tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

enc_optimize.h (4216B)


// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Utility functions for optimizing multi-dimensional nonlinear functions.

#ifndef LIB_JXL_OPTIMIZE_H_
#define LIB_JXL_OPTIMIZE_H_

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>

#include "lib/jxl/base/status.h"

namespace jxl {
namespace optimize {

// An array type of numeric values that supports math operations with
// operator-, operator+, etc.
template <typename T, size_t N>
class Array {
 public:
  Array() = default;
  explicit Array(T v) {
    for (size_t i = 0; i < N; i++) v_[i] = v;
  }

  size_t size() const { return N; }

  T& operator[](size_t index) {
    JXL_DASSERT(index < N);
    return v_[index];
  }
  T operator[](size_t index) const {
    JXL_DASSERT(index < N);
    return v_[index];
  }

 private:
  // The values used by this Array.
  T v_[N];
};

template <typename T, size_t N>
Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
  Array<T, N> z;
  for (size_t i = 0; i < N; ++i) {
    z[i] = x[i] + y[i];
  }
  return z;
}

template <typename T, size_t N>
Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
  Array<T, N> z;
  for (size_t i = 0; i < N; ++i) {
    z[i] = x[i] - y[i];
  }
  return z;
}

template <typename T, size_t N>
Array<T, N> operator*(T v, const Array<T, N>& x) {
  Array<T, N> y;
  for (size_t i = 0; i < N; ++i) {
    y[i] = v * x[i];
  }
  return y;
}

template <typename T, size_t N>
T operator*(const Array<T, N>& x, const Array<T, N>& y) {
  T r = 0.0;
  for (size_t i = 0; i < N; ++i) {
    r += x[i] * y[i];
  }
  return r;
}
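
// Illustrative sketch, not part of the upstream header: the operators above
// give Array exactly the vector arithmetic the optimizer below relies on.
//
//   optimize::Array<double, 3> x(1.0);       // [1, 1, 1]
//   optimize::Array<double, 3> y(2.0);       // [2, 2, 2]
//   optimize::Array<double, 3> s = x + y;    // element-wise sum: [3, 3, 3]
//   optimize::Array<double, 3> t = 2.0 * x;  // scalar * vector: [2, 2, 2]
//   double dot = x * y;  // operator* on two Arrays is the dot product: 6.0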

// Implementation of the Scaled Conjugate Gradient method described in the
// following paper:
//   Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
//   Learning", Neural Networks, Vol. 6, pp. 525-533, 1993
//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
//
// The Function template parameter is a class that has the following method:
//
//   // Returns the value of the function at point w and sets *df to be the
//   // negative gradient vector of the function at point w.
//   double Compute(const optimize::Array<T, N>& w,
//                  optimize::Array<T, N>* df) const;
//
// Returns a vector w, such that |df(w)| < grad_norm_threshold.
template <typename T, size_t N, typename Function>
Array<T, N> OptimizeWithScaledConjugateGradientMethod(
    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
    size_t max_iters) {
  const size_t n = w0.size();
  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
  const T sigma0 = static_cast<T>(0.0001);
  const T l_min = static_cast<T>(1.0e-15);
  const T l_max = static_cast<T>(1.0e15);

  Array<T, N> w = w0;  // current point
  Array<T, N> wp;      // candidate point w + a * p
  Array<T, N> r;       // negative gradient at w
  Array<T, N> rt;      // negative gradient at a trial point
  Array<T, N> e;       // negative gradient from the previous iteration
  Array<T, N> p;       // current search direction
  T psq;               // |p|^2
  T fp;                // function value at the candidate point wp
  T D;                 // comparison ratio between actual and predicted decrease
  T d;                 // regularized curvature along p
  T m;                 // -(p * r), the directional derivative of f along p
  T a;                 // step size along p
  T b;                 // conjugate direction update coefficient (beta)
  T s;                 // finite-difference step for the curvature estimate
  T t;                 // raw curvature estimate, approximately p^T H p

  T fw = f.Compute(w, &r);  // function value and negative gradient at w0
  T rsq = r * r;
  e = r;
  p = r;
  T l = static_cast<T>(1.0);  // the scaling parameter lambda
  bool success = true;
  size_t n_success = 0;
  size_t k = 0;

  while (k++ < max_iters) {
    if (success) {
      m = -(p * r);
      if (m >= 0) {
        // p is not a descent direction; restart from steepest descent.
        p = r;
        m = -(p * r);
      }
      psq = p * p;
      s = sigma0 / std::sqrt(psq);
      // Finite-difference estimate of the curvature p^T H p along p.
      f.Compute(w + (s * p), &rt);
      t = (p * (r - rt)) / s;
    }

    // Regularize the curvature estimate with lambda so that d > 0.
    d = t + l * psq;
    if (d <= 0) {
      d = l * psq;
      l = l - t / psq;
    }

    a = -m / d;
    wp = w + a * p;
    fp = f.Compute(wp, &rt);

    // Compare the actual decrease against the quadratic model's prediction.
    D = 2.0 * (fp - fw) / (a * m);
    if (D >= 0.0) {
      success = true;
      n_success++;
      w = wp;
    } else {
      success = false;
    }

    if (success) {
      e = r;
      r = rt;
      rsq = r * r;
      fw = fp;
      if (rsq <= rsq_threshold) {
        break;
      }
    }

    // Adjust lambda based on how well the model predicted the decrease.
    if (D < 0.25) {
      l = std::min(4.0 * l, l_max);
    } else if (D > 0.75) {
      l = std::max(0.25 * l, l_min);
    }

    // Restart with steepest descent every n successful steps; otherwise
    // take a conjugate direction after each successful step.
    if ((n_success % n) == 0) {
      p = r;
      l = 1.0;
    } else if (success) {
      b = ((e - r) * r) / m;
      p = b * p + r;
    }
  }

  return w;
}
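
// Usage sketch (illustrative, not part of the upstream header). For the
// convex quadratic f(w) = |w - c|^2 the negative gradient is -2 * (w - c),
// so the optimizer converges to w == c:
//
//   struct Quadratic {
//     optimize::Array<double, 2> c;
//     double Compute(const optimize::Array<double, 2>& w,
//                    optimize::Array<double, 2>* df) const {
//       optimize::Array<double, 2> diff = w - c;
//       *df = -2.0 * diff;   // negative gradient, as Compute() requires
//       return diff * diff;  // function value |w - c|^2
//     }
//   };
//
//   Quadratic f;
//   f.c[0] = 3.0;
//   f.c[1] = -1.0;
//   optimize::Array<double, 2> w0(0.0);
//   optimize::Array<double, 2> w =
//       optimize::OptimizeWithScaledConjugateGradientMethod(f, w0, 1e-8, 100);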

}  // namespace optimize
}  // namespace jxl

#endif  // LIB_JXL_OPTIMIZE_H_