reduce_sum_hwy.h (2173B)
1 /* 2 * Copyright (c) 2025, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 #ifndef AOM_AOM_DSP_REDUCE_SUM_HWY_H_ 12 #define AOM_AOM_DSP_REDUCE_SUM_HWY_H_ 13 14 #include <type_traits> 15 #include "third_party/highway/hwy/highway.h" 16 17 HWY_BEFORE_NAMESPACE(); 18 19 namespace { 20 namespace HWY_NAMESPACE { 21 22 namespace hn = hwy::HWY_NAMESPACE; 23 24 template <size_t NumBlocks> 25 struct BlockReduceTraits; 26 27 template <> 28 struct BlockReduceTraits<1> { 29 template <typename D> 30 HWY_ATTR HWY_INLINE static hn::VFromD<D> ReduceSum(D d, hn::VFromD<D> v) { 31 (void)d; 32 return v; 33 } 34 }; 35 36 template <size_t NumBlocks> 37 struct BlockReduceTraits { 38 static_assert(NumBlocks > 1, 39 "Primary template BlockReduceTraits assumes NumBlocks > 1"); 40 static_assert((NumBlocks & (NumBlocks - 1)) == 0, 41 "BlockReduceTraits requires NumBlocks to be a power of 2."); 42 43 template <typename D> 44 HWY_ATTR HWY_INLINE static hn::VFromD<hn::BlockDFromD<D>> ReduceSum( 45 D d, hn::VFromD<D> v) { 46 (void)d; 47 constexpr hn::Half<D> half_d; 48 auto v_half = hn::Add(hn::LowerHalf(half_d, v), hn::UpperHalf(half_d, v)); 49 return BlockReduceTraits<NumBlocks / 2>::ReduceSum(half_d, v_half); 50 } 51 }; 52 53 // ReduceSum across blocks. 54 // For example, with a 4-block vector with 16 lanes of uint32_t: 55 // [a3 b3 c3 d3 a2 b2 c2 d2 a1 b1 c1 d1 a0 b0 c0 d0] 56 // returns a vector with 4 lanes: 57 // [a3+a2+a1+a0 b3+b2+b1+b0 c3+c2+c1+c0 d3+d2+d1+d0] 58 template <typename D> 59 HWY_ATTR HWY_INLINE hn::Vec<hn::BlockDFromD<D>> BlockReduceSum( 60 D int_tag, hn::VFromD<D> v) { 61 return BlockReduceTraits<int_tag.MaxBlocks()>::ReduceSum(int_tag, v); 62 } 63 64 } // namespace HWY_NAMESPACE 65 } // namespace 66 67 HWY_AFTER_NAMESPACE(); 68 69 #endif // AOM_AOM_DSP_REDUCE_SUM_HWY_H_