tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

palette.c (43347B)


      1 /*
      2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <math.h>
     13 #include <stdlib.h>
     14 
     15 #include "av1/common/pred_common.h"
     16 
     17 #include "aom_ports/bitops.h"
     18 #include "av1/encoder/block.h"
     19 #include "av1/encoder/cost.h"
     20 #include "av1/encoder/encoder.h"
     21 #include "av1/encoder/intra_mode_search.h"
     22 #include "av1/encoder/intra_mode_search_utils.h"
     23 #include "av1/encoder/palette.h"
     24 #include "av1/encoder/random.h"
     25 #include "av1/encoder/rdopt_utils.h"
     26 #include "av1/encoder/tx_search.h"
     27 
     28 #define AV1_K_MEANS_DIM 1
     29 #include "av1/encoder/k_means_template.h"
     30 #undef AV1_K_MEANS_DIM
     31 #define AV1_K_MEANS_DIM 2
     32 #include "av1/encoder/k_means_template.h"
     33 #undef AV1_K_MEANS_DIM
     34 
     35 static int int16_comparer(const void *a, const void *b) {
     36  return (*(int16_t *)a - *(int16_t *)b);
     37 }
     38 
     39 /*!\brief Removes duplicated centroid indices.
     40 *
     41 * \ingroup palette_mode_search
     42 * \param[in]    centroids          A list of centroids index.
     43 * \param[in]    num_centroids      Number of centroids.
     44 *
     45 * \return Returns the number of unique centroids and saves the unique centroids
     46 * in beginning of the centroids array.
     47 *
     48 * \attention The centroids should be rounded to integers before calling this
     49 * method.
     50 */
     51 static int remove_duplicates(int16_t *centroids, int num_centroids) {
     52  int num_unique;  // number of unique centroids
     53  int i;
     54  qsort(centroids, num_centroids, sizeof(*centroids), int16_comparer);
     55  // Remove duplicates.
     56  num_unique = 1;
     57  for (i = 1; i < num_centroids; ++i) {
     58    if (centroids[i] != centroids[i - 1]) {  // found a new unique centroid
     59      centroids[num_unique++] = centroids[i];
     60    }
     61  }
     62  return num_unique;
     63 }
     64 
     65 static int delta_encode_cost(const int *colors, int num, int bit_depth,
     66                             int min_val) {
     67  if (num <= 0) return 0;
     68  int bits_cost = bit_depth;
     69  if (num == 1) return bits_cost;
     70  bits_cost += 2;
     71  int max_delta = 0;
     72  int deltas[PALETTE_MAX_SIZE];
     73  const int min_bits = bit_depth - 3;
     74  for (int i = 1; i < num; ++i) {
     75    const int delta = colors[i] - colors[i - 1];
     76    deltas[i - 1] = delta;
     77    assert(delta >= min_val);
     78    if (delta > max_delta) max_delta = delta;
     79  }
     80  int bits_per_delta = AOMMAX(aom_ceil_log2(max_delta + 1 - min_val), min_bits);
     81  assert(bits_per_delta <= bit_depth);
     82  int range = (1 << bit_depth) - colors[0] - min_val;
     83  for (int i = 0; i < num - 1; ++i) {
     84    bits_cost += bits_per_delta;
     85    range -= deltas[i];
     86    bits_per_delta = AOMMIN(bits_per_delta, aom_ceil_log2(range));
     87  }
     88  return bits_cost;
     89 }
     90 
     91 int av1_index_color_cache(const uint16_t *color_cache, int n_cache,
     92                          const uint16_t *colors, int n_colors,
     93                          uint8_t *cache_color_found, int *out_cache_colors) {
     94  if (n_cache <= 0) {
     95    for (int i = 0; i < n_colors; ++i) out_cache_colors[i] = colors[i];
     96    return n_colors;
     97  }
     98  memset(cache_color_found, 0, n_cache * sizeof(*cache_color_found));
     99  int n_in_cache = 0;
    100  int in_cache_flags[PALETTE_MAX_SIZE];
    101  memset(in_cache_flags, 0, sizeof(in_cache_flags));
    102  for (int i = 0; i < n_cache && n_in_cache < n_colors; ++i) {
    103    for (int j = 0; j < n_colors; ++j) {
    104      if (colors[j] == color_cache[i]) {
    105        in_cache_flags[j] = 1;
    106        cache_color_found[i] = 1;
    107        ++n_in_cache;
    108        break;
    109      }
    110    }
    111  }
    112  int j = 0;
    113  for (int i = 0; i < n_colors; ++i)
    114    if (!in_cache_flags[i]) out_cache_colors[j++] = colors[i];
    115  assert(j == n_colors - n_in_cache);
    116  return j;
    117 }
    118 
    119 int av1_get_palette_delta_bits_v(const PALETTE_MODE_INFO *const pmi,
    120                                 int bit_depth, int *zero_count,
    121                                 int *min_bits) {
    122  const int n = pmi->palette_size[1];
    123  const int max_val = 1 << bit_depth;
    124  int max_d = 0;
    125  *min_bits = bit_depth - 4;
    126  *zero_count = 0;
    127  for (int i = 1; i < n; ++i) {
    128    const int delta = pmi->palette_colors[2 * PALETTE_MAX_SIZE + i] -
    129                      pmi->palette_colors[2 * PALETTE_MAX_SIZE + i - 1];
    130    const int v = abs(delta);
    131    const int d = AOMMIN(v, max_val - v);
    132    if (d > max_d) max_d = d;
    133    if (d == 0) ++(*zero_count);
    134  }
    135  return AOMMAX(aom_ceil_log2(max_d + 1), *min_bits);
    136 }
    137 
    138 int av1_palette_color_cost_y(const PALETTE_MODE_INFO *const pmi,
    139                             const uint16_t *color_cache, int n_cache,
    140                             int bit_depth) {
    141  const int n = pmi->palette_size[0];
    142  int out_cache_colors[PALETTE_MAX_SIZE];
    143  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
    144  const int n_out_cache =
    145      av1_index_color_cache(color_cache, n_cache, pmi->palette_colors, n,
    146                            cache_color_found, out_cache_colors);
    147  const int total_bits =
    148      n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 1);
    149  return av1_cost_literal(total_bits);
    150 }
    151 
    152 int av1_palette_color_cost_uv(const PALETTE_MODE_INFO *const pmi,
    153                              const uint16_t *color_cache, int n_cache,
    154                              int bit_depth) {
    155  const int n = pmi->palette_size[1];
    156  int total_bits = 0;
    157  // U channel palette color cost.
    158  int out_cache_colors[PALETTE_MAX_SIZE];
    159  uint8_t cache_color_found[2 * PALETTE_MAX_SIZE];
    160  const int n_out_cache = av1_index_color_cache(
    161      color_cache, n_cache, pmi->palette_colors + PALETTE_MAX_SIZE, n,
    162      cache_color_found, out_cache_colors);
    163  total_bits +=
    164      n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 0);
    165 
    166  // V channel palette color cost.
    167  int zero_count = 0, min_bits_v = 0;
    168  const int bits_v =
    169      av1_get_palette_delta_bits_v(pmi, bit_depth, &zero_count, &min_bits_v);
    170  const int bits_using_delta =
    171      2 + bit_depth + (bits_v + 1) * (n - 1) - zero_count;
    172  const int bits_using_raw = bit_depth * n;
    173  total_bits += 1 + AOMMIN(bits_using_delta, bits_using_raw);
    174  return av1_cost_literal(total_bits);
    175 }
    176 
    177 // Extends 'color_map' array from 'orig_width x orig_height' to 'new_width x
    178 // new_height'. Extra rows and columns are filled in by copying last valid
    179 // row/column.
    180 static inline void extend_palette_color_map(uint8_t *const color_map,
    181                                            int orig_width, int orig_height,
    182                                            int new_width, int new_height) {
    183  int j;
    184  assert(new_width >= orig_width);
    185  assert(new_height >= orig_height);
    186  if (new_width == orig_width && new_height == orig_height) return;
    187 
    188  for (j = orig_height - 1; j >= 0; --j) {
    189    memmove(color_map + j * new_width, color_map + j * orig_width, orig_width);
    190    // Copy last column to extra columns.
    191    memset(color_map + j * new_width + orig_width,
    192           color_map[j * new_width + orig_width - 1], new_width - orig_width);
    193  }
    194  // Copy last row to extra rows.
    195  for (j = orig_height; j < new_height; ++j) {
    196    memcpy(color_map + j * new_width, color_map + (orig_height - 1) * new_width,
    197           new_width);
    198  }
    199 }
    200 
    201 // Bias toward using colors in the cache.
    202 // TODO(huisu): Try other schemes to improve compression.
    203 static inline void optimize_palette_colors(uint16_t *color_cache, int n_cache,
    204                                           int n_colors, int stride,
    205                                           int16_t *centroids, int bit_depth) {
    206  if (n_cache <= 0) return;
    207  for (int i = 0; i < n_colors * stride; i += stride) {
    208    int min_diff = abs((int)centroids[i] - (int)color_cache[0]);
    209    int idx = 0;
    210    for (int j = 1; j < n_cache; ++j) {
    211      const int this_diff = abs((int)centroids[i] - (int)color_cache[j]);
    212      if (this_diff < min_diff) {
    213        min_diff = this_diff;
    214        idx = j;
    215      }
    216    }
    217    const int min_threshold = 4 << (bit_depth - 8);
    218    if (min_diff <= min_threshold) centroids[i] = color_cache[idx];
    219  }
    220 }
    221 
