obmc_variance_neon.c (11599B)
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 14 #include "config/aom_config.h" 15 #include "config/aom_dsp_rtcd.h" 16 #include "mem_neon.h" 17 #include "sum_neon.h" 18 19 static inline void obmc_variance_8x1_s16_neon(int16x8_t pre_s16, 20 const int32_t *wsrc, 21 const int32_t *mask, 22 int32x4_t *ssev, 23 int32x4_t *sumv) { 24 // For 4xh and 8xh we observe it is faster to avoid the double-widening of 25 // pre. Instead we do a single widening step and narrow the mask to 16-bits 26 // to allow us to perform a widening multiply. Widening multiply 27 // instructions have better throughput on some micro-architectures but for 28 // the larger block sizes this benefit is outweighed by the additional 29 // instruction needed to first narrow the mask vectors. 30 31 int32x4_t wsrc_s32_lo = vld1q_s32(&wsrc[0]); 32 int32x4_t wsrc_s32_hi = vld1q_s32(&wsrc[4]); 33 int16x8_t mask_s16 = vuzpq_s16(vreinterpretq_s16_s32(vld1q_s32(&mask[0])), 34 vreinterpretq_s16_s32(vld1q_s32(&mask[4]))) 35 .val[0]; 36 37 int32x4_t diff_s32_lo = 38 vmlsl_s16(wsrc_s32_lo, vget_low_s16(pre_s16), vget_low_s16(mask_s16)); 39 int32x4_t diff_s32_hi = 40 vmlsl_s16(wsrc_s32_hi, vget_high_s16(pre_s16), vget_high_s16(mask_s16)); 41 42 // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away 43 // from zero, however vrshrq_n_s32 rounds to nearest with ties rounded up. 
44 // This difference only affects the bit patterns at the rounding breakpoints 45 // exactly, so we can add -1 to all negative numbers to move the breakpoint 46 // one value across and into the correct rounding region. 47 diff_s32_lo = vsraq_n_s32(diff_s32_lo, diff_s32_lo, 31); 48 diff_s32_hi = vsraq_n_s32(diff_s32_hi, diff_s32_hi, 31); 49 int32x4_t round_s32_lo = vrshrq_n_s32(diff_s32_lo, 12); 50 int32x4_t round_s32_hi = vrshrq_n_s32(diff_s32_hi, 12); 51 52 *sumv = vrsraq_n_s32(*sumv, diff_s32_lo, 12); 53 *sumv = vrsraq_n_s32(*sumv, diff_s32_hi, 12); 54 *ssev = vmlaq_s32(*ssev, round_s32_lo, round_s32_lo); 55 *ssev = vmlaq_s32(*ssev, round_s32_hi, round_s32_hi); 56 } 57 58 #if AOM_ARCH_AARCH64 59 60 // Use tbl for doing a double-width zero extension from 8->32 bits since we can 61 // do this in one instruction rather than two (indices out of range (255 here) 62 // are set to zero by tbl). 63 DECLARE_ALIGNED(16, static const uint8_t, obmc_variance_permute_idx[]) = { 64 0, 255, 255, 255, 1, 255, 255, 255, 2, 255, 255, 255, 3, 255, 255, 255, 65 4, 255, 255, 255, 5, 255, 255, 255, 6, 255, 255, 255, 7, 255, 255, 255, 66 8, 255, 255, 255, 9, 255, 255, 255, 10, 255, 255, 255, 11, 255, 255, 255, 67 12, 255, 255, 255, 13, 255, 255, 255, 14, 255, 255, 255, 15, 255, 255, 255 68 }; 69 70 static inline void obmc_variance_8x1_s32_neon( 71 int32x4_t pre_lo, int32x4_t pre_hi, const int32_t *wsrc, 72 const int32_t *mask, int32x4_t *ssev, int32x4_t *sumv) { 73 int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]); 74 int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]); 75 int32x4_t mask_lo = vld1q_s32(&mask[0]); 76 int32x4_t mask_hi = vld1q_s32(&mask[4]); 77 78 int32x4_t diff_lo = vmlsq_s32(wsrc_lo, pre_lo, mask_lo); 79 int32x4_t diff_hi = vmlsq_s32(wsrc_hi, pre_hi, mask_hi); 80 81 // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away from 82 // zero, however vrshrq_n_s32 rounds to nearest with ties rounded up. 
This 83 // difference only affects the bit patterns at the rounding breakpoints 84 // exactly, so we can add -1 to all negative numbers to move the breakpoint 85 // one value across and into the correct rounding region. 86 diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31); 87 diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31); 88 int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12); 89 int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12); 90 91 *sumv = vrsraq_n_s32(*sumv, diff_lo, 12); 92 *sumv = vrsraq_n_s32(*sumv, diff_hi, 12); 93 *ssev = vmlaq_s32(*ssev, round_lo, round_lo); 94 *ssev = vmlaq_s32(*ssev, round_hi, round_hi); 95 } 96 97 static inline void obmc_variance_large_neon(const uint8_t *pre, int pre_stride, 98 const int32_t *wsrc, 99 const int32_t *mask, int width, 100 int height, unsigned *sse, 101 int *sum) { 102 assert(width % 16 == 0); 103 104 // Use tbl for doing a double-width zero extension from 8->32 bits since we 105 // can do this in one instruction rather than two. 106 uint8x16_t pre_idx0 = vld1q_u8(&obmc_variance_permute_idx[0]); 107 uint8x16_t pre_idx1 = vld1q_u8(&obmc_variance_permute_idx[16]); 108 uint8x16_t pre_idx2 = vld1q_u8(&obmc_variance_permute_idx[32]); 109 uint8x16_t pre_idx3 = vld1q_u8(&obmc_variance_permute_idx[48]); 110 111 int32x4_t ssev = vdupq_n_s32(0); 112 int32x4_t sumv = vdupq_n_s32(0); 113 114 int h = height; 115 do { 116 int w = width; 117 do { 118 uint8x16_t pre_u8 = vld1q_u8(pre); 119 120 int32x4_t pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx0)); 121 int32x4_t pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx1)); 122 obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[0], &mask[0], 123 &ssev, &sumv); 124 125 pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx2)); 126 pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx3)); 127 obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[8], &mask[8], 128 &ssev, &sumv); 129 130 wsrc += 16; 131 mask += 16; 132 pre += 16; 133 w -= 16; 134 } while (w != 
0); 135 136 pre += pre_stride - width; 137 } while (--h != 0); 138 139 *sse = horizontal_add_s32x4(ssev); 140 *sum = horizontal_add_s32x4(sumv); 141 } 142 143 #else // !AOM_ARCH_AARCH64 144 145 static inline void obmc_variance_large_neon(const uint8_t *pre, int pre_stride, 146 const int32_t *wsrc, 147 const int32_t *mask, int width, 148 int height, unsigned *sse, 149 int *sum) { 150 // Non-aarch64 targets do not have a 128-bit tbl instruction, so use the 151 // widening version of the core kernel instead. 152 153 assert(width % 16 == 0); 154 155 int32x4_t ssev = vdupq_n_s32(0); 156 int32x4_t sumv = vdupq_n_s32(0); 157 158 int h = height; 159 do { 160 int w = width; 161 do { 162 uint8x16_t pre_u8 = vld1q_u8(pre); 163 164 int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(pre_u8))); 165 obmc_variance_8x1_s16_neon(pre_s16, &wsrc[0], &mask[0], &ssev, &sumv); 166 167 pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(pre_u8))); 168 obmc_variance_8x1_s16_neon(pre_s16, &wsrc[8], &mask[8], &ssev, &sumv); 169 170 wsrc += 16; 171 mask += 16; 172 pre += 16; 173 w -= 16; 174 } while (w != 0); 175 176 pre += pre_stride - width; 177 } while (--h != 0); 178 179 *sse = horizontal_add_s32x4(ssev); 180 *sum = horizontal_add_s32x4(sumv); 181 } 182 183 #endif // AOM_ARCH_AARCH64 184 185 static inline void obmc_variance_neon_128xh(const uint8_t *pre, int pre_stride, 186 const int32_t *wsrc, 187 const int32_t *mask, int h, 188 unsigned *sse, int *sum) { 189 obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse, sum); 190 } 191 192 static inline void obmc_variance_neon_64xh(const uint8_t *pre, int pre_stride, 193 const int32_t *wsrc, 194 const int32_t *mask, int h, 195 unsigned *sse, int *sum) { 196 obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum); 197 } 198 199 static inline void obmc_variance_neon_32xh(const uint8_t *pre, int pre_stride, 200 const int32_t *wsrc, 201 const int32_t *mask, int h, 202 unsigned *sse, int *sum) { 203 
obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum); 204 } 205 206 static inline void obmc_variance_neon_16xh(const uint8_t *pre, int pre_stride, 207 const int32_t *wsrc, 208 const int32_t *mask, int h, 209 unsigned *sse, int *sum) { 210 obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum); 211 } 212 213 static inline void obmc_variance_neon_8xh(const uint8_t *pre, int pre_stride, 214 const int32_t *wsrc, 215 const int32_t *mask, int h, 216 unsigned *sse, int *sum) { 217 int32x4_t ssev = vdupq_n_s32(0); 218 int32x4_t sumv = vdupq_n_s32(0); 219 220 do { 221 uint8x8_t pre_u8 = vld1_u8(pre); 222 int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8)); 223 224 obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv); 225 226 pre += pre_stride; 227 wsrc += 8; 228 mask += 8; 229 } while (--h != 0); 230 231 *sse = horizontal_add_s32x4(ssev); 232 *sum = horizontal_add_s32x4(sumv); 233 } 234 235 static inline void obmc_variance_neon_4xh(const uint8_t *pre, int pre_stride, 236 const int32_t *wsrc, 237 const int32_t *mask, int h, 238 unsigned *sse, int *sum) { 239 assert(h % 2 == 0); 240 241 int32x4_t ssev = vdupq_n_s32(0); 242 int32x4_t sumv = vdupq_n_s32(0); 243 244 do { 245 uint8x8_t pre_u8 = load_unaligned_u8(pre, pre_stride); 246 int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8)); 247 248 obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv); 249 250 pre += 2 * pre_stride; 251 wsrc += 8; 252 mask += 8; 253 h -= 2; 254 } while (h != 0); 255 256 *sse = horizontal_add_s32x4(ssev); 257 *sum = horizontal_add_s32x4(sumv); 258 } 259 260 #define OBMC_VARIANCE_WXH_NEON(W, H) \ 261 unsigned aom_obmc_variance##W##x##H##_neon( \ 262 const uint8_t *pre, int pre_stride, const int32_t *wsrc, \ 263 const int32_t *mask, unsigned *sse) { \ 264 int sum; \ 265 obmc_variance_neon_##W##xh(pre, pre_stride, wsrc, mask, H, sse, &sum); \ 266 return *sse - (unsigned)(((int64_t)sum * sum) / (W * H)); \ 267 } 268 269 
// Instantiate aom_obmc_variance<W>x<H>_neon for the square and 2:1
// rectangular block sizes.
OBMC_VARIANCE_WXH_NEON(4, 4)
OBMC_VARIANCE_WXH_NEON(4, 8)
OBMC_VARIANCE_WXH_NEON(8, 4)
OBMC_VARIANCE_WXH_NEON(8, 8)
OBMC_VARIANCE_WXH_NEON(8, 16)
OBMC_VARIANCE_WXH_NEON(16, 8)
OBMC_VARIANCE_WXH_NEON(16, 16)
OBMC_VARIANCE_WXH_NEON(16, 32)
OBMC_VARIANCE_WXH_NEON(32, 16)
OBMC_VARIANCE_WXH_NEON(32, 32)
OBMC_VARIANCE_WXH_NEON(32, 64)
OBMC_VARIANCE_WXH_NEON(64, 32)
OBMC_VARIANCE_WXH_NEON(64, 64)
OBMC_VARIANCE_WXH_NEON(64, 128)
OBMC_VARIANCE_WXH_NEON(128, 64)
OBMC_VARIANCE_WXH_NEON(128, 128)
// 4:1 / 1:4 extended partition sizes.
OBMC_VARIANCE_WXH_NEON(4, 16)
OBMC_VARIANCE_WXH_NEON(16, 4)
OBMC_VARIANCE_WXH_NEON(8, 32)
OBMC_VARIANCE_WXH_NEON(32, 8)
OBMC_VARIANCE_WXH_NEON(16, 64)
OBMC_VARIANCE_WXH_NEON(64, 16)