enc_optimize.h (4216B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 // Utility functions for optimizing multi-dimensional nonlinear functions. 7 8 #ifndef LIB_JXL_OPTIMIZE_H_ 9 #define LIB_JXL_OPTIMIZE_H_ 10 11 #include <cmath> 12 #include <cstdio> 13 14 #include "lib/jxl/base/status.h" 15 16 namespace jxl { 17 namespace optimize { 18 19 // An array type of numeric values that supports math operations with operator-, 20 // operator+, etc. 21 template <typename T, size_t N> 22 class Array { 23 public: 24 Array() = default; 25 explicit Array(T v) { 26 for (size_t i = 0; i < N; i++) v_[i] = v; 27 } 28 29 size_t size() const { return N; } 30 31 T& operator[](size_t index) { 32 JXL_DASSERT(index < N); 33 return v_[index]; 34 } 35 T operator[](size_t index) const { 36 JXL_DASSERT(index < N); 37 return v_[index]; 38 } 39 40 private: 41 // The values used by this Array. 42 T v_[N]; 43 }; 44 45 template <typename T, size_t N> 46 Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) { 47 Array<T, N> z; 48 for (size_t i = 0; i < N; ++i) { 49 z[i] = x[i] + y[i]; 50 } 51 return z; 52 } 53 54 template <typename T, size_t N> 55 Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) { 56 Array<T, N> z; 57 for (size_t i = 0; i < N; ++i) { 58 z[i] = x[i] - y[i]; 59 } 60 return z; 61 } 62 63 template <typename T, size_t N> 64 Array<T, N> operator*(T v, const Array<T, N>& x) { 65 Array<T, N> y; 66 for (size_t i = 0; i < N; ++i) { 67 y[i] = v * x[i]; 68 } 69 return y; 70 } 71 72 template <typename T, size_t N> 73 T operator*(const Array<T, N>& x, const Array<T, N>& y) { 74 T r = 0.0; 75 for (size_t i = 0; i < N; ++i) { 76 r += x[i] * y[i]; 77 } 78 return r; 79 } 80 81 // Implementation of the Scaled Conjugate Gradient method described in the 82 // following paper: 83 // Moller, M. 
// Implementation of the Scaled Conjugate Gradient method described in the
// following paper:
//   Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
//   Learning", Neural Networks, Vol. 6. pp. 525-533, 1993
//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
//
// The Function template parameter is a class that has the following method:
//
//   // Returns the value of the function at point w and sets *df to be the
//   // negative gradient vector of the function at point w.
//   double Compute(const optimize::Array<T, N>& w,
//                  optimize::Array<T, N>* df) const;
//
// Returns a vector w, such that |df(w)| < grad_norm_threshold.
//
// Minimizes f starting from w0. Iterates until either the squared gradient
// norm drops below grad_norm_threshold^2 or max_iters iterations have run;
// in the latter case the best point found so far is returned.
template <typename T, size_t N, typename Function>
Array<T, N> OptimizeWithScaledConjugateGradientMethod(
    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
    size_t max_iters) {
  const size_t n = w0.size();
  // Convergence test is done on |r|^2 to avoid a sqrt per iteration.
  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
  // sigma0: step used for the finite-difference Hessian-vector estimate.
  const T sigma0 = static_cast<T>(0.0001);
  // Clamp bounds for the scaling parameter l (lambda in the paper).
  const T l_min = static_cast<T>(1.0e-15);
  const T l_max = static_cast<T>(1.0e15);

  Array<T, N> w = w0;  // current point
  Array<T, N> wp;      // candidate point w + a*p
  Array<T, N> r;       // negative gradient at w (as returned by f.Compute)
  Array<T, N> rt;      // negative gradient at a probe/candidate point
  Array<T, N> e;       // previous negative gradient (for the beta term)
  Array<T, N> p;       // current search direction
  T psq;               // |p|^2
  T fp;                // f(wp)
  T D;                 // comparison ratio (Delta in the paper)
  T d;                 // scaled curvature estimate (delta in the paper)
  T m;                 // directional derivative -(p*r) (mu in the paper)
  T a;                 // step size along p (alpha in the paper)
  T b;                 // conjugate-direction coefficient (beta in the paper)
  T s;                 // finite-difference step sigma0 / |p|
  T t;                 // curvature along p from the finite difference

  T fw = f.Compute(w, &r);
  T rsq = r * r;
  e = r;
  p = r;  // initial direction: steepest descent
  T l = static_cast<T>(1.0);
  bool success = true;
  size_t n_success = 0;
  size_t k = 0;

  while (k++ < max_iters) {
    if (success) {
      m = -(p * r);
      if (m >= 0) {
        // p is not a descent direction; restart with steepest descent.
        p = r;
        m = -(p * r);
      }
      psq = p * p;
      s = sigma0 / std::sqrt(psq);
      // One-sided finite difference to estimate curvature along p.
      f.Compute(w + (s * p), &rt);
      t = (p * (r - rt)) / s;
    }

    // Scale the curvature; if still non-positive, adjust l so that d > 0
    // (makes the model locally positive definite, step 3 of the paper).
    d = t + l * psq;
    if (d <= 0) {
      d = l * psq;
      l = l - t / psq;
    }

    a = -m / d;
    wp = w + a * p;
    fp = f.Compute(wp, &rt);

    // D compares actual decrease to the quadratic model's prediction.
    D = 2.0 * (fp - fw) / (a * m);
    if (D >= 0.0) {
      success = true;
      n_success++;
      w = wp;
    } else {
      success = false;
    }

    if (success) {
      e = r;
      r = rt;
      rsq = r * r;
      fw = fp;
      if (rsq <= rsq_threshold) {
        break;
      }
    }

    // Trust-region-style update of the scale: poor model fit -> grow l
    // (shorter steps), very good fit -> shrink l (longer steps).
    if (D < 0.25) {
      l = std::min(4.0 * l, l_max);
    } else if (D > 0.75) {
      l = std::max(0.25 * l, l_min);
    }

    if ((n_success % n) == 0) {
      // Restart with steepest descent every n successful steps.
      p = r;
      l = 1.0;
    } else if (success) {
      // Polak-Ribiere-style conjugate direction update.
      b = ((e - r) * r) / m;
      p = b * p + r;
    }
  }

  return w;
}

}  // namespace optimize
}  // namespace jxl

#endif  // LIB_JXL_OPTIMIZE_H_