// robust_statistics.h (4844B)
1 // Copyright 2023 Google LLC 2 // SPDX-License-Identifier: Apache-2.0 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #ifndef HIGHWAY_HWY_ROBUST_STATISTICS_H_ 17 #define HIGHWAY_HWY_ROBUST_STATISTICS_H_ 18 19 #include <algorithm> // std::sort, std::find_if 20 #include <limits> 21 #include <utility> // std::pair 22 #include <vector> 23 24 #include "hwy/base.h" 25 26 namespace hwy { 27 namespace robust_statistics { 28 29 // Sorts integral values in ascending order (e.g. for Mode). About 3x faster 30 // than std::sort for input distributions with very few unique values. 31 template <class T> 32 void CountingSort(T* values, size_t num_values) { 33 // Unique values and their frequency (similar to flat_map). 34 using Unique = std::pair<T, int>; 35 std::vector<Unique> unique; 36 for (size_t i = 0; i < num_values; ++i) { 37 const T value = values[i]; 38 const auto pos = 39 std::find_if(unique.begin(), unique.end(), 40 [value](const Unique u) { return u.first == value; }); 41 if (pos == unique.end()) { 42 unique.push_back(std::make_pair(value, 1)); 43 } else { 44 ++pos->second; 45 } 46 } 47 48 // Sort in ascending order of value (pair.first). 49 std::sort(unique.begin(), unique.end()); 50 51 // Write that many copies of each unique value to the array. 
52 T* HWY_RESTRICT p = values; 53 for (const auto& value_count : unique) { 54 std::fill(p, p + value_count.second, value_count.first); 55 p += value_count.second; 56 } 57 HWY_ASSERT(p == values + num_values); 58 } 59 60 // @return i in [idx_begin, idx_begin + half_count) that minimizes 61 // sorted[i + half_count] - sorted[i]. 62 template <typename T> 63 size_t MinRange(const T* const HWY_RESTRICT sorted, const size_t idx_begin, 64 const size_t half_count) { 65 T min_range = std::numeric_limits<T>::max(); 66 size_t min_idx = 0; 67 68 for (size_t idx = idx_begin; idx < idx_begin + half_count; ++idx) { 69 HWY_ASSERT(sorted[idx] <= sorted[idx + half_count]); 70 const T range = sorted[idx + half_count] - sorted[idx]; 71 if (range < min_range) { 72 min_range = range; 73 min_idx = idx; 74 } 75 } 76 77 return min_idx; 78 } 79 80 // Returns an estimate of the mode by calling MinRange on successively 81 // halved intervals. "sorted" must be in ascending order. This is the 82 // Half Sample Mode estimator proposed by Bickel in "On a fast, robust 83 // estimator of the mode", with complexity O(N log N). The mode is less 84 // affected by outliers in highly-skewed distributions than the median. 85 // The averaging operation below assumes "T" is an unsigned integer type. 86 template <typename T> 87 T ModeOfSorted(const T* const HWY_RESTRICT sorted, const size_t num_values) { 88 size_t idx_begin = 0; 89 size_t half_count = num_values / 2; 90 while (half_count > 1) { 91 idx_begin = MinRange(sorted, idx_begin, half_count); 92 half_count >>= 1; 93 } 94 95 const T x = sorted[idx_begin + 0]; 96 if (half_count == 0) { 97 return x; 98 } 99 HWY_ASSERT(half_count == 1); 100 const T average = (x + sorted[idx_begin + 1] + 1) / 2; 101 return average; 102 } 103 104 // Returns the mode. Side effect: sorts "values". 
105 template <typename T> 106 T Mode(T* values, const size_t num_values) { 107 CountingSort(values, num_values); 108 return ModeOfSorted(values, num_values); 109 } 110 111 template <typename T, size_t N> 112 T Mode(T (&values)[N]) { 113 return Mode(&values[0], N); 114 } 115 116 // Returns the median value. Side effect: sorts "values". 117 template <typename T> 118 T Median(T* values, const size_t num_values) { 119 HWY_ASSERT(num_values != 0); 120 std::sort(values, values + num_values); 121 const size_t half = num_values / 2; 122 // Odd count: return middle 123 if (num_values % 2) { 124 return values[half]; 125 } 126 // Even count: return average of middle two. 127 return (values[half] + values[half - 1] + 1) / 2; 128 } 129 130 // Returns a robust measure of variability. 131 template <typename T> 132 T MedianAbsoluteDeviation(const T* values, const size_t num_values, 133 const T median) { 134 HWY_ASSERT(num_values != 0); 135 std::vector<T> abs_deviations; 136 abs_deviations.reserve(num_values); 137 for (size_t i = 0; i < num_values; ++i) { 138 const int64_t abs = ScalarAbs(static_cast<int64_t>(values[i]) - 139 static_cast<int64_t>(median)); 140 abs_deviations.push_back(static_cast<T>(abs)); 141 } 142 return Median(abs_deviations.data(), num_values); 143 } 144 145 } // namespace robust_statistics 146 } // namespace hwy 147 148 #endif // HIGHWAY_HWY_ROBUST_STATISTICS_H_