tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

enc_cluster.cc (12768B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_cluster.h"
      7 
      8 #include <algorithm>
      9 #include <cmath>
     10 #include <limits>
     11 #include <map>
     12 #include <numeric>
     13 #include <queue>
     14 #include <tuple>
     15 
     16 #include "lib/jxl/base/status.h"
     17 
     18 #undef HWY_TARGET_INCLUDE
     19 #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
     20 #include <hwy/foreach_target.h>
     21 #include <hwy/highway.h>
     22 
     23 #include "lib/jxl/base/fast_math-inl.h"
     24 #include "lib/jxl/enc_ans.h"
     25 HWY_BEFORE_NAMESPACE();
     26 namespace jxl {
     27 namespace HWY_NAMESPACE {
     28 
     29 // These templates are not found via ADL.
     30 using hwy::HWY_NAMESPACE::Eq;
     31 using hwy::HWY_NAMESPACE::IfThenZeroElse;
     32 
     33 template <class V>
     34 V Entropy(V count, V inv_total, V total) {
     35  const HWY_CAPPED(float, Histogram::kRounding) d;
     36  const auto zero = Set(d, 0.0f);
     37  // TODO(eustas): why (0 - x) instead of Neg(x)?
     38  return IfThenZeroElse(
     39      Eq(count, total),
     40      Sub(zero, Mul(count, FastLog2f(d, Mul(inv_total, count)))));
     41 }
     42 
     43 void HistogramEntropy(const Histogram& a) {
     44  a.entropy_ = 0.0f;
     45  if (a.total_count_ == 0) return;
     46 
     47  const HWY_CAPPED(float, Histogram::kRounding) df;
     48  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
     49 
     50  const auto inv_tot = Set(df, 1.0f / a.total_count_);
     51  auto entropy_lanes = Zero(df);
     52  auto total = Set(df, a.total_count_);
     53 
     54  for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
     55    const auto counts = LoadU(di, &a.data_[i]);
     56    entropy_lanes =
     57        Add(entropy_lanes, Entropy(ConvertTo(df, counts), inv_tot, total));
     58  }
     59  a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes));
     60 }
     61 
     62 float HistogramDistance(const Histogram& a, const Histogram& b) {
     63  if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
     64 
     65  const HWY_CAPPED(float, Histogram::kRounding) df;
     66  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
     67 
     68  const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
     69  auto distance_lanes = Zero(df);
     70  auto total = Set(df, a.total_count_ + b.total_count_);
     71 
     72  for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
     73       i += Lanes(di)) {
     74    const auto a_counts =
     75        a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
     76    const auto b_counts =
     77        b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
     78    const auto counts = ConvertTo(df, Add(a_counts, b_counts));
     79    distance_lanes = Add(distance_lanes, Entropy(counts, inv_tot, total));
     80  }
     81  const float total_distance = GetLane(SumOfLanes(df, distance_lanes));
     82  return total_distance - a.entropy_ - b.entropy_;
     83 }
     84 
     // Positive float infinity; HistogramKLDivergence returns it when the coding
     // histogram is empty.
     constexpr float kInfinity = std::numeric_limits<float>::infinity();
     86 
     87 float HistogramKLDivergence(const Histogram& actual, const Histogram& coding) {
     88  if (actual.total_count_ == 0) return 0;
     89  if (coding.total_count_ == 0) return kInfinity;
     90 
     91  const HWY_CAPPED(float, Histogram::kRounding) df;
     92  const HWY_CAPPED(int32_t, Histogram::kRounding) di;
     93 
     94  const auto coding_inv = Set(df, 1.0f / coding.total_count_);
     95  auto cost_lanes = Zero(df);
     96 
     97  for (size_t i = 0; i < actual.data_.size(); i += Lanes(di)) {
     98    const auto counts = LoadU(di, &actual.data_[i]);
     99    const auto coding_counts =
    100        coding.data_.size() > i ? LoadU(di, &coding.data_[i]) : Zero(di);
    101    const auto coding_probs = Mul(ConvertTo(df, coding_counts), coding_inv);
    102    const auto neg_coding_cost = BitCast(
    103        df,
    104        IfThenZeroElse(Eq(counts, Zero(di)),
    105                       IfThenElse(Eq(coding_counts, Zero(di)),
    106                                  BitCast(di, Set(df, -kInfinity)),
    107                                  BitCast(di, FastLog2f(df, coding_probs)))));
    108    cost_lanes = NegMulAdd(ConvertTo(df, counts), neg_coding_cost, cost_lanes);
    109  }
    110  const float total_cost = GetLane(SumOfLanes(df, cost_lanes));
    111  return total_cost - actual.entropy_;
    112 }
    113 
    114 // First step of a k-means clustering with a fancy distance metric.
    115 Status FastClusterHistograms(const std::vector<Histogram>& in,
    116                             size_t max_histograms, std::vector<Histogram>* out,
    117                             std::vector<uint32_t>* histogram_symbols) {
    118  const size_t prev_histograms = out->size();
    119  out->reserve(max_histograms);
    120  histogram_symbols->clear();
    121  histogram_symbols->resize(in.size(), max_histograms);
    122 
    123  std::vector<float> dists(in.size(), std::numeric_limits<float>::max());
    124  size_t largest_idx = 0;
    125  for (size_t i = 0; i < in.size(); i++) {
    126    if (in[i].total_count_ == 0) {
    127      (*histogram_symbols)[i] = 0;
    128      dists[i] = 0.0f;
    129      continue;
    130    }
    131    HistogramEntropy(in[i]);
    132    if (in[i].total_count_ > in[largest_idx].total_count_) {
    133      largest_idx = i;
    134    }
    135  }
    136 
    137  if (prev_histograms > 0) {
    138    for (size_t j = 0; j < prev_histograms; ++j) {
    139      HistogramEntropy((*out)[j]);
    140    }
    141    for (size_t i = 0; i < in.size(); i++) {
    142      if (dists[i] == 0.0f) continue;
    143      for (size_t j = 0; j < prev_histograms; ++j) {
    144        dists[i] = std::min(HistogramKLDivergence(in[i], (*out)[j]), dists[i]);
    145      }
    146    }
    147    auto max_dist = std::max_element(dists.begin(), dists.end());
    148    if (*max_dist > 0.0f) {
    149      largest_idx = max_dist - dists.begin();
    150    }
    151  }
    152 
    153  constexpr float kMinDistanceForDistinct = 48.0f;
    154  while (out->size() < max_histograms) {
    155    (*histogram_symbols)[largest_idx] = out->size();
    156    out->push_back(in[largest_idx]);
    157    dists[largest_idx] = 0.0f;
    158    largest_idx = 0;
    159    for (size_t i = 0; i < in.size(); i++) {
    160      if (dists[i] == 0.0f) continue;
    161      dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]);
    162      if (dists[i] > dists[largest_idx]) largest_idx = i;
    163    }
    164    if (dists[largest_idx] < kMinDistanceForDistinct) break;
    165  }
    166 
    167  for (size_t i = 0; i < in.size(); i++) {
    168    if ((*histogram_symbols)[i] != max_histograms) continue;
    169    size_t best = 0;
    170    float best_dist = std::numeric_limits<float>::max();
    171    for (size_t j = 0; j < out->size(); j++) {
    172      float dist = j < prev_histograms ? HistogramKLDivergence(in[i], (*out)[j])
    173                                       : HistogramDistance(in[i], (*out)[j]);
    174      if (dist < best_dist) {
    175        best = j;
    176        best_dist = dist;
    177      }
    178    }
    179    JXL_ENSURE(best_dist < std::numeric_limits<float>::max());
    180    if (best >= prev_histograms) {
    181      (*out)[best].AddHistogram(in[i]);
    182      HistogramEntropy((*out)[best]);
    183    }
    184    (*histogram_symbols)[i] = best;
    185  }
    186  return true;
    187 }
    188 
    189 // NOLINTNEXTLINE(google-readability-namespace-comments)
    190 }  // namespace HWY_NAMESPACE
    191 }  // namespace jxl
    192 HWY_AFTER_NAMESPACE();
    193 
    194 #if HWY_ONCE
    195 namespace jxl {
    196 HWY_EXPORT(FastClusterHistograms);  // Local function
    197 HWY_EXPORT(HistogramEntropy);       // Local function
    198 
    199 StatusOr<float> Histogram::PopulationCost() const {
    200  return ANSPopulationCost(data_.data(), data_.size());
    201 }
    202 
    203 float Histogram::ShannonEntropy() const {
    204  HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
    205  return entropy_;
    206 }
    207 
    208 namespace {
    209 // -----------------------------------------------------------------------------
    210 // Histogram refinement
    211 
    212 // Reorder histograms in *out so that the new symbols in *symbols come in
    213 // increasing order.
    214 void HistogramReindex(std::vector<Histogram>* out, size_t prev_histograms,
    215                      std::vector<uint32_t>* symbols) {
    216  std::vector<Histogram> tmp(*out);
    217  std::map<int, int> new_index;
    218  for (size_t i = 0; i < prev_histograms; ++i) {
    219    new_index[i] = i;
    220  }
    221  int next_index = prev_histograms;
    222  for (uint32_t symbol : *symbols) {
    223    if (new_index.find(symbol) == new_index.end()) {
    224      new_index[symbol] = next_index;
    225      (*out)[next_index] = tmp[symbol];
    226      ++next_index;
    227    }
    228  }
    229  out->resize(next_index);
    230  for (uint32_t& symbol : *symbols) {
    231    symbol = new_index[symbol];
    232  }
    233 }
    234 
    235 }  // namespace
    236 
    237 // Clusters similar histograms in 'in' together, the selected histograms are
    238 // placed in 'out', and for each index in 'in', *histogram_symbols will
    239 // indicate which of the 'out' histograms is the best approximation.
    240 Status ClusterHistograms(const HistogramParams& params,
    241                         const std::vector<Histogram>& in,
    242                         size_t max_histograms, std::vector<Histogram>* out,
    243                         std::vector<uint32_t>* histogram_symbols) {
    244  size_t prev_histograms = out->size();
    245  max_histograms = std::min(max_histograms, params.max_histograms);
    246  max_histograms = std::min(max_histograms, in.size());
    247  if (params.clustering == HistogramParams::ClusteringType::kFastest) {
    248    max_histograms = std::min(max_histograms, static_cast<size_t>(4));
    249  }
    250 
    251  JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(FastClusterHistograms)(
    252      in, prev_histograms + max_histograms, out, histogram_symbols));
    253 
    254  if (prev_histograms == 0 &&
    255      params.clustering == HistogramParams::ClusteringType::kBest) {
    256    for (auto& histo : *out) {
    257      JXL_ASSIGN_OR_RETURN(
    258          histo.entropy_,
    259          ANSPopulationCost(histo.data_.data(), histo.data_.size()));
    260    }
    261    uint32_t next_version = 2;
    262    std::vector<uint32_t> version(out->size(), 1);
    263    std::vector<uint32_t> renumbering(out->size());
    264    std::iota(renumbering.begin(), renumbering.end(), 0);
    265 
    266    // Try to pair up clusters if doing so reduces the total cost.
    267 
    268    struct HistogramPair {
    269      // validity of a pair: p.version == max(version[i], version[j])
    270      float cost;
    271      uint32_t first;
    272      uint32_t second;
    273      uint32_t version;
    274      // We use > because priority queues sort in *decreasing* order, but we
    275      // want lower cost elements to appear first.
    276      bool operator<(const HistogramPair& other) const {
    277        return std::make_tuple(cost, first, second, version) >
    278               std::make_tuple(other.cost, other.first, other.second,
    279                               other.version);
    280      }
    281    };
    282 
    283    // Create list of all pairs by increasing merging cost.
    284    std::priority_queue<HistogramPair> pairs_to_merge;
    285    for (uint32_t i = 0; i < out->size(); i++) {
    286      for (uint32_t j = i + 1; j < out->size(); j++) {
    287        Histogram histo;
    288        histo.AddHistogram((*out)[i]);
    289        histo.AddHistogram((*out)[j]);
    290        JXL_ASSIGN_OR_RETURN(float cost, ANSPopulationCost(histo.data_.data(),
    291                                                           histo.data_.size()));
    292        cost -= (*out)[i].entropy_ + (*out)[j].entropy_;
    293        // Avoid enqueueing pairs that are not advantageous to merge.
    294        if (cost >= 0) continue;
    295        pairs_to_merge.push(
    296            HistogramPair{cost, i, j, std::max(version[i], version[j])});
    297      }
    298    }
    299 
    300    // Merge the best pair to merge, add new pairs that get formed as a
    301    // consequence.
    302    while (!pairs_to_merge.empty()) {
    303      uint32_t first = pairs_to_merge.top().first;
    304      uint32_t second = pairs_to_merge.top().second;
    305      uint32_t ver = pairs_to_merge.top().version;
    306      pairs_to_merge.pop();
    307      if (ver != std::max(version[first], version[second]) ||
    308          version[first] == 0 || version[second] == 0) {
    309        continue;
    310      }
    311      (*out)[first].AddHistogram((*out)[second]);
    312      JXL_ASSIGN_OR_RETURN(float cost,
    313                           ANSPopulationCost((*out)[first].data_.data(),
    314                                             (*out)[first].data_.size()));
    315      (*out)[first].entropy_ = cost;
    316      for (uint32_t& item : renumbering) {
    317        if (item == second) {
    318          item = first;
    319        }
    320      }
    321      version[second] = 0;
    322      version[first] = next_version++;
    323      for (uint32_t j = 0; j < out->size(); j++) {
    324        if (j == first) continue;
    325        if (version[j] == 0) continue;
    326        Histogram histo;
    327        histo.AddHistogram((*out)[first]);
    328        histo.AddHistogram((*out)[j]);
    329        JXL_ASSIGN_OR_RETURN(float cost, ANSPopulationCost(histo.data_.data(),
    330                                                           histo.data_.size()));
    331        cost -= (*out)[first].entropy_ + (*out)[j].entropy_;
    332        // Avoid enqueueing pairs that are not advantageous to merge.
    333        if (cost >= 0) continue;
    334        pairs_to_merge.push(
    335            HistogramPair{cost, std::min(first, j), std::max(first, j),
    336                          std::max(version[first], version[j])});
    337      }
    338    }
    339    std::vector<uint32_t> reverse_renumbering(out->size(), -1);
    340    size_t num_alive = 0;
    341    for (size_t i = 0; i < out->size(); i++) {
    342      if (version[i] == 0) continue;
    343      (*out)[num_alive++] = (*out)[i];
    344      reverse_renumbering[i] = num_alive - 1;
    345    }
    346    out->resize(num_alive);
    347    for (uint32_t& item : *histogram_symbols) {
    348      item = reverse_renumbering[renumbering[item]];
    349    }
    350  }
    351 
    352  // Convert the context map to a canonical form.
    353  HistogramReindex(out, prev_histograms, histogram_symbols);
    354  return true;
    355 }
    356 
    357 }  // namespace jxl
    358 #endif  // HWY_ONCE