tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_noise.cc (13342B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_noise.h"
      7 
      8 #include <algorithm>
      9 #include <cstdint>
     10 #include <cstdlib>
     11 #include <numeric>
     12 #include <utility>
     13 
     14 #include "lib/jxl/base/status.h"
     15 #include "lib/jxl/enc_aux_out.h"
     16 #include "lib/jxl/enc_optimize.h"
     17 
     18 namespace jxl {
     19 namespace {
     20 
     21 using OptimizeArray = optimize::Array<double, NoiseParams::kNumNoisePoints>;
     22 
     23 float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x,
     24                                        const int y, const int block_size) {
     25  const int small_bl_size_x = 3;
     26  const int small_bl_size_y = 4;
     27  const int kNumSAD =
     28      (block_size - small_bl_size_x) * (block_size - small_bl_size_y);
     29  // block_size x block_size reference pixels
     30  int counter = 0;
     31  const int offset = 2;
     32 
     33  std::vector<float> sad(kNumSAD, 0);
     34  for (int y_bl = 0; y_bl + small_bl_size_y < block_size; ++y_bl) {
     35    for (int x_bl = 0; x_bl + small_bl_size_x < block_size; ++x_bl) {
     36      float sad_sum = 0;
     37      // size of the center patch, we compare all the patches inside window with
     38      // the center one
     39      for (int cy = 0; cy < small_bl_size_y; ++cy) {
     40        for (int cx = 0; cx < small_bl_size_x; ++cx) {
     41          float wnd = 0.5f * (opsin.PlaneRow(1, y + y_bl + cy)[x + x_bl + cx] +
     42                              opsin.PlaneRow(0, y + y_bl + cy)[x + x_bl + cx]);
     43          float center =
     44              0.5f * (opsin.PlaneRow(1, y + offset + cy)[x + offset + cx] +
     45                      opsin.PlaneRow(0, y + offset + cy)[x + offset + cx]);
     46          sad_sum += std::abs(center - wnd);
     47        }
     48      }
     49      sad[counter++] = sad_sum;
     50    }
     51  }
     52  const int kSamples = (kNumSAD) / 2;
     53  // As with ROAD (rank order absolute distance), we keep the smallest half of
     54  // the values in SAD (we use here the more robust patch SAD instead of
     55  // absolute single-pixel differences).
     56  std::sort(sad.begin(), sad.end());
     57  const float total_sad_sum =
     58      std::accumulate(sad.begin(), sad.begin() + kSamples, 0.0f);
     59  return total_sad_sum / kSamples;
     60 }
     61 
     62 class NoiseHistogram {
     63 public:
     64  static constexpr int kBins = 256;
     65 
     66  NoiseHistogram() { std::fill(bins, bins + kBins, 0); }
     67 
     68  void Increment(const float x) { bins[Index(x)] += 1; }
     69  int Get(const float x) const { return bins[Index(x)]; }
     70  int Bin(const size_t bin) const { return bins[bin]; }
     71 
     72  int Mode() const {
     73    size_t max_idx = 0;
     74    for (size_t i = 0; i < kBins; i++) {
     75      if (bins[i] > bins[max_idx]) max_idx = i;
     76    }
     77    return max_idx;
     78  }
     79 
     80  double Quantile(double q01) const {
     81    const int64_t total = std::accumulate(bins, bins + kBins, int64_t{1});
     82    const int64_t target = static_cast<int64_t>(q01 * total);
     83    // Until sum >= target:
     84    int64_t sum = 0;
     85    size_t i = 0;
     86    for (; i < kBins; ++i) {
     87      sum += bins[i];
     88      // Exact match: assume middle of bin i
     89      if (sum == target) {
     90        return i + 0.5;
     91      }
     92      if (sum > target) break;
     93    }
     94 
     95    // Next non-empty bin (in case histogram is sparsely filled)
     96    size_t next = i + 1;
     97    while (next < kBins && bins[next] == 0) {
     98      ++next;
     99    }
    100 
    101    // Linear interpolation according to how far into next we went
    102    const double excess = target - sum;
    103    const double weight_next = bins[Index(next)] / excess;
    104    return ClampX(next * weight_next + i * (1.0 - weight_next));
    105  }
    106 
    107  // Inter-quartile range
    108  double IQR() const { return Quantile(0.75) - Quantile(0.25); }
    109 
    110 private:
    111  template <typename T>
    112  T ClampX(const T x) const {
    113    return std::min(std::max(static_cast<T>(0), x), static_cast<T>(kBins - 1));
    114  }
    115  size_t Index(const float x) const { return ClampX(static_cast<int>(x)); }
    116 
    117  uint32_t bins[kBins];
    118 };
    119 
    120 std::vector<float> GetSADScoresForPatches(const Image3F& opsin,
    121                                          const size_t block_s,
    122                                          const size_t num_bin,
    123                                          NoiseHistogram* sad_histogram) {
    124  std::vector<float> sad_scores(
    125      (opsin.ysize() / block_s) * (opsin.xsize() / block_s), 0.0f);
    126 
    127  int block_index = 0;
    128 
    129  for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
    130    for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
    131      float sad_sc = GetScoreSumsOfAbsoluteDifferences(opsin, x, y, block_s);
    132      sad_scores[block_index++] = sad_sc;
    133      sad_histogram->Increment(sad_sc * num_bin);
    134    }
    135  }
    136  return sad_scores;
    137 }
    138 
    139 float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) {
    140  // Here we assume that the most patches with similar SAD value is a "flat"
    141  // patches. However, some images might contain regular texture part and
    142  // generate second strong peak at the histogram
    143  // TODO(user) handle bimodal and heavy-tailed case
    144  const int mode = histogram.Mode();
    145  return static_cast<float>(mode) / NoiseHistogram::kBins;
    146 }
    147 
    148 // loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2
    149 // where asym = 1 if F(x) < nl, kAsym if F(x) > nl.
    150 struct LossFunction {
    151  explicit LossFunction(std::vector<NoiseLevel> nl0) : nl(std::move(nl0)) {}
    152 
    153  double Compute(const OptimizeArray& w, OptimizeArray* df,
    154                 bool skip_regularization = false) const {
    155    constexpr double kReg = 0.005;
    156    constexpr double kAsym = 1.1;
    157    double loss_function = 0;
    158    for (size_t i = 0; i < w.size(); i++) {
    159      (*df)[i] = 0;
    160    }
    161    for (auto ind : nl) {
    162      std::pair<int, float> pos = IndexAndFrac(ind.intensity);
    163      JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) <
    164                                        NoiseParams::kNumNoisePoints - 1);
    165      double low = w[pos.first];
    166      double hi = w[pos.first + 1];
    167      double val = low * (1.0f - pos.second) + hi * pos.second;
    168      double dist = val - ind.noise_level;
    169      if (dist > 0) {
    170        loss_function += kAsym * dist * dist;
    171        (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist;
    172        (*df)[pos.first + 1] -= kAsym * pos.second * dist;
    173      } else {
    174        loss_function += dist * dist;
    175        (*df)[pos.first] -= (1.0f - pos.second) * dist;
    176        (*df)[pos.first + 1] -= pos.second * dist;
    177      }
    178    }
    179    if (skip_regularization) return loss_function;
    180    for (size_t i = 0; i + 1 < w.size(); i++) {
    181      double diff = w[i] - w[i + 1];
    182      loss_function += kReg * nl.size() * diff * diff;
    183      (*df)[i] -= kReg * diff * nl.size();
    184      (*df)[i + 1] += kReg * diff * nl.size();
    185    }
    186    return loss_function;
    187  }
    188 
    189  std::vector<NoiseLevel> nl;
    190 };
    191 
    192 void OptimizeNoiseParameters(const std::vector<NoiseLevel>& noise_level,
    193                             NoiseParams* noise_params) {
    194  constexpr double kMaxError = 1e-3;
    195  static const double kPrecision = 1e-8;
    196  static const int kMaxIter = 40;
    197 
    198  float avg = 0;
    199  for (const NoiseLevel& nl : noise_level) {
    200    avg += nl.noise_level;
    201  }
    202  avg /= noise_level.size();
    203 
    204  LossFunction loss_function(noise_level);
    205  OptimizeArray parameter_vector;
    206  for (size_t i = 0; i < parameter_vector.size(); i++) {
    207    parameter_vector[i] = avg;
    208  }
    209 
    210  parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod(
    211      loss_function, parameter_vector, kPrecision, kMaxIter);
    212 
    213  OptimizeArray df = parameter_vector;
    214  float loss = loss_function.Compute(parameter_vector, &df,
    215                                     /*skip_regularization=*/true) /
    216               noise_level.size();
    217 
    218  // Approximation went too badly: escape with no noise at all.
    219  if (loss > kMaxError) {
    220    noise_params->Clear();
    221    return;
    222  }
    223 
    224  for (size_t i = 0; i < parameter_vector.size(); i++) {
    225    noise_params->lut[i] = std::max(parameter_vector[i], 0.0);
    226  }
    227 }
    228 
    229 std::vector<NoiseLevel> GetNoiseLevel(
    230    const Image3F& opsin, const std::vector<float>& texture_strength,
    231    const float threshold, const size_t block_s) {
    232  std::vector<NoiseLevel> noise_level_per_intensity;
    233 
    234  const int filt_size = 1;
    235  static const float kLaplFilter[filt_size * 2 + 1][filt_size * 2 + 1] = {
    236      {-0.25f, -1.0f, -0.25f},
    237      {-1.0f, 5.0f, -1.0f},
    238      {-0.25f, -1.0f, -0.25f},
    239  };
    240 
    241  // The noise model is built based on channel 0.5 * (X+Y) as we notice that it
    242  // is similar to the model 0.5 * (Y-X)
    243  size_t patch_index = 0;
    244 
    245  for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) {
    246    for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) {
    247      if (texture_strength[patch_index] <= threshold) {
    248        // Calculate mean value
    249        float mean_int = 0;
    250        for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
    251          for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
    252            mean_int += 0.5f * (opsin.PlaneRow(1, y + y_bl)[x + x_bl] +
    253                                opsin.PlaneRow(0, y + y_bl)[x + x_bl]);
    254          }
    255        }
    256        mean_int /= block_s * block_s;
    257 
    258        // Calculate Noise level
    259        float noise_level = 0;
    260        size_t count = 0;
    261        for (size_t y_bl = 0; y_bl < block_s; ++y_bl) {
    262          for (size_t x_bl = 0; x_bl < block_s; ++x_bl) {
    263            float filtered_value = 0;
    264            for (int y_f = -1 * filt_size; y_f <= filt_size; ++y_f) {
    265              if ((static_cast<ssize_t>(y_bl) + y_f) >= 0 &&
    266                  (y_bl + y_f) < block_s) {
    267                for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
    268                  if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
    269                      (x_bl + x_f) < block_s) {
    270                    filtered_value +=
    271                        0.5f *
    272                        (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl + x_f] +
    273                         opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl + x_f]) *
    274                        kLaplFilter[y_f + filt_size][x_f + filt_size];
    275                  } else {
    276                    filtered_value +=
    277                        0.5f *
    278                        (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl - x_f] +
    279                         opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl - x_f]) *
    280                        kLaplFilter[y_f + filt_size][x_f + filt_size];
    281                  }
    282                }
    283              } else {
    284                for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) {
    285                  if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 &&
    286                      (x_bl + x_f) < block_s) {
    287                    filtered_value +=
    288                        0.5f *
    289                        (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl + x_f] +
    290                         opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl + x_f]) *
    291                        kLaplFilter[y_f + filt_size][x_f + filt_size];
    292                  } else {
    293                    filtered_value +=
    294                        0.5f *
    295                        (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl - x_f] +
    296                         opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl - x_f]) *
    297                        kLaplFilter[y_f + filt_size][x_f + filt_size];
    298                  }
    299                }
    300              }
    301            }
    302            noise_level += std::abs(filtered_value);
    303            ++count;
    304          }
    305        }
    306        noise_level /= count;
    307        NoiseLevel nl;
    308        nl.intensity = mean_int;
    309        nl.noise_level = noise_level;
    310        noise_level_per_intensity.push_back(nl);
    311      }
    312      ++patch_index;
    313    }
    314  }
    315  return noise_level_per_intensity;
    316 }
    317 
    318 Status EncodeFloatParam(float val, float precision, BitWriter* writer) {
    319  JXL_ENSURE(val >= 0);
    320  const int absval_quant = static_cast<int>(std::lround(val * precision));
    321  JXL_ENSURE(absval_quant < (1 << 10));
    322  writer->Write(10, absval_quant);
    323  return true;
    324 }
    325 
    326 }  // namespace
    327 
    328 Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params,
    329                         float quality_coef) {
    330  // The size of a patch in decoder might be different from encoder's patch
    331  // size.
    332  // For encoder: the patch size should be big enough to estimate
    333  //              noise level, but, at the same time, it should be not too big
    334  //              to be able to estimate intensity value of the patch
    335  const size_t block_s = 8;
    336  const size_t kNumBin = 256;
    337  NoiseHistogram sad_histogram;
    338  std::vector<float> sad_scores =
    339      GetSADScoresForPatches(opsin, block_s, kNumBin, &sad_histogram);
    340  float sad_threshold = GetSADThreshold(sad_histogram, kNumBin);
    341  // If threshold is too large, the image has a strong pattern. This pattern
    342  // fools our model and it will add too much noise. Therefore, we do not add
    343  // noise for such images
    344  if (sad_threshold > 0.15f || sad_threshold <= 0.0f) {
    345    noise_params->Clear();
    346    return false;
    347  }
    348  std::vector<NoiseLevel> nl =
    349      GetNoiseLevel(opsin, sad_scores, sad_threshold, block_s);
    350 
    351  OptimizeNoiseParameters(nl, noise_params);
    352  for (float& i : noise_params->lut) {
    353    i *= quality_coef * 1.4;
    354  }
    355  return noise_params->HasAny();
    356 }
    357 
    358 Status EncodeNoise(const NoiseParams& noise_params, BitWriter* writer,
    359                   LayerType layer, AuxOut* aux_out) {
    360  JXL_ENSURE(noise_params.HasAny());
    361 
    362  return writer->WithMaxBits(
    363      NoiseParams::kNumNoisePoints * 16, layer, aux_out, [&]() -> Status {
    364        for (float i : noise_params.lut) {
    365          JXL_RETURN_IF_ERROR(EncodeFloatParam(i, kNoisePrecision, writer));
    366        }
    367        return true;
    368      });
    369 }
    370 
    371 }  // namespace jxl