tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_palette.cc (24805B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/modular/transform/enc_palette.h"
      7 
      8 #include <jxl/memory_manager.h>
      9 
     10 #include <array>
     11 #include <map>
     12 #include <set>
     13 
     14 #include "lib/jxl/base/common.h"
     15 #include "lib/jxl/base/status.h"
     16 #include "lib/jxl/image_ops.h"
     17 #include "lib/jxl/modular/encoding/context_predict.h"
     18 #include "lib/jxl/modular/modular_image.h"
     19 #include "lib/jxl/modular/transform/enc_transform.h"
     20 #include "lib/jxl/modular/transform/palette.h"
     21 
     22 namespace jxl {
     23 
     24 namespace palette_internal {
     25 
// Compile-time switch: when true, the lossy per-pixel search in this file also
// tries the implicit-palette index obtained with high_quality=true.
static constexpr bool kEncodeToHighQualityImplicitPalette = true;

// Lowest (most negative) palette index tried by the lossy per-pixel search;
// the search loop starts at this value.
// Inclusive.
static constexpr int kMinImplicitPaletteIndex = -(2 * 72 - 1);
     30 
     31 float ColorDistance(const std::vector<float> &JXL_RESTRICT a,
     32                    const std::vector<pixel_type> &JXL_RESTRICT b) {
     33  JXL_DASSERT(a.size() == b.size());
     34  float distance = 0;
     35  float ave3 = 0;
     36  if (a.size() >= 3) {
     37    ave3 = (a[0] + b[0] + a[1] + b[1] + a[2] + b[2]) * (1.21f / 3.0f);
     38  }
     39  float sum_a = 0;
     40  float sum_b = 0;
     41  for (size_t c = 0; c < a.size(); ++c) {
     42    const float difference =
     43        static_cast<float>(a[c]) - static_cast<float>(b[c]);
     44    float weight = c == 0 ? 3 : c == 1 ? 5 : 2;
     45    if (c < 3 && (a[c] + b[c] >= ave3)) {
     46      const float add_w[3] = {
     47          1.15,
     48          1.15,
     49          1.12,
     50      };
     51      weight += add_w[c];
     52      if (c == 2 && ((a[2] + b[2]) < 1.22 * ave3)) {
     53        weight -= 0.5;
     54      }
     55    }
     56    distance += difference * difference * weight * weight;
     57    const int sum_weight = c == 0 ? 3 : c == 1 ? 5 : 1;
     58    sum_a += a[c] * sum_weight;
     59    sum_b += b[c] * sum_weight;
     60  }
     61  distance *= 4;
     62  float sum_difference = sum_a - sum_b;
     63  distance += sum_difference * sum_difference;
     64  return distance;
     65 }
     66 
     67 static int QuantizeColorToImplicitPaletteIndex(
     68    const std::vector<pixel_type> &color, const int palette_size,
     69    const int bit_depth, bool high_quality) {
     70  int index = 0;
     71  if (high_quality) {
     72    int multiplier = 1;
     73    for (int value : color) {
     74      int quantized = ((kLargeCube - 1) * value + (1 << (bit_depth - 1))) /
     75                      ((1 << bit_depth) - 1);
     76      JXL_DASSERT((quantized % kLargeCube) == quantized);
     77      index += quantized * multiplier;
     78      multiplier *= kLargeCube;
     79    }
     80    return index + palette_size + kLargeCubeOffset;
     81  } else {
     82    int multiplier = 1;
     83    for (int value : color) {
     84      value -= 1 << (std::max(0, bit_depth - 3));
     85      value = std::max(0, value);
     86      int quantized = ((kLargeCube - 1) * value + (1 << (bit_depth - 1))) /
     87                      ((1 << bit_depth) - 1);
     88      JXL_DASSERT((quantized % kLargeCube) == quantized);
     89      if (quantized > kSmallCube - 1) {
     90        quantized = kSmallCube - 1;
     91      }
     92      index += quantized * multiplier;
     93      multiplier *= kSmallCube;
     94    }
     95    return index + palette_size;
     96  }
     97 }
     98 
     99 }  // namespace palette_internal
    100 
    101 int RoundInt(int value, int div) {  // symmetric rounding around 0
    102  if (value < 0) return -RoundInt(-value, div);
    103  return (value + div / 2) / div;
    104 }
    105 
    106 struct PaletteIterationData {
    107  static constexpr int kMaxDeltas = 128;
    108  bool final_run = false;
    109  std::vector<pixel_type> deltas[3];
    110  std::vector<double> delta_distances;
    111  std::vector<pixel_type> frequent_deltas[3];
    112 
    113  // Populates `frequent_deltas` with items from `deltas` based on frequencies
    114  // and color distances.
    115  void FindFrequentColorDeltas(int num_pixels, int bitdepth) {
    116    using pixel_type_3d = std::array<pixel_type, 3>;
    117    std::map<pixel_type_3d, double> delta_frequency_map;
    118    pixel_type bucket_size = 3 << std::max(0, bitdepth - 8);
    119    // Store frequency weighted by delta distance from quantized value.
    120    for (size_t i = 0; i < deltas[0].size(); ++i) {
    121      pixel_type_3d delta = {
    122          {RoundInt(deltas[0][i], bucket_size),
    123           RoundInt(deltas[1][i], bucket_size),
    124           RoundInt(deltas[2][i], bucket_size)}};  // a basic form of clustering
    125      if (delta[0] == 0 && delta[1] == 0 && delta[2] == 0) continue;
    126      delta_frequency_map[delta] += sqrt(sqrt(delta_distances[i]));
    127    }
    128 
    129    const float delta_distance_multiplier = 1.0f / num_pixels;
    130 
    131    // Weigh frequencies by magnitude and normalize.
    132    for (auto &delta_frequency : delta_frequency_map) {
    133      std::vector<pixel_type> current_delta = {delta_frequency.first[0],
    134                                               delta_frequency.first[1],
    135                                               delta_frequency.first[2]};
    136      float delta_distance =
    137          std::sqrt(palette_internal::ColorDistance({0, 0, 0}, current_delta)) +
    138          1;
    139      delta_frequency.second *= delta_distance * delta_distance_multiplier;
    140    }
    141 
    142    // Sort by weighted frequency.
    143    using pixel_type_3d_frequency = std::pair<pixel_type_3d, double>;
    144    std::vector<pixel_type_3d_frequency> sorted_delta_frequency_map(
    145        delta_frequency_map.begin(), delta_frequency_map.end());
    146    std::sort(
    147        sorted_delta_frequency_map.begin(), sorted_delta_frequency_map.end(),
    148        [](const pixel_type_3d_frequency &a, const pixel_type_3d_frequency &b) {
    149          return a.second > b.second;
    150        });
    151 
    152    // Store the top deltas.
    153    for (auto &delta_frequency : sorted_delta_frequency_map) {
    154      if (frequent_deltas[0].size() >= kMaxDeltas) break;
    155      // Number obtained by optimizing on jyrki31 corpus:
    156      if (delta_frequency.second < 17) break;
    157      for (int c = 0; c < 3; ++c) {
    158        frequent_deltas[c].push_back(delta_frequency.first[c] * bucket_size);
    159      }
    160    }
    161  }
    162 };
    163 
    164 Status FwdPaletteIteration(Image &input, uint32_t begin_c, uint32_t end_c,
    165                           uint32_t &nb_colors, uint32_t &nb_deltas,
    166                           bool ordered, bool lossy, Predictor &predictor,
    167                           const weighted::Header &wp_header,
    168                           PaletteIterationData &palette_iteration_data) {
    169  JXL_QUIET_RETURN_IF_ERROR(CheckEqualChannels(input, begin_c, end_c));
    170  JXL_ENSURE(begin_c >= input.nb_meta_channels);
    171  JxlMemoryManager *memory_manager = input.memory_manager();
    172  uint32_t nb = end_c - begin_c + 1;
    173 
    174  size_t w = input.channel[begin_c].w;
    175  size_t h = input.channel[begin_c].h;
    176  if (input.bitdepth >= 32) return false;
    177  if (!lossy && nb_colors < 2) return false;
    178 
    179  if (!lossy && nb == 1) {
    180    // Channel palette special case
    181    if (nb_colors == 0) return false;
    182    std::vector<pixel_type> lookup;
    183    pixel_type minval;
    184    pixel_type maxval;
    185    compute_minmax(input.channel[begin_c], &minval, &maxval);
    186    size_t lookup_table_size =
    187        static_cast<int64_t>(maxval) - static_cast<int64_t>(minval) + 1;
    188    if (lookup_table_size > palette_internal::kMaxPaletteLookupTableSize) {
    189      // a lookup table would use too much memory, instead use a slower approach
    190      // with std::set
    191      std::set<pixel_type> chpalette;
    192      pixel_type idx = 0;
    193      for (size_t y = 0; y < h; y++) {
    194        const pixel_type *p = input.channel[begin_c].Row(y);
    195        for (size_t x = 0; x < w; x++) {
    196          const bool new_color = chpalette.insert(p[x]).second;
    197          if (new_color) {
    198            idx++;
    199            if (idx > static_cast<int>(nb_colors)) return false;
    200          }
    201        }
    202      }
    203      JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx);
    204      JXL_ASSIGN_OR_RETURN(Channel pch,
    205                           Channel::Create(memory_manager, idx, 1));
    206      pch.hshift = -1;
    207      pch.vshift = -1;
    208      nb_colors = idx;
    209      idx = 0;
    210      pixel_type *JXL_RESTRICT p_palette = pch.Row(0);
    211      for (pixel_type p : chpalette) {
    212        p_palette[idx++] = p;
    213      }
    214      for (size_t y = 0; y < h; y++) {
    215        pixel_type *p = input.channel[begin_c].Row(y);
    216        for (size_t x = 0; x < w; x++) {
    217          for (idx = 0;
    218               p[x] != p_palette[idx] && idx < static_cast<int>(nb_colors);
    219               idx++) {
    220            // no-op
    221          }
    222          JXL_DASSERT(idx < static_cast<int>(nb_colors));
    223          p[x] = idx;
    224        }
    225      }
    226      predictor = Predictor::Zero;
    227      input.nb_meta_channels++;
    228      input.channel.insert(input.channel.begin(), std::move(pch));
    229 
    230      return true;
    231    }
    232    lookup.resize(lookup_table_size, 0);
    233    pixel_type idx = 0;
    234    for (size_t y = 0; y < h; y++) {
    235      const pixel_type *p = input.channel[begin_c].Row(y);
    236      for (size_t x = 0; x < w; x++) {
    237        if (lookup[p[x] - minval] == 0) {
    238          lookup[p[x] - minval] = 1;
    239          idx++;
    240          if (idx > static_cast<int>(nb_colors)) return false;
    241        }
    242      }
    243    }
    244    JXL_DEBUG_V(6, "Channel %i uses only %i colors.", begin_c, idx);
    245    JXL_ASSIGN_OR_RETURN(Channel pch, Channel::Create(memory_manager, idx, 1));
    246    pch.hshift = -1;
    247    pch.vshift = -1;
    248    nb_colors = idx;
    249    idx = 0;
    250    pixel_type *JXL_RESTRICT p_palette = pch.Row(0);
    251    for (size_t i = 0; i < lookup_table_size; i++) {
    252      if (lookup[i]) {
    253        p_palette[idx] = i + minval;
    254        lookup[i] = idx;
    255        idx++;
    256      }
    257    }
    258    for (size_t y = 0; y < h; y++) {
    259      pixel_type *p = input.channel[begin_c].Row(y);
    260      for (size_t x = 0; x < w; x++) p[x] = lookup[p[x] - minval];
    261    }
    262    predictor = Predictor::Zero;
    263    input.nb_meta_channels++;
    264    input.channel.insert(input.channel.begin(), std::move(pch));
    265    return true;
    266  }
    267 
    268  Image quantized_input(memory_manager);
    269  if (lossy) {
    270    JXL_ASSIGN_OR_RETURN(quantized_input, Image::Create(memory_manager, w, h,
    271                                                        input.bitdepth, nb));
    272    for (size_t c = 0; c < nb; c++) {
    273      JXL_RETURN_IF_ERROR(CopyImageTo(input.channel[begin_c + c].plane,
    274                                      &quantized_input.channel[c].plane));
    275    }
    276  }
    277 
    278  JXL_DEBUG_V(
    279      7, "Trying to represent channels %i-%i using at most a %i-color palette.",
    280      begin_c, end_c, nb_colors);
    281  nb_deltas = 0;
    282  bool delta_used = false;
    283  std::set<std::vector<pixel_type>> candidate_palette;
    284  std::vector<std::vector<pixel_type>> candidate_palette_imageorder;
    285  std::vector<pixel_type> color(nb);
    286  std::vector<float> color_with_error(nb);
    287  std::vector<const pixel_type *> p_in(nb);
    288  std::map<std::vector<pixel_type>, size_t> inv_palette;
    289 
    290  if (lossy) {
    291    palette_iteration_data.FindFrequentColorDeltas(w * h, input.bitdepth);
    292    nb_deltas = palette_iteration_data.frequent_deltas[0].size();
    293 
    294    // Count color frequency for colors that make a cross.
    295    std::map<std::vector<pixel_type>, size_t> color_freq_map;
    296    for (size_t y = 1; y + 1 < h; y++) {
    297      for (uint32_t c = 0; c < nb; c++) {
    298        p_in[c] = input.channel[begin_c + c].Row(y);
    299      }
    300      for (size_t x = 1; x + 1 < w; x++) {
    301        for (uint32_t c = 0; c < nb; c++) {
    302          color[c] = p_in[c][x];
    303        }
    304        int offsets[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
    305        bool makes_cross = true;
    306        for (int i = 0; i < 4 && makes_cross; ++i) {
    307          int dx = offsets[i][0];
    308          int dy = offsets[i][1];
    309          for (uint32_t c = 0; c < nb && makes_cross; c++) {
    310            if (input.channel[begin_c + c].Row(y + dy)[x + dx] != color[c]) {
    311              makes_cross = false;
    312            }
    313          }
    314        }
    315        if (makes_cross) color_freq_map[color] += 1;
    316      }
    317    }
    318    // Add colors satisfying frequency condition to the palette.
    319    constexpr float kImageFraction = 0.01f;
    320    size_t color_frequency_lower_bound = 5 + input.h * input.w * kImageFraction;
    321    for (const auto &color_freq : color_freq_map) {
    322      if (color_freq.second > color_frequency_lower_bound) {
    323        candidate_palette.insert(color_freq.first);
    324        candidate_palette_imageorder.push_back(color_freq.first);
    325      }
    326    }
    327  }
    328 
    329  std::map<std::vector<pixel_type>, bool> implicit_color;
    330  std::vector<std::vector<pixel_type>> implicit_colors;
    331  implicit_colors.reserve(palette_internal::kImplicitPaletteSize);
    332  for (size_t k = 0; k < palette_internal::kImplicitPaletteSize; k++) {
    333    for (size_t i = 0; i < nb; i++) {
    334      color[i] = palette_internal::GetPaletteValue(nullptr, k, i, 0, 0,
    335                                                   input.bitdepth);
    336    }
    337    implicit_color[color] = true;
    338    implicit_colors.push_back(color);
    339  }
    340 
    341  std::map<std::vector<pixel_type>, size_t> color_freq_map;
    342  uint32_t implicit_colors_used = 0;
    343  for (size_t y = 0; y < h; y++) {
    344    for (uint32_t c = 0; c < nb; c++) {
    345      p_in[c] = input.channel[begin_c + c].Row(y);
    346    }
    347    for (size_t x = 0; x < w; x++) {
    348      if (lossy && candidate_palette.size() >= nb_colors) break;
    349      for (uint32_t c = 0; c < nb; c++) {
    350        color[c] = p_in[c][x];
    351      }
    352      const bool new_color = candidate_palette.insert(color).second;
    353      if (new_color) {
    354        if (implicit_color[color]) {
    355          implicit_colors_used++;
    356        } else {
    357          candidate_palette_imageorder.push_back(color);
    358          if (candidate_palette_imageorder.size() > nb_colors) {
    359            return false;  // too many colors
    360          }
    361        }
    362      }
    363      color_freq_map[color] += 1;
    364    }
    365  }
    366 
    367  nb_colors = nb_deltas + candidate_palette_imageorder.size();
    368 
    369  // not useful to make a single-color palette
    370  if (!lossy && nb_colors + implicit_colors_used == 1) return false;
    371  // TODO(jon): if this happens (e.g. solid white group), special-case it for
    372  // faster encode
    373 
    374  for (size_t k = 0; k < palette_internal::kImplicitPaletteSize; k++) {
    375    color = implicit_colors[k];
    376    // still add the color to the explicit palette if it is frequent enough
    377    if (color_freq_map[color] > 10) {
    378      nb_colors++;
    379      candidate_palette_imageorder.push_back(color);
    380    }
    381  }
    382  for (size_t k = 0; k < palette_internal::kImplicitPaletteSize; k++) {
    383    color = implicit_colors[k];
    384    inv_palette[color] = nb_colors + k;
    385  }
    386 
    387  JXL_DEBUG_V(6, "Channels %i-%i can be represented using a %i-color palette.",
    388              begin_c, end_c, nb_colors);
    389 
    390  JXL_ASSIGN_OR_RETURN(Channel pch,
    391                       Channel::Create(memory_manager, nb_colors, nb));
    392  pch.hshift = -1;
    393  pch.vshift = -1;
    394  pixel_type *JXL_RESTRICT p_palette = pch.Row(0);
    395  intptr_t onerow = pch.plane.PixelsPerRow();
    396  intptr_t onerow_image = input.channel[begin_c].plane.PixelsPerRow();
    397  const int bit_depth = std::min(input.bitdepth, 24);
    398 
    399  if (lossy) {
    400    for (uint32_t i = 0; i < nb_deltas; i++) {
    401      for (size_t c = 0; c < 3; c++) {
    402        p_palette[c * onerow + i] =
    403            palette_iteration_data.frequent_deltas[c][i];
    404      }
    405    }
    406  }
    407  // Separate the palette in two buckets, first the common colors, then the
    408  // rare colors.
    409  // Within each bucket, the colors are sorted on luma (times alpha).
    410  float freq_threshold = 4;  // arbitrary threshold
    411  int x = 0;
    412  if (ordered && nb >= 3) {
    413    JXL_DEBUG_V(7, "Palette of %i colors, using luma order", nb_colors);
    414    // sort on luma (multiplied by alpha if available)
    415    std::sort(candidate_palette_imageorder.begin(),
    416              candidate_palette_imageorder.end(),
    417              [&](std::vector<pixel_type> ap, std::vector<pixel_type> bp) {
    418                float ay;
    419                float by;
    420                ay = (0.299f * ap[0] + 0.587f * ap[1] + 0.114f * ap[2] + 0.1f);
    421                if (ap.size() > 3) ay *= 1.f + ap[3];
    422                by = (0.299f * bp[0] + 0.587f * bp[1] + 0.114f * bp[2] + 0.1f);
    423                if (bp.size() > 3) by *= 1.f + bp[3];
    424                // put common colors first, transparent dark to opaque bright,
    425                // then rare colors, bright to dark
    426                ay = color_freq_map[ap] > freq_threshold ? -ay : ay;
    427                by = color_freq_map[bp] > freq_threshold ? -by : by;
    428                return ay < by;
    429              });
    430  } else {
    431    JXL_DEBUG_V(7, "Palette of %i colors, using image order", nb_colors);
    432  }
    433 
    434  for (auto pcol : candidate_palette_imageorder) {
    435    JXL_DEBUG_V(9, "  Color %i :  ", x);
    436    for (size_t i = 0; i < nb; i++) {
    437      p_palette[nb_deltas + i * onerow + x] = pcol[i];
    438      JXL_DEBUG_V(9, "%i ", pcol[i]);
    439    }
    440    inv_palette[pcol] = x;
    441    x++;
    442  }
    443  std::vector<weighted::State> wp_states;
    444  for (size_t c = 0; c < nb; c++) {
    445    wp_states.emplace_back(wp_header, w, h);
    446  }
    447  std::vector<pixel_type *> p_quant(nb);
    448  // Three rows of error for dithering: y to y + 2.
    449  // Each row has two pixels of padding in the ends, which is
    450  // beneficial for both precision and encoding speed.
    451  std::vector<std::vector<float>> error_row[3];
    452  if (lossy) {
    453    for (auto &row : error_row) {
    454      row.resize(nb);
    455      for (size_t c = 0; c < nb; ++c) {
    456        row[c].resize(w + 4);
    457      }
    458    }
    459  }
    460  for (size_t y = 0; y < h; y++) {
    461    for (size_t c = 0; c < nb; c++) {
    462      p_in[c] = input.channel[begin_c + c].Row(y);
    463      if (lossy) p_quant[c] = quantized_input.channel[c].Row(y);
    464    }
    465    pixel_type *JXL_RESTRICT p = input.channel[begin_c].Row(y);
    466    for (size_t x = 0; x < w; x++) {
    467      int index;
    468      if (!lossy) {
    469        for (size_t c = 0; c < nb; c++) color[c] = p_in[c][x];
    470        index = inv_palette[color];
    471      } else {
    472        int best_index = 0;
    473        bool best_is_delta = false;
    474        float best_distance = std::numeric_limits<float>::infinity();
    475        std::vector<pixel_type> best_val(nb, 0);
    476        std::vector<pixel_type> ideal_residual(nb, 0);
    477        std::vector<pixel_type> quantized_val(nb);
    478        std::vector<pixel_type> predictions(nb);
    479        for (double diffusion_multiplier : {0.55, 0.75}) {
    480          for (size_t c = 0; c < nb; c++) {
    481            color_with_error[c] =
    482                p_in[c][x] + (palette_iteration_data.final_run ? 1 : 0) *
    483                                 diffusion_multiplier * error_row[0][c][x + 2];
    484            color[c] = Clamp1(lroundf(color_with_error[c]), 0l,
    485                              (1l << input.bitdepth) - 1);
    486          }
    487 
    488          for (size_t c = 0; c < nb; ++c) {
    489            predictions[c] = PredictNoTreeWP(w, p_quant[c] + x, onerow_image, x,
    490                                             y, predictor, &wp_states[c])
    491                                 .guess;
    492          }
    493          const auto TryIndex = [&](const int index) {
    494            for (size_t c = 0; c < nb; c++) {
    495              quantized_val[c] = palette_internal::GetPaletteValue(
    496                  p_palette, index, /*c=*/c,
    497                  /*palette_size=*/nb_colors,
    498                  /*onerow=*/onerow, /*bit_depth=*/bit_depth);
    499              if (index < static_cast<int>(nb_deltas)) {
    500                quantized_val[c] += predictions[c];
    501              }
    502            }
    503            const float color_distance =
    504                32.0 / (1LL << std::max(0, 2 * (bit_depth - 8))) *
    505                palette_internal::ColorDistance(color_with_error,
    506                                                quantized_val);
    507            float index_penalty = 0;
    508            if (index == -1) {
    509              index_penalty = -124;
    510            } else if (index < 0) {
    511              index_penalty = -2 * index;
    512            } else if (index < static_cast<int>(nb_deltas)) {
    513              index_penalty = 250;
    514            } else if (index < static_cast<int>(nb_colors)) {
    515              index_penalty = 150;
    516            } else if (index < static_cast<int>(nb_colors) +
    517                                   palette_internal::kLargeCubeOffset) {
    518              index_penalty = 70;
    519            } else {
    520              index_penalty = 256;
    521            }
    522            const float distance = color_distance + index_penalty;
    523            if (distance < best_distance) {
    524              best_distance = distance;
    525              best_index = index;
    526              best_is_delta = index < static_cast<int>(nb_deltas);
    527              best_val.swap(quantized_val);
    528              for (size_t c = 0; c < nb; ++c) {
    529                ideal_residual[c] = color_with_error[c] - predictions[c];
    530              }
    531            }
    532          };
    533          for (index = palette_internal::kMinImplicitPaletteIndex;
    534               index < static_cast<int32_t>(nb_colors); index++) {
    535            TryIndex(index);
    536          }
    537          TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex(
    538              color, nb_colors, bit_depth,
    539              /*high_quality=*/false));
    540          if (palette_internal::kEncodeToHighQualityImplicitPalette) {
    541            TryIndex(palette_internal::QuantizeColorToImplicitPaletteIndex(
    542                color, nb_colors, bit_depth,
    543                /*high_quality=*/true));
    544          }
    545        }
    546        index = best_index;
    547        delta_used |= best_is_delta;
    548        if (!palette_iteration_data.final_run) {
    549          for (size_t c = 0; c < 3; ++c) {
    550            palette_iteration_data.deltas[c].push_back(ideal_residual[c]);
    551          }
    552          palette_iteration_data.delta_distances.push_back(best_distance);
    553        }
    554 
    555        for (size_t c = 0; c < nb; ++c) {
    556          wp_states[c].UpdateErrors(best_val[c], x, y, w);
    557          p_quant[c][x] = best_val[c];
    558        }
    559        float len_error = 0;
    560        for (size_t c = 0; c < nb; ++c) {
    561          float local_error = color_with_error[c] - best_val[c];
    562          len_error += local_error * local_error;
    563        }
    564        len_error = std::sqrt(len_error);
    565        float modulate = 1.0;
    566        int len_limit = 38 << std::max(0, bit_depth - 8);
    567        if (len_error > len_limit) {
    568          modulate *= len_limit / len_error;
    569        }
    570        for (size_t c = 0; c < nb; ++c) {
    571          float total_error = (color_with_error[c] - best_val[c]);
    572 
    573          // If the neighboring pixels have some error in the opposite
    574          // direction of total_error, cancel some or all of it out before
    575          // spreading among them.
    576          constexpr int offsets[12][2] = {{1, 2}, {0, 3}, {0, 4}, {1, 1},
    577                                          {1, 3}, {2, 2}, {1, 0}, {1, 4},
    578                                          {2, 1}, {2, 3}, {2, 0}, {2, 4}};
    579          float total_available = 0;
    580          for (int i = 0; i < 11; ++i) {
    581            const int row = offsets[i][0];
    582            const int col = offsets[i][1];
    583            if (std::signbit(error_row[row][c][x + col]) !=
    584                std::signbit(total_error)) {
    585              total_available += error_row[row][c][x + col];
    586            }
    587          }
    588          float weight =
    589              std::abs(total_error) / (std::abs(total_available) + 1e-3);
    590          weight = std::min(weight, 1.0f);
    591          for (int i = 0; i < 11; ++i) {
    592            const int row = offsets[i][0];
    593            const int col = offsets[i][1];
    594            if (std::signbit(error_row[row][c][x + col]) !=
    595                std::signbit(total_error)) {
    596              total_error += weight * error_row[row][c][x + col];
    597              error_row[row][c][x + col] *= (1 - weight);
    598            }
    599          }
    600          total_error *= modulate;
    601          const float remaining_error = (1.0f / 14.) * total_error;
    602          error_row[0][c][x + 3] += 2 * remaining_error;
    603          error_row[0][c][x + 4] += remaining_error;
    604          error_row[1][c][x + 0] += remaining_error;
    605          for (int i = 0; i < 5; ++i) {
    606            error_row[1][c][x + i] += remaining_error;
    607            error_row[2][c][x + i] += remaining_error;
    608          }
    609        }
    610      }
    611      if (palette_iteration_data.final_run) p[x] = index;
    612    }
    613    if (lossy) {
    614      for (size_t c = 0; c < nb; ++c) {
    615        error_row[0][c].swap(error_row[1][c]);
    616        error_row[1][c].swap(error_row[2][c]);
    617        std::fill(error_row[2][c].begin(), error_row[2][c].end(), 0.f);
    618      }
    619    }
    620  }
    621  if (!delta_used) {
    622    predictor = Predictor::Zero;
    623  }
    624  if (palette_iteration_data.final_run) {
    625    input.nb_meta_channels++;
    626    input.channel.erase(input.channel.begin() + begin_c + 1,
    627                        input.channel.begin() + end_c + 1);
    628    input.channel.insert(input.channel.begin(), std::move(pch));
    629  }
    630  nb_colors -= nb_deltas;
    631  return true;
    632 }
    633 
    634 Status FwdPalette(Image &input, uint32_t begin_c, uint32_t end_c,
    635                  uint32_t &nb_colors, uint32_t &nb_deltas, bool ordered,
    636                  bool lossy, Predictor &predictor,
    637                  const weighted::Header &wp_header) {
    638  PaletteIterationData palette_iteration_data;
    639  uint32_t nb_colors_orig = nb_colors;
    640  uint32_t nb_deltas_orig = nb_deltas;
    641  // preprocessing pass in case of lossy palette
    642  if (lossy && input.bitdepth >= 8) {
    643    JXL_RETURN_IF_ERROR(FwdPaletteIteration(
    644        input, begin_c, end_c, nb_colors_orig, nb_deltas_orig, ordered, lossy,
    645        predictor, wp_header, palette_iteration_data));
    646  }
    647  palette_iteration_data.final_run = true;
    648  return FwdPaletteIteration(input, begin_c, end_c, nb_colors, nb_deltas,
    649                             ordered, lossy, predictor, wp_header,
    650                             palette_iteration_data);
    651 }
    652 
    653 }  // namespace jxl