highbd_obmc_variance_neon.c (14558B)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

// Accumulate the OBMC contribution of 8 high-bitdepth pixels into *sum and
// *sse. For each lane: diff = round((wsrc - mask * pre) >> 12) (rounding to
// nearest, ties away from zero), then *sum += diff and *sse += diff * diff.
// NOTE(review): the masks are narrowed to 16 bits via vmovn below, so this
// assumes all mask values fit in int16_t — confirm against the callers that
// build the mask buffers.
static inline void highbd_obmc_variance_8x1_s16_neon(uint16x8_t pre,
                                                     const int32_t *wsrc,
                                                     const int32_t *mask,
                                                     uint32x4_t *sse,
                                                     int32x4_t *sum) {
  // Reinterpret (not convert) to signed: valid provided pixel values do not
  // use bit 15, as is the case for <=15-bit pixel data.
  int16x8_t pre_s16 = vreinterpretq_s16_u16(pre);
  int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]);
  int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]);

  int32x4_t mask_lo = vld1q_s32(&mask[0]);
  int32x4_t mask_hi = vld1q_s32(&mask[4]);

  // Narrow the 32-bit masks to 16 bits so we can use the widening
  // 16x16 -> 32-bit multiply below.
  int16x8_t mask_s16 = vcombine_s16(vmovn_s32(mask_lo), vmovn_s32(mask_hi));

  int32x4_t diff_lo = vmull_s16(vget_low_s16(pre_s16), vget_low_s16(mask_s16));
  int32x4_t diff_hi =
      vmull_s16(vget_high_s16(pre_s16), vget_high_s16(mask_s16));

  // diff = wsrc - mask * pre.
  diff_lo = vsubq_s32(wsrc_lo, diff_lo);
  diff_hi = vsubq_s32(wsrc_hi, diff_hi);

  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
  // from zero, however vrshrq_n_s32 rounds to nearest with ties rounded up.
  // This difference only affects the bit patterns at the rounding breakpoints
  // exactly, so we can add -1 to all negative numbers to move the breakpoint
  // one value across and into the correct rounding region.
  // (vsraq_n_s32(x, x, 31) adds the sign bit, i.e. x + (x < 0 ? -1 : 0).)
  diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31);
  diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31);
  int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12);
  int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12);

  *sum = vaddq_s32(*sum, round_lo);
  *sum = vaddq_s32(*sum, round_hi);
  // Squaring makes the sign irrelevant, so accumulate the squares as
  // unsigned via a bit-pattern reinterpret; products are congruent mod 2^32.
  *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_lo),
                   vreinterpretq_u32_s32(round_lo));
  *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_hi),
                   vreinterpretq_u32_s32(round_hi));
}

// For 12-bit data, we can only accumulate up to 256 elements in the unsigned
// 32-bit elements (4095*4095*256 = 4292870400) before we have to accumulate
// into 64-bit elements. Therefore blocks of size 32x64, 64x32, 64x64, 64x128,
// 128x64, 128x128 are processed in a different helper function.

// Compute OBMC sum and SSE over a width x h block, flushing the 32-bit SSE
// accumulators into 64-bit lanes every 'h_tmp' rows to avoid overflow.
// 'width' must be a multiple of 16.
static inline void highbd_obmc_variance_xlarge_neon(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int width, int h, int h_limit, uint64_t *sse,
    int64_t *sum) {
  uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
  int32x4_t sum_s32 = vdupq_n_s32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  // 'h_limit' is the number of 'w'-width rows we can process before our 32-bit
  // accumulator overflows. After hitting this limit we accumulate into 64-bit
  // elements.
  // NOTE(review): h_tmp is computed once, so h must be a multiple of
  // min(h, h_limit); every (width, h, h_limit) combination passed by the
  // wrappers below satisfies this.
  int h_tmp = h > h_limit ? h_limit : h;

  do {
    uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
    int j = 0;

    do {
      int i = 0;

      // Process each row in chunks of 16 pixels (two 8-wide helper calls).
      do {
        uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
        highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32[0],
                                          &sum_s32);

        uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
        highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32[1],
                                          &sum_s32);

        i += 16;
        wsrc += 16;
        mask += 16;
      } while (i < width);

      pre_ptr += pre_stride;
      j++;
    } while (j < h_tmp);

    // Flush the 32-bit partial SSE sums into the 64-bit accumulator before
    // they can overflow.
    sse_u64 = vpadalq_u32(sse_u64, sse_u32[0]);
    sse_u64 = vpadalq_u32(sse_u64, sse_u32[1]);
    h -= h_tmp;
  } while (h != 0);

  *sse = horizontal_add_u64x2(sse_u64);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

// 128-wide 12-bit blocks: 128*16 = 2048 elements per flush interval.
static inline void highbd_obmc_variance_xlarge_neon_128xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 128, h, 16, sse,
                                   sum);
}

// 64-wide 12-bit blocks: 64*32 = 2048 elements per flush interval.
static inline void highbd_obmc_variance_xlarge_neon_64xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 64, h, 32, sse,
                                   sum);
}

// 32-wide 12-bit blocks: 32*64 = 2048 elements per flush interval.
static inline void highbd_obmc_variance_xlarge_neon_32xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 32, h, 64, sse,
                                   sum);
}