/*!\brief Calculate the luma palette cost from a given color palette
 *
 * \ingroup palette_mode_search
 * \callergraph
 * Given the base colors as specified in centroids[], calculate the RD cost
 * of palette mode.
 *
 * Steps:
 *  1. The n centroids are pulled toward the color cache
 *     (optimize_palette_colors()), then sorted and de-duplicated in place
 *     (remove_duplicates()). If fewer than PALETTE_MIN_SIZE unique colors
 *     remain, the function returns without evaluating anything.
 *  2. The unique colors are passed through clip_pixel() /
 *     clip_pixel_highbd() into mbmi->palette_mode_info, and palette_size[0]
 *     is set.
 *  3. The block's color index map is computed (av1_calc_indices()) and padded
 *     to the full block size (extend_palette_color_map()).
 *  4. av1_pick_uniform_tx_size_type_yrd() runs the transform search. If it
 *     reports rate == INT_MAX, the function returns without recording
 *     anything.
 *
 * Winner-mode stats are stored for every evaluated palette. When this_rd
 * beats *best_rd, the following are updated:
 *  - *best_rd, *best_mbmi, best_palette_color_map, blk_skip and tx_type_map;
 *  - each non-NULL output among rate / rate_tokenonly / distortion /
 *    skippable;
 *  - *beat_best_rd and *beat_best_palette_rd, set to 1 when non-NULL.
 *
 * When do_header_rd_based_gating is true, do_header_rd_based_breakout must be
 * non-NULL. It is set to true, and the transform search is skipped, when the
 * header RD cost alone exceeds *best_rd. The header cost is first shifted
 * right by one when prune_luma_palette_size_search_level == 1.
 */
