warp_plane_hwy.h (66918B)
/*
 * Copyright (c) 2025, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AV1_COMMON_WARP_PLANE_HWY_H_
#define AV1_COMMON_WARP_PLANE_HWY_H_

#include "av1/common/warped_motion.h"
#include "config/av1_rtcd.h"
#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();

// Everything below lives in an unnamed namespace nested with the per-target
// HWY_NAMESPACE, so each translation unit / target that includes this header
// gets its own internal-linkage copy of these helpers.
namespace {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

// Tags spanning the full vector width of the compilation target.
constexpr hn::ScalableTag<uint8_t> uint8_tag;
constexpr hn::ScalableTag<uint16_t> uint16_tag;

constexpr hn::ScalableTag<int8_t> int8_tag;
constexpr hn::ScalableTag<int16_t> int16_tag;
constexpr hn::ScalableTag<int32_t> int32_tag;
constexpr hn::ScalableTag<int64_t> int64_tag;

// Tags capped at 32 x u8 / 16 x i16 lanes (i.e. at most 256 bits).
constexpr hn::CappedTag<uint8_t, 32> uint8x32_tag;
constexpr hn::CappedTag<int16_t, 16> int16x16_tag;

// Fixed-size unsigned tags (32, 64 and 128 bits wide).
constexpr hn::FixedTag<uint8_t, 4> uint8x4_tag;
constexpr hn::FixedTag<uint8_t, 8> uint8x8_tag;
constexpr hn::FixedTag<uint8_t, 16> uint8x16_tag;
constexpr hn::FixedTag<uint16_t, 4> uint16x4_tag;
constexpr hn::FixedTag<uint16_t, 8> uint16x8_tag;

// Fixed-size signed tags (64 and 128 bits wide).
constexpr hn::FixedTag<int8_t, 8> int8x8_tag;
constexpr hn::FixedTag<int8_t, 16> int8x16_tag;
constexpr hn::FixedTag<int16_t, 8> int16x8_tag;
constexpr hn::FixedTag<int32_t, 4> int32x4_tag;
constexpr hn::FixedTag<int64_t, 2> int64x2_tag;

// Vector type aliases for the tags used in function signatures below.
using IVec8 = hn::Vec<decltype(int8_tag)>;
using IVec16 = hn::Vec<decltype(int16_tag)>;
using IVec32 = hn::Vec<decltype(int32_tag)>;
using IVec8x16 = hn::Vec<decltype(int8x16_tag)>;

// Applies the 8-tap horizontal filter to the pixels in `src` and stores the
// rounded, shifted 16-bit results to `horz_out`.
//
//   tag         - uint8 tag describing the width of `src` (one 128-bit block
//                 per image row being filtered).
//   src         - source pixels; each 128-bit block is shuffled independently.
//   horz_out    - output buffer, addressed as rows of 8 int16 values; results
//                 are written starting at row index `row`.
//   coeff       - four coefficient vectors laid out at a stride of
//                 MaxLanes(int8_tag) bytes (see the Prepare*Coefficients
//                 helpers below).
//   round_const - rounding term added before the shift.
//   shift       - right-shift amount; the shift is performed on the unsigned
//                 reinterpretation of the sum, i.e. it is a logical shift.
//   row         - destination row index within `horz_out`.
template <typename D>
HWY_ATTR inline void FilterPixelsHorizontal(D tag, const hn::VFromD<D> src,
                                            int16_t *HWY_RESTRICT horz_out,
                                            int8_t *HWY_RESTRICT coeff,
                                            const IVec16 round_const,
                                            const int shift, int row) {
  constexpr hn::Repartition<int8_t, D> coeff_tag;
  constexpr hn::Repartition<int16_t, D> result_tag;
  constexpr hn::Repartition<uint16_t, D> unsigned_result_tag;
  // N.B. coeffs are stored to support the maximum vector width, which may not
  // be the vector width being filtered on now.
  const auto coeff0 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 0);
  const auto coeff1 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 1);
  const auto coeff2 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 2);
  const auto coeff3 = hn::Load(coeff_tag, coeff + hn::MaxLanes(int8_tag) * 3);

  // Per-128-bit-block byte shuffles. Each one gathers the byte pairs that are
  // multiplied against the matching coefficient vector (coeff0..coeff3).
  const auto shuffle0 = hn::Dup128VecFromValues(
      uint8_tag, 0, 2, 2, 4, 4, 6, 6, 8, 1, 3, 3, 5, 5, 7, 7, 9  //
  );
  const auto shuffle1 = hn::Dup128VecFromValues(
      uint8_tag, 4, 6, 6, 8, 8, 10, 10, 12, 5, 7, 7, 9, 9, 11, 11, 13  //
  );
  const auto shuffle2 = hn::Dup128VecFromValues(
      uint8_tag, 1, 3, 3, 5, 5, 7, 7, 9, 2, 4, 4, 6, 6, 8, 8, 10  //
  );
  const auto shuffle3 = hn::Dup128VecFromValues(
      uint8_tag, 5, 7, 7, 9, 9, 11, 11, 13, 6, 8, 8, 10, 10, 12, 12, 14  //
  );

  const auto src_0 =
      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle0));
  const auto src_1 =
      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle1));
  const auto src_2 =
      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle2));
  const auto src_3 =
      hn::TableLookupBytes(src, hn::ResizeBitCast(tag, shuffle3));

  // u8 x i8 products, adjacent pairs summed into saturating i16 lanes.
  const auto res_02 = hn::SatWidenMulPairwiseAdd(result_tag, src_0, coeff0);
  const auto res_46 = hn::SatWidenMulPairwiseAdd(result_tag, src_1, coeff1);
  const auto res_13 = hn::SatWidenMulPairwiseAdd(result_tag, src_2, coeff2);
  const auto res_57 = hn::SatWidenMulPairwiseAdd(result_tag, src_3, coeff3);

  const auto res_even = hn::Add(res_02, res_46);
  const auto res_odd = hn::Add(res_13, res_57);

  const auto res = hn::Add(hn::Add(res_even, res_odd),
                           hn::ResizeBitCast(result_tag, round_const));

  // Shift as unsigned (logical shift), then store reinterpreted as int16.
  hn::Store(hn::BitCast(result_tag,
                        hn::ShiftRightSame(
                            hn::BitCast(unsigned_result_tag, res), shift)),
            result_tag, horz_out + row * hn::MaxLanes(int16x8_tag));
}

// Loads the 8 int8 taps of av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS]
// into the low 8 lanes of a 128-bit vector (LoadN fills the rest with zero).
HWY_ATTR HWY_INLINE IVec8x16 LoadAV1Filter8Bit(unsigned int offset) {
  return hn::LoadN(int8x16_tag, av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS],
                   8);
}

// Same as LoadAV1Filter8Bit, but returns a full-width vector: the 8 taps
// occupy the low 8 lanes (of block 0) and all remaining lanes are zero.
HWY_ATTR HWY_INLINE IVec8 LoadAV1Filter8BitLower(unsigned int offset) {
  return hn::LoadN(int8_tag, av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS],
                   8);
}

// Loads the 8 taps selected by `offset` and inserts them as 128-bit block
// `Block` of `src`, returning the updated vector.
template <int Block>
HWY_ATTR HWY_INLINE IVec8 LoadAV1Filter8BitUpper(unsigned int offset,
                                                 IVec8 src) {
  return hn::InsertBlock<Block>(
      src, hn::LoadN(int8x16_tag,
                     av1_filter_8bit[offset >> WARPEDDIFF_PREC_BITS], 8));
}

// General case: builds the four coefficient vectors consumed by
// FilterPixelsHorizontal. Within each 128-bit block b (one block per row),
// the eight filters are selected by sx + b * beta + {0..7} * alpha; they are
// then transposed via 16/32/64-bit zips and stored to `coeff` at a stride of
// MaxLanes(int8_tag) bytes.
HWY_ATTR inline void PrepareHorizontalFilterCoefficients(
    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
  // Block 0: filters for the first row.
  auto tmp_0 = LoadAV1Filter8BitLower(sx + 0 * alpha);
  auto tmp_1 = LoadAV1Filter8BitLower(sx + 1 * alpha);
  auto tmp_2 = LoadAV1Filter8BitLower(sx + 2 * alpha);
  auto tmp_3 = LoadAV1Filter8BitLower(sx + 3 * alpha);
  auto tmp_4 = LoadAV1Filter8BitLower(sx + 4 * alpha);
  auto tmp_5 = LoadAV1Filter8BitLower(sx + 5 * alpha);
  auto tmp_6 = LoadAV1Filter8BitLower(sx + 6 * alpha);
  auto tmp_7 = LoadAV1Filter8BitLower(sx + 7 * alpha);

  // Block 1: filters for the next row (sx advanced by beta).
  if constexpr (int16_tag.MaxBlocks() >= 2) {
    tmp_0 = LoadAV1Filter8BitUpper<1>(sx + beta + 0 * alpha, tmp_0);
    tmp_1 = LoadAV1Filter8BitUpper<1>(sx + beta + 1 * alpha, tmp_1);
    tmp_2 = LoadAV1Filter8BitUpper<1>(sx + beta + 2 * alpha, tmp_2);
    tmp_3 = LoadAV1Filter8BitUpper<1>(sx + beta + 3 * alpha, tmp_3);
    tmp_4 = LoadAV1Filter8BitUpper<1>(sx + beta + 4 * alpha, tmp_4);
    tmp_5 = LoadAV1Filter8BitUpper<1>(sx + beta + 5 * alpha, tmp_5);
    tmp_6 = LoadAV1Filter8BitUpper<1>(sx + beta + 6 * alpha, tmp_6);
    tmp_7 = LoadAV1Filter8BitUpper<1>(sx + beta + 7 * alpha, tmp_7);
  }

  // Blocks 2 and 3: only present on vectors with more than two blocks.
  if constexpr (int16_tag.MaxBlocks() >= 3) {
    tmp_0 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 0 * alpha, tmp_0);
    tmp_1 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 1 * alpha, tmp_1);
    tmp_2 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 2 * alpha, tmp_2);
    tmp_3 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 3 * alpha, tmp_3);
    tmp_4 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 4 * alpha, tmp_4);
    tmp_5 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 5 * alpha, tmp_5);
    tmp_6 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 6 * alpha, tmp_6);
    tmp_7 = LoadAV1Filter8BitUpper<2>(sx + beta * 2 + 7 * alpha, tmp_7);

    tmp_0 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 0 * alpha, tmp_0);
    tmp_1 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 1 * alpha, tmp_1);
    tmp_2 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 2 * alpha, tmp_2);
    tmp_3 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 3 * alpha, tmp_3);
    tmp_4 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 4 * alpha, tmp_4);
    tmp_5 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 5 * alpha, tmp_5);
    tmp_6 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 6 * alpha, tmp_6);
    tmp_7 = LoadAV1Filter8BitUpper<3>(sx + beta * 3 + 7 * alpha, tmp_7);
  }

  // Treat each pair of adjacent taps as one 16-bit unit for the transpose.
  const auto tmp_0_16 = hn::BitCast(int16_tag, tmp_0);
  const auto tmp_1_16 = hn::BitCast(int16_tag, tmp_1);
  const auto tmp_2_16 = hn::BitCast(int16_tag, tmp_2);
  const auto tmp_3_16 = hn::BitCast(int16_tag, tmp_3);
  const auto tmp_4_16 = hn::BitCast(int16_tag, tmp_4);
  const auto tmp_5_16 = hn::BitCast(int16_tag, tmp_5);
  const auto tmp_6_16 = hn::BitCast(int16_tag, tmp_6);
  const auto tmp_7_16 = hn::BitCast(int16_tag, tmp_7);

  // Interleave filters (0,2), (1,3), (4,6), (5,7) at 16-bit granularity.
  const auto tmp_12 = hn::ZipLower(int32_tag, tmp_0_16, tmp_2_16);
  const auto tmp_13 = hn::ZipLower(int32_tag, tmp_1_16, tmp_3_16);
  const auto tmp_14 = hn::ZipLower(int32_tag, tmp_4_16, tmp_6_16);
  const auto tmp_15 = hn::ZipLower(int32_tag, tmp_5_16, tmp_7_16);

  // Then at 32-bit granularity: res_0/res_1 hold the even-indexed filters,
  // res_2/res_3 the odd-indexed ones.
  const auto res_0 = hn::ZipLower(int64_tag, tmp_12, tmp_14);
  const auto res_1 = hn::ZipUpper(int64_tag, tmp_12, tmp_14);
  const auto res_2 = hn::ZipLower(int64_tag, tmp_13, tmp_15);
  const auto res_3 = hn::ZipUpper(int64_tag, tmp_13, tmp_15);

  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_0, res_2)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 0);
  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_0, res_2)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 1);
  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_1, res_3)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 2);
  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_1, res_3)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 3);
}

// Variant that ignores `beta`: the filters are loaded once as 128-bit
// vectors and the same block is broadcast to every block of the full-width
// vector, so all rows share identical coefficients. Output layout matches
// PrepareHorizontalFilterCoefficients.
HWY_ATTR inline void PrepareHorizontalFilterCoefficientsBeta0(
    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
  (void)beta;
  const auto tmp_0 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 0 * alpha));
  const auto tmp_1 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 1 * alpha));
  const auto tmp_2 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 2 * alpha));
  const auto tmp_3 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 3 * alpha));
  const auto tmp_4 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 4 * alpha));
  const auto tmp_5 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 5 * alpha));
  const auto tmp_6 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 6 * alpha));
  const auto tmp_7 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 7 * alpha));

  const auto tmp_02 = hn::ZipLower(int32x4_tag, tmp_0, tmp_2);
  const auto tmp_13 = hn::ZipLower(int32x4_tag, tmp_1, tmp_3);
  const auto tmp_46 = hn::ZipLower(int32x4_tag, tmp_4, tmp_6);
  const auto tmp_57 = hn::ZipLower(int32x4_tag, tmp_5, tmp_7);

  // Widen to the full vector and replicate block 0 into every block.
  const auto broadcast_12 =
      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_02));
  const auto broadcast_13 =
      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_13));
  const auto broadcast_14 =
      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_46));
  const auto broadcast_15 =
      hn::BroadcastBlock<0>(hn::ResizeBitCast(int32_tag, tmp_57));

  const auto res_0 = hn::ZipLower(int64_tag, broadcast_12, broadcast_14);
  const auto res_1 = hn::ZipUpper(int64_tag, broadcast_12, broadcast_14);
  const auto res_2 = hn::ZipLower(int64_tag, broadcast_13, broadcast_15);
  const auto res_3 = hn::ZipUpper(int64_tag, broadcast_13, broadcast_15);

  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_0, res_2)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 0);
  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_0, res_2)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 1);
  hn::Store(hn::BitCast(int8_tag, hn::InterleaveLower(int64_tag, res_1, res_3)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 2);
  hn::Store(hn::BitCast(int8_tag, hn::InterleaveUpper(int64_tag, res_1, res_3)),
            int8_tag, coeff + hn::MaxLanes(int8_tag) * 3);
}

// Variant that ignores `alpha`: a single filter per row (selected by
// sx + b * beta for block b) is loaded, and each of its four 16-bit tap
// pairs is broadcast across its block to form the four coefficient vectors.
HWY_ATTR inline void PrepareHorizontalFilterCoefficientsAlpha0(
    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
  (void)alpha;
  auto tmp_0 = LoadAV1Filter8BitLower(sx);
  if constexpr (int16_tag.MaxBlocks() >= 2) {
    tmp_0 = LoadAV1Filter8BitUpper<1>(sx + beta, tmp_0);
  }
  if constexpr (int16_tag.MaxBlocks() >= 3) {
    tmp_0 = LoadAV1Filter8BitUpper<2>(sx + beta * 2, tmp_0);
    tmp_0 = LoadAV1Filter8BitUpper<3>(sx + beta * 3, tmp_0);
  }
  const auto res_0 = hn::BitCast(int16_tag, tmp_0);

  // Broadcast<N> replicates 16-bit lane N within each 128-bit block.
  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<0>(res_0)), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 0);
  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<1>(res_0)), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 1);
  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<2>(res_0)), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 2);
  hn::Store(hn::BitCast(int8_tag, hn::Broadcast<3>(res_0)), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 3);
}

// Convenience wrapper: prepares the general-case coefficients into a local
// aligned buffer and immediately filters `src` into row `row` of `horz_out`.
template <typename D>
HWY_ATTR inline void HorizontalFilter(D tag, const hn::VFromD<D> src,
                                      int16_t *HWY_RESTRICT horz_out, int sx,
                                      int alpha, int beta, int row,
                                      const IVec16 round_const,
                                      const int reduce_bits_horiz) {
  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
  PrepareHorizontalFilterCoefficients(alpha, beta, sx, coeff);
  FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
                         reduce_bits_horiz, row);
}

// Coefficient preparation for a single row (`beta` is ignored). The transpose
// is carried out entirely on 128-bit vectors; each result is then widened and
// block 0 is broadcast to all blocks before being stored, so the buffer uses
// the same MaxLanes(int8_tag) stride as the other Prepare* helpers.
HWY_ATTR inline void PrepareLastHorizontalFilterCoefficients(
    int alpha, int beta, int sx, int8_t *HWY_RESTRICT coeff) {
  (void)beta;
  const auto tmp_0 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 0 * alpha));
  const auto tmp_1 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 1 * alpha));
  const auto tmp_2 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 2 * alpha));
  const auto tmp_3 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 3 * alpha));
  const auto tmp_4 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 4 * alpha));
  const auto tmp_5 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 5 * alpha));
  const auto tmp_6 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 6 * alpha));
  const auto tmp_7 =
      hn::BitCast(int16x8_tag, LoadAV1Filter8Bit(sx + 7 * alpha));

  const auto tmp_8 = hn::ZipLower(int32x4_tag, tmp_0, tmp_2);
  const auto tmp_9 = hn::ZipLower(int32x4_tag, tmp_1, tmp_3);
  const auto tmp_10 = hn::ZipLower(int32x4_tag, tmp_4, tmp_6);
  const auto tmp_11 = hn::ZipLower(int32x4_tag, tmp_5, tmp_7);

  const auto tmp_12 = hn::ZipLower(int64x2_tag, tmp_8, tmp_10);
  const auto tmp_13 = hn::ZipUpper(int64x2_tag, tmp_8, tmp_10);
  const auto tmp_14 = hn::ZipLower(int64x2_tag, tmp_9, tmp_11);
  const auto tmp_15 = hn::ZipUpper(int64x2_tag, tmp_9, tmp_11);

  const auto tmp_16 = hn::InterleaveLower(int64x2_tag, tmp_12, tmp_14);
  const auto tmp_17 = hn::InterleaveUpper(int64x2_tag, tmp_12, tmp_14);
  const auto tmp_18 = hn::InterleaveLower(int64x2_tag, tmp_13, tmp_15);
  const auto tmp_19 = hn::InterleaveUpper(int64x2_tag, tmp_13, tmp_15);

  const auto tmp_20 = hn::ResizeBitCast(int8_tag, tmp_16);
  const auto tmp_21 = hn::ResizeBitCast(int8_tag, tmp_17);
  const auto tmp_22 = hn::ResizeBitCast(int8_tag, tmp_18);
  const auto tmp_23 = hn::ResizeBitCast(int8_tag, tmp_19);

  hn::Store(hn::BroadcastBlock<0>(tmp_20), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 0);
  hn::Store(hn::BroadcastBlock<0>(tmp_21), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 1);
  hn::Store(hn::BroadcastBlock<0>(tmp_22), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 2);
  hn::Store(hn::BroadcastBlock<0>(tmp_23), int8_tag,
            coeff + hn::MaxLanes(int8_tag) * 3);
}

// Loads one 128-bit block of pixels per image row into consecutive blocks of
// a `tag`-sized vector: block b comes from row clamp(iy + b, 0, height - 1).
// `ref` must already point at the first column to load; loads are unaligned.
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> LoadRowsClamped(
    D tag, const uint8_t *HWY_RESTRICT ref, const int stride, const int iy,
    const int height) {
  constexpr hn::BlockDFromD<D> block_tag;
  const int iy0 = clamp(iy + 0, 0, height - 1);
  auto src = hn::ResizeBitCast(tag, hn::LoadU(block_tag, ref + iy0 * stride));
  if constexpr (tag.MaxBlocks() >= 2) {
    const int iy1 = clamp(iy + 1, 0, height - 1);
    const auto src_1 = hn::LoadU(block_tag, ref + iy1 * stride);
    src = hn::InsertBlock<1>(src, src_1);
  }
  if constexpr (tag.MaxBlocks() >= 3) {
    const int iy2 = clamp(iy + 2, 0, height - 1);
    const auto src_2 = hn::LoadU(block_tag, ref + iy2 * stride);
    const int iy3 = clamp(iy + 3, 0, height - 1);
    const auto src_3 = hn::LoadU(block_tag, ref + iy3 * stride);
    src = hn::InsertBlock<2>(src, src_2);
    src = hn::InsertBlock<3>(src, src_3);
  }
  return src;
}

// Horizontally filters kNumRows (= tag.MaxBlocks()) rows per iteration,
// starting at row offset `k`, while at least kNumRows + 1 rows remain before
// AOMMIN(8, p_height - i). Row offset k is written to horz_out row k + 7.
// If PrepareCoeffs is non-null, coefficients are recomputed every iteration
// from sx4 + beta * (k + 4); otherwise the caller-provided `coeff` is reused.
// Returns the first row offset that was not processed.
template <void (*PrepareCoeffs)(int alpha, int beta, int sx,
                                int8_t *HWY_RESTRICT coeffs),
          typename D>
HWY_ATTR int WarpHorizontalFilterLoop(
    D tag, const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out,
    int stride, int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta,
    int p_height, int height, int i, const IVec16 round_const,
    const int reduce_bits_horiz, int k, int8_t *HWY_RESTRICT coeff) {
  constexpr int kNumRows = tag.MaxBlocks();
  for (; k < AOMMIN(8, p_height - i) - kNumRows; k += kNumRows) {
    const auto src =
        LoadRowsClamped(tag, ref + ix4 - 7, stride, iy4 + k, height);
    if constexpr (PrepareCoeffs != nullptr) {
      int sx = sx4 + beta * (k + 4);
      PrepareCoeffs(alpha, beta, sx, coeff);
    }
    FilterPixelsHorizontal(tag, src, horz_out, coeff, round_const,
                           reduce_bits_horiz, k + 7);
  }
  return k;
}

// Runs the horizontal filter pass starting at row offset k = -7.
//
//   InnerCoeffUpdate  - when true, coefficients are recomputed for every
//                       group of rows; when false they are computed once
//                       from sx4 and reused.
//   PrepareCoeffs     - coefficient builder used by the multi-row loops.
//   LastPrepareCoeffs - coefficient builder used for the single-row
//                       (128-bit) loop and the final row; defaults to
//                       PrepareCoeffs.
//
// Rows are consumed with the widest vector first (full width when it has
// at least 3 blocks, then the 256-bit capped tag), or with the 128-bit tag
// when vectors hold a single block. Whatever `k` results from those loops,
// one final row is always filtered with the 128-bit tag.
template <
    bool InnerCoeffUpdate,
    void (*PrepareCoeffs)(int alpha, int beta, int sx,
                          int8_t *HWY_RESTRICT coeffs),
    void (*LastPrepareCoeffs)(int alpha, int beta, int sx,
                              int8_t *HWY_RESTRICT coeffs) = PrepareCoeffs>
HWY_ATTR inline void WarpHorizontalFilterTemplate(
    const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out, int stride,
    int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta, int p_height,
    int height, int i, const IVec16 round_const, const int reduce_bits_horiz) {
  int k = -7, iy;
  HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)];
  if constexpr (!InnerCoeffUpdate) {
    PrepareCoeffs(alpha, beta, sx4, coeff);
  }
  if constexpr (uint8_tag.MaxBlocks() >= 3) {
    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? PrepareCoeffs : nullptr)>(
        uint8_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height,
        height, i, round_const, reduce_bits_horiz, k, coeff);
  }
  if constexpr (uint8_tag.MaxBlocks() >= 2) {
    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? PrepareCoeffs : nullptr)>(
        uint8x32_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta,
        p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
  }
  if constexpr (uint8_tag.MaxBlocks() == 1) {
    k = WarpHorizontalFilterLoop<(InnerCoeffUpdate ? LastPrepareCoeffs
                                                   : nullptr)>(
        uint8x16_tag, ref, horz_out, stride, ix4, iy4, sx4, alpha, beta,
        p_height, height, i, round_const, reduce_bits_horiz, k, coeff);
  }
  // Final row: always a single 128-bit load, with the row index clamped to
  // the picture.
  iy = iy4 + k;
  iy = clamp(iy, 0, height - 1);
  const auto src = hn::LoadU(uint8x16_tag, ref + iy * stride + ix4 - 7);
  if constexpr (InnerCoeffUpdate) {
    int sx = sx4 + beta * (k + 4);
    LastPrepareCoeffs(alpha, beta, sx, coeff);
  }
  FilterPixelsHorizontal(uint8x16_tag, src, horz_out, coeff, round_const,
                         reduce_bits_horiz, k + 7);
}

// Fills the constant vectors used when writing compound predictions:
//   res_sub_const    = -(1 << (offset_bits - round_1))
//                      - (1 << (offset_bits - round_1 - 1)), in every lane.
//   round_bits_const = (1 << round_bits) >> 1, in every lane.
//   wt               = fwd_offset and bck_offset (truncated to int16)
//                      interleaved as alternating lanes.
HWY_ATTR inline void UnpackWeightsAndSetRoundConst(
    ConvolveParams *HWY_RESTRICT conv_params, const int round_bits,
    const int offset_bits, IVec16 &HWY_RESTRICT res_sub_const,
    IVec16 &HWY_RESTRICT round_bits_const, IVec16 &HWY_RESTRICT wt) {
  res_sub_const =
      hn::Set(int16_tag, -(1 << (offset_bits - conv_params->round_1)) -
                             (1 << (offset_bits - conv_params->round_1 - 1)));
  round_bits_const = hn::Set(int16_tag, ((1 << round_bits) >> 1));

  const auto w0 = static_cast<int16_t>(conv_params->fwd_offset);
  const auto w1 = static_cast<int16_t>(conv_params->bck_offset);
  const auto wt0 = hn::Set(int16_tag, w0);
  const auto wt1 = hn::Set(int16_tag, w1);
  wt = hn::InterleaveLower(wt0, wt1);
}

// Loads the 8 int16 taps of av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]
// into the low 8 lanes of a full-width vector; remaining lanes are zero.
HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilter(size_t offset) {
  return hn::LoadN(int16_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS],
                   8);
}

// Loads the selected filter with an aligned 128-bit load and resizes it to a
// full-width vector (the filter ends up in block 0).
// NOTE(review): hn::Load requires alignment - presumably av1_warped_filter
// rows are 16-byte aligned; confirm against the table's declaration.
HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilterLower(size_t offset) {
  return hn::ResizeBitCast(
      int16_tag,
      hn::Load(int16x8_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]));
}

// Loads the selected filter and inserts it as 128-bit block `Block` of `src`.
template <int Block>
HWY_ATTR HWY_INLINE IVec16 LoadAV1WarpedFilterUpper(size_t offset, IVec16 src) {
  return hn::InsertBlock<Block>(
      src,
      hn::Load(int16x8_tag, av1_warped_filter[offset >> WARPEDDIFF_PREC_BITS]));
}

// General case: builds the eight vertical-filter coefficient vectors read by
// FilterPixelsVertical. Block b of each vector uses filters selected by
// sy + b * delta + n * gamma. Slots 0..3 of `coeffs` are derived from the
// even n (0, 2, 4, 6) and slots 4..7 from the odd n (1, 3, 5, 7); slots are
// MaxLanes(int16_tag) elements apart.
HWY_ATTR inline void PrepareVerticalFilterCoeffs(int gamma, int delta, int sy,
                                                 int16_t *HWY_RESTRICT coeffs) {
  // Even-indexed filters.
  auto filt_00 = LoadAV1WarpedFilterLower(sy + 0 * gamma);
  auto filt_01 = LoadAV1WarpedFilterLower(sy + 2 * gamma);
  auto filt_02 = LoadAV1WarpedFilterLower(sy + 4 * gamma);
  auto filt_03 = LoadAV1WarpedFilterLower(sy + 6 * gamma);

  if constexpr (int16_tag.MaxBlocks() >= 2) {
    filt_00 = LoadAV1WarpedFilterUpper<1>(sy + delta + 0 * gamma, filt_00);
    filt_01 = LoadAV1WarpedFilterUpper<1>(sy + delta + 2 * gamma, filt_01);
    filt_02 = LoadAV1WarpedFilterUpper<1>(sy + delta + 4 * gamma, filt_02);
    filt_03 = LoadAV1WarpedFilterUpper<1>(sy + delta + 6 * gamma, filt_03);
  }

  if constexpr (int16_tag.MaxBlocks() >= 3) {
    filt_00 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 0 * gamma, filt_00);
    filt_01 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 2 * gamma, filt_01);
    filt_02 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 4 * gamma, filt_02);
    filt_03 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 6 * gamma, filt_03);

    filt_00 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 0 * gamma, filt_00);
    filt_01 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 2 * gamma, filt_01);
    filt_02 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 4 * gamma, filt_02);
    filt_03 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 6 * gamma, filt_03);
  }

  // Transpose at 32-bit (tap-pair) then 64-bit granularity.
  auto filt_0 = hn::BitCast(int32_tag, filt_00);
  auto filt_1 = hn::BitCast(int32_tag, filt_01);
  auto filt_2 = hn::BitCast(int32_tag, filt_02);
  auto filt_3 = hn::BitCast(int32_tag, filt_03);

  auto res_0 = hn::ZipLower(int64_tag, filt_0, filt_1);
  auto res_1 = hn::ZipLower(int64_tag, filt_2, filt_3);
  auto res_2 = hn::ZipUpper(int64_tag, filt_0, filt_1);
  auto res_3 = hn::ZipUpper(int64_tag, filt_2, filt_3);

  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));

  // Odd-indexed filters: same procedure, stored to slots 4..7.
  filt_00 = LoadAV1WarpedFilterLower(sy + 1 * gamma);
  filt_01 = LoadAV1WarpedFilterLower(sy + 3 * gamma);
  filt_02 = LoadAV1WarpedFilterLower(sy + 5 * gamma);
  filt_03 = LoadAV1WarpedFilterLower(sy + 7 * gamma);

  if constexpr (int16_tag.MaxBlocks() >= 2) {
    filt_00 = LoadAV1WarpedFilterUpper<1>(sy + delta + 1 * gamma, filt_00);
    filt_01 = LoadAV1WarpedFilterUpper<1>(sy + delta + 3 * gamma, filt_01);
    filt_02 = LoadAV1WarpedFilterUpper<1>(sy + delta + 5 * gamma, filt_02);
    filt_03 = LoadAV1WarpedFilterUpper<1>(sy + delta + 7 * gamma, filt_03);
  }

  if constexpr (int16_tag.MaxBlocks() >= 3) {
    filt_00 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 1 * gamma, filt_00);
    filt_01 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 3 * gamma, filt_01);
    filt_02 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 5 * gamma, filt_02);
    filt_03 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta + 7 * gamma, filt_03);

    filt_00 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 1 * gamma, filt_00);
    filt_01 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 3 * gamma, filt_01);
    filt_02 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 5 * gamma, filt_02);
    filt_03 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta + 7 * gamma, filt_03);
  }

  filt_0 = hn::BitCast(int32_tag, filt_00);
  filt_1 = hn::BitCast(int32_tag, filt_01);
  filt_2 = hn::BitCast(int32_tag, filt_02);
  filt_3 = hn::BitCast(int32_tag, filt_03);

  res_0 = hn::ZipLower(int64_tag, filt_0, filt_1);
  res_1 = hn::ZipLower(int64_tag, filt_2, filt_3);
  res_2 = hn::ZipUpper(int64_tag, filt_0, filt_1);
  res_3 = hn::ZipUpper(int64_tag, filt_2, filt_3);

  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
}

// Variant that ignores `delta`: each filter is loaded once and its block is
// broadcast to every block, so all rows share the same coefficients. Output
// layout matches PrepareVerticalFilterCoeffs.
HWY_ATTR inline void PrepareVerticalFilterCoeffsDelta0(
    int gamma, int delta, int sy, int16_t *HWY_RESTRICT coeffs) {
  (void)delta;
  // Even-indexed filters -> slots 0..3.
  auto filt_00 = LoadAV1WarpedFilter(sy + 0 * gamma);
  auto filt_01 = LoadAV1WarpedFilter(sy + 2 * gamma);
  auto filt_02 = LoadAV1WarpedFilter(sy + 4 * gamma);
  auto filt_03 = LoadAV1WarpedFilter(sy + 6 * gamma);

  auto filt_10 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_00));
  auto filt_11 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_01));
  auto filt_12 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_02));
  auto filt_13 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_03));

  auto res_0 = hn::ZipLower(int64_tag, filt_10, filt_11);
  auto res_1 = hn::ZipLower(int64_tag, filt_12, filt_13);
  auto res_2 = hn::ZipUpper(int64_tag, filt_10, filt_11);
  auto res_3 = hn::ZipUpper(int64_tag, filt_12, filt_13);

  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));

  // Odd-indexed filters -> slots 4..7.
  filt_00 = LoadAV1WarpedFilter(sy + 1 * gamma);
  filt_01 = LoadAV1WarpedFilter(sy + 3 * gamma);
  filt_02 = LoadAV1WarpedFilter(sy + 5 * gamma);
  filt_03 = LoadAV1WarpedFilter(sy + 7 * gamma);

  filt_10 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_00));
  filt_11 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_01));
  filt_12 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_02));
  filt_13 = hn::BitCast(int32_tag, hn::BroadcastBlock<0>(filt_03));

  res_0 = hn::ZipLower(int64_tag, filt_10, filt_11);
  res_1 = hn::ZipLower(int64_tag, filt_12, filt_13);
  res_2 = hn::ZipUpper(int64_tag, filt_10, filt_11);
  res_3 = hn::ZipUpper(int64_tag, filt_12, filt_13);

  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_0, res_1)),
      int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveLower(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
  hn::Store(
      hn::BitCast(int16_tag, hn::InterleaveUpper(int64_tag, res_2, res_3)),
      int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
}

// Variant that ignores `gamma`: one filter per row (sy + b * delta for block
// b). Each of its four 32-bit tap pairs is broadcast across its block; the
// same four vectors are stored to slots 0..3 and again to slots 4..7.
HWY_ATTR inline void PrepareVerticalFilterCoeffsGamma0(
    int gamma, int delta, int sy, int16_t *HWY_RESTRICT coeffs) {
  (void)gamma;
  auto filt_0 = LoadAV1WarpedFilterLower(sy);
  if constexpr (int16_tag.MaxBlocks() >= 2) {
    filt_0 = LoadAV1WarpedFilterUpper<1>(sy + delta, filt_0);
  }
  if constexpr (int16_tag.MaxBlocks() >= 3) {
    filt_0 = LoadAV1WarpedFilterUpper<2>(sy + 2 * delta, filt_0);
    filt_0 = LoadAV1WarpedFilterUpper<3>(sy + 3 * delta, filt_0);
  }
  auto res_0 = hn::BitCast(int32_tag, filt_0);

  auto broadcast_0 = hn::BitCast(int16_tag, hn::Broadcast<0>(res_0));
  auto broadcast_1 = hn::BitCast(int16_tag, hn::Broadcast<1>(res_0));
  auto broadcast_2 = hn::BitCast(int16_tag, hn::Broadcast<2>(res_0));
  auto broadcast_3 = hn::BitCast(int16_tag, hn::Broadcast<3>(res_0));

  hn::Store(broadcast_0, int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
  hn::Store(broadcast_1, int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
  hn::Store(broadcast_2, int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
  hn::Store(broadcast_3, int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
  hn::Store(broadcast_0, int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
  hn::Store(broadcast_1, int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
  hn::Store(broadcast_2, int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
  hn::Store(broadcast_3, int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));
}

// Vertical 8-tap filter step.
//
// First refreshes the newest interleaved row pairs in `src_lo` / `src_hi`
// from `horz_out` (which slots are rewritten depends on the vector width:
// slots 2 and 3 for >= 3 blocks, slot 3 otherwise), then multiplies all four
// slots by the coefficient vectors in `coeffs` (slots 0..3 with src_lo,
// slots 4..7 with src_hi) and accumulates into 32-bit sums.
//
//   res_lo / res_hi - outputs: the two 32-bit sums re-interleaved so that
//                     pixels are back in order 0..7 within each block.
//   row             - base row index into `horz_out` (rows of 8 int16).
//
// NOTE(review): slots of src_lo / src_hi not written here are read as-is, so
// the caller presumably fills or rotates them between calls - confirm at the
// call site.
HWY_ATTR inline void FilterPixelsVertical(
    int16_t *HWY_RESTRICT horz_out, int16_t *HWY_RESTRICT src_lo,
    int16_t *HWY_RESTRICT src_hi, int16_t *HWY_RESTRICT coeffs,
    IVec32 &HWY_RESTRICT res_lo, IVec32 &HWY_RESTRICT res_hi, int row) {
  if constexpr (int16_tag.MaxBlocks() >= 3) {
    // Four rows per vector: vectors starting at rows row+4 .. row+7; only
    // the first load is aligned.
    const auto horz_out_4 =
        hn::Load(int16_tag, horz_out + (row + 4) * hn::MaxLanes(int16x8_tag));
    const auto horz_out_5 =
        hn::LoadU(int16_tag, horz_out + (row + 5) * hn::MaxLanes(int16x8_tag));
    const auto horz_out_6 =
        hn::LoadU(int16_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
    const auto horz_out_7 =
        hn::LoadU(int16_tag, horz_out + (row + 7) * hn::MaxLanes(int16x8_tag));
    const auto src_lo_2 =
        hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5);
    const auto src_hi_2 =
        hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5);
    const auto src_lo_3 =
        hn::InterleaveLower(int16_tag, horz_out_6, horz_out_7);
    const auto src_hi_3 =
        hn::InterleaveUpper(int16_tag, horz_out_6, horz_out_7);
    hn::Store(src_lo_2, int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
    hn::Store(src_hi_2, int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
    hn::Store(src_lo_3, int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
    hn::Store(src_hi_3, int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
  } else if constexpr (int16_tag.MaxBlocks() == 2) {
    // Two rows per vector: horz_out_6 holds rows row+6/row+7 and horz_out_8
    // rows row+8/row+9. ConcatLowerUpper builds the vector for rows
    // row+7/row+8 from them without an unaligned load.
    const auto horz_out_6 =
        hn::Load(int16_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
    const auto horz_out_8 =
        hn::Load(int16_tag, horz_out + (row + 8) * hn::MaxLanes(int16x8_tag));
    const auto horz_out_7 =
        hn::ConcatLowerUpper(int16_tag, horz_out_8, horz_out_6);
    const auto src_lo_3 =
        hn::InterleaveLower(int16_tag, horz_out_6, horz_out_7);
    const auto src_hi_3 =
        hn::InterleaveUpper(int16_tag, horz_out_6, horz_out_7);
    hn::Store(src_lo_3, int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
    hn::Store(src_hi_3, int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));
  } else if constexpr (int16_tag.MaxBlocks() == 1) {
    // One row per vector: interleave rows row+6 and row+7.
    const auto horz_out_6 =
        hn::Load(int16x8_tag, horz_out + (row + 6) * hn::MaxLanes(int16x8_tag));
    const auto horz_out_7 =
        hn::Load(int16x8_tag, horz_out + (row + 7) * hn::MaxLanes(int16x8_tag));
    const auto src_lo_3 =
        hn::InterleaveLower(int16x8_tag, horz_out_6, horz_out_7);
    const auto src_hi_3 =
        hn::InterleaveUpper(int16x8_tag, horz_out_6, horz_out_7);
    hn::Store(src_lo_3, int16x8_tag, src_lo + 3 * hn::MaxLanes(int16x8_tag));
    hn::Store(src_hi_3, int16x8_tag, src_hi + 3 * hn::MaxLanes(int16x8_tag));
  }

  const auto coeff_0 =
      hn::Load(int16_tag, coeffs + 0 * hn::MaxLanes(int16_tag));
  const auto coeff_1 =
      hn::Load(int16_tag, coeffs + 1 * hn::MaxLanes(int16_tag));
  const auto coeff_2 =
      hn::Load(int16_tag, coeffs + 2 * hn::MaxLanes(int16_tag));
  const auto coeff_3 =
      hn::Load(int16_tag, coeffs + 3 * hn::MaxLanes(int16_tag));
  const auto coeff_4 =
      hn::Load(int16_tag, coeffs + 4 * hn::MaxLanes(int16_tag));
  const auto coeff_5 =
      hn::Load(int16_tag, coeffs + 5 * hn::MaxLanes(int16_tag));
  const auto coeff_6 =
      hn::Load(int16_tag, coeffs + 6 * hn::MaxLanes(int16_tag));
  const auto coeff_7 =
      hn::Load(int16_tag, coeffs + 7 * hn::MaxLanes(int16_tag));

  const auto src_lo_0 =
      hn::Load(int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag));
  const auto src_lo_1 =
      hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag));
  const auto src_lo_2 =
      hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag));
  const auto src_lo_3 =
      hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag));
  const auto src_hi_0 =
      hn::Load(int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag));
  const auto src_hi_1 =
      hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag));
  const auto src_hi_2 =
      hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag));
  const auto src_hi_3 =
      hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag));

  // Widening i16 x i16 -> i32 multiply-accumulate. The two accumulators are
  // combined by RearrangeToOddPlusEven into per-lane pairwise sums.
  auto even_sum0 = hn::Zero(int32_tag);
  auto even_sum1 = hn::Zero(int32_tag);
  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_0, coeff_0,
                                            even_sum0, even_sum1);
  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_1, coeff_1,
                                            even_sum0, even_sum1);
  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_2, coeff_2,
                                            even_sum0, even_sum1);
  even_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_lo_3, coeff_3,
                                            even_sum0, even_sum1);
  auto res_even = hn::RearrangeToOddPlusEven(even_sum0, even_sum1);

  auto odd_sum0 = hn::Zero(int32_tag);
  auto odd_sum1 = hn::Zero(int32_tag);
  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_0, coeff_4,
                                           odd_sum0, odd_sum1);
  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_1, coeff_5,
                                           odd_sum0, odd_sum1);
  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_2, coeff_6,
                                           odd_sum0, odd_sum1);
  odd_sum0 = hn::ReorderWidenMulAccumulate(int32_tag, src_hi_3, coeff_7,
                                           odd_sum0, odd_sum1);
  auto res_odd = hn::RearrangeToOddPlusEven(odd_sum0, odd_sum1);

  // Rearrange pixels back into the order 0 ... 7
  res_lo = hn::InterleaveLower(int32_tag, res_even, res_odd);
  res_hi = hn::InterleaveUpper(int32_tag, res_even, res_odd);
}

// Scatters the 128-bit blocks of `vec` to consecutive output rows: block b is
// resized to `store_tag` and stored at &out[(y + b) * stride + x].
//
//   store_tag - tag describing how many lanes of each block are written.
//   row_tag   - tag of `vec`; its MaxBlocks() is the number of rows written.
//   stride, y, x - row stride and top-left position within `out`.
template <typename DS, typename DR, typename A, typename B, typename C>
HWY_ATTR HWY_INLINE void StoreRows(DS store_tag, DR row_tag, hn::VFromD<DR> vec,
                                   A stride, B y, C x,
                                   hn::TFromD<DS> *HWY_RESTRICT out) {
  hn::TFromD<DS> *HWY_RESTRICT pointers[row_tag.MaxBlocks()];
  for (int i = 0; i < static_cast<int>(row_tag.MaxBlocks()); ++i) {
    pointers[i] = &out[(y + i) * stride + x];
  }
  hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<0>(vec)), store_tag,
            pointers[0]);
  if constexpr (row_tag.MaxBlocks() >= 2) {
    hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<1>(vec)), store_tag,
              pointers[1]);
  }
  if constexpr (row_tag.MaxBlocks() >= 3) {
    hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<2>(vec)), store_tag,
              pointers[2]);
    hn::Store(hn::ResizeBitCast(store_tag, hn::ExtractBlock<3>(vec)), store_tag,
              pointers[3]);
  }
}

// NOTE(review): the definition below runs past the end of this span; its
// parameter list continues on the following line.
HWY_ATTR HWY_INLINE void StoreVerticalFilterOutput(
    IVec32 res_lo, IVec32 res_hi, const IVec32 res_add_const, const IVec16 wt,
    const IVec16 res_sub_const, const IVec16 round_bits_const,
    uint8_t *HWY_RESTRICT
pred, ConvolveParams *HWY_RESTRICT conv_params, int i, 778 int j, int k, const int reduce_bits_vert, int p_stride, int p_width, 779 const int round_bits) { 780 constexpr int kNumRows = uint16_tag.MaxBlocks(); 781 if (conv_params->is_compound) { 782 uint16_t *HWY_RESTRICT pointers[kNumRows]; 783 for (int row = 0; row < kNumRows; ++row) { 784 pointers[row] = 785 &conv_params->dst[(i + k + row) * conv_params->dst_stride + j]; 786 } 787 788 res_lo = 789 hn::ShiftRightSame(hn::Add(res_lo, res_add_const), reduce_bits_vert); 790 791 const auto temp_lo_16 = hn::ReorderDemote2To(uint16_tag, res_lo, res_lo); 792 if (conv_params->do_average) { 793 auto p_16 = 794 hn::ResizeBitCast(uint16_tag, hn::Load(uint16x4_tag, pointers[0])); 795 if constexpr (kNumRows >= 2) { 796 p_16 = hn::InsertBlock<1>( 797 p_16, hn::ResizeBitCast(uint16x8_tag, 798 hn::Load(uint16x4_tag, pointers[1]))); 799 } 800 if constexpr (kNumRows >= 3) { 801 p_16 = hn::InsertBlock<2>( 802 p_16, hn::ResizeBitCast(uint16x8_tag, 803 hn::Load(uint16x4_tag, pointers[2]))); 804 p_16 = hn::InsertBlock<3>( 805 p_16, hn::ResizeBitCast(uint16x8_tag, 806 hn::Load(uint16x4_tag, pointers[3]))); 807 } 808 auto res_lo_16 = hn::Undefined(int16_tag); 809 if (conv_params->use_dist_wtd_comp_avg) { 810 const auto p_16_lo = 811 hn::BitCast(int16_tag, hn::InterleaveLower(p_16, temp_lo_16)); 812 const auto wt_res_lo = hn::WidenMulPairwiseAdd(int32_tag, p_16_lo, wt); 813 const auto shifted_32 = hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_lo); 814 res_lo_16 = hn::BitCast( 815 int16_tag, 816 hn::ReorderDemote2To(uint16_tag, shifted_32, shifted_32)); 817 } else { 818 res_lo_16 = hn::ShiftRight<1>( 819 hn::BitCast(int16_tag, hn::Add(p_16, temp_lo_16))); 820 } 821 res_lo_16 = hn::Add(res_lo_16, res_sub_const); 822 res_lo_16 = 823 hn::ShiftRightSame(hn::Add(res_lo_16, round_bits_const), round_bits); 824 const auto res_8_lo = 825 hn::ReorderDemote2To(uint8_tag, res_lo_16, res_lo_16); 826 StoreRows(uint8x4_tag, uint8_tag, res_8_lo, p_stride, i 
+ k, j, pred); 827 } else { 828 hn::Store( 829 hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<0>(temp_lo_16)), 830 uint16x4_tag, pointers[0]); 831 if constexpr (kNumRows >= 2) { 832 hn::Store( 833 hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<1>(temp_lo_16)), 834 uint16x4_tag, pointers[1]); 835 } 836 if constexpr (kNumRows >= 3) { 837 hn::Store( 838 hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<2>(temp_lo_16)), 839 uint16x4_tag, pointers[2]); 840 hn::Store( 841 hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<3>(temp_lo_16)), 842 uint16x4_tag, pointers[3]); 843 } 844 } 845 if (p_width > 4) { 846 uint16_t *HWY_RESTRICT pointers4[kNumRows]; 847 for (int row = 0; row < kNumRows; ++row) { 848 pointers4[row] = 849 &conv_params->dst[(i + k + row) * conv_params->dst_stride + j + 4]; 850 } 851 res_hi = 852 hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert); 853 const auto temp_hi_16 = hn::ReorderDemote2To(uint16_tag, res_hi, res_hi); 854 if (conv_params->do_average) { 855 auto p4_16 = 856 hn::ResizeBitCast(uint16_tag, hn::Load(uint16x4_tag, pointers4[0])); 857 if constexpr (kNumRows >= 2) { 858 p4_16 = hn::InsertBlock<1>( 859 p4_16, hn::ResizeBitCast(uint16x8_tag, 860 hn::Load(uint16x4_tag, pointers4[1]))); 861 } 862 if constexpr (kNumRows >= 3) { 863 p4_16 = hn::InsertBlock<2>( 864 p4_16, hn::ResizeBitCast(uint16x8_tag, 865 hn::Load(uint16x4_tag, pointers4[2]))); 866 p4_16 = hn::InsertBlock<3>( 867 p4_16, hn::ResizeBitCast(uint16x8_tag, 868 hn::Load(uint16x4_tag, pointers4[3]))); 869 } 870 871 auto res_hi_16 = hn::Undefined(int16_tag); 872 if (conv_params->use_dist_wtd_comp_avg) { 873 const auto p_16_hi = 874 hn::BitCast(int16_tag, hn::InterleaveLower(p4_16, temp_hi_16)); 875 const auto wt_res_hi = 876 hn::WidenMulPairwiseAdd(int32_tag, p_16_hi, wt); 877 const auto shifted_32 = 878 hn::ShiftRight<DIST_PRECISION_BITS>(wt_res_hi); 879 res_hi_16 = hn::BitCast( 880 int16_tag, 881 hn::ReorderDemote2To(uint16_tag, shifted_32, shifted_32)); 882 } else { 
883 res_hi_16 = hn::ShiftRight<1>( 884 hn::BitCast(int16_tag, hn::Add(p4_16, temp_hi_16))); 885 } 886 res_hi_16 = hn::Add(res_hi_16, res_sub_const); 887 res_hi_16 = hn::ShiftRightSame(hn::Add(res_hi_16, round_bits_const), 888 round_bits); 889 const auto res_8_hi = 890 hn::ReorderDemote2To(uint8_tag, res_hi_16, res_hi_16); 891 StoreRows(uint8x4_tag, uint8_tag, res_8_hi, p_stride, i + k, j + 4, 892 pred); 893 } else { 894 hn::Store(hn::ResizeBitCast(uint16x4_tag, temp_hi_16), uint16x4_tag, 895 pointers4[0]); 896 if constexpr (kNumRows >= 2) { 897 hn::Store( 898 hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<1>(temp_hi_16)), 899 uint16x4_tag, pointers4[1]); 900 } 901 if constexpr (kNumRows >= 3) { 902 hn::Store( 903 hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<2>(temp_hi_16)), 904 uint16x4_tag, pointers4[2]); 905 hn::Store( 906 hn::ResizeBitCast(uint16x4_tag, hn::ExtractBlock<3>(temp_hi_16)), 907 uint16x4_tag, pointers4[3]); 908 } 909 } 910 } 911 } else { 912 const auto res_lo_round = 913 hn::ShiftRightSame(hn::Add(res_lo, res_add_const), reduce_bits_vert); 914 const auto res_hi_round = 915 hn::ShiftRightSame(hn::Add(res_hi, res_add_const), reduce_bits_vert); 916 917 const auto res_16bit = 918 hn::ReorderDemote2To(int16_tag, res_lo_round, res_hi_round); 919 const auto res_8bit = hn::ReorderDemote2To(uint8_tag, res_16bit, res_16bit); 920 // Store, blending with 'pred' if needed 921 if (p_width == 4) { 922 StoreRows(uint8x4_tag, uint8_tag, res_8bit, p_stride, i + k, j, pred); 923 } else { 924 StoreRows(uint8x8_tag, uint8_tag, res_8bit, p_stride, i + k, j, pred); 925 } 926 } 927 } 928 929 template <bool InnerCoeffUpdate, 930 void (*PrepareCoeffs)(int gamma, int delta, int sy, 931 int16_t *HWY_RESTRICT coeffs)> 932 HWY_ATTR inline void WarpVerticalFilterTemplate( 933 uint8_t *HWY_RESTRICT pred, int16_t *HWY_RESTRICT horz_out, 934 ConvolveParams *HWY_RESTRICT conv_params, int16_t gamma, int16_t delta, 935 int p_height, int p_stride, int p_width, int i, int j, int 
sy4, 936 const int reduce_bits_vert, const IVec32 res_add_const, 937 const int round_bits, const IVec16 res_sub_const, 938 const IVec16 round_bits_const, const IVec16 wt) { 939 HWY_ALIGN int16_t src_lo[4 * hn::MaxLanes(int16_tag)]; 940 HWY_ALIGN int16_t src_hi[4 * hn::MaxLanes(int16_tag)]; 941 if constexpr (int16_tag.MaxBlocks() >= 3) { 942 const auto horz_out_0 = 943 hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16x8_tag)); 944 const auto horz_out_1 = 945 hn::LoadU(int16_tag, horz_out + 1 * hn::MaxLanes(int16x8_tag)); 946 const auto horz_out_2 = 947 hn::LoadU(int16_tag, horz_out + 2 * hn::MaxLanes(int16x8_tag)); 948 const auto horz_out_3 = 949 hn::LoadU(int16_tag, horz_out + 3 * hn::MaxLanes(int16x8_tag)); 950 hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag, 951 src_lo + 0 * hn::MaxLanes(int16_tag)); 952 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag, 953 src_hi + 0 * hn::MaxLanes(int16_tag)); 954 hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag, 955 src_lo + 1 * hn::MaxLanes(int16_tag)); 956 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag, 957 src_hi + 1 * hn::MaxLanes(int16_tag)); 958 } else if constexpr (int16_tag.MaxBlocks() == 2) { 959 const auto horz_out_0 = 960 hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16_tag)); 961 const auto horz_out_2 = 962 hn::Load(int16_tag, horz_out + 1 * hn::MaxLanes(int16_tag)); 963 const auto horz_out_4 = 964 hn::Load(int16_tag, horz_out + 2 * hn::MaxLanes(int16_tag)); 965 const auto horz_out_6 = 966 hn::Load(int16_tag, horz_out + 3 * hn::MaxLanes(int16_tag)); 967 const auto horz_out_1 = 968 hn::ConcatLowerUpper(int16_tag, horz_out_2, horz_out_0); 969 const auto horz_out_3 = 970 hn::ConcatLowerUpper(int16_tag, horz_out_4, horz_out_2); 971 const auto horz_out_5 = 972 hn::ConcatLowerUpper(int16_tag, horz_out_6, horz_out_4); 973 hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag, 974 
src_lo + 0 * hn::MaxLanes(int16_tag)); 975 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag, 976 src_hi + 0 * hn::MaxLanes(int16_tag)); 977 hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag, 978 src_lo + 1 * hn::MaxLanes(int16_tag)); 979 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag, 980 src_hi + 1 * hn::MaxLanes(int16_tag)); 981 hn::Store(hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5), int16_tag, 982 src_lo + 2 * hn::MaxLanes(int16_tag)); 983 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5), int16_tag, 984 src_hi + 2 * hn::MaxLanes(int16_tag)); 985 } else { 986 const auto horz_out_0 = 987 hn::Load(int16_tag, horz_out + 0 * hn::MaxLanes(int16_tag)); 988 const auto horz_out_1 = 989 hn::Load(int16_tag, horz_out + 1 * hn::MaxLanes(int16_tag)); 990 const auto horz_out_2 = 991 hn::Load(int16_tag, horz_out + 2 * hn::MaxLanes(int16_tag)); 992 const auto horz_out_3 = 993 hn::Load(int16_tag, horz_out + 3 * hn::MaxLanes(int16_tag)); 994 const auto horz_out_4 = 995 hn::Load(int16_tag, horz_out + 4 * hn::MaxLanes(int16_tag)); 996 const auto horz_out_5 = 997 hn::Load(int16_tag, horz_out + 5 * hn::MaxLanes(int16_tag)); 998 hn::Store(hn::InterleaveLower(int16_tag, horz_out_0, horz_out_1), int16_tag, 999 src_lo + 0 * hn::MaxLanes(int16_tag)); 1000 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_0, horz_out_1), int16_tag, 1001 src_hi + 0 * hn::MaxLanes(int16_tag)); 1002 hn::Store(hn::InterleaveLower(int16_tag, horz_out_2, horz_out_3), int16_tag, 1003 src_lo + 1 * hn::MaxLanes(int16_tag)); 1004 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_2, horz_out_3), int16_tag, 1005 src_hi + 1 * hn::MaxLanes(int16_tag)); 1006 hn::Store(hn::InterleaveLower(int16_tag, horz_out_4, horz_out_5), int16_tag, 1007 src_lo + 2 * hn::MaxLanes(int16_tag)); 1008 hn::Store(hn::InterleaveUpper(int16_tag, horz_out_4, horz_out_5), int16_tag, 1009 src_hi + 2 * hn::MaxLanes(int16_tag)); 1010 } 1011 
1012 HWY_ALIGN int16_t coeffs[8 * hn::MaxLanes(int16_tag)]; 1013 if constexpr (!InnerCoeffUpdate) { 1014 PrepareCoeffs(gamma, delta, sy4, coeffs); 1015 } 1016 1017 for (int k = -4; k < AOMMIN(4, p_height - i - 4); 1018 k += static_cast<int>(int16_tag.MaxBlocks())) { 1019 if constexpr (InnerCoeffUpdate) { 1020 int sy = sy4 + delta * (k + 4); 1021 PrepareCoeffs(gamma, delta, sy, coeffs); 1022 } 1023 1024 IVec32 res_lo, res_hi; 1025 FilterPixelsVertical(horz_out, src_lo, src_hi, coeffs, res_lo, res_hi, 1026 k + 4); 1027 StoreVerticalFilterOutput(res_lo, res_hi, res_add_const, wt, res_sub_const, 1028 round_bits_const, pred, conv_params, i, j, k + 4, 1029 reduce_bits_vert, p_stride, p_width, round_bits); 1030 1031 if constexpr (int16_tag.MaxBlocks() >= 3) { 1032 hn::Store(hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)), 1033 int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag)); 1034 hn::Store(hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag)), 1035 int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag)); 1036 hn::Store(hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)), 1037 int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag)); 1038 hn::Store(hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag)), 1039 int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag)); 1040 } else if constexpr (int16_tag.MaxBlocks() == 2) { 1041 hn::Store(hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag)), 1042 int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag)); 1043 hn::Store(hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)), 1044 int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag)); 1045 hn::Store(hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag)), 1046 int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)); 1047 hn::Store(hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag)), 1048 int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag)); 1049 hn::Store(hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)), 1050 int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag)); 1051 
hn::Store(hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag)), 1052 int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)); 1053 } else if constexpr (int16_tag.MaxBlocks() == 1) { 1054 const auto src_lo_0 = 1055 hn::Load(int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag)); 1056 const auto src_lo_1 = 1057 hn::Load(int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag)); 1058 const auto src_lo_2 = 1059 hn::Load(int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)); 1060 const auto src_lo_3 = 1061 hn::Load(int16_tag, src_lo + 3 * hn::MaxLanes(int16_tag)); 1062 const auto src_lo_0_new = hn::InterleaveEven( 1063 hn::ShiftRightLanes<1>(int16_tag, src_lo_0), src_lo_1); 1064 const auto src_lo_1_new = hn::InterleaveEven( 1065 hn::ShiftRightLanes<1>(int16_tag, src_lo_1), src_lo_2); 1066 const auto src_lo_2_new = hn::InterleaveEven( 1067 hn::ShiftRightLanes<1>(int16_tag, src_lo_2), src_lo_3); 1068 hn::Store(src_lo_0_new, int16_tag, src_lo + 0 * hn::MaxLanes(int16_tag)); 1069 hn::Store(src_lo_1_new, int16_tag, src_lo + 1 * hn::MaxLanes(int16_tag)); 1070 hn::Store(src_lo_2_new, int16_tag, src_lo + 2 * hn::MaxLanes(int16_tag)); 1071 const auto src_hi_0 = 1072 hn::Load(int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag)); 1073 const auto src_hi_1 = 1074 hn::Load(int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag)); 1075 const auto src_hi_2 = 1076 hn::Load(int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)); 1077 const auto src_hi_3 = 1078 hn::Load(int16_tag, src_hi + 3 * hn::MaxLanes(int16_tag)); 1079 const auto src_hi_0_new = hn::InterleaveEven( 1080 hn::ShiftRightLanes<1>(int16_tag, src_hi_0), src_hi_1); 1081 const auto src_hi_1_new = hn::InterleaveEven( 1082 hn::ShiftRightLanes<1>(int16_tag, src_hi_1), src_hi_2); 1083 const auto src_hi_2_new = hn::InterleaveEven( 1084 hn::ShiftRightLanes<1>(int16_tag, src_hi_2), src_hi_3); 1085 hn::Store(src_hi_0_new, int16_tag, src_hi + 0 * hn::MaxLanes(int16_tag)); 1086 hn::Store(src_hi_1_new, int16_tag, src_hi + 1 * hn::MaxLanes(int16_tag)); 1087 
hn::Store(src_hi_2_new, int16_tag, src_hi + 2 * hn::MaxLanes(int16_tag)); 1088 } 1089 } 1090 } 1091 1092 HWY_ATTR inline void PrepareWarpVerticalFilter( 1093 uint8_t *HWY_RESTRICT pred, int16_t *HWY_RESTRICT horz_out, 1094 ConvolveParams *HWY_RESTRICT conv_params, int16_t gamma, int16_t delta, 1095 int p_height, int p_stride, int p_width, int i, int j, int sy4, 1096 const int reduce_bits_vert, const IVec32 res_add_const, 1097 const int round_bits, const IVec16 res_sub_const, 1098 const IVec16 round_bits_const, const IVec16 wt) { 1099 if (gamma == 0 && delta == 0) 1100 WarpVerticalFilterTemplate<false, PrepareVerticalFilterCoeffsGamma0>( 1101 pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width, 1102 i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const, 1103 round_bits_const, wt); 1104 else if (gamma == 0 && delta != 0) 1105 WarpVerticalFilterTemplate<true, PrepareVerticalFilterCoeffsGamma0>( 1106 pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width, 1107 i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const, 1108 round_bits_const, wt); 1109 else if (gamma != 0 && delta == 0) 1110 WarpVerticalFilterTemplate<false, PrepareVerticalFilterCoeffsDelta0>( 1111 pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width, 1112 i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const, 1113 round_bits_const, wt); 1114 else 1115 WarpVerticalFilterTemplate<true, PrepareVerticalFilterCoeffs>( 1116 pred, horz_out, conv_params, gamma, delta, p_height, p_stride, p_width, 1117 i, j, sy4, reduce_bits_vert, res_add_const, round_bits, res_sub_const, 1118 round_bits_const, wt); 1119 } 1120 1121 HWY_ATTR inline void PrepareWarpHorizontalFilter( 1122 const uint8_t *HWY_RESTRICT ref, int16_t *HWY_RESTRICT horz_out, int stride, 1123 int32_t ix4, int32_t iy4, int32_t sx4, int alpha, int beta, int p_height, 1124 int height, int i, const IVec16 round_const, const int reduce_bits_horiz) { 1125 if (alpha 
== 0 && beta == 0) 1126 WarpHorizontalFilterTemplate<false, 1127 PrepareHorizontalFilterCoefficientsAlpha0>( 1128 ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, 1129 round_const, reduce_bits_horiz); 1130 else if (alpha == 0 && beta != 0) 1131 WarpHorizontalFilterTemplate<true, 1132 PrepareHorizontalFilterCoefficientsAlpha0>( 1133 ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, 1134 round_const, reduce_bits_horiz); 1135 else if (alpha != 0 && beta == 0) 1136 WarpHorizontalFilterTemplate<false, 1137 PrepareHorizontalFilterCoefficientsBeta0>( 1138 ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, 1139 round_const, reduce_bits_horiz); 1140 else 1141 WarpHorizontalFilterTemplate<true, PrepareHorizontalFilterCoefficients, 1142 PrepareLastHorizontalFilterCoefficients>( 1143 ref, horz_out, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, 1144 round_const, reduce_bits_horiz); 1145 } 1146 1147 template <typename D> 1148 HWY_ATTR HWY_INLINE int WarpHorizontalFilterOutOfBoundsSetLoop( 1149 D tag, const uint8_t *HWY_RESTRICT ref, int height, int stride, 1150 int p_height, int i, int iy4, int16_t const4, int16_t const5, int offset, 1151 int k, int16_t *HWY_RESTRICT horz_out) { 1152 constexpr int kNumRows = tag.MaxBlocks(); 1153 for (; k < AOMMIN(8, p_height - i) - kNumRows; k += kNumRows) { 1154 int iy = clamp(iy4 + k + 0, 0, height - 1); 1155 auto src = hn::ResizeBitCast( 1156 tag, hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5)); 1157 if constexpr (kNumRows >= 2) { 1158 iy = clamp(iy4 + k + 1, 0, height - 1); 1159 src = hn::InsertBlock<1>( 1160 src, 1161 hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5)); 1162 } 1163 if constexpr (kNumRows >= 3) { 1164 iy = clamp(iy4 + k + 2, 0, height - 1); 1165 src = hn::InsertBlock<2>( 1166 src, 1167 hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5)); 1168 iy = clamp(iy4 + k + 3, 0, height - 1); 1169 src = 
hn::InsertBlock<3>( 1170 src, 1171 hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5)); 1172 } 1173 hn::Store(src, tag, horz_out + (k + 7) * hn::MaxLanes(int16x8_tag)); 1174 } 1175 return k; 1176 } 1177 1178 HWY_ATTR void WarpHorizontalFilterOutOfBoundsSet( 1179 const uint8_t *HWY_RESTRICT ref, int height, int stride, int p_height, 1180 int i, int iy4, int16_t const4, int16_t const5, int offset, 1181 int16_t *HWY_RESTRICT horz_out) { 1182 int k = -7, iy; 1183 if constexpr (int16_tag.MaxBlocks() >= 3) { 1184 k = WarpHorizontalFilterOutOfBoundsSetLoop(int16_tag, ref, height, stride, 1185 p_height, i, iy4, const4, const5, 1186 offset, k, horz_out); 1187 } 1188 if constexpr (int16_tag.MaxBlocks() >= 2) { 1189 k = WarpHorizontalFilterOutOfBoundsSetLoop(int16x16_tag, ref, height, 1190 stride, p_height, i, iy4, const4, 1191 const5, offset, k, horz_out); 1192 } 1193 if constexpr (int16_tag.MaxBlocks() == 1) { 1194 k = WarpHorizontalFilterOutOfBoundsSetLoop(int16x8_tag, ref, height, stride, 1195 p_height, i, iy4, const4, const5, 1196 offset, k, horz_out); 1197 } 1198 iy = iy4 + k; 1199 iy = clamp(iy4 + k, 0, height - 1); 1200 hn::Store(hn::Set(int16x8_tag, const4 + ref[iy * stride + offset] * const5), 1201 int16x8_tag, horz_out + (k + 7) * hn::MaxLanes(int16x8_tag)); 1202 } 1203 1204 template <typename D> 1205 HWY_ATTR int WarpHorizontalFilterOutOfBoundsPadLoop( 1206 D tag, const uint8_t *HWY_RESTRICT ref, int stride, int32_t ix4, 1207 int32_t iy4, int32_t sx4, int alpha, int beta, int p_height, int height, 1208 int i, const IVec16 round_const, const int reduce_bits_horiz, 1209 int out_of_boundary_left, int out_of_boundary_right, int k, 1210 int16_t *HWY_RESTRICT horz_out) { 1211 constexpr int kNumRows = tag.MaxBlocks(); 1212 for (; k < (AOMMIN(8, p_height - i) - kNumRows); k += kNumRows) { 1213 auto src = LoadRowsClamped(tag, ref + ix4 - 7, stride, iy4 + k, height); 1214 if (out_of_boundary_left >= 0) { 1215 const auto shuffle_reg_left = 1216 
hn::LoadDup128(tag, warp_pad_left[out_of_boundary_left]); 1217 src = hn::TableLookupBytes(src, shuffle_reg_left); 1218 } 1219 if (out_of_boundary_right >= 0) { 1220 const auto shuffle_reg_right = 1221 hn::LoadDup128(tag, warp_pad_right[out_of_boundary_right]); 1222 src = hn::TableLookupBytes(src, shuffle_reg_right); 1223 } 1224 int sx = sx4 + beta * (k + 4); 1225 HorizontalFilter(tag, src, horz_out, sx, alpha, beta, k + 7, round_const, 1226 reduce_bits_horiz); 1227 } 1228 return k; 1229 } 1230 1231 HWY_ATTR void WarpHorizontalFilterOutOfBoundsPad( 1232 const uint8_t *HWY_RESTRICT ref, int stride, int32_t ix4, int32_t iy4, 1233 int32_t sx4, int alpha, int beta, int p_height, int width, int height, 1234 int i, const IVec16 round_const, const int reduce_bits_horiz, 1235 int16_t *HWY_RESTRICT horz_out) { 1236 const int out_of_boundary_left = -(ix4 - 6); 1237 const int out_of_boundary_right = (ix4 + 8) - width; 1238 int k = -7, iy, sx; 1239 if constexpr (uint8_tag.MaxBlocks() >= 3) { 1240 k = WarpHorizontalFilterOutOfBoundsPadLoop( 1241 uint8_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, 1242 round_const, reduce_bits_horiz, out_of_boundary_left, 1243 out_of_boundary_right, k, horz_out); 1244 } 1245 if constexpr (uint8_tag.MaxBlocks() >= 2) { 1246 k = WarpHorizontalFilterOutOfBoundsPadLoop( 1247 uint8x32_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height, 1248 i, round_const, reduce_bits_horiz, out_of_boundary_left, 1249 out_of_boundary_right, k, horz_out); 1250 } 1251 if constexpr (uint8_tag.MaxBlocks() == 1) { 1252 k = WarpHorizontalFilterOutOfBoundsPadLoop( 1253 uint8_tag, ref, stride, ix4, iy4, sx4, alpha, beta, p_height, height, i, 1254 round_const, reduce_bits_horiz, out_of_boundary_left, 1255 out_of_boundary_right, k, horz_out); 1256 } 1257 iy = iy4 + k; 1258 iy = clamp(iy, 0, height - 1); 1259 auto src = hn::LoadU(uint8x16_tag, ref + iy * stride + ix4 - 7); 1260 if (out_of_boundary_left >= 0) { 1261 const auto shuffle_reg_left = 
1262 hn::LoadU(uint8x16_tag, warp_pad_left[out_of_boundary_left]); 1263 src = hn::TableLookupBytes(src, shuffle_reg_left); 1264 } 1265 if (out_of_boundary_right >= 0) { 1266 const auto shuffle_reg_right = 1267 hn::LoadU(uint8x16_tag, warp_pad_right[out_of_boundary_right]); 1268 src = hn::TableLookupBytes(src, shuffle_reg_right); 1269 } 1270 sx = sx4 + beta * (k + 4); 1271 HWY_ALIGN int8_t coeff[4 * hn::MaxLanes(int8_tag)]; 1272 PrepareLastHorizontalFilterCoefficients(alpha, beta, sx, coeff); 1273 FilterPixelsHorizontal(uint8x16_tag, src, horz_out, coeff, round_const, 1274 reduce_bits_horiz, k + 7); 1275 } 1276 1277 HWY_ATTR void WarpAffine(const int32_t *HWY_RESTRICT mat, 1278 const uint8_t *HWY_RESTRICT ref, int width, int height, 1279 int stride, uint8_t *HWY_RESTRICT pred, int p_col, 1280 int p_row, int p_width, int p_height, int p_stride, 1281 int subsampling_x, int subsampling_y, 1282 ConvolveParams *HWY_RESTRICT conv_params, 1283 int16_t alpha, int16_t beta, int16_t gamma, 1284 int16_t delta) { 1285 int i, j; 1286 const int bd = 8; 1287 const int reduce_bits_horiz = conv_params->round_0; 1288 const int reduce_bits_vert = conv_params->is_compound 1289 ? 
conv_params->round_1 1290 : 2 * FILTER_BITS - reduce_bits_horiz; 1291 const int offset_bits_horiz = bd + FILTER_BITS - 1; 1292 assert(IMPLIES(conv_params->is_compound, conv_params->dst != NULL)); 1293 1294 const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz; 1295 const auto reduce_bits_vert_const = 1296 hn::Set(int32_tag, ((1 << reduce_bits_vert) >> 1)); 1297 const auto res_add_const = hn::Set(int32_tag, 1 << offset_bits_vert); 1298 const int round_bits = 1299 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1; 1300 const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0; 1301 assert(IMPLIES(conv_params->do_average, conv_params->is_compound)); 1302 1303 const auto round_const = hn::Set( 1304 int16_tag, (1 << offset_bits_horiz) + ((1 << reduce_bits_horiz) >> 1)); 1305 1306 IVec16 res_sub_const, round_bits_const, wt; 1307 UnpackWeightsAndSetRoundConst(conv_params, round_bits, offset_bits, 1308 res_sub_const, round_bits_const, wt); 1309 1310 IVec32 res_add_const_1; 1311 if (conv_params->is_compound == 1) { 1312 res_add_const_1 = hn::Add(reduce_bits_vert_const, res_add_const); 1313 } else { 1314 res_add_const_1 = hn::Set(int32_tag, -(1 << (bd + reduce_bits_vert - 1)) + 1315 ((1 << reduce_bits_vert) >> 1)); 1316 } 1317 const int32_t const1 = alpha * (-4) + beta * (-4) + 1318 (1 << (WARPEDDIFF_PREC_BITS - 1)) + 1319 (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS); 1320 const int32_t const2 = gamma * (-4) + delta * (-4) + 1321 (1 << (WARPEDDIFF_PREC_BITS - 1)) + 1322 (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS); 1323 const int32_t const3 = ((1 << WARP_PARAM_REDUCE_BITS) - 1); 1324 const int16_t const4 = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)); 1325 const int16_t const5 = (1 << (FILTER_BITS - reduce_bits_horiz)); 1326 1327 for (i = 0; i < p_height; i += 8) { 1328 for (j = 0; j < p_width; j += 8) { 1329 HWY_ALIGN int16_t horz_out[8 * 16 + hn::MaxLanes(int16_tag)]; 1330 const int32_t src_x = (p_col + j + 4) << 
subsampling_x; 1331 const int32_t src_y = (p_row + i + 4) << subsampling_y; 1332 const int64_t dst_x = 1333 (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0]; 1334 const int64_t dst_y = 1335 (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1]; 1336 const int64_t x4 = dst_x >> subsampling_x; 1337 const int64_t y4 = dst_y >> subsampling_y; 1338 1339 int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS); 1340 int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1); 1341 int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS); 1342 int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1); 1343 1344 // Add in all the constant terms, including rounding and offset 1345 sx4 += const1; 1346 sy4 += const2; 1347 1348 sx4 &= ~const3; 1349 sy4 &= ~const3; 1350 1351 // Horizontal filter 1352 // If the block is aligned such that, after clamping, every sample 1353 // would be taken from the leftmost/rightmost column, then we can 1354 // skip the expensive horizontal filter. 1355 1356 if (ix4 <= -7) { 1357 WarpHorizontalFilterOutOfBoundsSet(ref, height, stride, p_height, i, 1358 iy4, const4, const5, 0, horz_out); 1359 } else if (ix4 >= width + 6) { 1360 WarpHorizontalFilterOutOfBoundsSet(ref, height, stride, p_height, i, 1361 iy4, const4, const5, width - 1, 1362 horz_out); 1363 } else if (((ix4 - 7) < 0) || ((ix4 + 9) > width)) { 1364 WarpHorizontalFilterOutOfBoundsPad( 1365 ref, stride, ix4, iy4, sx4, alpha, beta, p_height, width, height, i, 1366 round_const, reduce_bits_horiz, horz_out); 1367 } else { 1368 PrepareWarpHorizontalFilter(ref, horz_out, stride, ix4, iy4, sx4, alpha, 1369 beta, p_height, height, i, round_const, 1370 reduce_bits_horiz); 1371 } 1372 1373 // Vertical filter 1374 PrepareWarpVerticalFilter(pred, horz_out, conv_params, gamma, delta, 1375 p_height, p_stride, p_width, i, j, sy4, 1376 reduce_bits_vert, res_add_const_1, round_bits, 1377 res_sub_const, round_bits_const, wt); 1378 } 1379 } 1380 } 1381 1382 } // namespace HWY_NAMESPACE 1383 
}  // namespace

// Defines the entry point `av1_warp_affine_<suffix>`.
//
// The expansion first declares the function inside `extern "C"` and then
// defines it with HWY_ATTR. The definition forwards every argument, unchanged
// and in the same order, to HWY_NAMESPACE::WarpAffine.
//
//   suffix  token appended to `av1_warp_affine_` to form the function name
#define MAKE_WARP_AFFINE(suffix)                                             \
  extern "C" void av1_warp_affine_##suffix(                                  \
      const int32_t *HWY_RESTRICT mat, const uint8_t *HWY_RESTRICT ref,      \
      int width, int height, int stride, uint8_t *HWY_RESTRICT pred,         \
      int p_col, int p_row, int p_width, int p_height, int p_stride,         \
      int subsampling_x, int subsampling_y,                                  \
      ConvolveParams *HWY_RESTRICT conv_params, int16_t alpha, int16_t beta, \
      int16_t gamma, int16_t delta);                                         \
  HWY_ATTR void av1_warp_affine_##suffix(                                    \
      const int32_t *HWY_RESTRICT mat, const uint8_t *HWY_RESTRICT ref,      \
      int width, int height, int stride, uint8_t *HWY_RESTRICT pred,         \
      int p_col, int p_row, int p_width, int p_height, int p_stride,         \
      int subsampling_x, int subsampling_y,                                  \
      ConvolveParams *HWY_RESTRICT conv_params, int16_t alpha, int16_t beta, \
      int16_t gamma, int16_t delta) {                                        \
    HWY_NAMESPACE::WarpAffine(mat, ref, width, height, stride, pred, p_col,  \
                              p_row, p_width, p_height, p_stride,            \
                              subsampling_x, subsampling_y, conv_params,     \
                              alpha, beta, gamma, delta);                    \
  }

HWY_AFTER_NAMESPACE();

#endif  // AV1_COMMON_WARP_PLANE_HWY_H_