tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

robust_statistics.h (4844B)


      1 // Copyright 2023 Google LLC
      2 // SPDX-License-Identifier: Apache-2.0
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 
     16 #ifndef HIGHWAY_HWY_ROBUST_STATISTICS_H_
     17 #define HIGHWAY_HWY_ROBUST_STATISTICS_H_
     18 
     19 #include <algorithm>  // std::sort, std::find_if
     20 #include <limits>
     21 #include <utility>  // std::pair
     22 #include <vector>
     23 
     24 #include "hwy/base.h"
     25 
     26 namespace hwy {
     27 namespace robust_statistics {
     28 
     29 // Sorts integral values in ascending order (e.g. for Mode). About 3x faster
     30 // than std::sort for input distributions with very few unique values.
     31 template <class T>
     32 void CountingSort(T* values, size_t num_values) {
     33  // Unique values and their frequency (similar to flat_map).
     34  using Unique = std::pair<T, int>;
     35  std::vector<Unique> unique;
     36  for (size_t i = 0; i < num_values; ++i) {
     37    const T value = values[i];
     38    const auto pos =
     39        std::find_if(unique.begin(), unique.end(),
     40                     [value](const Unique u) { return u.first == value; });
     41    if (pos == unique.end()) {
     42      unique.push_back(std::make_pair(value, 1));
     43    } else {
     44      ++pos->second;
     45    }
     46  }
     47 
     48  // Sort in ascending order of value (pair.first).
     49  std::sort(unique.begin(), unique.end());
     50 
     51  // Write that many copies of each unique value to the array.
     52  T* HWY_RESTRICT p = values;
     53  for (const auto& value_count : unique) {
     54    std::fill(p, p + value_count.second, value_count.first);
     55    p += value_count.second;
     56  }
     57  HWY_ASSERT(p == values + num_values);
     58 }
     59 
     60 // @return i in [idx_begin, idx_begin + half_count) that minimizes
     61 // sorted[i + half_count] - sorted[i].
     62 template <typename T>
     63 size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin,
     64                const size_t half_count) {
     65  T min_range = std::numeric_limits<T>::max();
     66  size_t min_idx = 0;
     67 
     68  for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) {
     69    HWY_ASSERT(sorted[idx] <= sorted[idx + half_count]);
     70    const T range = sorted[idx + half_count] - sorted[idx];
     71    if (range < min_range) {
     72      min_range = range;
     73      min_idx = idx;
     74    }
     75  }
     76 
     77  return min_idx;
     78 }
     79 
     80 // Returns an estimate of the mode by calling MinRange on successively
     81 // halved intervals. "sorted" must be in ascending order. This is the
     82 // Half Sample Mode estimator proposed by Bickel in "On a fast, robust
     83 // estimator of the mode", with complexity O(N log N). The mode is less
     84 // affected by outliers in highly-skewed distributions than the median.
     85 // The averaging operation below assumes "T" is an unsigned integer type.
     86 template <typename T>
     87 T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) {
     88  size_t idx_begin = 0;
     89  size_t half_count = num_values / 2;
     90  while (half_count > 1) {
     91    idx_begin = MinRange(sorted, idx_begin, half_count);
     92    half_count >>= 1;
     93  }
     94 
     95  const T x = sorted[idx_begin + 0];
     96  if (half_count == 0) {
     97    return x;
     98  }
     99  HWY_ASSERT(half_count == 1);
    100  const T average = (x + sorted[idx_begin + 1] + 1) / 2;
    101  return average;
    102 }
    103 
    104 // Returns the mode. Side effect: sorts "values".
    105 template <typename T>
    106 T Mode(T* values, const size_t num_values) {
    107  CountingSort(values, num_values);
    108  return ModeOfSorted(values, num_values);
    109 }
    110 
    111 template <typename T, size_t N>
    112 T Mode(T (&values)[N]) {
    113  return Mode(&values[0], N);
    114 }
    115 
    116 // Returns the median value. Side effect: sorts "values".
    117 template <typename T>
    118 T Median(T* values, const size_t num_values) {
    119  HWY_ASSERT(num_values != 0);
    120  std::sort(values, values + num_values);
    121  const size_t half = num_values / 2;
    122  // Odd count: return middle
    123  if (num_values % 2) {
    124    return values[half];
    125  }
    126  // Even count: return average of middle two.
    127  return (values[half] + values[half - 1] + 1) / 2;
    128 }
    129 
    130 // Returns a robust measure of variability.
    131 template <typename T>
    132 T MedianAbsoluteDeviation(const T* values, const size_t num_values,
    133                          const T median) {
    134  HWY_ASSERT(num_values != 0);
    135  std::vector<T> abs_deviations;
    136  abs_deviations.reserve(num_values);
    137  for (size_t i = 0; i < num_values; ++i) {
    138    const int64_t abs = ScalarAbs(static_cast<int64_t>(values[i]) -
    139                                  static_cast<int64_t>(median));
    140    abs_deviations.push_back(static_cast<T>(abs));
    141  }
    142  return Median(abs_deviations.data(), num_values);
    143 }
    144 
    145 }  // namespace robust_statistics
    146 }  // namespace hwy
    147 
    148 #endif  // HIGHWAY_HWY_ROBUST_STATISTICS_H_