tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_optimize_test.cc (2712B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_optimize.h"
      7 
      8 #include <cmath>
      9 #include <cstddef>
     10 #include <vector>
     11 
     12 #include "lib/jxl/testing.h"
     13 
     14 namespace jxl {
     15 namespace optimize {
     16 namespace {
     17 
     18 // The maximum number of iterations for the test.
     19 const size_t kMaxTestIter = 100000;
     20 
     21 // F(w) = (w - w_min)^2.
     22 struct SimpleQuadraticFunction {
     23  typedef Array<double, 2> ArrayType;
     24  explicit SimpleQuadraticFunction(const ArrayType& w0) : w_min(w0) {}
     25 
     26  double Compute(const ArrayType& w, ArrayType* df) const {
     27    ArrayType dw = w - w_min;
     28    *df = -2.0 * dw;
     29    return dw * dw;
     30  }
     31 
     32  ArrayType w_min;
     33 };
     34 
     35 // F(alpha, beta, gamma| x,y) = \sum_i(y_i - (alpha x_i ^ gamma + beta))^2.
     36 struct PowerFunction {
     37  explicit PowerFunction(const std::vector<double>& x0,
     38                         const std::vector<double>& y0)
     39      : x(x0), y(y0) {}
     40 
     41  typedef Array<double, 3> ArrayType;
     42  double Compute(const ArrayType& w, ArrayType* df) const {
     43    double loss_function = 0;
     44    (*df)[0] = 0;
     45    (*df)[1] = 0;
     46    (*df)[2] = 0;
     47    for (size_t ind = 0; ind < y.size(); ++ind) {
     48      if (x[ind] != 0) {
     49        double l_f = y[ind] - (w[0] * pow(x[ind], w[1]) + w[2]);
     50        (*df)[0] += 2.0 * l_f * pow(x[ind], w[1]);
     51        (*df)[1] += 2.0 * l_f * w[0] * pow(x[ind], w[1]) * log(x[ind]);
     52        (*df)[2] += 2.0 * l_f * 1;
     53        loss_function += l_f * l_f;
     54      }
     55    }
     56    return loss_function;
     57  }
     58 
     59  std::vector<double> x;
     60  std::vector<double> y;
     61 };
     62 
     63 TEST(OptimizeTest, SimpleQuadraticFunction) {
     64  SimpleQuadraticFunction::ArrayType w_min;
     65  w_min[0] = 1.0;
     66  w_min[1] = 2.0;
     67  SimpleQuadraticFunction f(w_min);
     68  SimpleQuadraticFunction::ArrayType w(0.);
     69  static const double kPrecision = 1e-8;
     70  w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision,
     71                                                          kMaxTestIter);
     72  EXPECT_NEAR(w[0], 1.0, kPrecision);
     73  EXPECT_NEAR(w[1], 2.0, kPrecision);
     74 }
     75 
     76 TEST(OptimizeTest, PowerFunction) {
     77  std::vector<double> x(10);
     78  std::vector<double> y(10);
     79  for (int ind = 0; ind < 10; ++ind) {
     80    x[ind] = 1. * ind;
     81    y[ind] = 2. * pow(x[ind], 3) + 5.;
     82  }
     83  PowerFunction f(x, y);
     84  PowerFunction::ArrayType w(0.);
     85 
     86  static const double kPrecision = 0.01;
     87  w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision,
     88                                                          kMaxTestIter);
     89  EXPECT_NEAR(w[0], 2.0, kPrecision);
     90  EXPECT_NEAR(w[1], 3.0, kPrecision);
     91  EXPECT_NEAR(w[2], 5.0, kPrecision);
     92 }
     93 
     94 }  // namespace
     95 }  // namespace optimize
     96 }  // namespace jxl