/* highbd_sse_neon.c (10973B) */
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 14 #include "config/aom_dsp_rtcd.h" 15 #include "aom_dsp/arm/sum_neon.h" 16 17 static inline void highbd_sse_8x1_init_neon(const uint16_t *src, 18 const uint16_t *ref, 19 uint32x4_t *sse_acc0, 20 uint32x4_t *sse_acc1) { 21 uint16x8_t s = vld1q_u16(src); 22 uint16x8_t r = vld1q_u16(ref); 23 24 uint16x8_t abs_diff = vabdq_u16(s, r); 25 uint16x4_t abs_diff_lo = vget_low_u16(abs_diff); 26 uint16x4_t abs_diff_hi = vget_high_u16(abs_diff); 27 28 *sse_acc0 = vmull_u16(abs_diff_lo, abs_diff_lo); 29 *sse_acc1 = vmull_u16(abs_diff_hi, abs_diff_hi); 30 } 31 32 static inline void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref, 33 uint32x4_t *sse_acc0, 34 uint32x4_t *sse_acc1) { 35 uint16x8_t s = vld1q_u16(src); 36 uint16x8_t r = vld1q_u16(ref); 37 38 uint16x8_t abs_diff = vabdq_u16(s, r); 39 uint16x4_t abs_diff_lo = vget_low_u16(abs_diff); 40 uint16x4_t abs_diff_hi = vget_high_u16(abs_diff); 41 42 *sse_acc0 = vmlal_u16(*sse_acc0, abs_diff_lo, abs_diff_lo); 43 *sse_acc1 = vmlal_u16(*sse_acc1, abs_diff_hi, abs_diff_hi); 44 } 45 46 static inline int64_t highbd_sse_128xh_neon(const uint16_t *src, int src_stride, 47 const uint16_t *ref, int ref_stride, 48 int height) { 49 uint32x4_t sse[16]; 50 highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 51 highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 52 highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); 
53 highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); 54 highbd_sse_8x1_init_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]); 55 highbd_sse_8x1_init_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]); 56 highbd_sse_8x1_init_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]); 57 highbd_sse_8x1_init_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]); 58 highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]); 59 highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]); 60 highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]); 61 highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]); 62 highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]); 63 highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]); 64 highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]); 65 highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]); 66 67 src += src_stride; 68 ref += ref_stride; 69 70 while (--height != 0) { 71 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 72 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 73 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); 74 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); 75 highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]); 76 highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]); 77 highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]); 78 highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]); 79 highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]); 80 highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]); 81 highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]); 82 highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]); 83 highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]); 84 highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]); 85 
highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]); 86 highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]); 87 88 src += src_stride; 89 ref += ref_stride; 90 } 91 92 return horizontal_long_add_u32x4_x16(sse); 93 } 94 95 static inline int64_t highbd_sse_64xh_neon(const uint16_t *src, int src_stride, 96 const uint16_t *ref, int ref_stride, 97 int height) { 98 uint32x4_t sse[8]; 99 highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 100 highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 101 highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); 102 highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); 103 highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]); 104 highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]); 105 highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]); 106 highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]); 107 108 src += src_stride; 109 ref += ref_stride; 110 111 while (--height != 0) { 112 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 113 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 114 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); 115 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); 116 highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]); 117 highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]); 118 highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]); 119 highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]); 120 121 src += src_stride; 122 ref += ref_stride; 123 } 124 125 return horizontal_long_add_u32x4_x8(sse); 126 } 127 128 static inline int64_t highbd_sse_32xh_neon(const uint16_t *src, int src_stride, 129 const uint16_t *ref, int ref_stride, 130 int height) { 131 uint32x4_t sse[8]; 132 highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 133 
highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 134 highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); 135 highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); 136 137 src += src_stride; 138 ref += ref_stride; 139 140 while (--height != 0) { 141 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 142 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 143 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]); 144 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]); 145 146 src += src_stride; 147 ref += ref_stride; 148 } 149 150 return horizontal_long_add_u32x4_x8(sse); 151 } 152 153 static inline int64_t highbd_sse_16xh_neon(const uint16_t *src, int src_stride, 154 const uint16_t *ref, int ref_stride, 155 int height) { 156 uint32x4_t sse[4]; 157 highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 158 highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 159 160 src += src_stride; 161 ref += ref_stride; 162 163 while (--height != 0) { 164 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]); 165 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]); 166 167 src += src_stride; 168 ref += ref_stride; 169 } 170 171 return horizontal_long_add_u32x4_x4(sse); 172 } 173 174 static inline int64_t highbd_sse_8xh_neon(const uint16_t *src, int src_stride, 175 const uint16_t *ref, int ref_stride, 176 int height) { 177 uint32x4_t sse[2]; 178 highbd_sse_8x1_init_neon(src, ref, &sse[0], &sse[1]); 179 180 src += src_stride; 181 ref += ref_stride; 182 183 while (--height != 0) { 184 highbd_sse_8x1_neon(src, ref, &sse[0], &sse[1]); 185 186 src += src_stride; 187 ref += ref_stride; 188 } 189 190 return horizontal_long_add_u32x4_x2(sse); 191 } 192 193 static inline int64_t highbd_sse_4xh_neon(const uint16_t *src, int src_stride, 194 const uint16_t *ref, int ref_stride, 195 int height) { 196 // Peel the first loop iteration. 
197 uint16x4_t s = vld1_u16(src); 198 uint16x4_t r = vld1_u16(ref); 199 200 uint16x4_t abs_diff = vabd_u16(s, r); 201 uint32x4_t sse = vmull_u16(abs_diff, abs_diff); 202 203 src += src_stride; 204 ref += ref_stride; 205 206 while (--height != 0) { 207 s = vld1_u16(src); 208 r = vld1_u16(ref); 209 210 abs_diff = vabd_u16(s, r); 211 sse = vmlal_u16(sse, abs_diff, abs_diff); 212 213 src += src_stride; 214 ref += ref_stride; 215 } 216 217 return horizontal_long_add_u32x4(sse); 218 } 219 220 static inline int64_t highbd_sse_wxh_neon(const uint16_t *src, int src_stride, 221 const uint16_t *ref, int ref_stride, 222 int width, int height) { 223 // { 0, 1, 2, 3, 4, 5, 6, 7 } 224 uint16x8_t k01234567 = vmovl_u8(vcreate_u8(0x0706050403020100)); 225 uint16x8_t remainder_mask = vcltq_u16(k01234567, vdupq_n_u16(width & 7)); 226 uint64_t sse = 0; 227 228 do { 229 int w = width; 230 int offset = 0; 231 232 do { 233 uint16x8_t s = vld1q_u16(src + offset); 234 uint16x8_t r = vld1q_u16(ref + offset); 235 236 if (w < 8) { 237 // Mask out-of-range elements. 
238 s = vandq_u16(s, remainder_mask); 239 r = vandq_u16(r, remainder_mask); 240 } 241 242 uint16x8_t abs_diff = vabdq_u16(s, r); 243 uint16x4_t abs_diff_lo = vget_low_u16(abs_diff); 244 uint16x4_t abs_diff_hi = vget_high_u16(abs_diff); 245 246 uint32x4_t sse_u32 = vmull_u16(abs_diff_lo, abs_diff_lo); 247 sse_u32 = vmlal_u16(sse_u32, abs_diff_hi, abs_diff_hi); 248 249 sse += horizontal_long_add_u32x4(sse_u32); 250 251 offset += 8; 252 w -= 8; 253 } while (w > 0); 254 255 src += src_stride; 256 ref += ref_stride; 257 } while (--height != 0); 258 259 return sse; 260 } 261 262 int64_t aom_highbd_sse_neon(const uint8_t *src8, int src_stride, 263 const uint8_t *ref8, int ref_stride, int width, 264 int height) { 265 uint16_t *src = CONVERT_TO_SHORTPTR(src8); 266 uint16_t *ref = CONVERT_TO_SHORTPTR(ref8); 267 268 switch (width) { 269 case 4: 270 return highbd_sse_4xh_neon(src, src_stride, ref, ref_stride, height); 271 case 8: 272 return highbd_sse_8xh_neon(src, src_stride, ref, ref_stride, height); 273 case 16: 274 return highbd_sse_16xh_neon(src, src_stride, ref, ref_stride, height); 275 case 32: 276 return highbd_sse_32xh_neon(src, src_stride, ref, ref_stride, height); 277 case 64: 278 return highbd_sse_64xh_neon(src, src_stride, ref, ref_stride, height); 279 case 128: 280 return highbd_sse_128xh_neon(src, src_stride, ref, ref_stride, height); 281 default: 282 return highbd_sse_wxh_neon(src, src_stride, ref, ref_stride, width, 283 height); 284 } 285 }