sad_hwy.h (9566B)
1 /* 2 * Copyright (c) 2025, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 #ifndef AOM_AOM_DSP_SAD_HWY_H_ 12 #define AOM_AOM_DSP_SAD_HWY_H_ 13 14 #include "aom_dsp/reduce_sum_hwy.h" 15 #include "third_party/highway/hwy/highway.h" 16 17 HWY_BEFORE_NAMESPACE(); 18 19 namespace { 20 namespace HWY_NAMESPACE { 21 22 namespace hn = hwy::HWY_NAMESPACE; 23 24 template <int BlockWidth> 25 HWY_MAYBE_UNUSED unsigned int SumOfAbsoluteDiff( 26 const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, 27 int ref_stride, int h, const uint8_t *second_pred = nullptr) { 28 constexpr hn::CappedTag<uint8_t, BlockWidth> pixel_tag; 29 constexpr hn::Repartition<uint64_t, decltype(pixel_tag)> intermediate_sum_tag; 30 const int vw = hn::Lanes(pixel_tag); 31 auto sum_sad = hn::Zero(intermediate_sum_tag); 32 const bool is_sad_avg = second_pred != nullptr; 33 for (int i = 0; i < h; ++i) { 34 for (int j = 0; j < BlockWidth; j += vw) { 35 auto src_vec = hn::LoadU(pixel_tag, &src_ptr[j]); 36 auto ref_vec = hn::LoadU(pixel_tag, &ref_ptr[j]); 37 if (is_sad_avg) { 38 auto sec_pred_vec = hn::LoadU(pixel_tag, &second_pred[j]); 39 ref_vec = hn::AverageRound(ref_vec, sec_pred_vec); 40 } 41 auto sad = hn::SumsOf8AbsDiff(src_vec, ref_vec); 42 sum_sad = hn::Add(sum_sad, sad); 43 } 44 src_ptr += src_stride; 45 ref_ptr += ref_stride; 46 if (is_sad_avg) { 47 second_pred += BlockWidth; 48 } 49 } 50 return static_cast<unsigned int>( 51 hn::ReduceSum(intermediate_sum_tag, sum_sad)); 52 } 53 54 template <int BlockWidth, int NumRef> 55 HWY_MAYBE_UNUSED void SumOfAbsoluteDiffND(const uint8_t *src_ptr, 56 int src_stride, 57 const uint8_t *const ref_ptr[4], 58 int ref_stride, int h, 59 uint32_t res[4]) { 60 static_assert(NumRef == 3 || NumRef == 4, "NumRef must be 3 or 4."); 61 constexpr hn::CappedTag<uint8_t, BlockWidth> pixel_tag; 62 constexpr hn::Repartition<uint64_t, decltype(pixel_tag)> intermediate_sum_tag; 63 const int vw = hn::Lanes(pixel_tag); 64 auto sum_sad_0 = hn::Zero(intermediate_sum_tag); 65 auto sum_sad_1 = hn::Zero(intermediate_sum_tag); 66 auto sum_sad_2 = hn::Zero(intermediate_sum_tag); 67 auto sum_sad_3 = hn::Zero(intermediate_sum_tag); 68 const uint8_t *ref_0, *ref_1, *ref_2, *ref_3; 69 ref_0 = ref_ptr[0]; 70 ref_1 = ref_ptr[1]; 71 ref_2 = ref_ptr[2]; 72 if (NumRef == 4) { 73 ref_3 = ref_ptr[3]; 74 } 75 for (int i = 0; i < h; ++i) { 76 for (int j = 0; j < BlockWidth; j += vw) { 77 auto src_vec = hn::LoadU(pixel_tag, &src_ptr[j]); 78 auto ref_vec_0 = hn::LoadU(pixel_tag, &ref_0[j]); 79 auto ref_vec_1 = hn::LoadU(pixel_tag, &ref_1[j]); 80 auto ref_vec_2 = hn::LoadU(pixel_tag, &ref_2[j]); 81 auto sad_0 = hn::SumsOf8AbsDiff(src_vec, ref_vec_0); 82 auto sad_1 = hn::SumsOf8AbsDiff(src_vec, ref_vec_1); 83 auto sad_2 = hn::SumsOf8AbsDiff(src_vec, ref_vec_2); 84 sum_sad_0 = hn::Add(sum_sad_0, sad_0); 85 sum_sad_1 = hn::Add(sum_sad_1, sad_1); 86 sum_sad_2 = hn::Add(sum_sad_2, sad_2); 87 if (NumRef == 4) { 88 auto ref_vec_3 = hn::LoadU(pixel_tag, &ref_3[j]); 89 auto sad_3 = hn::SumsOf8AbsDiff(src_vec, ref_vec_3); 90 sum_sad_3 = hn::Add(sum_sad_3, sad_3); 91 } 92 } 93 src_ptr += src_stride; 94 ref_0 += ref_stride; 95 ref_1 += ref_stride; 96 ref_2 += ref_stride; 97 if (NumRef == 4) { 98 ref_3 += ref_stride; 99 } 100 } 101 constexpr hn::Repartition<uint32_t, decltype(pixel_tag)> uint32_tag; 102 auto r02 = hn::InterleaveEven(uint32_tag, hn::BitCast(uint32_tag, sum_sad_0), 103 hn::BitCast(uint32_tag, sum_sad_2)); 104 auto r13 = hn::InterleaveEven(uint32_tag, hn::BitCast(uint32_tag, sum_sad_1), 105 hn::BitCast(uint32_tag, sum_sad_3)); 106 auto r0123 = hn::Add(hn::InterleaveLower(uint32_tag, r02, r13), 107 hn::InterleaveUpper(uint32_tag, r02, r13)); 108 109 auto block_sum = BlockReduceSum(uint32_tag, r0123); 110 constexpr hn::FixedTag<uint32_t, 4> block_sum_tag; 111 hn::StoreU(block_sum, block_sum_tag, res); 112 } 113 114 } // namespace HWY_NAMESPACE 115 } // namespace 116 117 #define FSAD(w, h, suffix) \ 118 extern "C" unsigned int aom_sad##w##x##h##_##suffix( \ 119 const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ 120 int ref_stride); \ 121 HWY_ATTR unsigned int aom_sad##w##x##h##_##suffix( \ 122 const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ 123 int ref_stride) { \ 124 return HWY_NAMESPACE::SumOfAbsoluteDiff<w>(src_ptr, src_stride, ref_ptr, \ 125 ref_stride, h); \ 126 } 127 128 #define FSAD_4D(w, h, suffix) \ 129 extern "C" void aom_sad##w##x##h##x4d_##suffix( \ 130 const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \ 131 int ref_stride, uint32_t res[4]); \ 132 HWY_ATTR void aom_sad##w##x##h##x4d_##suffix( \ 133 const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \ 134 int ref_stride, uint32_t res[4]) { \ 135 HWY_NAMESPACE::SumOfAbsoluteDiffND<w, 4>(src_ptr, src_stride, ref_ptr, \ 136 ref_stride, h, res); \ 137 } 138 139 #define FSAD_3D(w, h, suffix) \ 140 extern "C" void aom_sad##w##x##h##x3d_##suffix( \ 141 const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \ 142 int ref_stride, uint32_t res[4]); \ 143 HWY_ATTR void aom_sad##w##x##h##x3d_##suffix( \ 144 const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \ 145 int ref_stride, uint32_t res[4]) { \ 146 HWY_NAMESPACE::SumOfAbsoluteDiffND<w, 3>(src_ptr, src_stride, ref_ptr, \ 147 ref_stride, h, res); \ 148 } 149 150 #define FSAD_SKIP(w, h, suffix) \ 151 extern "C" unsigned int aom_sad_skip_##w##x##h##_##suffix( \ 152 const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ 153 int ref_stride); \ 154 HWY_ATTR unsigned int aom_sad_skip_##w##x##h##_##suffix( \ 155 const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ 156 int ref_stride) { \ 157 return 2 * HWY_NAMESPACE::SumOfAbsoluteDiff<w>( \ 158 src_ptr, src_stride * 2, ref_ptr, ref_stride * 2, h / 2); \ 159 } 160 161 #define FSAD_4D_SKIP(w, h, suffix) \ 162 extern "C" void aom_sad_skip_##w##x##h##x4d_##suffix( \ 163 const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \ 164 int ref_stride, uint32_t res[4]); \ 165 HWY_ATTR void aom_sad_skip_##w##x##h##x4d_##suffix( \ 166 const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4], \ 167 int ref_stride, uint32_t res[4]) { \ 168 HWY_NAMESPACE::SumOfAbsoluteDiffND<w, 4>(src_ptr, 2 * src_stride, ref_ptr, \ 169 2 * ref_stride, ((h) >> 1), res); \ 170 res[0] <<= 1; \ 171 res[1] <<= 1; \ 172 res[2] <<= 1; \ 173 res[3] <<= 1; \ 174 } 175 176 #define FSAD_AVG(w, h, suffix) \ 177 extern "C" unsigned int aom_sad##w##x##h##_avg_##suffix( \ 178 const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ 179 int ref_stride, const uint8_t *second_pred); \ 180 HWY_ATTR unsigned int aom_sad##w##x##h##_avg_##suffix( \ 181 const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \ 182 int ref_stride, const uint8_t *second_pred) { \ 183 return HWY_NAMESPACE::SumOfAbsoluteDiff<w>(src_ptr, src_stride, ref_ptr, \ 184 ref_stride, h, second_pred); \ 185 } 186 187 #define FOR_EACH_SAD_BLOCK_SIZE(X, suffix) \ 188 X(128, 128, suffix) \ 189 X(128, 64, suffix) \ 190 X(64, 128, suffix) \ 191 X(64, 64, suffix) \ 192 X(64, 32, suffix) 193 194 HWY_AFTER_NAMESPACE(); 195 196 #endif // AOM_AOM_DSP_SAD_HWY_H_