selfguided_hwy.h (30338B)
/*
 * Copyright (c) 2025, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AV1_COMMON_SELFGUIDED_HWY_H_
#define AV1_COMMON_SELFGUIDED_HWY_H_

#include "av1/common/restoration.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"
#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();

namespace {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;

// Helper for Scan32 below: after per-128-bit-block prefix sums have been
// computed, AddBlocks() propagates each block's running total into all later
// blocks. Specialised on the number of 128-bit blocks in the vector
// (1 => 128-bit targets, 2 => 256-bit, 4 => 512-bit).
template <int NumBlocks>
struct ScanTraits {};

// One block: nothing to propagate.
template <>
struct ScanTraits<1> {
  template <typename D>
  HWY_ATTR HWY_INLINE static hn::VFromD<D> AddBlocks(D int32_tag,
                                                     hn::VFromD<D> v) {
    (void)int32_tag;
    return v;
  }
};

// Two blocks: broadcast the low block's total (lane 3) and add it to the
// upper block only (lanes 4..7), leaving the lower block untouched.
template <>
struct ScanTraits<2> {
  template <typename D>
  HWY_ATTR HWY_INLINE static hn::VFromD<D> AddBlocks(D int32_tag,
                                                     hn::VFromD<D> v) {
    constexpr hn::Half<D> half_tag;
    const int32_t s = hn::ExtractLane(v, 3);
    const auto s01 = hn::Set(half_tag, s);
    // Zero vector with s broadcast into block 1 => adds s to lanes 4..7 only.
    const auto s02 = hn::InsertBlock<1>(hn::Zero(int32_tag), s01);
    return hn::Add(v, s02);
  }
};

// Four blocks: three table lookups broadcast the totals of blocks 0, 1 and 2
// (lanes 3, 7 and 11) into every later block. For TwoTablesLookupLanes,
// indices 0..15 select from the first (all-zero) table and indices 16..31
// select lane (index - 16) of v; hence 19/23/27 pick up lanes 3/7/11 and
// index 0 leaves a lane unchanged.
template <>
struct ScanTraits<4> {
  template <typename D>
  HWY_ATTR HWY_INLINE static hn::VFromD<D> AddBlocks(D int32_tag,
                                                     hn::VFromD<D> v) {
    HWY_ALIGN static const int32_t kA[] = {
      0, 0, 0, 0, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
    };
    HWY_ALIGN static const int32_t kB[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 23, 23, 23, 23, 23, 23, 23, 23,
    };
    HWY_ALIGN static const int32_t kC[] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 27, 27, 27,
    };
    const auto a = hn::SetTableIndices(int32_tag, kA);
    const auto b = hn::SetTableIndices(int32_tag, kB);
    const auto c = hn::SetTableIndices(int32_tag, kC);
    const auto s01 =
        hn::TwoTablesLookupLanes(int32_tag, hn::Zero(int32_tag), v, a);
    const auto s02 =
        hn::TwoTablesLookupLanes(int32_tag, hn::Zero(int32_tag), v, b);
    const auto s03 =
        hn::TwoTablesLookupLanes(int32_tag, hn::Zero(int32_tag), v, c);
    v = hn::Add(v, s01);
    v = hn::Add(v, s02);
    v = hn::Add(v, s03);
    return v;
  }
};

// Compute the scan of a register holding 32-bit integers. If the register holds
// x0..x7 then the scan will hold x0, x0+x1, x0+x1+x2, ..., x0+x1+...+x7
//
// For the AVX2 example below, let [...] represent a 128-bit block, and let a,
// ..., h be 32-bit integers (assumed small enough to be able to add them
// without overflow).
//
// Use -> as shorthand for summing, i.e. h->a = h + g + f + e + d + c + b + a.
//
// x = [h g f e][d c b a]
// x01 = [g f e 0][c b a 0]
// x02 = [g+h f+g e+f e][c+d b+c a+b a]
// x03 = [e+f e 0 0][a+b a 0 0]
// x04 = [e->h e->g e->f e][a->d a->c a->b a]
// s = a->d
// s01 = [a->d a->d a->d a->d]
// s02 = [a->d a->d a->d a->d][0 0 0 0]
// ret = [a->h a->g a->f a->e][a->d a->c a->b a]
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> Scan32(D int32_tag, hn::VFromD<D> x) {
  // Prefix sum within each 128-bit block (log-time shift-and-add).
  const auto x01 = hn::ShiftLeftBytes<4>(x);
  const auto x02 = hn::Add(x, x01);
  const auto x03 = hn::ShiftLeftBytes<8>(x02);
  const auto x04 = hn::Add(x02, x03);
  // Then carry each block's total into the later blocks.
  return ScanTraits<int32_tag.MaxBlocks()>::AddBlocks(int32_tag, x04);
}

// Compute two integral images from src. B sums elements; A sums their
// squares. The images are offset by one pixel, so will have width and height
// equal to width + 1, height + 1 and the first row and column will be zero.
//
// A+1 and B+1 should be aligned to 32 bytes.
buf_stride should be a multiple 115 // of 8. 116 template <typename T, typename D> 117 HWY_ATTR HWY_INLINE void IntegralImages(D int32_tag, const T *HWY_RESTRICT src, 118 int src_stride, int width, int height, 119 int32_t *HWY_RESTRICT A, 120 int32_t *HWY_RESTRICT B, 121 int buf_stride) { 122 constexpr hn::Rebind<T, D> uint_tag; 123 constexpr hn::Repartition<int16_t, D> int16_tag; 124 // Write out the zero top row 125 hwy::ZeroBytes(A, 4 * (width + 8)); 126 hwy::ZeroBytes(B, 4 * (width + 8)); 127 128 for (int i = 0; i < height; ++i) { 129 // Zero the left column. 130 A[(i + 1) * buf_stride] = B[(i + 1) * buf_stride] = 0; 131 132 // ldiff is the difference H - D where H is the output sample immediately 133 // to the left and D is the output sample above it. These are scalars, 134 // replicated across the eight lanes. 135 auto ldiff1 = hn::Zero(int32_tag); 136 auto ldiff2 = hn::Zero(int32_tag); 137 for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) { 138 const int ABj = 1 + j; 139 140 const auto above1 = hn::Load(int32_tag, B + ABj + i * buf_stride); 141 const auto above2 = hn::Load(int32_tag, A + ABj + i * buf_stride); 142 143 const auto x1 = hn::PromoteTo( 144 int32_tag, hn::LoadU(uint_tag, src + j + i * src_stride)); 145 const auto x2 = hn::WidenMulPairwiseAdd( 146 int32_tag, hn::BitCast(int16_tag, x1), hn::BitCast(int16_tag, x1)); 147 148 const auto sc1 = Scan32(int32_tag, x1); 149 const auto sc2 = Scan32(int32_tag, x2); 150 151 const auto row1 = hn::Add(hn::Add(sc1, above1), ldiff1); 152 const auto row2 = hn::Add(hn::Add(sc2, above2), ldiff2); 153 154 hn::Store(row1, int32_tag, B + ABj + (i + 1) * buf_stride); 155 hn::Store(row2, int32_tag, A + ABj + (i + 1) * buf_stride); 156 157 // Calculate the new H - D. 
158 ldiff1 = hn::Set(int32_tag, hn::ExtractLane(hn::Sub(row1, above1), 159 hn::MaxLanes(int32_tag) - 1)); 160 ldiff2 = hn::Set(int32_tag, hn::ExtractLane(hn::Sub(row2, above2), 161 hn::MaxLanes(int32_tag) - 1)); 162 } 163 } 164 } 165 166 template <typename D> 167 HWY_ATTR HWY_INLINE hn::VFromD<D> BoxSumFromII(D int32_tag, 168 const int32_t *HWY_RESTRICT ii, 169 int stride, int r) { 170 const auto tl = hn::LoadU(int32_tag, ii - (r + 1) - (r + 1) * stride); 171 const auto tr = hn::LoadU(int32_tag, ii + (r + 0) - (r + 1) * stride); 172 const auto bl = hn::LoadU(int32_tag, ii - (r + 1) + r * stride); 173 const auto br = hn::LoadU(int32_tag, ii + (r + 0) + r * stride); 174 const auto u = hn::Sub(tr, tl); 175 const auto v = hn::Sub(br, bl); 176 return hn::Sub(v, u); 177 } 178 179 template <typename D> 180 HWY_ATTR HWY_INLINE hn::VFromD<D> RoundForShift(D int32_tag, 181 unsigned int shift) { 182 return hn::Set(int32_tag, (1 << shift) >> 1); 183 } 184 185 template <typename D> 186 HWY_ATTR HWY_INLINE hn::VFromD<D> ComputeP(D int32_tag, hn::VFromD<D> sum1, 187 hn::VFromD<D> sum2, int bit_depth, 188 int n) { 189 constexpr hn::Repartition<int16_t, D> int16_tag; 190 if (bit_depth > 8) { 191 const auto rounding_a = RoundForShift(int32_tag, 2 * (bit_depth - 8)); 192 const auto rounding_b = RoundForShift(int32_tag, bit_depth - 8); 193 const auto a = 194 hn::ShiftRightSame(hn::Add(sum2, rounding_a), 2 * (bit_depth - 8)); 195 const auto b = hn::ShiftRightSame(hn::Add(sum1, rounding_b), bit_depth - 8); 196 // b < 2^14, so we can use a 16-bit madd rather than a 32-bit 197 // mullo to square it 198 const auto b_16 = hn::BitCast(int16_tag, b); 199 const auto bb = hn::WidenMulPairwiseAdd(int32_tag, b_16, b_16); 200 const auto an = hn::Max(hn::Mul(a, hn::Set(int32_tag, n)), bb); 201 return hn::Sub(an, bb); 202 } 203 const auto sum1_16 = hn::BitCast(int16_tag, sum1); 204 const auto bb = hn::WidenMulPairwiseAdd(int32_tag, sum1_16, sum1_16); 205 const auto an = hn::Mul(sum2, 
hn::Set(int32_tag, n)); 206 return hn::Sub(an, bb); 207 } 208 209 // Calculate 8 values of the "cross sum" starting at buf. This is a 3x3 filter 210 // where the outer four corners have weight 3 and all other pixels have weight 211 // 4. 212 // 213 // Pixels are indexed as follows: 214 // xtl xt xtr 215 // xl x xr 216 // xbl xb xbr 217 // 218 // buf points to x 219 // 220 // fours = xl + xt + xr + xb + x 221 // threes = xtl + xtr + xbr + xbl 222 // cross_sum = 4 * fours + 3 * threes 223 // = 4 * (fours + threes) - threes 224 // = (fours + threes) << 2 - threes 225 template <typename D> 226 HWY_ATTR HWY_INLINE hn::VFromD<D> CrossSum(D int32_tag, 227 const int32_t *HWY_RESTRICT buf, 228 int stride) { 229 const auto xtl = hn::LoadU(int32_tag, buf - 1 - stride); 230 const auto xt = hn::LoadU(int32_tag, buf - stride); 231 const auto xtr = hn::LoadU(int32_tag, buf + 1 - stride); 232 const auto xl = hn::LoadU(int32_tag, buf - 1); 233 const auto x = hn::LoadU(int32_tag, buf); 234 const auto xr = hn::LoadU(int32_tag, buf + 1); 235 const auto xbl = hn::LoadU(int32_tag, buf - 1 + stride); 236 const auto xb = hn::LoadU(int32_tag, buf + stride); 237 const auto xbr = hn::LoadU(int32_tag, buf + 1 + stride); 238 239 const auto fours = hn::Add(xl, hn::Add(xt, hn::Add(xr, hn::Add(xb, x)))); 240 const auto threes = hn::Add(xtl, hn::Add(xtr, hn::Add(xbr, xbl))); 241 242 return hn::Sub(hn::ShiftLeft<2>(hn::Add(fours, threes)), threes); 243 } 244 245 // The final filter for self-guided restoration. Computes a weighted average 246 // across A, B with "cross sums" (see CrossSum implementation above). 
template <typename DL>
HWY_ATTR HWY_INLINE void FinalFilter(
    DL int32_tag, int32_t *HWY_RESTRICT dst, int dst_stride,
    const int32_t *HWY_RESTRICT A, const int32_t *HWY_RESTRICT B,
    int buf_stride, const void *HWY_RESTRICT dgd8, int dgd_stride, int width,
    int height, int highbd) {
  constexpr hn::Repartition<uint8_t, hn::Half<DL>> uint8_half_tag;
  constexpr hn::Repartition<int16_t, DL> int16_tag;
  // nb bits of extra precision from the 3x3 cross-sum weights (they sum to
  // 4*5 + 3*4 = 32 = 2^5); removed again by the final shift.
  constexpr int nb = 5;
  constexpr int kShift = SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS;
  const auto rounding = RoundForShift(int32_tag, kShift);
  // For high bit depth, dgd8 is a CONVERT_TO_BYTEPTR-style pointer to uint16
  // samples; index arithmetic below scales byte offsets by << highbd.
  const uint8_t *HWY_RESTRICT dgd_real =
      highbd ? reinterpret_cast<const uint8_t *>(CONVERT_TO_SHORTPTR(dgd8))
             : reinterpret_cast<const uint8_t *>(dgd8);

  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) {
      // Weighted 3x3 neighbourhood sums of the per-pixel filter values.
      const auto a = CrossSum(int32_tag, A + i * buf_stride + j, buf_stride);
      const auto b = CrossSum(int32_tag, B + i * buf_stride + j, buf_stride);

      // Load one vector's worth of source pixels and widen to int32. For
      // highbd the loaded bytes are reinterpreted as int16 samples.
      const auto raw = hn::LoadU(uint8_half_tag,
                                 dgd_real + ((i * dgd_stride + j) << highbd));
      const auto src =
          highbd ? hn::PromoteTo(
                       int32_tag,
                       hn::BitCast(
                           hn::Repartition<int16_t, decltype(uint8_half_tag)>(),
                           raw))
                 : hn::PromoteTo(int32_tag, hn::LowerHalf(raw));

      // v = a * src + b; a and src fit in 16 bits so a 16-bit madd (with the
      // odd int16 lanes of each int32 being zero/small) computes the product.
      auto v =
          hn::Add(hn::WidenMulPairwiseAdd(int32_tag, hn::BitCast(int16_tag, a),
                                          hn::BitCast(int16_tag, src)),
                  b);
      // Round-to-nearest shift back to SGRPROJ_RST_BITS precision.
      auto w = hn::ShiftRight<kShift>(hn::Add(v, rounding));

      hn::StoreU(w, int32_tag, dst + i * dst_stride + j);
    }
  }
}

// Assumes that C, D are integral images for the original buffer which has been
// extended to have a padding of SGRPROJ_BORDER_VERT/SGRPROJ_BORDER_HORZ pixels
// on the sides. A, B, C, D point at logical position (0, 0).
template <int Step, typename DL>
HWY_ATTR HWY_INLINE void CalcAB(DL int32_tag, int32_t *HWY_RESTRICT A,
                                int32_t *HWY_RESTRICT B,
                                const int32_t *HWY_RESTRICT C,
                                const int32_t *HWY_RESTRICT D, int width,
                                int height, int buf_stride, int bit_depth,
                                int sgr_params_idx, int radius_idx) {
  constexpr hn::Repartition<int16_t, DL> int16_tag;
  constexpr hn::Repartition<uint32_t, DL> uint32_tag;
  const sgr_params_type *HWY_RESTRICT const params =
      &av1_sgr_params[sgr_params_idx];
  const int r = params->r[radius_idx];
  // n = number of pixels in the (2r+1)x(2r+1) box.
  const int n = (2 * r + 1) * (2 * r + 1);
  const auto s = hn::Set(int32_tag, params->s[radius_idx]);
  // one_over_n[n-1] is 2^12/n, so easily fits in an int16
  const auto one_over_n =
      hn::BitCast(int16_tag, hn::Set(int32_tag, av1_one_by_x[n - 1]));

  const auto rnd_z = RoundForShift(int32_tag, SGRPROJ_MTABLE_BITS);
  const auto rnd_res = RoundForShift(int32_tag, SGRPROJ_RECIP_BITS);

  // Set up masks: mask[k] keeps the first k lanes (used to zero the lanes
  // past the right edge of the row).
  const int max_lanes = static_cast<int>(hn::MaxLanes(int32_tag));
  HWY_ALIGN hn::Mask<decltype(int32_tag)> mask[max_lanes];
  for (int idx = 0; idx < max_lanes; idx++) {
    mask[idx] = hn::FirstN(int32_tag, idx);
  }

  // Step is 2 for the "fast" filter (A/B computed on every other row) and 1
  // for the full filter; rows/columns run from -1 to width/height inclusive.
  for (int i = -1; i < height + 1; i += Step) {
    for (int j = -1; j < width + 1; j += max_lanes) {
      const int32_t *HWY_RESTRICT Cij = C + i * buf_stride + j;
      const int32_t *HWY_RESTRICT Dij = D + i * buf_stride + j;

      // sum1 = box sum of pixels (from D); sum2 = box sum of squares (from C).
      auto sum1 = BoxSumFromII(int32_tag, Dij, buf_stride, r);
      auto sum2 = BoxSumFromII(int32_tag, Cij, buf_stride, r);

      // When width + 2 isn't a multiple of 8, sum1 and sum2 will contain
      // some uninitialised data in their upper words. We use a mask to
      // ensure that these bits are set to 0.
      int idx = AOMMIN(max_lanes, width + 1 - j);
      assert(idx >= 1);

      if (idx < max_lanes) {
        sum1 = hn::IfThenElseZero(mask[idx], sum1);
        sum2 = hn::IfThenElseZero(mask[idx], sum2);
      }

      // p ~ n * variance of the box (see ComputeP).
      const auto p = ComputeP(int32_tag, sum1, sum2, bit_depth, n);

      // z = min(round(p * s >> SGRPROJ_MTABLE_BITS), 255): table index for
      // the x/(x+1) lookup. Unsigned shift so the clamp works on the full
      // 32-bit range.
      const auto z = hn::BitCast(
          int32_tag, hn::Min(hn::ShiftRight<SGRPROJ_MTABLE_BITS>(hn::BitCast(
                                 uint32_tag, hn::MulAdd(p, s, rnd_z))),
                             hn::Set(uint32_tag, 255)));

      // a = av1_x_by_xplus1[z], the filter strength for this pixel.
      const auto a_res = hn::GatherIndex(int32_tag, av1_x_by_xplus1, z);

      hn::StoreU(a_res, int32_tag, A + i * buf_stride + j);

      // b is computed from (SGRPROJ_SGR - a) * mean = (SGRPROJ_SGR - a) *
      // sum1 / n.
      const auto a_complement = hn::Sub(hn::Set(int32_tag, SGRPROJ_SGR), a_res);

      // sum1 might have lanes greater than 2^15, so we can't use madd to do
      // multiplication involving sum1. However, a_complement and one_over_n
      // are both less than 256, so we can multiply them first.
      const auto a_comp_over_n = hn::WidenMulPairwiseAdd(
          int32_tag, hn::BitCast(int16_tag, a_complement), one_over_n);
      const auto b_int = hn::Mul(a_comp_over_n, sum1);
      const auto b_res =
          hn::ShiftRight<SGRPROJ_RECIP_BITS>(hn::Add(b_int, rnd_res));

      hn::StoreU(b_res, int32_tag, B + i * buf_stride + j);
    }
  }
}

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xtl xt xtr
// - buf -
// xbl xb xbr
//
// Pixels are weighted like this:
// 5 6 5
// 0 0 0
// 5 6 5
//
// fives = xtl + xtr + xbl + xbr
// sixes = xt + xb
// cross_sum = 6 * sixes + 5 * fives
// = 5 * (fives + sixes) + sixes
// = 4 * (fives + sixes) + (fives + sixes) + sixes
// = (fives + sixes) << 2 + (fives + sixes) + sixes
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> CrossSumFastEvenRow(
    D int32_tag, const int32_t *HWY_RESTRICT buf, int stride) {
  const auto xtl = hn::LoadU(int32_tag, buf - 1 - stride);
  const auto xt = hn::LoadU(int32_tag, buf - stride);
  const auto xtr = hn::LoadU(int32_tag, buf + 1 - stride);
  const auto xbl = hn::LoadU(int32_tag, buf - 1 + stride);
  const auto xb = hn::LoadU(int32_tag, buf + stride);
  const auto xbr = hn::LoadU(int32_tag, buf + 1 + stride);

  const auto fives = hn::Add(xtl, hn::Add(xtr, hn::Add(xbr, xbl)));
  const auto sixes = hn::Add(xt, xb);
  const auto fives_plus_sixes = hn::Add(fives, sixes);

  return hn::Add(hn::Add(hn::ShiftLeft<2>(fives_plus_sixes), fives_plus_sixes),
                 sixes);
}

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xl x xr
//
// Pixels are weighted like this:
// 5 6 5
//
// buf points to x
//
// fives = xl + xr
// sixes = x
// cross_sum = 5 * fives + 6 * sixes
// = 4 * (fives + sixes) + (fives + sixes) + sixes
// = (fives + sixes) << 2 + (fives + sixes) + sixes
template <typename D>
HWY_ATTR HWY_INLINE hn::VFromD<D> CrossSumFastOddRow(
    D int32_tag, const int32_t *HWY_RESTRICT buf) {
  const auto xl = hn::LoadU(int32_tag, buf - 1);
  const auto x = hn::LoadU(int32_tag, buf);
  const auto xr = hn::LoadU(int32_tag, buf + 1);

  const auto fives = hn::Add(xl, xr);
  const auto sixes = x;

  const auto fives_plus_sixes = hn::Add(fives, sixes);

  return hn::Add(hn::Add(hn::ShiftLeft<2>(fives_plus_sixes), fives_plus_sixes),
                 sixes);
}

// The final filter for the self-guided restoration. Computes a
// weighted average across A, B with "cross sums" (see cross_sum_...
// implementations above).
template <typename DL>
HWY_ATTR HWY_INLINE void FinalFilterFast(
    DL int32_tag, int32_t *HWY_RESTRICT dst, int dst_stride,
    const int32_t *HWY_RESTRICT A, const int32_t *HWY_RESTRICT B,
    int buf_stride, const void *HWY_RESTRICT dgd8, int dgd_stride, int width,
    int height, int highbd) {
  constexpr hn::Repartition<uint8_t, hn::Half<DL>> uint8_half_tag;
  constexpr hn::Repartition<int16_t, DL> int16_tag;
  // Cross-sum weights total 2^nb0 (even rows) / 2^nb1 (odd rows); the shifts
  // remove that extra precision again.
  constexpr int nb0 = 5;
  constexpr int nb1 = 4;
  constexpr int kShift0 = SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS;
  constexpr int kShift1 = SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS;

  const auto rounding0 = RoundForShift(int32_tag, kShift0);
  const auto rounding1 = RoundForShift(int32_tag, kShift1);

  // For high bit depth, dgd8 points at uint16 samples; offsets below are
  // scaled by << highbd.
  const uint8_t *HWY_RESTRICT dgd_real =
      highbd ? reinterpret_cast<const uint8_t *>(CONVERT_TO_SHORTPTR(dgd8))
             : reinterpret_cast<const uint8_t *>(dgd8);

  for (int i = 0; i < height; ++i) {
    // A/B were only computed on even rows (CalcAB<2>), so even rows use the
    // two-row cross sum and odd rows the single-row one.
    if (!(i & 1)) {  // even row
      for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) {
        const auto a =
            CrossSumFastEvenRow(int32_tag, A + i * buf_stride + j, buf_stride);
        const auto b =
            CrossSumFastEvenRow(int32_tag, B + i * buf_stride + j, buf_stride);

        const auto raw = hn::LoadU(
            uint8_half_tag, dgd_real + ((i * dgd_stride + j) << highbd));
        const auto src =
            highbd
                ? hn::PromoteTo(
                      int32_tag,
                      hn::BitCast(
                          hn::Repartition<int16_t, decltype(uint8_half_tag)>(),
                          raw))
                : hn::PromoteTo(int32_tag, hn::LowerHalf(raw));

        // v = a * src + b via a 16-bit madd (a and src fit in 16 bits).
        auto v = hn::Add(
            hn::WidenMulPairwiseAdd(int32_tag, hn::BitCast(int16_tag, a),
                                    hn::BitCast(int16_tag, src)),
            b);
        auto w = hn::ShiftRight<kShift0>(hn::Add(v, rounding0));

        hn::StoreU(w, int32_tag, dst + i * dst_stride + j);
      }
    } else {  // odd row
      for (int j = 0; j < width; j += hn::MaxLanes(int32_tag)) {
        const auto a = CrossSumFastOddRow(int32_tag, A + i * buf_stride + j);
        const auto b = CrossSumFastOddRow(int32_tag, B + i * buf_stride + j);

        const auto raw = hn::LoadU(
            uint8_half_tag, dgd_real + ((i * dgd_stride + j) << highbd));
        const auto src =
            highbd
                ? hn::PromoteTo(
                      int32_tag,
                      hn::BitCast(
                          hn::Repartition<int16_t, decltype(uint8_half_tag)>(),
                          raw))
                : hn::PromoteTo(int32_tag, hn::LowerHalf(raw));

        auto v = hn::Add(
            hn::WidenMulPairwiseAdd(int32_tag, hn::BitCast(int16_tag, a),
                                    hn::BitCast(int16_tag, src)),
            b);
        auto w = hn::ShiftRight<kShift1>(hn::Add(v, rounding1));

        hn::StoreU(w, int32_tag, dst + i * dst_stride + j);
      }
    }
  }
}

// Top-level self-guided filter: builds integral images, computes the A/B
// per-pixel parameters for each enabled radius and writes the two filtered
// planes to flt0 (fast filter) and flt1 (full filter). Returns 0 on success,
// -1 on allocation failure.
HWY_ATTR HWY_INLINE int SelfGuidedRestoration(
    const uint8_t *dgd8, int width, int height, int dgd_stride,
    int32_t *HWY_RESTRICT flt0, int32_t *HWY_RESTRICT flt1, int flt_stride,
    int sgr_params_idx, int bit_depth, int highbd) {
  constexpr hn::ScalableTag<int32_t> int32_tag;
  constexpr int kAlignment32Log2 = hwy::CeilLog2(hn::MaxLanes(int32_tag));
  // The ALIGN_POWER_OF_TWO macro here ensures that column 1 of Atl, Btl, Ctl
  // and Dtl is vector aligned.
  const int buf_elts =
      ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, kAlignment32Log2);

  // One allocation holding the four working planes A, B, C, D.
  int32_t *buf = reinterpret_cast<int32_t *>(
      aom_memalign(4 << kAlignment32Log2, 4 * sizeof(*buf) * buf_elts));
  if (!buf) {
    return -1;
  }

  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;

  // Adjusting the stride of A and B here appears to avoid bad cache effects,
  // leading to a significant speed improvement.
  // We also align the stride to a multiple of the vector size for efficiency.
  int buf_stride =
      ALIGN_POWER_OF_TWO(width_ext + (2 << kAlignment32Log2), kAlignment32Log2);

  // The "tl" pointers point at the top-left of the initialised data for the
  // array.
  // Offset by (1 << kAlignment32Log2) - 1 so that column 1 (the first real
  // integral-image column) is vector aligned.
  int32_t *Atl = buf + 0 * buf_elts + (1 << kAlignment32Log2) - 1;
  int32_t *Btl = buf + 1 * buf_elts + (1 << kAlignment32Log2) - 1;
  int32_t *Ctl = buf + 2 * buf_elts + (1 << kAlignment32Log2) - 1;
  int32_t *Dtl = buf + 3 * buf_elts + (1 << kAlignment32Log2) - 1;

  // The "0" pointers are (- SGRPROJ_BORDER_VERT, -SGRPROJ_BORDER_HORZ). Note
  // there's a zero row and column in A, B (integral images), so we move down
  // and right one for them.
  const int buf_diag_border =
      SGRPROJ_BORDER_HORZ + buf_stride * SGRPROJ_BORDER_VERT;

  int32_t *A0 = Atl + 1 + buf_stride;
  int32_t *B0 = Btl + 1 + buf_stride;
  int32_t *C0 = Ctl + 1 + buf_stride;
  int32_t *D0 = Dtl + 1 + buf_stride;

  // Finally, A, B, C, D point at position (0, 0).
  int32_t *A = A0 + buf_diag_border;
  int32_t *B = B0 + buf_diag_border;
  int32_t *C = C0 + buf_diag_border;
  int32_t *D = D0 + buf_diag_border;

  const int dgd_diag_border =
      SGRPROJ_BORDER_HORZ + dgd_stride * SGRPROJ_BORDER_VERT;
  const uint8_t *dgd0 = dgd8 - dgd_diag_border;

  // Generate integral images from the input. C will contain sums of squares; D
  // will contain just sums
  if (highbd) {
    IntegralImages(int32_tag, CONVERT_TO_SHORTPTR(dgd0), dgd_stride, width_ext,
                   height_ext, Ctl, Dtl, buf_stride);
  } else {
    IntegralImages(int32_tag, dgd0, dgd_stride, width_ext, height_ext, Ctl, Dtl,
                   buf_stride);
  }

  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  // Write to flt0 and flt1
  // If params->r == 0 we skip the corresponding filter. We only allow one of
  // the radii to be 0, as having both equal to 0 would be equivalent to
  // skipping SGR entirely.
  assert(!(params->r[0] == 0 && params->r[1] == 0));
  assert(params->r[0] < AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ));
  assert(params->r[1] < AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ));

  if (params->r[0] > 0) {
    // Radius 0: A/B on every other row (Step = 2) + the fast filter.
    CalcAB<2>(int32_tag, A, B, C, D, width, height, buf_stride, bit_depth,
              sgr_params_idx, 0);
    FinalFilterFast(int32_tag, flt0, flt_stride, A, B, buf_stride, dgd8,
                    dgd_stride, width, height, highbd);
  }

  if (params->r[1] > 0) {
    // Radius 1: A/B on every row (Step = 1) + the full filter.
    CalcAB<1>(int32_tag, A, B, C, D, width, height, buf_stride, bit_depth,
              sgr_params_idx, 1);
    FinalFilter(int32_tag, flt1, flt_stride, A, B, buf_stride, dgd8, dgd_stride,
                width, height, highbd);
  }
  aom_free(buf);
  return 0;
}

// Combine the two filtered planes with the decoded projection parameters xqd
// and write the final restored pixels to dst8. Delegates the plane filtering
// to the target-specific av1_selfguided_restoration_* entry point.
HWY_ATTR HWY_INLINE int ApplySelfGuidedRestoration(
    const uint8_t *HWY_RESTRICT dat8, int width, int height, int stride,
    int eps, const int *HWY_RESTRICT xqd, uint8_t *HWY_RESTRICT dst8,
    int dst_stride, int32_t *HWY_RESTRICT tmpbuf, int bit_depth, int highbd) {
  constexpr hn::CappedTag<int32_t, 16> int32_tag;
  // Each iteration of the inner loop below processes two vectors of pixels.
  constexpr size_t kBatchSize = hn::MaxLanes(int32_tag) * 2;
  int32_t *flt0 = tmpbuf;
  int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
  assert(width * height <= RESTORATION_UNITPELS_MAX);
#if HWY_TARGET == HWY_SSE4
  const int ret = av1_selfguided_restoration_sse4_1(
      dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd);
#elif HWY_TARGET == HWY_AVX2
  const int ret = av1_selfguided_restoration_avx2(
      dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd);
#elif HWY_TARGET <= HWY_AVX3
  const int ret = av1_selfguided_restoration_avx512(
      dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd);
#else
#error "HWY_TARGET is not supported."
  const int ret = -1;
#endif
  if (ret != 0) {
    return ret;
  }
  const sgr_params_type *const params = &av1_sgr_params[eps];
  int xq[2];
  av1_decode_xq(xqd, xq, params);

  // Projection weights for the two filtered planes, broadcast to all lanes.
  auto xq0 = hn::Set(int32_tag, xq[0]);
  auto xq1 = hn::Set(int32_tag, xq[1]);

  for (int i = 0; i < height; ++i) {
    // Calculate output in batches of pixels
    for (int j = 0; j < width; j += kBatchSize) {
      const int k = i * width + j;
      const int m = i * dst_stride + j;

      const uint8_t *dat8ij = dat8 + i * stride + j;
      // ep_0/ep_1: two vectors of source pixels widened to int32.
      auto ep_0 = hn::Undefined(int32_tag);
      auto ep_1 = hn::Undefined(int32_tag);
      if (highbd) {
        constexpr hn::Repartition<uint16_t, hn::Half<decltype(int32_tag)>>
            uint16_tag;
        const auto src_0 = hn::LoadU(uint16_tag, CONVERT_TO_SHORTPTR(dat8ij));
        const auto src_1 = hn::LoadU(
            uint16_tag, CONVERT_TO_SHORTPTR(dat8ij) + hn::MaxLanes(int32_tag));
        ep_0 = hn::PromoteTo(int32_tag, src_0);
        ep_1 = hn::PromoteTo(int32_tag, src_1);
      } else {
        constexpr hn::Repartition<uint8_t, hn::Half<decltype(int32_tag)>>
            uint8_tag;
        // One uint8 load covers both int32 vectors.
        const auto src_0 = hn::LoadU(uint8_tag, dat8ij);
        ep_0 = hn::PromoteLowerTo(int32_tag, src_0);
        ep_1 = hn::PromoteUpperTo(int32_tag, src_0);
      }

      // u = source scaled up to SGRPROJ_RST_BITS precision (matching flt0/1).
      const auto u_0 = hn::ShiftLeft<SGRPROJ_RST_BITS>(ep_0);
      const auto u_1 = hn::ShiftLeft<SGRPROJ_RST_BITS>(ep_1);

      // v accumulates u << SGRPROJ_PRJ_BITS plus the weighted corrections
      // xq[i] * (flt[i] - u) for each enabled filter.
      auto v_0 = hn::ShiftLeft<SGRPROJ_PRJ_BITS>(u_0);
      auto v_1 = hn::ShiftLeft<SGRPROJ_PRJ_BITS>(u_1);

      if (params->r[0] > 0) {
        const auto f1_0 = hn::Sub(hn::LoadU(int32_tag, &flt0[k]), u_0);
        v_0 = hn::Add(v_0, hn::Mul(xq0, f1_0));

        const auto f1_1 = hn::Sub(
            hn::LoadU(int32_tag, &flt0[k + hn::MaxLanes(int32_tag)]), u_1);
        v_1 = hn::Add(v_1, hn::Mul(xq0, f1_1));
      }

      if (params->r[1] > 0) {
        const auto f2_0 = hn::Sub(hn::LoadU(int32_tag, &flt1[k]), u_0);
        v_0 = hn::Add(v_0, hn::Mul(xq1, f2_0));

        const auto f2_1 = hn::Sub(
            hn::LoadU(int32_tag, &flt1[k + hn::MaxLanes(int32_tag)]), u_1);
        v_1 = hn::Add(v_1, hn::Mul(xq1, f2_1));
      }

      // Round-to-nearest shift back down to pixel precision.
      const auto rounding =
          RoundForShift(int32_tag, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
      const auto w_0 = hn::ShiftRight<SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS>(
          hn::Add(v_0, rounding));
      const auto w_1 = hn::ShiftRight<SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS>(
          hn::Add(v_1, rounding));

      if (highbd) {
        // Pack into 16 bits and clamp to [0, 2^bit_depth)
        // Note that packing into 16 bits messes up the order of the bits,
        // so we use a permute function to correct this
        constexpr hn::Repartition<uint16_t, decltype(int32_tag)> uint16_tag;
        const auto tmp = hn::OrderedDemote2To(uint16_tag, w_0, w_1);
        const auto max = hn::Set(uint16_tag, (1 << bit_depth) - 1);
        const auto res = hn::Min(tmp, max);
        hn::StoreU(res, uint16_tag, CONVERT_TO_SHORTPTR(dst8 + m));
      } else {
        // Pack into 8 bits and clamp to [0, 256)
        // Note that each pack messes up the order of the bits,
        // so we use a permute function to correct this
        constexpr hn::Repartition<int16_t, decltype(int32_tag)> int16_tag;
        constexpr hn::Repartition<uint8_t, hn::Half<decltype(int32_tag)>>
            uint8_tag;
        const auto tmp = hn::OrderedDemote2To(int16_tag, w_0, w_1);
        const auto res = hn::DemoteTo(uint8_tag, tmp);
        hn::StoreU(res, uint8_tag, dst8 + m);
      }
    }
  }
  return 0;
}

}  // namespace HWY_NAMESPACE
}  // namespace

HWY_AFTER_NAMESPACE();

// Instantiates the extern "C" entry points for one SIMD target suffix,
// forwarding to the Highway implementations above.
#define MAKE_SELFGUIDED_RESTORATION(suffix)                              \
  extern "C" int av1_selfguided_restoration_##suffix(                    \
      const uint8_t *dgd8, int width, int height, int dgd_stride,        \
      int32_t *flt0, int32_t *flt1, int flt_stride, int sgr_params_idx,  \
      int bit_depth, int highbd);                                        \
  HWY_ATTR HWY_NOINLINE int av1_selfguided_restoration_##suffix(         \
      const uint8_t *dgd8, int width, int height, int dgd_stride,        \
      int32_t *flt0, int32_t *flt1, int flt_stride, int sgr_params_idx,  \
      int bit_depth, int highbd) {                                       \
    return HWY_NAMESPACE::SelfGuidedRestoration(                         \
        dgd8, width, height, dgd_stride, flt0, flt1, flt_stride,         \
        sgr_params_idx, bit_depth, highbd);                              \
  }                                                                      \
  extern "C" int av1_apply_selfguided_restoration_##suffix(              \
      const uint8_t *dat8, int width, int height, int stride, int eps,   \
      const int *xqd, uint8_t *dst8, int dst_stride, int32_t *tmpbuf,    \
      int bit_depth, int highbd);                                        \
  HWY_ATTR int av1_apply_selfguided_restoration_##suffix(                \
      const uint8_t *dat8, int width, int height, int stride, int eps,   \
      const int *xqd, uint8_t *dst8, int dst_stride, int32_t *tmpbuf,    \
      int bit_depth, int highbd) {                                       \
    return HWY_NAMESPACE::ApplySelfGuidedRestoration(                    \
        dat8, width, height, stride, eps, xqd, dst8, dst_stride, tmpbuf, \
        bit_depth, highbd);                                              \
  }

#endif  // AV1_COMMON_SELFGUIDED_HWY_H_