static inline void palette_rd_y(
    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int16_t *centroids,
    int n, uint16_t *color_cache, int n_cache, bool do_header_rd_based_gating,
    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
    int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *blk_skip,
    uint8_t *tx_type_map, int *beat_best_palette_rd,
    bool *do_header_rd_based_breakout, int discount_color_cost) {
  if (do_header_rd_based_breakout != NULL) *do_header_rd_based_breakout = false;
  // Note: both calls below modify centroids[] in place.
  optimize_palette_colors(color_cache, n_cache, n, 1, centroids,
                          cpi->common.seq_params->bit_depth);
  const int num_unique_colors = remove_duplicates(centroids, n);
  if (num_unique_colors < PALETTE_MIN_SIZE) {
    // Too few unique colors to create a palette. And DC_PRED will work
    // well for that case anyway. So skip.
    return;
  }
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  if (cpi->common.seq_params->use_highbitdepth) {
    for (int i = 0; i < num_unique_colors; ++i) {
      pmi->palette_colors[i] = clip_pixel_highbd(
          (int)centroids[i], cpi->common.seq_params->bit_depth);
    }
  } else {
    for (int i = 0; i < num_unique_colors; ++i) {
      pmi->palette_colors[i] = clip_pixel(centroids[i]);
    }
  }
  pmi->palette_size[0] = num_unique_colors;
  MACROBLOCKD *const xd = &x->e_mbd;
  uint8_t *const color_map = xd->plane[0].color_index_map;
  int block_width, block_height, rows, cols;
  av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
                           &cols);
  // Indices are computed for the rows x cols samples of 'data' and then padded
  // out to block_width x block_height.
  av1_calc_indices(data, centroids, color_map, rows * cols, num_unique_colors,
                   1);
  extend_palette_color_map(color_map, cols, rows, block_width, block_height);

  RD_STATS tokenonly_rd_stats;
  int this_rate;

  if (do_header_rd_based_gating) {
    assert(do_header_rd_based_breakout != NULL);
    const int palette_mode_rate = intra_mode_info_cost_y(
        cpi, x, mbmi, bsize, dc_mode_cost, discount_color_cost);
    const int64_t header_rd = RDCOST(x->rdmult, palette_mode_rate, 0);
    // Less aggressive pruning when prune_luma_palette_size_search_level == 1.
    const int header_rd_shift =
        (cpi->sf.intra_sf.prune_luma_palette_size_search_level == 1) ? 1 : 0;
    // Terminate further palette_size search, if the header cost corresponding
    // to lower palette_size is more than *best_rd << header_rd_shift. This
    // logic is implemented with a right shift in the LHS to prevent a possible
    // overflow with the left shift in RHS.
    if ((header_rd >> header_rd_shift) > *best_rd) {
      *do_header_rd_based_breakout = true;
      return;
    }
    av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
                                      *best_rd);
    if (tokenonly_rd_stats.rate == INT_MAX) return;
    this_rate = tokenonly_rd_stats.rate + palette_mode_rate;
  } else {
    av1_pick_uniform_tx_size_type_yrd(cpi, x, &tokenonly_rd_stats, bsize,
                                      *best_rd);
    if (tokenonly_rd_stats.rate == INT_MAX) return;
    this_rate = tokenonly_rd_stats.rate +
                intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost,
                                       discount_color_cost);
  }

  int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
  // Runs after this_rd is computed, so only the reported token-only rate
  // (*rate_tokenonly below) has the tx-size signaling cost removed; this_rd
  // and *rate keep it.
  if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->bsize)) {
    tokenonly_rd_stats.rate -= tx_size_cost(x, bsize, mbmi->tx_size);
  }
  // Collect mode stats for multiwinner mode processing
  const int txfm_search_done = 1;
  store_winner_mode_stats(
      &cpi->common, x, mbmi, NULL, NULL, NULL, THR_DC, color_map, bsize,
      this_rd, cpi->sf.winner_mode_sf.multi_winner_mode_type, txfm_search_done);
  if (this_rd < *best_rd) {
    *best_rd = this_rd;
    // Setting beat_best_rd flag because current mode rd is better than best_rd.
    // This flag needs to be updated only for palette evaluation in key frames.
    if (beat_best_rd) *beat_best_rd = 1;
    memcpy(best_palette_color_map, color_map,
           block_width * block_height * sizeof(color_map[0]));
    *best_mbmi = *mbmi;
    memcpy(blk_skip, x->txfm_search_info.blk_skip,
           sizeof(x->txfm_search_info.blk_skip[0]) * ctx->num_4x4_blk);
    av1_copy_array(tx_type_map, xd->tx_type_map, ctx->num_4x4_blk);
    if (rate) *rate = this_rate;
    if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
    if (distortion) *distortion = tokenonly_rd_stats.dist;
    if (skippable) *skippable = tokenonly_rd_stats.skip_txfm;
    if (beat_best_palette_rd) *beat_best_palette_rd = 1;
  }
}
    327 
    328 static inline int is_iter_over(int curr_idx, int end_idx, int step_size) {
    329  assert(step_size != 0);
    330  return (step_size > 0) ? curr_idx >= end_idx : curr_idx <= end_idx;
    331 }
    332 
    333 // Performs count-based palette search with number of colors in interval
    334 // [start_n, end_n) with step size step_size. If step_size < 0, then end_n can
    335 // be less than start_n. Saves the last numbers searched in last_n_searched and
    336 // returns the best number of colors found.
    337 static inline int perform_top_color_palette_search(
    338    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
    339    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data,
    340    int16_t *top_colors, int start_n, int end_n, int step_size,
    341    bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
    342    int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
    343    int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
    344    uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
    345    uint8_t *best_blk_skip, uint8_t *tx_type_map, int discount_color_cost) {
    346  int16_t centroids[PALETTE_MAX_SIZE];
    347  int n = start_n;
    348  int top_color_winner = end_n;
    349  /* clang-format off */
    350  assert(IMPLIES(step_size < 0, start_n > end_n));
    351  /* clang-format on */
    352  assert(IMPLIES(step_size > 0, start_n < end_n));
    353  while (!is_iter_over(n, end_n, step_size)) {
    354    int beat_best_palette_rd = 0;
    355    bool do_header_rd_based_breakout = false;
    356    memcpy(centroids, top_colors, n * sizeof(top_colors[0]));
    357    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
    358                 color_cache, n_cache, do_header_rd_based_gating, best_mbmi,
    359                 best_palette_color_map, best_rd, rate, rate_tokenonly,
    360                 distortion, skippable, beat_best_rd, ctx, best_blk_skip,
    361                 tx_type_map, &beat_best_palette_rd,
    362                 &do_header_rd_based_breakout, discount_color_cost);
    363    *last_n_searched = n;
    364    if (do_header_rd_based_breakout) {
    365      // Terminate palette_size search by setting last_n_searched to end_n.
    366      *last_n_searched = end_n;
    367      break;
    368    }
    369    if (beat_best_palette_rd) {
    370      top_color_winner = n;
    371    } else if (cpi->sf.intra_sf.prune_palette_search_level == 2) {
    372      // At search level 2, we return immediately if we don't see an improvement
    373      return top_color_winner;
    374    }
    375    n += step_size;
    376  }
    377  return top_color_winner;
    378 }
    379 
    380 // Performs k-means based palette search with number of colors in interval
    381 // [start_n, end_n) with step size step_size. If step_size < 0, then end_n can
    382 // be less than start_n. Saves the last numbers searched in last_n_searched and
    383 // returns the best number of colors found.
    384 static inline int perform_k_means_palette_search(
    385    const AV1_COMP *const cpi, MACROBLOCK *x, MB_MODE_INFO *mbmi,
    386    BLOCK_SIZE bsize, int dc_mode_cost, const int16_t *data, int lower_bound,
    387    int upper_bound, int start_n, int end_n, int step_size,
    388    bool do_header_rd_based_gating, int *last_n_searched, uint16_t *color_cache,
    389    int n_cache, MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map,
    390    int64_t *best_rd, int *rate, int *rate_tokenonly, int64_t *distortion,
    391    uint8_t *skippable, int *beat_best_rd, PICK_MODE_CONTEXT *ctx,
    392    uint8_t *best_blk_skip, uint8_t *tx_type_map, uint8_t *color_map,
    393    int data_points, int discount_color_cost) {
    394  int16_t centroids[PALETTE_MAX_SIZE];
    395  const int max_itr = 50;
    396  int n = start_n;
    397  int top_color_winner = end_n;
    398  /* clang-format off */
    399  assert(IMPLIES(step_size < 0, start_n > end_n));
    400  /* clang-format on */
    401  assert(IMPLIES(step_size > 0, start_n < end_n));
    402  while (!is_iter_over(n, end_n, step_size)) {
    403    int beat_best_palette_rd = 0;
    404    bool do_header_rd_based_breakout = false;
    405    for (int i = 0; i < n; ++i) {
    406      centroids[i] =
    407          lower_bound + (2 * i + 1) * (upper_bound - lower_bound) / n / 2;
    408    }
    409    av1_k_means(data, centroids, color_map, data_points, n, 1, max_itr);
    410    palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, n,
    411                 color_cache, n_cache, do_header_rd_based_gating, best_mbmi,
    412                 best_palette_color_map, best_rd, rate, rate_tokenonly,
    413                 distortion, skippable, beat_best_rd, ctx, best_blk_skip,
    414                 tx_type_map, &beat_best_palette_rd,
    415                 &do_header_rd_based_breakout, discount_color_cost);
    416    *last_n_searched = n;
    417    if (do_header_rd_based_breakout) {
    418      // Terminate palette_size search by setting last_n_searched to end_n.
    419      *last_n_searched = end_n;
    420      break;
    421    }
    422    if (beat_best_palette_rd) {
    423      top_color_winner = n;
    424    } else if (cpi->sf.intra_sf.prune_palette_search_level == 2) {
    425      // At search level 2, we return immediately if we don't see an improvement
    426      return top_color_winner;
    427    }
    428    n += step_size;
    429  }
    430  return top_color_winner;
    431 }
    432 
    433 // Sets the parameters to search the current number of colors +- 1
    434 static inline void set_stage2_params(int *min_n, int *max_n, int *step_size,
    435                                     int winner, int end_n) {
    436  // Set min to winner - 1 unless we are already at the border, then we set it
    437  // to winner + 1
    438  *min_n = (winner == PALETTE_MIN_SIZE) ? (PALETTE_MIN_SIZE + 1)
    439                                        : AOMMAX(winner - 1, PALETTE_MIN_SIZE);
    440  // Set max to winner + 1 unless we are already at the border, then we set it
    441  // to winner - 1
    442  *max_n =
    443      (winner == end_n) ? (winner - 1) : AOMMIN(winner + 1, PALETTE_MAX_SIZE);
    444 
    445  // Set the step size to max_n - min_n so we only search those two values.
    446  // If max_n == min_n, then set step_size to 1 to avoid infinite loop later.
    447  *step_size = AOMMAX(1, *max_n - *min_n);
    448 }
    449 
    450 static inline void fill_data_and_get_bounds(const uint8_t *src,
    451                                            const int src_stride,
    452                                            const int rows, const int cols,
    453                                            const int is_high_bitdepth,
    454                                            int16_t *data, int *lower_bound,
    455                                            int *upper_bound) {
    456  if (is_high_bitdepth) {
    457    const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
    458    *lower_bound = *upper_bound = src_ptr[0];
    459    for (int r = 0; r < rows; ++r) {
    460      for (int c = 0; c < cols; ++c) {
    461        const int val = src_ptr[c];
    462        data[c] = (int16_t)val;
    463        *lower_bound = AOMMIN(*lower_bound, val);
    464        *upper_bound = AOMMAX(*upper_bound, val);
    465      }
    466      src_ptr += src_stride;
    467      data += cols;
    468    }
    469    return;
    470  }
    471 
    472  // low bit depth
    473  *lower_bound = *upper_bound = src[0];
    474  for (int r = 0; r < rows; ++r) {
    475    for (int c = 0; c < cols; ++c) {
    476      const int val = src[c];
    477      data[c] = (int16_t)val;
    478      *lower_bound = AOMMIN(*lower_bound, val);
    479      *upper_bound = AOMMAX(*upper_bound, val);
    480    }
    481    src += src_stride;
    482    data += cols;
    483  }
    484 }
    485 
    486 /*! \brief Colors are sorted by their count: the higher the better.
    487 */
    488 struct ColorCount {
    489  //! Color index in the histogram.
    490  int index;
    491  //! Histogram count.
    492  int count;
    493 };
    494 
    495 static int color_count_comp(const void *c1, const void *c2) {
    496  const struct ColorCount *color_count1 = (const struct ColorCount *)c1;
    497  const struct ColorCount *color_count2 = (const struct ColorCount *)c2;
    498  if (color_count1->count > color_count2->count) return -1;
    499  if (color_count1->count < color_count2->count) return 1;
    500  if (color_count1->index < color_count2->index) return -1;
    501  return 1;
    502 }
    503 
    504 static void find_top_colors(const int *const count_buf, int bit_depth,
    505                            int n_colors, int16_t *top_colors) {
    506  // Top color array, serving as a priority queue if more than n_colors are
    507  // found.
    508  struct ColorCount top_color_counts[PALETTE_MAX_SIZE] = { { 0 } };
    509  int n_color_count = 0;
    510  for (int i = 0; i < (1 << bit_depth); ++i) {
    511    if (count_buf[i] > 0) {
    512      if (n_color_count < n_colors) {
    513        // Keep adding to the top colors.
    514        top_color_counts[n_color_count].index = i;
    515        top_color_counts[n_color_count].count = count_buf[i];
    516        ++n_color_count;
    517        if (n_color_count == n_colors) {
    518          qsort(top_color_counts, n_colors, sizeof(top_color_counts[0]),
    519                color_count_comp);
    520        }
    521      } else {
    522        // Check the worst in the sorted top.
    523        if (count_buf[i] > top_color_counts[n_colors - 1].count) {
    524          int j = n_colors - 1;
    525          // Move up to the best one.
    526          while (j >= 1 && count_buf[i] > top_color_counts[j - 1].count) --j;
    527          memmove(top_color_counts + j + 1, top_color_counts + j,
    528                  (n_colors - j - 1) * sizeof(top_color_counts[0]));
    529          top_color_counts[j].index = i;
    530          top_color_counts[j].count = count_buf[i];
    531        }
    532      }
    533    }
    534  }
    535  assert(n_color_count == n_colors);
    536 
    537  for (int i = 0; i < n_colors; ++i) {
    538    top_colors[i] = top_color_counts[i].index;
    539  }
    540 }
    541 
    542 void av1_rd_pick_palette_intra_sby(
    543    const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int dc_mode_cost,
    544    MB_MODE_INFO *best_mbmi, uint8_t *best_palette_color_map, int64_t *best_rd,
    545    int *rate, int *rate_tokenonly, int64_t *distortion, uint8_t *skippable,
    546    int *beat_best_rd, PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip,
    547    uint8_t *tx_type_map) {
    548  MACROBLOCKD *const xd = &x->e_mbd;
    549  MB_MODE_INFO *const mbmi = xd->mi[0];
    550  assert(!is_inter_block(mbmi));
    551  assert(av1_allow_palette(cpi->common.features.allow_screen_content_tools,
    552                           bsize));
    553  assert(PALETTE_MAX_SIZE == 8);
    554  assert(PALETTE_MIN_SIZE == 2);
    555 
    556  const int src_stride = x->plane[0].src.stride;
    557  const uint8_t *const src = x->plane[0].src.buf;
    558  int block_width, block_height, rows, cols;
    559  av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
    560                           &cols);
    561  const SequenceHeader *const seq_params = cpi->common.seq_params;
    562  const int is_hbd = seq_params->use_highbitdepth;
    563  const int bit_depth = seq_params->bit_depth;
    564  const int discount_color_cost = cpi->sf.rt_sf.discount_color_cost;
    565  int unused;
    566 
    567  int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
    568  int colors, colors_threshold = 0;
    569  if (is_hbd) {
    570    int count_buf_8bit[1 << 8];  // Maximum (1 << 8) bins for hbd path.
    571    av1_count_colors_highbd(src, src_stride, rows, cols, bit_depth, count_buf,
    572                            count_buf_8bit, &colors_threshold, &colors);
    573  } else {
    574    av1_count_colors(src, src_stride, rows, cols, count_buf, &colors);
    575    colors_threshold = colors;
    576  }
    577 
    578  uint8_t *const color_map = xd->plane[0].color_index_map;
    579  int color_thresh_palette = x->color_palette_thresh;
    580  // Allow for larger color_threshold for palette search, based on color,
    581  // scene_change, and block source variance.
    582  // Since palette is Y based, only allow larger threshold if block
    583  // color_dist is below threshold.
    584  if (cpi->sf.rt_sf.use_nonrd_pick_mode &&
    585      cpi->sf.rt_sf.increase_color_thresh_palette && cpi->rc.high_source_sad &&
    586      x->source_variance > 50) {
    587    int64_t norm_color_dist = 0;
    588    if (x->color_sensitivity[0] || x->color_sensitivity[1]) {
    589      norm_color_dist = x->min_dist_inter_uv >>
    590                        (mi_size_wide_log2[bsize] + mi_size_high_log2[bsize]);
    591      if (x->color_sensitivity[0] && x->color_sensitivity[1])
    592        norm_color_dist = norm_color_dist >> 1;
    593    }
    594    if (norm_color_dist < 8000) color_thresh_palette += 20;
    595  }
    596  if (colors_threshold > 1 && colors_threshold <= color_thresh_palette) {
    597    int16_t *const data = x->palette_buffer->kmeans_data_buf;
    598    int16_t centroids[PALETTE_MAX_SIZE];
    599    int lower_bound, upper_bound;
    600    fill_data_and_get_bounds(src, src_stride, rows, cols, is_hbd, data,
    601                             &lower_bound, &upper_bound);
    602 
    603    mbmi->mode = DC_PRED;
    604    mbmi->filter_intra_mode_info.use_filter_intra = 0;
    605 
    606    uint16_t color_cache[2 * PALETTE_MAX_SIZE];
    607    const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
    608 
    609    // Find the dominant colors, stored in top_colors[].
    610    int16_t top_colors[PALETTE_MAX_SIZE] = { 0 };
    611    find_top_colors(count_buf, bit_depth, AOMMIN(colors, PALETTE_MAX_SIZE),
    612                    top_colors);
    613 
    614    // The following are the approaches used for header rdcost based gating
    615    // for early termination for different values of prune_palette_search_level.
    616    // 0: Pruning based on header rdcost for ascending order palette_size
    617    // search.
    618    // 1: When colors > PALETTE_MIN_SIZE, enabled only for coarse palette_size
    619    // search and for finer search do_header_rd_based_gating parameter is
    620    // explicitly passed as 'false'.
    621    // 2: Enabled only for ascending order palette_size search and for
    622    // descending order search do_header_rd_based_gating parameter is explicitly
    623    // passed as 'false'.
    624    const bool do_header_rd_based_gating =
    625        cpi->sf.intra_sf.prune_luma_palette_size_search_level != 0;
    626 
    627    // TODO(huisu@google.com): Try to avoid duplicate computation in cases
    628    // where the dominant colors and the k-means results are similar.
    629    if ((cpi->sf.intra_sf.prune_palette_search_level == 1) &&
    630        (colors > PALETTE_MIN_SIZE)) {
    631      // Start index and step size below are chosen to evaluate unique
    632      // candidates in neighbor search, in case a winner candidate is found in
    633      // coarse search. Example,
    634      // 1) 8 colors (end_n = 8): 2,3,4,5,6,7,8. start_n is chosen as 2 and step
    635      // size is chosen as 3. Therefore, coarse search will evaluate 2, 5 and 8.
    636      // If winner is found at 5, then 4 and 6 are evaluated. Similarly, for 2
    637      // (3) and 8 (7).
    638      // 2) 7 colors (end_n = 7): 2,3,4,5,6,7. If start_n is chosen as 2 (same
    639      // as for 8 colors) then step size should also be 2, to cover all
    640      // candidates. Coarse search will evaluate 2, 4 and 6. If winner is either
    641      // 2 or 4, 3 will be evaluated. Instead, if start_n=3 and step_size=3,
    642      // coarse search will evaluate 3 and 6. For the winner, unique neighbors
    643      // (3: 2,4 or 6: 5,7) would be evaluated.
    644 
    645      // Start index for coarse palette search for dominant colors and k-means
    646      const uint8_t start_n_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
    647                                                                   3, 3, 2,
    648                                                                   3, 3, 2 };
    649      // Step size for coarse palette search for dominant colors and k-means
    650      const uint8_t step_size_lookup_table[PALETTE_MAX_SIZE + 1] = { 0, 0, 0,
    651                                                                     3, 3, 3,
    652                                                                     3, 3, 3 };
    653 
    654      // Choose the start index and step size for coarse search based on number
    655      // of colors
    656      const int max_n = AOMMIN(colors, PALETTE_MAX_SIZE);
    657      const int min_n = start_n_lookup_table[max_n];
    658      const int step_size = step_size_lookup_table[max_n];
    659      assert(min_n >= PALETTE_MIN_SIZE);
    660      // Perform top color coarse palette search to find the winner candidate
    661      const int top_color_winner = perform_top_color_palette_search(
    662          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, min_n, max_n + 1,
    663          step_size, do_header_rd_based_gating, &unused, color_cache, n_cache,
    664          best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
    665          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
    666          discount_color_cost);
    667      // Evaluate neighbors for the winner color (if winner is found) in the
    668      // above coarse search for dominant colors
    669      if (top_color_winner <= max_n) {
    670        int stage2_min_n, stage2_max_n, stage2_step_size;
    671        set_stage2_params(&stage2_min_n, &stage2_max_n, &stage2_step_size,
    672                          top_color_winner, max_n);
    673        // perform finer search for the winner candidate
    674        perform_top_color_palette_search(
    675            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, stage2_min_n,
    676            stage2_max_n + 1, stage2_step_size,
    677            /*do_header_rd_based_gating=*/false, &unused, color_cache, n_cache,
    678            best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
    679            distortion, skippable, beat_best_rd, ctx, best_blk_skip,
    680            tx_type_map, discount_color_cost);
    681      }
    682      // K-means clustering.
    683      // Perform k-means coarse palette search to find the winner candidate
    684      const int k_means_winner = perform_k_means_palette_search(
    685          cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
    686          min_n, max_n + 1, step_size, do_header_rd_based_gating, &unused,
    687          color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
    688          rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
    689          best_blk_skip, tx_type_map, color_map, rows * cols,
    690          discount_color_cost);
    691      // Evaluate neighbors for the winner color (if winner is found) in the
    692      // above coarse search for k-means
    693      if (k_means_winner <= max_n) {
    694        int start_n_stage2, end_n_stage2, step_size_stage2;
    695        set_stage2_params(&start_n_stage2, &end_n_stage2, &step_size_stage2,
    696                          k_means_winner, max_n);
    697        // perform finer search for the winner candidate
    698        perform_k_means_palette_search(
    699            cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
    700            start_n_stage2, end_n_stage2 + 1, step_size_stage2,
    701            /*do_header_rd_based_gating=*/false, &unused, color_cache, n_cache,
    702            best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
    703            distortion, skippable, beat_best_rd, ctx, best_blk_skip,
    704            tx_type_map, color_map, rows * cols, discount_color_cost);
    705      }
    706    } else {
    707      const int max_n = AOMMIN(colors, PALETTE_MAX_SIZE),
    708                min_n = PALETTE_MIN_SIZE;
    709      // Perform top color palette search in ascending order
    710      int last_n_searched = min_n;
    711      perform_top_color_palette_search(
    712          cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, min_n, max_n + 1,
    713          1, do_header_rd_based_gating, &last_n_searched, color_cache, n_cache,
    714          best_mbmi, best_palette_color_map, best_rd, rate, rate_tokenonly,
    715          distortion, skippable, beat_best_rd, ctx, best_blk_skip, tx_type_map,
    716          discount_color_cost);
    717      if (last_n_searched < max_n) {
    718        // Search in descending order until we get to the previous best
    719        perform_top_color_palette_search(
    720            cpi, x, mbmi, bsize, dc_mode_cost, data, top_colors, max_n,
    721            last_n_searched, -1, /*do_header_rd_based_gating=*/false, &unused,
    722            color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
    723            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
    724            best_blk_skip, tx_type_map, discount_color_cost);
    725      }
    726      // K-means clustering.
    727      if (colors == PALETTE_MIN_SIZE) {
    728        // Special case: These colors automatically become the centroids.
    729        assert(colors == 2);
    730        centroids[0] = lower_bound;
    731        centroids[1] = upper_bound;
    732        palette_rd_y(cpi, x, mbmi, bsize, dc_mode_cost, data, centroids, colors,
    733                     color_cache, n_cache, /*do_header_rd_based_gating=*/false,
    734                     best_mbmi, best_palette_color_map, best_rd, rate,
    735                     rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
    736                     best_blk_skip, tx_type_map, NULL, NULL,
    737                     discount_color_cost);
    738      } else {
    739        // Perform k-means palette search in ascending order
    740        last_n_searched = min_n;
    741        perform_k_means_palette_search(
    742            cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
    743            min_n, max_n + 1, 1, do_header_rd_based_gating, &last_n_searched,
    744            color_cache, n_cache, best_mbmi, best_palette_color_map, best_rd,
    745            rate, rate_tokenonly, distortion, skippable, beat_best_rd, ctx,
    746            best_blk_skip, tx_type_map, color_map, rows * cols,
    747            discount_color_cost);
    748        if (last_n_searched < max_n) {
    749          // Search in descending order until we get to the previous best
    750          perform_k_means_palette_search(
    751              cpi, x, mbmi, bsize, dc_mode_cost, data, lower_bound, upper_bound,
    752              max_n, last_n_searched, -1, /*do_header_rd_based_gating=*/false,
    753              &unused, color_cache, n_cache, best_mbmi, best_palette_color_map,
    754              best_rd, rate, rate_tokenonly, distortion, skippable,
    755              beat_best_rd, ctx, best_blk_skip, tx_type_map, color_map,
    756              rows * cols, discount_color_cost);
    757        }
    758      }
    759    }
    760  }
    761 
    762  if (best_mbmi->palette_mode_info.palette_size[0] > 0) {
    763    memcpy(color_map, best_palette_color_map,
    764           block_width * block_height * sizeof(best_palette_color_map[0]));
    765    // Gather the stats to determine whether to use screen content tools in
    766    // function av1_determine_sc_tools_with_encoding().
    767    x->palette_pixels += (block_width * block_height);
    768  }
    769  *mbmi = *best_mbmi;
    770 }
    771 
    772 void av1_rd_pick_palette_intra_sbuv(const AV1_COMP *cpi, MACROBLOCK *x,
    773                                    int dc_mode_cost,
    774                                    uint8_t *best_palette_color_map,
    775                                    MB_MODE_INFO *const best_mbmi,
    776                                    int64_t *best_rd, int *rate,
    777                                    int *rate_tokenonly, int64_t *distortion,
    778                                    uint8_t *skippable) {
    779  MACROBLOCKD *const xd = &x->e_mbd;
    780  MB_MODE_INFO *const mbmi = xd->mi[0];
    781  assert(!is_inter_block(mbmi));
    782  assert(av1_allow_palette(cpi->common.features.allow_screen_content_tools,
    783                           mbmi->bsize));
    784  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
    785  const BLOCK_SIZE bsize = mbmi->bsize;
    786  const SequenceHeader *const seq_params = cpi->common.seq_params;
    787  int this_rate;
    788  int64_t this_rd;
    789  int colors_u, colors_v;
    790  int colors_threshold_u = 0, colors_threshold_v = 0, colors_threshold = 0;
    791  const int src_stride = x->plane[1].src.stride;
    792  const uint8_t *const src_u = x->plane[1].src.buf;
    793  const uint8_t *const src_v = x->plane[2].src.buf;
    794  uint8_t *const color_map = xd->plane[1].color_index_map;
    795  RD_STATS tokenonly_rd_stats;
    796  int plane_block_width, plane_block_height, rows, cols;
    797  av1_get_block_dimensions(bsize, 1, xd, &plane_block_width,
    798                           &plane_block_height, &rows, &cols);
    799 
    800  mbmi->uv_mode = UV_DC_PRED;
    801  if (seq_params->use_highbitdepth) {
    802    int count_buf[1 << 12];      // Maximum (1 << 12) color levels.
    803    int count_buf_8bit[1 << 8];  // Maximum (1 << 8) bins for hbd path.
    804    av1_count_colors_highbd(src_u, src_stride, rows, cols,
    805                            seq_params->bit_depth, count_buf, count_buf_8bit,
    806                            &colors_threshold_u, &colors_u);
    807    av1_count_colors_highbd(src_v, src_stride, rows, cols,
    808                            seq_params->bit_depth, count_buf, count_buf_8bit,
    809                            &colors_threshold_v, &colors_v);
    810  } else {
    811    int count_buf[1 << 8];
    812    av1_count_colors(src_u, src_stride, rows, cols, count_buf, &colors_u);
    813    av1_count_colors(src_v, src_stride, rows, cols, count_buf, &colors_v);
    814    colors_threshold_u = colors_u;
    815    colors_threshold_v = colors_v;
    816  }
    817 
    818  uint16_t color_cache[2 * PALETTE_MAX_SIZE];
    819  const int n_cache = av1_get_palette_cache(xd, 1, color_cache);
    820 
    821  colors_threshold = colors_threshold_u > colors_threshold_v
    822                         ? colors_threshold_u
    823                         : colors_threshold_v;
    824  if (colors_threshold > 1 && colors_threshold <= 64) {
    825    int r, c, n, i, j;
    826    const int max_itr = 50;
    827    int lb_u, ub_u, val_u;
    828    int lb_v, ub_v, val_v;
    829    int16_t *const data = x->palette_buffer->kmeans_data_buf;
    830    int16_t centroids[2 * PALETTE_MAX_SIZE];
    831 
    832    uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
    833    uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v);
    834    if (seq_params->use_highbitdepth) {
    835      lb_u = src_u16[0];
    836      ub_u = src_u16[0];
    837      lb_v = src_v16[0];
    838      ub_v = src_v16[0];
    839    } else {
    840      lb_u = src_u[0];
    841      ub_u = src_u[0];
    842      lb_v = src_v[0];
    843      ub_v = src_v[0];
    844    }
    845 
    846    for (r = 0; r < rows; ++r) {
    847      for (c = 0; c < cols; ++c) {
    848        if (seq_params->use_highbitdepth) {
    849          val_u = src_u16[r * src_stride + c];
    850          val_v = src_v16[r * src_stride + c];
    851          data[(r * cols + c) * 2] = val_u;
    852          data[(r * cols + c) * 2 + 1] = val_v;
    853        } else {
    854          val_u = src_u[r * src_stride + c];
    855          val_v = src_v[r * src_stride + c];
    856          data[(r * cols + c) * 2] = val_u;
    857          data[(r * cols + c) * 2 + 1] = val_v;
    858        }
    859        if (val_u < lb_u)
    860          lb_u = val_u;
    861        else if (val_u > ub_u)
    862          ub_u = val_u;
    863        if (val_v < lb_v)
    864          lb_v = val_v;
    865        else if (val_v > ub_v)
    866          ub_v = val_v;
    867      }
    868    }
    869 
    870    const int colors = colors_u > colors_v ? colors_u : colors_v;
    871    const int max_colors =
    872        colors > PALETTE_MAX_SIZE ? PALETTE_MAX_SIZE : colors;
    873    for (n = PALETTE_MIN_SIZE; n <= max_colors; ++n) {
    874      for (i = 0; i < n; ++i) {
    875        centroids[i * 2] = lb_u + (2 * i + 1) * (ub_u - lb_u) / n / 2;
    876        centroids[i * 2 + 1] = lb_v + (2 * i + 1) * (ub_v - lb_v) / n / 2;
    877      }
    878      av1_k_means(data, centroids, color_map, rows * cols, n, 2, max_itr);
    879      optimize_palette_colors(color_cache, n_cache, n, 2, centroids,
    880                              cpi->common.seq_params->bit_depth);
    881      // Sort the U channel colors in ascending order.
    882      for (i = 0; i < 2 * (n - 1); i += 2) {
    883        int min_idx = i;
    884        int min_val = centroids[i];
    885        for (j = i + 2; j < 2 * n; j += 2)
    886          if (centroids[j] < min_val) min_val = centroids[j], min_idx = j;
    887        if (min_idx != i) {
    888          int temp_u = centroids[i], temp_v = centroids[i + 1];
    889          centroids[i] = centroids[min_idx];
    890          centroids[i + 1] = centroids[min_idx + 1];
    891          centroids[min_idx] = temp_u, centroids[min_idx + 1] = temp_v;
    892        }
    893      }
    894      av1_calc_indices(data, centroids, color_map, rows * cols, n, 2);
    895      extend_palette_color_map(color_map, cols, rows, plane_block_width,
    896                               plane_block_height);
    897      pmi->palette_size[1] = n;
    898      for (i = 1; i < 3; ++i) {
    899        for (j = 0; j < n; ++j) {
    900          if (seq_params->use_highbitdepth)
    901            pmi->palette_colors[i * PALETTE_MAX_SIZE + j] = clip_pixel_highbd(
    902                (int)centroids[j * 2 + i - 1], seq_params->bit_depth);
    903          else
    904            pmi->palette_colors[i * PALETTE_MAX_SIZE + j] =
    905                clip_pixel((int)centroids[j * 2 + i - 1]);
    906        }
    907      }
    908 
    909      if (cpi->sf.intra_sf.early_term_chroma_palette_size_search) {
    910        const int palette_mode_rate =
    911            intra_mode_info_cost_uv(cpi, x, mbmi, bsize, dc_mode_cost);
    912        const int64_t header_rd = RDCOST(x->rdmult, palette_mode_rate, 0);
    913        // Terminate further palette_size search, if header cost corresponding
    914        // to lower palette_size is more than the best_rd.
    915        if (header_rd >= *best_rd) break;
    916        av1_txfm_uvrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
    917        if (tokenonly_rd_stats.rate == INT_MAX) continue;
    918        this_rate = tokenonly_rd_stats.rate + palette_mode_rate;
    919      } else {
    920        av1_txfm_uvrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
    921        if (tokenonly_rd_stats.rate == INT_MAX) continue;
    922        this_rate = tokenonly_rd_stats.rate +
    923                    intra_mode_info_cost_uv(cpi, x, mbmi, bsize, dc_mode_cost);
    924      }
    925 
    926      this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
    927      if (this_rd < *best_rd) {
    928        *best_rd = this_rd;
    929        *best_mbmi = *mbmi;
    930        memcpy(best_palette_color_map, color_map,
    931               plane_block_width * plane_block_height *
    932                   sizeof(best_palette_color_map[0]));
    933        *rate = this_rate;
    934        *distortion = tokenonly_rd_stats.dist;
    935        *rate_tokenonly = tokenonly_rd_stats.rate;
    936        *skippable = tokenonly_rd_stats.skip_txfm;
    937      }
    938    }
    939  }
    940  if (best_mbmi->palette_mode_info.palette_size[1] > 0) {
    941    memcpy(color_map, best_palette_color_map,
    942           plane_block_width * plane_block_height *
    943               sizeof(best_palette_color_map[0]));
    944  }
    945 }
    946 
    947 void av1_restore_uv_color_map(const AV1_COMP *cpi, MACROBLOCK *x) {
    948  MACROBLOCKD *const xd = &x->e_mbd;
    949  MB_MODE_INFO *const mbmi = xd->mi[0];
    950  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
    951  const BLOCK_SIZE bsize = mbmi->bsize;
    952  int src_stride = x->plane[1].src.stride;
    953  const uint8_t *const src_u = x->plane[1].src.buf;
    954  const uint8_t *const src_v = x->plane[2].src.buf;
    955  int16_t *const data = x->palette_buffer->kmeans_data_buf;
    956  int16_t centroids[2 * PALETTE_MAX_SIZE];
    957  uint8_t *const color_map = xd->plane[1].color_index_map;
    958  int r, c;
    959  const uint16_t *const src_u16 = CONVERT_TO_SHORTPTR(src_u);
    960  const uint16_t *const src_v16 = CONVERT_TO_SHORTPTR(src_v);
    961  int plane_block_width, plane_block_height, rows, cols;
    962  av1_get_block_dimensions(bsize, 1, xd, &plane_block_width,
    963                           &plane_block_height, &rows, &cols);
    964 
    965  for (r = 0; r < rows; ++r) {
    966    for (c = 0; c < cols; ++c) {
    967      if (cpi->common.seq_params->use_highbitdepth) {
    968        data[(r * cols + c) * 2] = src_u16[r * src_stride + c];
    969        data[(r * cols + c) * 2 + 1] = src_v16[r * src_stride + c];
    970      } else {
    971        data[(r * cols + c) * 2] = src_u[r * src_stride + c];
    972        data[(r * cols + c) * 2 + 1] = src_v[r * src_stride + c];
    973      }
    974    }
    975  }
    976 
    977  for (r = 1; r < 3; ++r) {
    978    for (c = 0; c < pmi->palette_size[1]; ++c) {
    979      centroids[c * 2 + r - 1] = pmi->palette_colors[r * PALETTE_MAX_SIZE + c];
    980    }
    981  }
    982 
    983  av1_calc_indices(data, centroids, color_map, rows * cols,
    984                   pmi->palette_size[1], 2);
    985  extend_palette_color_map(color_map, cols, rows, plane_block_width,
    986                           plane_block_height);
    987 }