tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

metrics.cc (7159B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/extras/metrics.h"
      7 
      8 #include <math.h>
      9 #include <stdlib.h>
     10 
     11 #include <atomic>
     12 
     13 #undef HWY_TARGET_INCLUDE
     14 #define HWY_TARGET_INCLUDE "lib/extras/metrics.cc"
     15 #include <hwy/foreach_target.h>
     16 #include <hwy/highway.h>
     17 
     18 #include "lib/jxl/base/compiler_specific.h"
     19 #include "lib/jxl/base/rect.h"
     20 #include "lib/jxl/base/status.h"
     21 #include "lib/jxl/color_encoding_internal.h"
     22 HWY_BEFORE_NAMESPACE();
     23 namespace jxl {
     24 namespace HWY_NAMESPACE {
     25 
     26 // These templates are not found via ADL.
     27 using hwy::HWY_NAMESPACE::Add;
     28 using hwy::HWY_NAMESPACE::GetLane;
     29 using hwy::HWY_NAMESPACE::Mul;
     30 using hwy::HWY_NAMESPACE::Rebind;
     31 
     32 double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params,
     33                        double p) {
     34  const double onePerPixels = 1.0 / (distmap.ysize() * distmap.xsize());
     35  if (std::abs(p - 3.0) < 1E-6) {
     36    double sum1[3] = {0.0};
     37 
     38 // Prefer double if possible, but otherwise use float rather than scalar.
     39 #if HWY_CAP_FLOAT64
     40    using T = double;
     41    const Rebind<float, HWY_FULL(double)> df;
     42 #else
     43    using T = float;
     44 #endif
     45    const HWY_FULL(T) d;
     46    constexpr size_t N = MaxLanes(d);
     47    // Manually aligned storage to avoid asan crash on clang-7 due to
     48    // unaligned spill.
     49    HWY_ALIGN T sum_totals0[N] = {0};
     50    HWY_ALIGN T sum_totals1[N] = {0};
     51    HWY_ALIGN T sum_totals2[N] = {0};
     52 
     53    for (size_t y = 0; y < distmap.ysize(); ++y) {
     54      const float* JXL_RESTRICT row = distmap.ConstRow(y);
     55 
     56      auto sums0 = Zero(d);
     57      auto sums1 = Zero(d);
     58      auto sums2 = Zero(d);
     59 
     60      size_t x = 0;
     61      for (; x + Lanes(d) <= distmap.xsize(); x += Lanes(d)) {
     62 #if HWY_CAP_FLOAT64
     63        const auto d1 = PromoteTo(d, Load(df, row + x));
     64 #else
     65        const auto d1 = Load(d, row + x);
     66 #endif
     67        const auto d2 = Mul(d1, Mul(d1, d1));
     68        sums0 = Add(sums0, d2);
     69        const auto d3 = Mul(d2, d2);
     70        sums1 = Add(sums1, d3);
     71        const auto d4 = Mul(d3, d3);
     72        sums2 = Add(sums2, d4);
     73      }
     74 
     75      Store(Add(sums0, Load(d, sum_totals0)), d, sum_totals0);
     76      Store(Add(sums1, Load(d, sum_totals1)), d, sum_totals1);
     77      Store(Add(sums2, Load(d, sum_totals2)), d, sum_totals2);
     78 
     79      for (; x < distmap.xsize(); ++x) {
     80        const double d1 = row[x];
     81        double d2 = d1 * d1 * d1;
     82        sum1[0] += d2;
     83        d2 *= d2;
     84        sum1[1] += d2;
     85        d2 *= d2;
     86        sum1[2] += d2;
     87      }
     88    }
     89    double v = 0;
     90    v += pow(
     91        onePerPixels * (sum1[0] + GetLane(SumOfLanes(d, Load(d, sum_totals0)))),
     92        1.0 / (p * 1.0));
     93    v += pow(
     94        onePerPixels * (sum1[1] + GetLane(SumOfLanes(d, Load(d, sum_totals1)))),
     95        1.0 / (p * 2.0));
     96    v += pow(
     97        onePerPixels * (sum1[2] + GetLane(SumOfLanes(d, Load(d, sum_totals2)))),
     98        1.0 / (p * 4.0));
     99    v /= 3.0;
    100    return v;
    101  } else {
    102    static std::atomic<int> once{0};
    103    if (once.fetch_add(1, std::memory_order_relaxed) == 0) {
    104      JXL_WARNING("WARNING: using slow ComputeDistanceP");
    105    }
    106    double sum1[3] = {0.0};
    107    for (size_t y = 0; y < distmap.ysize(); ++y) {
    108      const float* JXL_RESTRICT row = distmap.ConstRow(y);
    109      for (size_t x = 0; x < distmap.xsize(); ++x) {
    110        double d2 = std::pow(row[x], p);
    111        sum1[0] += d2;
    112        d2 *= d2;
    113        sum1[1] += d2;
    114        d2 *= d2;
    115        sum1[2] += d2;
    116      }
    117    }
    118    double v = 0;
    119    for (int i = 0; i < 3; ++i) {
    120      v += pow(onePerPixels * (sum1[i]), 1.0 / (p * (1 << i)));
    121    }
    122    v /= 3.0;
    123    return v;
    124  }
    125 }
    126 
    127 void ComputeSumOfSquares(const ImageBundle& ib1, const ImageBundle& ib2,
    128                         const JxlCmsInterface& cms, double sum_of_squares[3]) {
    129  sum_of_squares[0] = sum_of_squares[1] = sum_of_squares[2] =
    130      std::numeric_limits<double>::max();
    131  // Convert to sRGB - closer to perception than linear.
    132  const Image3F* srgb1 = &ib1.color();
    133  Image3F copy1;
    134  if (!ib1.IsSRGB()) {
    135    if (!ib1.CopyTo(Rect(ib1), ColorEncoding::SRGB(ib1.IsGray()), cms, &copy1))
    136      return;
    137    srgb1 = &copy1;
    138  }
    139  const Image3F* srgb2 = &ib2.color();
    140  Image3F copy2;
    141  if (!ib2.IsSRGB()) {
    142    if (!ib2.CopyTo(Rect(ib2), ColorEncoding::SRGB(ib2.IsGray()), cms, &copy2))
    143      return;
    144    srgb2 = &copy2;
    145  }
    146 
    147  if (!SameSize(*srgb1, *srgb2)) return;
    148 
    149  sum_of_squares[0] = sum_of_squares[1] = sum_of_squares[2] = 0.0;
    150 
    151  // TODO(veluca): SIMD.
    152  float yuvmatrix[3][3] = {{0.299, 0.587, 0.114},
    153                           {-0.14713, -0.28886, 0.436},
    154                           {0.615, -0.51499, -0.10001}};
    155  for (size_t y = 0; y < srgb1->ysize(); ++y) {
    156    const float* JXL_RESTRICT row1[3];
    157    const float* JXL_RESTRICT row2[3];
    158    for (size_t j = 0; j < 3; j++) {
    159      row1[j] = srgb1->ConstPlaneRow(j, y);
    160      row2[j] = srgb2->ConstPlaneRow(j, y);
    161    }
    162    for (size_t x = 0; x < srgb1->xsize(); ++x) {
    163      float cdiff[3] = {};
    164      // YUV conversion is linear, so we can run it on the difference.
    165      for (size_t j = 0; j < 3; j++) {
    166        cdiff[j] = row1[j][x] - row2[j][x];
    167      }
    168      float yuvdiff[3] = {};
    169      for (size_t j = 0; j < 3; j++) {
    170        for (size_t k = 0; k < 3; k++) {
    171          yuvdiff[j] += yuvmatrix[j][k] * cdiff[k];
    172        }
    173      }
    174      for (size_t j = 0; j < 3; j++) {
    175        sum_of_squares[j] += yuvdiff[j] * yuvdiff[j];
    176      }
    177    }
    178  }
    179 }
    180 
    181 // NOLINTNEXTLINE(google-readability-namespace-comments)
    182 }  // namespace HWY_NAMESPACE
    183 }  // namespace jxl
    184 HWY_AFTER_NAMESPACE();
    185 
    186 #if HWY_ONCE
    187 namespace jxl {
    188 HWY_EXPORT(ComputeDistanceP);
    189 double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params,
    190                        double p) {
    191  return HWY_DYNAMIC_DISPATCH(ComputeDistanceP)(distmap, params, p);
    192 }
    193 
    194 HWY_EXPORT(ComputeSumOfSquares);
    195 
    196 double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2,
    197                        const JxlCmsInterface& cms) {
    198  double sum_of_squares[3] = {};
    199  HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, sum_of_squares);
    200  // Weighted PSNR as in JPEG-XL: chroma counts 1/8.
    201  const float weights[3] = {6.0f / 8, 1.0f / 8, 1.0f / 8};
    202  // Avoid squaring the weight - 1/64 is too extreme.
    203  double norm = 0;
    204  for (size_t i = 0; i < 3; i++) {
    205    norm += std::sqrt(sum_of_squares[i]) * weights[i];
    206  }
    207  // This function returns distance *squared*.
    208  return norm * norm;
    209 }
    210 
    211 double ComputePSNR(const ImageBundle& ib1, const ImageBundle& ib2,
    212                   const JxlCmsInterface& cms) {
    213  if (!SameSize(ib1, ib2)) return 0.0;
    214  double sum_of_squares[3] = {};
    215  HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, sum_of_squares);
    216  constexpr double kChannelWeights[3] = {6.0 / 8, 1.0 / 8, 1.0 / 8};
    217  double avg_psnr = 0;
    218  const size_t input_pixels = ib1.xsize() * ib1.ysize();
    219  for (int i = 0; i < 3; ++i) {
    220    const double rmse = std::sqrt(sum_of_squares[i] / input_pixels);
    221    const double psnr =
    222        sum_of_squares[i] == 0 ? 99.99 : (20 * std::log10(1 / rmse));
    223    avg_psnr += kChannelWeights[i] * psnr;
    224  }
    225  return avg_psnr;
    226 }
    227 
    228 }  // namespace jxl
    229 #endif