// Compute OBMC sum and SSE over a width x h block where the whole block's
// SSE fits in 32-bit accumulators (no intermediate 64-bit flush needed).
// 'width' must be a multiple of 16.
static inline void highbd_obmc_variance_large_neon(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int width, int h, uint64_t *sse, int64_t *sum) {
  uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    int i = 0;

    // Process each row in chunks of 16 pixels (two 8-wide helper calls).
    do {
      uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
      highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32, &sum_s32);

      uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
      highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32,
                                        &sum_s32);

      i += 16;
      wsrc += 16;
      mask += 16;
    } while (i < width);

    pre_ptr += pre_stride;
  } while (--h != 0);

  *sse = horizontal_long_add_u32x4(sse_u32);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

static inline void highbd_obmc_variance_neon_128xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse,
                                  sum);
}

static inline void highbd_obmc_variance_neon_64xh(const uint8_t *pre,
                                                  int pre_stride,
                                                  const int32_t *wsrc,
                                                  const int32_t *mask, int h,
                                                  uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum);
}

static inline void highbd_obmc_variance_neon_32xh(const uint8_t *pre,
                                                  int pre_stride,
                                                  const int32_t *wsrc,
                                                  const int32_t *mask, int h,
                                                  uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum);
}

static inline void highbd_obmc_variance_neon_16xh(const uint8_t *pre,
                                                  int pre_stride,
                                                  const int32_t *wsrc,
                                                  const int32_t *mask, int h,
                                                  uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum);
}

// 8-wide blocks: one helper call per row.
static inline void highbd_obmc_variance_neon_8xh(const uint8_t *pre8,
                                                 int pre_stride,
                                                 const int32_t *wsrc,
                                                 const int32_t *mask, int h,
                                                 uint64_t *sse, int64_t *sum) {
  uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    uint16x8_t pre_u16 = vld1q_u16(pre);

    highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);

    pre += pre_stride;
    wsrc += 8;
    mask += 8;
  } while (--h != 0);

  *sse = horizontal_long_add_u32x4(sse_u32);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

// 4-wide blocks: load two rows at a time into one 8-lane vector so the same
// 8-wide helper can be reused. Requires an even height.
static inline void highbd_obmc_variance_neon_4xh(const uint8_t *pre8,
                                                 int pre_stride,
                                                 const int32_t *wsrc,
                                                 const int32_t *mask, int h,
                                                 uint64_t *sse, int64_t *sum) {
  assert(h % 2 == 0);
  uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    uint16x8_t pre_u16 = load_unaligned_u16_4x2(pre, pre_stride);

    highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);

    pre += 2 * pre_stride;
    wsrc += 8;
    mask += 8;
    h -= 2;
  } while (h != 0);

  *sse = horizontal_long_add_u32x4(sse_u32);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

// Per-bitdepth normalization of the raw 64-bit sum/SSE accumulators down to
// the int/unsigned int values the public API returns. 8-bit: no scaling.
static inline void highbd_8_obmc_variance_cast(int64_t sum64, uint64_t sse64,
                                               int *sum, unsigned int *sse) {
  *sum = (int)sum64;
  *sse = (unsigned int)sse64;
}

// 10-bit: sum scaled by 2 bits, SSE by 4 bits (rounded).
static inline void highbd_10_obmc_variance_cast(int64_t sum64, uint64_t sse64,
                                                int *sum, unsigned int *sse) {
  *sum = (int)ROUND_POWER_OF_TWO(sum64, 2);
  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 4);
}

// 12-bit: sum scaled by 4 bits, SSE by 8 bits (rounded).
static inline void highbd_12_obmc_variance_cast(int64_t sum64, uint64_t sse64,
                                                int *sum, unsigned int *sse) {
  *sum = (int)ROUND_POWER_OF_TWO(sum64, 4);
  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8);
}

// Define the public entry point for one (w, h, bitdepth) combination:
// variance = SSE - sum^2 / (w * h), computed from the width-specialized
// helper plus the bitdepth-specific normalization above.
#define HIGHBD_OBMC_VARIANCE_WXH_NEON(w, h, bitdepth)                         \
  unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon(         \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,                \
      const int32_t *mask, unsigned int *sse) {                               \
    int sum;                                                                  \
    int64_t sum64;                                                            \
    uint64_t sse64;                                                           \
    highbd_obmc_variance_neon_##w##xh(pre, pre_stride, wsrc, mask, h, &sse64, \
                                      &sum64);                                \
    highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse);          \
    return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h));             \
  }

// As above, but routed through the xlarge helpers that periodically flush
// into 64-bit accumulators (needed for the largest 12-bit blocks).
#define HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(w, h, bitdepth)               \
  unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon(      \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,             \
      const int32_t *mask, unsigned int *sse) {                            \
    int sum;                                                               \
    int64_t sum64;                                                         \
    uint64_t sse64;                                                        \
    highbd_obmc_variance_xlarge_neon_##w##xh(pre, pre_stride, wsrc, mask, h, \
                                             &sse64, &sum64);              \
    highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse);       \
    return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h));          \
  }

// 8-bit
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 8)

// 10-bit
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 10)

// 12-bit
// The largest 12-bit blocks (32x64, 64x32, 64x64, 64x128, 128x64, 128x128)
// use the XLARGE variant, which flushes the 32-bit SSE accumulators into
// 64-bit lanes before they can overflow (more than 256 12-bit squares do not
// fit in 32 bits; see the comment above highbd_obmc_variance_xlarge_neon).
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(32, 64, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 32, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 64, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 128, 12)

HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 64, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 128, 12)