/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/txfm_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_mem/aom_mem.h"
#include "aom_ports/mem.h"
#include "av1/common/av1_common_int.h"
#include "av1/common/common.h"
#include "av1/common/resize.h"
#include "av1/common/restoration.h"

// Constants used for right shift in final_filter calculation.
#define NB_EVEN 5
#define NB_ODD 4

static inline void calc_ab_fast_internal_common(
    uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
    uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, int32x4_t sr4, int32x4_t sr5,
    int32x4_t sr6, int32x4_t sr7, uint32x4_t const_n_val, uint32x4_t s_vec,
    uint32x4_t const_val, uint32x4_t one_by_n_minus_1_vec,
    uint16x4_t sgrproj_sgr, int32_t *src1, uint16_t *dst_A16, int32_t *src2,
    const int buf_stride) {
  uint32x4_t q0, q1, q2, q3;
  uint32x4_t p0, p1, p2, p3;
  uint16x4_t d0, d1, d2, d3;

  s0 = vmulq_u32(s0, const_n_val);
  s1 = vmulq_u32(s1, const_n_val);
  s2 = vmulq_u32(s2, const_n_val);
  s3 = vmulq_u32(s3, const_n_val);

  q0 = vmulq_u32(s4, s4);
  q1 = vmulq_u32(s5, s5);
  q2 = vmulq_u32(s6, s6);
  q3 = vmulq_u32(s7, s7);

  p0 = vcleq_u32(q0, s0);
  p1 = vcleq_u32(q1, s1);
  p2 = vcleq_u32(q2, s2);
  p3 = vcleq_u32(q3, s3);

  q0 = vsubq_u32(s0, q0);
  q1 = vsubq_u32(s1, q1);
  q2 = vsubq_u32(s2, q2);
  q3 = vsubq_u32(s3, q3);

  p0 = vandq_u32(p0, q0);
  p1 = vandq_u32(p1, q1);
  p2 = vandq_u32(p2, q2);
  p3 = vandq_u32(p3, q3);

  p0 = vmulq_u32(p0, s_vec);
  p1 = vmulq_u32(p1, s_vec);
  p2 = vmulq_u32(p2, s_vec);
  p3 = vmulq_u32(p3, s_vec);

  p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
  p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
  p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
  p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);

  p0 = vminq_u32(p0, const_val);
  p1 = vminq_u32(p1, const_val);
  p2 = vminq_u32(p2, const_val);
  p3 = vminq_u32(p3, const_val);

  {
    store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);

    for (int x = 0; x < 4; x++) {
      for (int y = 0; y < 4; y++) {
        dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
      }
    }
    load_u16_4x4(dst_A16, buf_stride, &d0, &d1, &d2, &d3);
  }
  p0 = vsubl_u16(sgrproj_sgr, d0);
  p1 = vsubl_u16(sgrproj_sgr, d1);
  p2 = vsubl_u16(sgrproj_sgr, d2);
  p3 = vsubl_u16(sgrproj_sgr, d3);

  s4 = vmulq_u32(vreinterpretq_u32_s32(sr4), one_by_n_minus_1_vec);
  s5 = vmulq_u32(vreinterpretq_u32_s32(sr5), one_by_n_minus_1_vec);
  s6 = vmulq_u32(vreinterpretq_u32_s32(sr6), one_by_n_minus_1_vec);
  s7 = vmulq_u32(vreinterpretq_u32_s32(sr7), one_by_n_minus_1_vec);

  s4 = vmulq_u32(s4, p0);
  s5 = vmulq_u32(s5, p1);
  s6 = vmulq_u32(s6, p2);
  s7 = vmulq_u32(s7, p3);
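
  // Note: s4..s7 now hold sum * av1_one_by_x[n - 1] * (SGRPROJ_SGR - a); the
  // rounding shift below completes
  //   b = ROUND_POWER_OF_TWO((SGRPROJ_SGR - a) * sum * av1_one_by_x[n - 1],
  //                          SGRPROJ_RECIP_BITS).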
  p0 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
  p1 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
  p2 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
  p3 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);

  store_s32_4x4(src2, buf_stride, vreinterpretq_s32_u32(p0),
                vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
                vreinterpretq_s32_u32(p3));
}

static inline void calc_ab_internal_common(
    uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
    uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, uint16x8_t s16_0,
    uint16x8_t s16_1, uint16x8_t s16_2, uint16x8_t s16_3, uint16x8_t s16_4,
    uint16x8_t s16_5, uint16x8_t s16_6, uint16x8_t s16_7,
    uint32x4_t const_n_val, uint32x4_t s_vec, uint32x4_t const_val,
    uint16x4_t one_by_n_minus_1_vec, uint16x8_t sgrproj_sgr, int32_t *src1,
    uint16_t *dst_A16, int32_t *dst2, const int buf_stride) {
  uint16x4_t d0, d1, d2, d3, d4, d5, d6, d7;
  uint32x4_t q0, q1, q2, q3, q4, q5, q6, q7;
  uint32x4_t p0, p1, p2, p3, p4, p5, p6, p7;

  s0 = vmulq_u32(s0, const_n_val);
  s1 = vmulq_u32(s1, const_n_val);
  s2 = vmulq_u32(s2, const_n_val);
  s3 = vmulq_u32(s3, const_n_val);
  s4 = vmulq_u32(s4, const_n_val);
  s5 = vmulq_u32(s5, const_n_val);
  s6 = vmulq_u32(s6, const_n_val);
  s7 = vmulq_u32(s7, const_n_val);

  d0 = vget_low_u16(s16_4);
  d1 = vget_low_u16(s16_5);
  d2 = vget_low_u16(s16_6);
  d3 = vget_low_u16(s16_7);
  d4 = vget_high_u16(s16_4);
  d5 = vget_high_u16(s16_5);
  d6 = vget_high_u16(s16_6);
  d7 = vget_high_u16(s16_7);

  q0 = vmull_u16(d0, d0);
  q1 = vmull_u16(d1, d1);
  q2 = vmull_u16(d2, d2);
  q3 = vmull_u16(d3, d3);
  q4 = vmull_u16(d4, d4);
  q5 = vmull_u16(d5, d5);
  q6 = vmull_u16(d6, d6);
  q7 = vmull_u16(d7, d7);

  p0 = vcleq_u32(q0, s0);
  p1 = vcleq_u32(q1, s1);
  p2 = vcleq_u32(q2, s2);
  p3 = vcleq_u32(q3, s3);
  p4 = vcleq_u32(q4, s4);
  p5 = vcleq_u32(q5, s5);
  p6 = vcleq_u32(q6, s6);
  p7 = vcleq_u32(q7, s7);

  q0 = vsubq_u32(s0, q0);
  q1 = vsubq_u32(s1, q1);
  q2 = vsubq_u32(s2, q2);
  q3 = vsubq_u32(s3, q3);
  q4 = vsubq_u32(s4, q4);
  q5 = vsubq_u32(s5, q5);
  q6 = vsubq_u32(s6, q6);
  q7 = vsubq_u32(s7, q7);

  p0 = vandq_u32(p0, q0);
  p1 = vandq_u32(p1, q1);
  p2 = vandq_u32(p2, q2);
  p3 = vandq_u32(p3, q3);
  p4 = vandq_u32(p4, q4);
  p5 = vandq_u32(p5, q5);
  p6 = vandq_u32(p6, q6);
  p7 = vandq_u32(p7, q7);

  p0 = vmulq_u32(p0, s_vec);
  p1 = vmulq_u32(p1, s_vec);
  p2 = vmulq_u32(p2, s_vec);
  p3 = vmulq_u32(p3, s_vec);
  p4 = vmulq_u32(p4, s_vec);
  p5 = vmulq_u32(p5, s_vec);
  p6 = vmulq_u32(p6, s_vec);
  p7 = vmulq_u32(p7, s_vec);

  p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
  p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
  p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
  p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);
  p4 = vrshrq_n_u32(p4, SGRPROJ_MTABLE_BITS);
  p5 = vrshrq_n_u32(p5, SGRPROJ_MTABLE_BITS);
  p6 = vrshrq_n_u32(p6, SGRPROJ_MTABLE_BITS);
  p7 = vrshrq_n_u32(p7, SGRPROJ_MTABLE_BITS);

  p0 = vminq_u32(p0, const_val);
  p1 = vminq_u32(p1, const_val);
  p2 = vminq_u32(p2, const_val);
  p3 = vminq_u32(p3, const_val);
  p4 = vminq_u32(p4, const_val);
  p5 = vminq_u32(p5, const_val);
  p6 = vminq_u32(p6, const_val);
  p7 = vminq_u32(p7, const_val);

  {
    store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);
    store_u32_4x4((uint32_t *)src1 + 4, buf_stride, p4, p5, p6, p7);
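
    // NEON has no gather load, so the av1_x_by_xplus1[] lookup (a Q8
    // fixed-point approximation of z / (z + 1), with values in [1, 256]) is
    // done with scalar code below.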
    for (int x = 0; x < 4; x++) {
      for (int y = 0; y < 8; y++) {
        dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
      }
    }
    load_u16_8x4(dst_A16, buf_stride, &s16_4, &s16_5, &s16_6, &s16_7);
  }

  s16_4 = vsubq_u16(sgrproj_sgr, s16_4);
  s16_5 = vsubq_u16(sgrproj_sgr, s16_5);
  s16_6 = vsubq_u16(sgrproj_sgr, s16_6);
  s16_7 = vsubq_u16(sgrproj_sgr, s16_7);

  s0 = vmull_u16(vget_low_u16(s16_0), one_by_n_minus_1_vec);
  s1 = vmull_u16(vget_low_u16(s16_1), one_by_n_minus_1_vec);
  s2 = vmull_u16(vget_low_u16(s16_2), one_by_n_minus_1_vec);
  s3 = vmull_u16(vget_low_u16(s16_3), one_by_n_minus_1_vec);
  s4 = vmull_u16(vget_high_u16(s16_0), one_by_n_minus_1_vec);
  s5 = vmull_u16(vget_high_u16(s16_1), one_by_n_minus_1_vec);
  s6 = vmull_u16(vget_high_u16(s16_2), one_by_n_minus_1_vec);
  s7 = vmull_u16(vget_high_u16(s16_3), one_by_n_minus_1_vec);

  s0 = vmulq_u32(s0, vmovl_u16(vget_low_u16(s16_4)));
  s1 = vmulq_u32(s1, vmovl_u16(vget_low_u16(s16_5)));
  s2 = vmulq_u32(s2, vmovl_u16(vget_low_u16(s16_6)));
  s3 = vmulq_u32(s3, vmovl_u16(vget_low_u16(s16_7)));
  s4 = vmulq_u32(s4, vmovl_u16(vget_high_u16(s16_4)));
  s5 = vmulq_u32(s5, vmovl_u16(vget_high_u16(s16_5)));
  s6 = vmulq_u32(s6, vmovl_u16(vget_high_u16(s16_6)));
  s7 = vmulq_u32(s7, vmovl_u16(vget_high_u16(s16_7)));

  p0 = vrshrq_n_u32(s0, SGRPROJ_RECIP_BITS);
  p1 = vrshrq_n_u32(s1, SGRPROJ_RECIP_BITS);
  p2 = vrshrq_n_u32(s2, SGRPROJ_RECIP_BITS);
  p3 = vrshrq_n_u32(s3, SGRPROJ_RECIP_BITS);
  p4 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
  p5 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
  p6 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
  p7 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);

  store_s32_4x4(dst2, buf_stride, vreinterpretq_s32_u32(p0),
                vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
                vreinterpretq_s32_u32(p3));
  store_s32_4x4(dst2 + 4, buf_stride, vreinterpretq_s32_u32(p4),
                vreinterpretq_s32_u32(p5), vreinterpretq_s32_u32(p6),
                vreinterpretq_s32_u32(p7));
}

static inline void boxsum2_square_sum_calc(
    int16x4_t t1, int16x4_t t2, int16x4_t t3, int16x4_t t4, int16x4_t t5,
    int16x4_t t6, int16x4_t t7, int16x4_t t8, int16x4_t t9, int16x4_t t10,
    int16x4_t t11, int32x4_t *r0, int32x4_t *r1, int32x4_t *r2, int32x4_t *r3) {
  int32x4_t d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11;
  int32x4_t r12, r34, r67, r89, r1011;
  int32x4_t r345, r6789, r789;

  d1 = vmull_s16(t1, t1);
  d2 = vmull_s16(t2, t2);
  d3 = vmull_s16(t3, t3);
  d4 = vmull_s16(t4, t4);
  d5 = vmull_s16(t5, t5);
  d6 = vmull_s16(t6, t6);
  d7 = vmull_s16(t7, t7);
  d8 = vmull_s16(t8, t8);
  d9 = vmull_s16(t9, t9);
  d10 = vmull_s16(t10, t10);
  d11 = vmull_s16(t11, t11);

  r12 = vaddq_s32(d1, d2);
  r34 = vaddq_s32(d3, d4);
  r67 = vaddq_s32(d6, d7);
  r89 = vaddq_s32(d8, d9);
  r1011 = vaddq_s32(d10, d11);
  r345 = vaddq_s32(r34, d5);
  r6789 = vaddq_s32(r67, r89);
  r789 = vsubq_s32(r6789, d6);
  *r0 = vaddq_s32(r12, r345);
  *r1 = vaddq_s32(r67, r345);
  *r2 = vaddq_s32(d5, r6789);
  *r3 = vaddq_s32(r789, r1011);
}

static inline void boxsum2(int16_t *src, const int src_stride, int16_t *dst16,
                           int32_t *dst32, int32_t *dst2, const int dst_stride,
                           const int width, const int height) {
  assert(width > 2 * SGRPROJ_BORDER_HORZ);
  assert(height > 2 * SGRPROJ_BORDER_VERT);
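
  // boxsum2 computes 5x5 box sums (and sums of squares) for the r == 2
  // filter in two separable passes: the first pass below produces vertical
  // 5-tap sums, and the second pass adds 5 horizontal neighbours, using 4x4
  // transposes so those additions can also be vectorised.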
  int16_t *dst1_16_ptr, *src_ptr;
  int32_t *dst2_ptr;
  int h, w, count = 0;
  const int dst_stride_2 = (dst_stride << 1);
  const int dst_stride_8 = (dst_stride << 3);

  dst1_16_ptr = dst16;
  dst2_ptr = dst2;
  src_ptr = src;
  w = width;
  {
    int16x8_t t1, t2, t3, t4, t5, t6, t7;
    int16x8_t t8, t9, t10, t11, t12;

    int16x8_t q12345, q56789, q34567, q7891011;
    int16x8_t q12, q34, q67, q89, q1011;
    int16x8_t q345, q6789, q789;

    int32x4_t r12345, r56789, r34567, r7891011;

    do {
      h = height;
      dst1_16_ptr = dst16 + (count << 3);
      dst2_ptr = dst2 + (count << 3);
      src_ptr = src + (count << 3);

      dst1_16_ptr += dst_stride_2;
      dst2_ptr += dst_stride_2;
      do {
        load_s16_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
        src_ptr += 4 * src_stride;
        load_s16_8x4(src_ptr, src_stride, &t5, &t6, &t7, &t8);
        src_ptr += 4 * src_stride;
        load_s16_8x4(src_ptr, src_stride, &t9, &t10, &t11, &t12);

        q12 = vaddq_s16(t1, t2);
        q34 = vaddq_s16(t3, t4);
        q67 = vaddq_s16(t6, t7);
        q89 = vaddq_s16(t8, t9);
        q1011 = vaddq_s16(t10, t11);
        q345 = vaddq_s16(q34, t5);
        q6789 = vaddq_s16(q67, q89);
        q789 = vaddq_s16(q89, t7);
        q12345 = vaddq_s16(q12, q345);
        q34567 = vaddq_s16(q67, q345);
        q56789 = vaddq_s16(t5, q6789);
        q7891011 = vaddq_s16(q789, q1011);

        store_s16_8x4(dst1_16_ptr, dst_stride_2, q12345, q34567, q56789,
                      q7891011);
        dst1_16_ptr += dst_stride_8;

        boxsum2_square_sum_calc(
            vget_low_s16(t1), vget_low_s16(t2), vget_low_s16(t3),
            vget_low_s16(t4), vget_low_s16(t5), vget_low_s16(t6),
            vget_low_s16(t7), vget_low_s16(t8), vget_low_s16(t9),
            vget_low_s16(t10), vget_low_s16(t11), &r12345, &r34567, &r56789,
            &r7891011);

        store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r34567, r56789, r7891011);

        boxsum2_square_sum_calc(
            vget_high_s16(t1), vget_high_s16(t2), vget_high_s16(t3),
            vget_high_s16(t4), vget_high_s16(t5), vget_high_s16(t6),
            vget_high_s16(t7), vget_high_s16(t8), vget_high_s16(t9),
            vget_high_s16(t10), vget_high_s16(t11), &r12345, &r34567, &r56789,
            &r7891011);

        store_s32_4x4(dst2_ptr + 4, dst_stride_2, r12345, r34567, r56789,
                      r7891011);
        dst2_ptr += (dst_stride_8);
        h -= 8;
      } while (h > 0);
      w -= 8;
      count++;
    } while (w > 0);

    // memset needed for the first rows, as the 2nd stage of the boxsum filter
    // uses the first 2 rows of the dst16 and dst2 buffers, which are not
    // filled in the first stage.
    for (int x = 0; x < 2; x++) {
      memset(dst16 + x * dst_stride, 0, (width + 4) * sizeof(*dst16));
      memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
    }

    // memset needed for the extra columns, as the 2nd stage of the boxsum
    // filter uses the last 2 columns of the dst16 and dst2 buffers, which are
    // not filled in the first stage.
    for (int x = 2; x < height + 2; x++) {
      int dst_offset = x * dst_stride + width + 2;
      memset(dst16 + dst_offset, 0, 3 * sizeof(*dst16));
      memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
    }
  }
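
  // Second pass: horizontal 5-tap sums of the first-pass outputs. 4x4 tiles
  // are transposed so the row-direction additions below add horizontal
  // neighbours, then transposed again before the store.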
  {
    int16x4_t s1, s2, s3, s4, s5, s6, s7, s8;
    int32x4_t d1, d2, d3, d4, d5, d6, d7, d8;
    int32x4_t q12345, q34567, q23456, q45678;
    int32x4_t q23, q45, q67;
    int32x4_t q2345, q4567;

    int32x4_t r12345, r34567, r23456, r45678;
    int32x4_t r23, r45, r67;
    int32x4_t r2345, r4567;

    int32_t *src2_ptr, *dst1_32_ptr;
    int16_t *src1_ptr;
    count = 0;
    h = height;
    do {
      dst1_32_ptr = dst32 + count * dst_stride_8 + (dst_stride_2);
      dst2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
      src1_ptr = dst16 + count * dst_stride_8 + (dst_stride_2);
      src2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
      w = width;

      dst1_32_ptr += 2;
      dst2_ptr += 2;
      load_s16_4x4(src1_ptr, dst_stride_2, &s1, &s2, &s3, &s4);
      transpose_elems_inplace_s16_4x4(&s1, &s2, &s3, &s4);
      load_s32_4x4(src2_ptr, dst_stride_2, &d1, &d2, &d3, &d4);
      transpose_elems_inplace_s32_4x4(&d1, &d2, &d3, &d4);
      do {
        src1_ptr += 4;
        src2_ptr += 4;
        load_s16_4x4(src1_ptr, dst_stride_2, &s5, &s6, &s7, &s8);
        transpose_elems_inplace_s16_4x4(&s5, &s6, &s7, &s8);
        load_s32_4x4(src2_ptr, dst_stride_2, &d5, &d6, &d7, &d8);
        transpose_elems_inplace_s32_4x4(&d5, &d6, &d7, &d8);
        q23 = vaddl_s16(s2, s3);
        q45 = vaddl_s16(s4, s5);
        q67 = vaddl_s16(s6, s7);
        q2345 = vaddq_s32(q23, q45);
        q4567 = vaddq_s32(q45, q67);
        q12345 = vaddq_s32(vmovl_s16(s1), q2345);
        q23456 = vaddq_s32(q2345, vmovl_s16(s6));
        q34567 = vaddq_s32(q4567, vmovl_s16(s3));
        q45678 = vaddq_s32(q4567, vmovl_s16(s8));

        transpose_elems_inplace_s32_4x4(&q12345, &q23456, &q34567, &q45678);
        store_s32_4x4(dst1_32_ptr, dst_stride_2, q12345, q23456, q34567,
                      q45678);
        dst1_32_ptr += 4;
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;

        r23 = vaddq_s32(d2, d3);
        r45 = vaddq_s32(d4, d5);
        r67 = vaddq_s32(d6, d7);
        r2345 = vaddq_s32(r23, r45);
        r4567 = vaddq_s32(r45, r67);
        r12345 = vaddq_s32(d1, r2345);
        r23456 = vaddq_s32(r2345, d6);
        r34567 = vaddq_s32(r4567, d3);
        r45678 = vaddq_s32(r4567, d8);

        transpose_elems_inplace_s32_4x4(&r12345, &r23456, &r34567, &r45678);
        store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r23456, r34567, r45678);
        dst2_ptr += 4;
        d1 = d5;
        d2 = d6;
        d3 = d7;
        d4 = d8;
        w -= 4;
      } while (w > 0);
      h -= 8;
      count++;
    } while (h > 0);
  }
}
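
// Computes the per-pixel (a, b) pair of the self-guided filter:
//   p = n * sum_sq - sum * sum                (clamped below at 0)
//   z = ROUND_POWER_OF_TWO(p * s, SGRPROJ_MTABLE_BITS), capped at 255
//   a = av1_x_by_xplus1[z]
//   b = ROUND_POWER_OF_TWO((SGRPROJ_SGR - a) * sum * av1_one_by_x[n - 1],
//                          SGRPROJ_RECIP_BITS)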
static inline void calc_ab_internal_lbd(int32_t *A, uint16_t *A16,
                                        uint16_t *B16, int32_t *B,
                                        const int buf_stride, const int width,
                                        const int height, const int r,
                                        const int s, const int ht_inc) {
  int32_t *src1, *dst2, count = 0;
  uint16_t *dst_A16, *src2;
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  const uint32x4_t const_n_val = vdupq_n_u32(n);
  const uint16x8_t sgrproj_sgr = vdupq_n_u16(SGRPROJ_SGR);
  const uint16x4_t one_by_n_minus_1_vec = vdup_n_u16(av1_one_by_x[n - 1]);
  const uint32x4_t const_val = vdupq_n_u32(255);

  uint16x8_t s16_0, s16_1, s16_2, s16_3, s16_4, s16_5, s16_6, s16_7;

  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;

  const uint32x4_t s_vec = vdupq_n_u32(s);
  int w, h = height;

  do {
    dst_A16 = A16 + (count << 2) * buf_stride;
    src1 = A + (count << 2) * buf_stride;
    src2 = B16 + (count << 2) * buf_stride;
    dst2 = B + (count << 2) * buf_stride;
    w = width;
    do {
      load_u32_4x4((uint32_t *)src1, buf_stride, &s0, &s1, &s2, &s3);
      load_u32_4x4((uint32_t *)src1 + 4, buf_stride, &s4, &s5, &s6, &s7);
      load_u16_8x4(src2, buf_stride, &s16_0, &s16_1, &s16_2, &s16_3);

      s16_4 = s16_0;
      s16_5 = s16_1;
      s16_6 = s16_2;
      s16_7 = s16_3;

      calc_ab_internal_common(
          s0, s1, s2, s3, s4, s5, s6, s7, s16_0, s16_1, s16_2, s16_3, s16_4,
          s16_5, s16_6, s16_7, const_n_val, s_vec, const_val,
          one_by_n_minus_1_vec, sgrproj_sgr, src1, dst_A16, dst2, buf_stride);

      w -= 8;
      dst2 += 8;
      src1 += 8;
      src2 += 8;
      dst_A16 += 8;
    } while (w > 0);
    count++;
    h -= (ht_inc * 4);
  } while (h > 0);
}

#if CONFIG_AV1_HIGHBITDEPTH
static inline void calc_ab_internal_hbd(int32_t *A, uint16_t *A16,
                                        uint16_t *B16, int32_t *B,
                                        const int buf_stride, const int width,
                                        const int height, const int bit_depth,
                                        const int r, const int s,
                                        const int ht_inc) {
  int32_t *src1, *dst2, count = 0;
  uint16_t *dst_A16, *src2;
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  const int16x8_t bd_min_2_vec = vdupq_n_s16(-(bit_depth - 8));
  const int32x4_t bd_min_1_vec = vdupq_n_s32(-((bit_depth - 8) << 1));
  const uint32x4_t const_n_val = vdupq_n_u32(n);
  const uint16x8_t sgrproj_sgr = vdupq_n_u16(SGRPROJ_SGR);
  const uint16x4_t one_by_n_minus_1_vec = vdup_n_u16(av1_one_by_x[n - 1]);
  const uint32x4_t const_val = vdupq_n_u32(255);

  int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
  uint16x8_t s16_0, s16_1, s16_2, s16_3;
  uint16x8_t s16_4, s16_5, s16_6, s16_7;
  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;

  const uint32x4_t s_vec = vdupq_n_u32(s);
  int w, h = height;

  do {
    src1 = A + (count << 2) * buf_stride;
    src2 = B16 + (count << 2) * buf_stride;
    dst2 = B + (count << 2) * buf_stride;
    dst_A16 = A16 + (count << 2) * buf_stride;
    w = width;
    do {
      load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
      load_s32_4x4(src1 + 4, buf_stride, &sr4, &sr5, &sr6, &sr7);
      load_u16_8x4(src2, buf_stride, &s16_0, &s16_1, &s16_2, &s16_3);

      s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), bd_min_1_vec);
      s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), bd_min_1_vec);
      s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), bd_min_1_vec);
      s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), bd_min_1_vec);
      s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), bd_min_1_vec);
      s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), bd_min_1_vec);
      s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), bd_min_1_vec);
      s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), bd_min_1_vec);

      s16_4 = vrshlq_u16(s16_0, bd_min_2_vec);
      s16_5 = vrshlq_u16(s16_1, bd_min_2_vec);
      s16_6 = vrshlq_u16(s16_2, bd_min_2_vec);
      s16_7 = vrshlq_u16(s16_3, bd_min_2_vec);

      calc_ab_internal_common(
          s0, s1, s2, s3, s4, s5, s6, s7, s16_0, s16_1, s16_2, s16_3, s16_4,
          s16_5, s16_6, s16_7, const_n_val, s_vec, const_val,
          one_by_n_minus_1_vec, sgrproj_sgr, src1, dst_A16, dst2, buf_stride);

      w -= 8;
      dst2 += 8;
      src1 += 8;
      src2 += 8;
      dst_A16 += 8;
    } while (w > 0);
    count++;
    h -= (ht_inc * 4);
  } while (h > 0);
}
#endif  // CONFIG_AV1_HIGHBITDEPTH
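
// The "fast" variants below serve the r == 2 filter, which is evaluated on
// every other row only: callers pass buf_stride * 2 and ht_inc == 2 so that
// (a, b) are produced for alternate rows.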
static inline void calc_ab_fast_internal_lbd(int32_t *A, uint16_t *A16,
                                             int32_t *B, const int buf_stride,
                                             const int width, const int height,
                                             const int r, const int s,
                                             const int ht_inc) {
  int32_t *src1, *src2, count = 0;
  uint16_t *dst_A16;
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  const uint32x4_t const_n_val = vdupq_n_u32(n);
  const uint16x4_t sgrproj_sgr = vdup_n_u16(SGRPROJ_SGR);
  const uint32x4_t one_by_n_minus_1_vec = vdupq_n_u32(av1_one_by_x[n - 1]);
  const uint32x4_t const_val = vdupq_n_u32(255);

  int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;

  const uint32x4_t s_vec = vdupq_n_u32(s);
  int w, h = height;

  do {
    src1 = A + (count << 2) * buf_stride;
    src2 = B + (count << 2) * buf_stride;
    dst_A16 = A16 + (count << 2) * buf_stride;
    w = width;
    do {
      load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
      load_s32_4x4(src2, buf_stride, &sr4, &sr5, &sr6, &sr7);

      s0 = vreinterpretq_u32_s32(sr0);
      s1 = vreinterpretq_u32_s32(sr1);
      s2 = vreinterpretq_u32_s32(sr2);
      s3 = vreinterpretq_u32_s32(sr3);
      s4 = vreinterpretq_u32_s32(sr4);
      s5 = vreinterpretq_u32_s32(sr5);
      s6 = vreinterpretq_u32_s32(sr6);
      s7 = vreinterpretq_u32_s32(sr7);

      calc_ab_fast_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, sr4, sr5,
                                   sr6, sr7, const_n_val, s_vec, const_val,
                                   one_by_n_minus_1_vec, sgrproj_sgr, src1,
                                   dst_A16, src2, buf_stride);

      w -= 4;
      src1 += 4;
      src2 += 4;
      dst_A16 += 4;
    } while (w > 0);
    count++;
    h -= (ht_inc * 4);
  } while (h > 0);
}

#if CONFIG_AV1_HIGHBITDEPTH
static inline void calc_ab_fast_internal_hbd(int32_t *A, uint16_t *A16,
                                             int32_t *B, const int buf_stride,
                                             const int width, const int height,
                                             const int bit_depth, const int r,
                                             const int s, const int ht_inc) {
  int32_t *src1, *src2, count = 0;
  uint16_t *dst_A16;
  const uint32_t n = (2 * r + 1) * (2 * r + 1);
  const int32x4_t bd_min_2_vec = vdupq_n_s32(-(bit_depth - 8));
  const int32x4_t bd_min_1_vec = vdupq_n_s32(-((bit_depth - 8) << 1));
  const uint32x4_t const_n_val = vdupq_n_u32(n);
  const uint16x4_t sgrproj_sgr = vdup_n_u16(SGRPROJ_SGR);
  const uint32x4_t one_by_n_minus_1_vec = vdupq_n_u32(av1_one_by_x[n - 1]);
  const uint32x4_t const_val = vdupq_n_u32(255);

  int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;

  const uint32x4_t s_vec = vdupq_n_u32(s);
  int w, h = height;

  do {
    src1 = A + (count << 2) * buf_stride;
    src2 = B + (count << 2) * buf_stride;
    dst_A16 = A16 + (count << 2) * buf_stride;
    w = width;
    do {
      load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
      load_s32_4x4(src2, buf_stride, &sr4, &sr5, &sr6, &sr7);

      s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), bd_min_1_vec);
      s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), bd_min_1_vec);
      s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), bd_min_1_vec);
      s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), bd_min_1_vec);
      s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), bd_min_2_vec);
      s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), bd_min_2_vec);
      s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), bd_min_2_vec);
      s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), bd_min_2_vec);

      calc_ab_fast_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, sr4, sr5,
                                   sr6, sr7, const_n_val, s_vec, const_val,
                                   one_by_n_minus_1_vec, sgrproj_sgr, src1,
                                   dst_A16, src2, buf_stride);

      w -= 4;
      src1 += 4;
      src2 += 4;
      dst_A16 += 4;
    } while (w > 0);
    count++;
    h -= (ht_inc * 4);
  } while (h > 0);
}
#endif  // CONFIG_AV1_HIGHBITDEPTH
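
// boxsum1 computes 3x3 box sums (and sums of squares) for the r == 1 filter
// in two separable passes: a vertical 3-tap pass followed by a horizontal
// 3-tap pass that transposes 4x4 tiles so the horizontal additions can also
// be vectorised.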
static inline void boxsum1(int16_t *src, const int src_stride, uint16_t *dst1,
                           int32_t *dst2, const int dst_stride, const int width,
                           const int height) {
  assert(width > 2 * SGRPROJ_BORDER_HORZ);
  assert(height > 2 * SGRPROJ_BORDER_VERT);

  int16_t *src_ptr;
  int32_t *dst2_ptr;
  uint16_t *dst1_ptr;
  int h, w, count = 0;

  w = width;
  {
    int16x8_t s1, s2, s3, s4, s5, s6, s7, s8;
    int16x8_t q23, q34, q56, q234, q345, q456, q567;
    int32x4_t r23, r56, r345, r456, r567, r78, r678;
    int32x4_t r4_low, r4_high, r34_low, r34_high, r234_low, r234_high;
    int32x4_t r2, r3, r5, r6, r7, r8;
    int16x8_t q678, q78;

    do {
      dst1_ptr = dst1 + (count << 3);
      dst2_ptr = dst2 + (count << 3);
      src_ptr = src + (count << 3);
      h = height;

      load_s16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
      src_ptr += 4 * src_stride;

      q23 = vaddq_s16(s2, s3);
      q234 = vaddq_s16(q23, s4);
      q34 = vaddq_s16(s3, s4);
      dst1_ptr += (dst_stride << 1);

      r2 = vmull_s16(vget_low_s16(s2), vget_low_s16(s2));
      r3 = vmull_s16(vget_low_s16(s3), vget_low_s16(s3));
      r4_low = vmull_s16(vget_low_s16(s4), vget_low_s16(s4));
      r23 = vaddq_s32(r2, r3);
      r234_low = vaddq_s32(r23, r4_low);
      r34_low = vaddq_s32(r3, r4_low);

      r2 = vmull_s16(vget_high_s16(s2), vget_high_s16(s2));
      r3 = vmull_s16(vget_high_s16(s3), vget_high_s16(s3));
      r4_high = vmull_s16(vget_high_s16(s4), vget_high_s16(s4));
      r23 = vaddq_s32(r2, r3);
      r234_high = vaddq_s32(r23, r4_high);
      r34_high = vaddq_s32(r3, r4_high);

      dst2_ptr += (dst_stride << 1);

      do {
        load_s16_8x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);
        src_ptr += 4 * src_stride;

        q345 = vaddq_s16(s5, q34);
        q56 = vaddq_s16(s5, s6);
        q456 = vaddq_s16(s4, q56);
        q567 = vaddq_s16(s7, q56);
        q78 = vaddq_s16(s7, s8);
        q678 = vaddq_s16(s6, q78);

        store_s16_8x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
        dst1_ptr += (dst_stride << 2);

        s4 = s8;
        q34 = q78;
        q234 = q678;

        r5 = vmull_s16(vget_low_s16(s5), vget_low_s16(s5));
        r6 = vmull_s16(vget_low_s16(s6), vget_low_s16(s6));
        r7 = vmull_s16(vget_low_s16(s7), vget_low_s16(s7));
        r8 = vmull_s16(vget_low_s16(s8), vget_low_s16(s8));

        r345 = vaddq_s32(r5, r34_low);
        r56 = vaddq_s32(r5, r6);
        r456 = vaddq_s32(r4_low, r56);
        r567 = vaddq_s32(r7, r56);
        r78 = vaddq_s32(r7, r8);
        r678 = vaddq_s32(r6, r78);
        store_s32_4x4(dst2_ptr, dst_stride, r234_low, r345, r456, r567);

        r4_low = r8;
        r34_low = r78;
        r234_low = r678;

        r5 = vmull_s16(vget_high_s16(s5), vget_high_s16(s5));
        r6 = vmull_s16(vget_high_s16(s6), vget_high_s16(s6));
        r7 = vmull_s16(vget_high_s16(s7), vget_high_s16(s7));
        r8 = vmull_s16(vget_high_s16(s8), vget_high_s16(s8));

        r345 = vaddq_s32(r5, r34_high);
        r56 = vaddq_s32(r5, r6);
        r456 = vaddq_s32(r4_high, r56);
        r567 = vaddq_s32(r7, r56);
        r78 = vaddq_s32(r7, r8);
        r678 = vaddq_s32(r6, r78);
        store_s32_4x4((dst2_ptr + 4), dst_stride, r234_high, r345, r456, r567);
        dst2_ptr += (dst_stride << 2);

        r4_high = r8;
        r34_high = r78;
        r234_high = r678;

        h -= 4;
      } while (h > 0);
      w -= 8;
      count++;
    } while (w > 0);

    // memset needed for the first rows, as the 2nd stage of the boxsum filter
    // uses the first 2 rows of the dst1 and dst2 buffers, which are not
    // filled in the first stage.
    for (int x = 0; x < 2; x++) {
      memset(dst1 + x * dst_stride, 0, (width + 4) * sizeof(*dst1));
      memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
    }

    // memset needed for the extra columns, as the 2nd stage of the boxsum
    // filter uses the last 2 columns of the dst1 and dst2 buffers, which are
    // not filled in the first stage.
    for (int x = 2; x < height + 2; x++) {
      int dst_offset = x * dst_stride + width + 2;
      memset(dst1 + dst_offset, 0, 3 * sizeof(*dst1));
      memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
    }
  }

  {
    int16x4_t d1, d2, d3, d4, d5, d6, d7, d8;
    int16x4_t q23, q34, q56, q234, q345, q456, q567;
    int32x4_t r23, r56, r234, r345, r456, r567, r34, r78, r678;
    int32x4_t r1, r2, r3, r4, r5, r6, r7, r8;
    int16x4_t q678, q78;

    int32_t *src2_ptr;
    uint16_t *src1_ptr;
    count = 0;
    h = height;
    w = width;
    do {
      dst1_ptr = dst1 + (count << 2) * dst_stride;
      dst2_ptr = dst2 + (count << 2) * dst_stride;
      src1_ptr = dst1 + (count << 2) * dst_stride;
      src2_ptr = dst2 + (count << 2) * dst_stride;
      w = width;

      load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d1, &d2, &d3, &d4);
      transpose_elems_inplace_s16_4x4(&d1, &d2, &d3, &d4);
      load_s32_4x4(src2_ptr, dst_stride, &r1, &r2, &r3, &r4);
      transpose_elems_inplace_s32_4x4(&r1, &r2, &r3, &r4);
      src1_ptr += 4;
      src2_ptr += 4;

      q23 = vadd_s16(d2, d3);
      q234 = vadd_s16(q23, d4);
      q34 = vadd_s16(d3, d4);
      dst1_ptr += 2;
      r23 = vaddq_s32(r2, r3);
      r234 = vaddq_s32(r23, r4);
      r34 = vaddq_s32(r3, r4);
      dst2_ptr += 2;

      do {
        load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d5, &d6, &d7, &d8);
        transpose_elems_inplace_s16_4x4(&d5, &d6, &d7, &d8);
        load_s32_4x4(src2_ptr, dst_stride, &r5, &r6, &r7, &r8);
        transpose_elems_inplace_s32_4x4(&r5, &r6, &r7, &r8);
        src1_ptr += 4;
        src2_ptr += 4;

        q345 = vadd_s16(d5, q34);
        q56 = vadd_s16(d5, d6);
        q456 = vadd_s16(d4, q56);
        q567 = vadd_s16(d7, q56);
        q78 = vadd_s16(d7, d8);
        q678 = vadd_s16(d6, q78);
        transpose_elems_inplace_s16_4x4(&q234, &q345, &q456, &q567);
        store_s16_4x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
        dst1_ptr += 4;

        d4 = d8;
        q34 = q78;
        q234 = q678;

        r345 = vaddq_s32(r5, r34);
        r56 = vaddq_s32(r5, r6);
        r456 = vaddq_s32(r4, r56);
        r567 = vaddq_s32(r7, r56);
        r78 = vaddq_s32(r7, r8);
        r678 = vaddq_s32(r6, r78);
        transpose_elems_inplace_s32_4x4(&r234, &r345, &r456, &r567);
        store_s32_4x4(dst2_ptr, dst_stride, r234, r345, r456, r567);
        dst2_ptr += 4;

        r4 = r8;
        r34 = r78;
        r234 = r678;
        w -= 4;
      } while (w > 0);
      h -= 4;
      count++;
    } while (h > 0);
  }
}
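
// The cross_sum helpers evaluate the 3x3 neighbourhood average used by the
// final filter: weight 4 on the centre cross (centre, left, right, top,
// bottom) and weight 3 on the diagonals, computed below as
// 4 * fours + 3 * threes = ((fours + threes) << 2) - threes.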
static inline int32x4_t cross_sum_inp_s32(int32_t *buf, int buf_stride) {
  int32x4_t xtr, xt, xtl, xl, x, xr, xbr, xb, xbl;
  int32x4_t fours, threes, res;

  xtl = vld1q_s32(buf - buf_stride - 1);
  xt = vld1q_s32(buf - buf_stride);
  xtr = vld1q_s32(buf - buf_stride + 1);
  xl = vld1q_s32(buf - 1);
  x = vld1q_s32(buf);
  xr = vld1q_s32(buf + 1);
  xbl = vld1q_s32(buf + buf_stride - 1);
  xb = vld1q_s32(buf + buf_stride);
  xbr = vld1q_s32(buf + buf_stride + 1);

  fours = vaddq_s32(xl, vaddq_s32(xt, vaddq_s32(xr, vaddq_s32(xb, x))));
  threes = vaddq_s32(xtl, vaddq_s32(xtr, vaddq_s32(xbr, xbl)));
  res = vsubq_s32(vshlq_n_s32(vaddq_s32(fours, threes), 2), threes);
  return res;
}

static inline void cross_sum_inp_u16(uint16_t *buf, int buf_stride,
                                     int32x4_t *a0, int32x4_t *a1) {
  uint16x8_t xtr, xt, xtl, xl, x, xr, xbr, xb, xbl;
  uint16x8_t r0, r1;

  xtl = vld1q_u16(buf - buf_stride - 1);
  xt = vld1q_u16(buf - buf_stride);
  xtr = vld1q_u16(buf - buf_stride + 1);
  xl = vld1q_u16(buf - 1);
  x = vld1q_u16(buf);
  xr = vld1q_u16(buf + 1);
  xbl = vld1q_u16(buf + buf_stride - 1);
  xb = vld1q_u16(buf + buf_stride);
  xbr = vld1q_u16(buf + buf_stride + 1);

  xb = vaddq_u16(xb, x);
  xt = vaddq_u16(xt, xr);
  xl = vaddq_u16(xl, xb);
  xl = vaddq_u16(xl, xt);

  r0 = vshlq_n_u16(xl, 2);

  xbl = vaddq_u16(xbl, xbr);
  xtl = vaddq_u16(xtl, xtr);
  xtl = vaddq_u16(xtl, xbl);

  r1 = vshlq_n_u16(xtl, 2);
  r1 = vsubq_u16(r1, xtl);

  *a0 = vreinterpretq_s32_u32(
      vaddq_u32(vmovl_u16(vget_low_u16(r0)), vmovl_u16(vget_low_u16(r1))));
  *a1 = vreinterpretq_s32_u32(
      vaddq_u32(vmovl_u16(vget_high_u16(r0)), vmovl_u16(vget_high_u16(r1))));
}

static inline int32x4_t cross_sum_fast_even_row(int32_t *buf, int buf_stride) {
  int32x4_t xtr, xt, xtl, xbr, xb, xbl;
  int32x4_t fives, sixes, fives_plus_sixes;

  xtl = vld1q_s32(buf - buf_stride - 1);
  xt = vld1q_s32(buf - buf_stride);
  xtr = vld1q_s32(buf - buf_stride + 1);
  xbl = vld1q_s32(buf + buf_stride - 1);
  xb = vld1q_s32(buf + buf_stride);
  xbr = vld1q_s32(buf + buf_stride + 1);

  fives = vaddq_s32(xtl, vaddq_s32(xtr, vaddq_s32(xbr, xbl)));
  sixes = vaddq_s32(xt, xb);
  fives_plus_sixes = vaddq_s32(fives, sixes);

  return vaddq_s32(
      vaddq_s32(vshlq_n_s32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
}

static inline void cross_sum_fast_even_row_inp16(uint16_t *buf, int buf_stride,
                                                 int32x4_t *a0, int32x4_t *a1) {
  uint16x8_t xtr, xt, xtl, xbr, xb, xbl, xb0;

  xtl = vld1q_u16(buf - buf_stride - 1);
  xt = vld1q_u16(buf - buf_stride);
  xtr = vld1q_u16(buf - buf_stride + 1);
  xbl = vld1q_u16(buf + buf_stride - 1);
  xb = vld1q_u16(buf + buf_stride);
  xbr = vld1q_u16(buf + buf_stride + 1);

  xbr = vaddq_u16(xbr, xbl);
  xtr = vaddq_u16(xtr, xtl);
  xbr = vaddq_u16(xbr, xtr);
  xtl = vshlq_n_u16(xbr, 2);
  xbr = vaddq_u16(xtl, xbr);

  xb = vaddq_u16(xb, xt);
  xb0 = vshlq_n_u16(xb, 1);
  xb = vshlq_n_u16(xb, 2);
  xb = vaddq_u16(xb, xb0);

  *a0 = vreinterpretq_s32_u32(
      vaddq_u32(vmovl_u16(vget_low_u16(xbr)), vmovl_u16(vget_low_u16(xb))));
  *a1 = vreinterpretq_s32_u32(
      vaddq_u32(vmovl_u16(vget_high_u16(xbr)), vmovl_u16(vget_high_u16(xb))));
}

static inline int32x4_t cross_sum_fast_odd_row(int32_t *buf) {
  int32x4_t xl, x, xr;
  int32x4_t fives, sixes, fives_plus_sixes;

  xl = vld1q_s32(buf - 1);
  x = vld1q_s32(buf);
  xr = vld1q_s32(buf + 1);
  fives = vaddq_s32(xl, xr);
  sixes = x;
  fives_plus_sixes = vaddq_s32(fives, sixes);

  return vaddq_s32(
      vaddq_s32(vshlq_n_s32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
}

static inline void cross_sum_fast_odd_row_inp16(uint16_t *buf, int32x4_t *a0,
                                                int32x4_t *a1) {
  uint16x8_t xl, x, xr;
  uint16x8_t x0;

  xl = vld1q_u16(buf - 1);
  x = vld1q_u16(buf);
  xr = vld1q_u16(buf + 1);
  xl = vaddq_u16(xl, xr);
  x0 = vshlq_n_u16(xl, 2);
  xl = vaddq_u16(xl, x0);

  x0 = vshlq_n_u16(x, 1);
  x = vshlq_n_u16(x, 2);
  x = vaddq_u16(x, x0);

  *a0 = vreinterpretq_s32_u32(
      vaddq_u32(vmovl_u16(vget_low_u16(xl)), vmovl_u16(vget_low_u16(x))));
  *a1 = vreinterpretq_s32_u32(
      vaddq_u32(vmovl_u16(vget_high_u16(xl)), vmovl_u16(vget_high_u16(x))));
}
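
// In the fast path the cross sums use weights 5 and 6 instead: even rows
// take 5 * (the four corner sums) + 6 * (top + bottom), a total weight of
// 32, so final_filter_fast normalises with NB_EVEN = 5 extra shift bits;
// odd rows take 5 * (left + right) + 6 * centre, a total weight of 16,
// hence NB_ODD = 4.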
static void final_filter_fast_internal(uint16_t *A, int32_t *B,
                                       const int buf_stride, int16_t *src,
                                       const int src_stride, int32_t *dst,
                                       const int dst_stride, const int width,
                                       const int height) {
  int16x8_t s0;
  int32_t *B_tmp, *dst_ptr;
  uint16_t *A_tmp;
  int16_t *src_ptr;
  int32x4_t a_res0, a_res1, b_res0, b_res1;
  int w, h, count = 0;
  assert(SGRPROJ_SGR_BITS == 8);
  assert(SGRPROJ_RST_BITS == 4);

  A_tmp = A;
  B_tmp = B;
  src_ptr = src;
  dst_ptr = dst;
  h = height;
  do {
    A_tmp = (A + count * buf_stride);
    B_tmp = (B + count * buf_stride);
    src_ptr = (src + count * src_stride);
    dst_ptr = (dst + count * dst_stride);
    w = width;
    if (!(count & 1)) {
      do {
        s0 = vld1q_s16(src_ptr);
        cross_sum_fast_even_row_inp16(A_tmp, buf_stride, &a_res0, &a_res1);
        a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
        a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);

        b_res0 = cross_sum_fast_even_row(B_tmp, buf_stride);
        b_res1 = cross_sum_fast_even_row(B_tmp + 4, buf_stride);
        a_res0 = vaddq_s32(a_res0, b_res0);
        a_res1 = vaddq_s32(a_res1, b_res1);

        a_res0 =
            vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
        a_res1 =
            vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);

        vst1q_s32(dst_ptr, a_res0);
        vst1q_s32(dst_ptr + 4, a_res1);

        A_tmp += 8;
        B_tmp += 8;
        src_ptr += 8;
        dst_ptr += 8;
        w -= 8;
      } while (w > 0);
    } else {
      do {
        s0 = vld1q_s16(src_ptr);
        cross_sum_fast_odd_row_inp16(A_tmp, &a_res0, &a_res1);
        a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
        a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);

        b_res0 = cross_sum_fast_odd_row(B_tmp);
        b_res1 = cross_sum_fast_odd_row(B_tmp + 4);
        a_res0 = vaddq_s32(a_res0, b_res0);
        a_res1 = vaddq_s32(a_res1, b_res1);

        a_res0 =
            vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);
        a_res1 =
            vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);

        vst1q_s32(dst_ptr, a_res0);
        vst1q_s32(dst_ptr + 4, a_res1);

        A_tmp += 8;
        B_tmp += 8;
        src_ptr += 8;
        dst_ptr += 8;
        w -= 8;
      } while (w > 0);
    }
    count++;
    h -= 1;
  } while (h > 0);
}
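
// final_filter computes, per output pixel,
//   dst = ROUND_POWER_OF_TWO(cross_sum(A) * src + cross_sum(B),
//                            SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS)
// applying the locally averaged (a, b) pair to the source pixel.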
static void final_filter_internal(uint16_t *A, int32_t *B, const int buf_stride,
                                  int16_t *src, const int src_stride,
                                  int32_t *dst, const int dst_stride,
                                  const int width, const int height) {
  int16x8_t s0;
  int32_t *B_tmp, *dst_ptr;
  uint16_t *A_tmp;
  int16_t *src_ptr;
  int32x4_t a_res0, a_res1, b_res0, b_res1;
  int w, h, count = 0;

  assert(SGRPROJ_SGR_BITS == 8);
  assert(SGRPROJ_RST_BITS == 4);
  h = height;

  do {
    A_tmp = (A + count * buf_stride);
    B_tmp = (B + count * buf_stride);
    src_ptr = (src + count * src_stride);
    dst_ptr = (dst + count * dst_stride);
    w = width;
    do {
      s0 = vld1q_s16(src_ptr);
      cross_sum_inp_u16(A_tmp, buf_stride, &a_res0, &a_res1);
      a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
      a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);

      b_res0 = cross_sum_inp_s32(B_tmp, buf_stride);
      b_res1 = cross_sum_inp_s32(B_tmp + 4, buf_stride);
      a_res0 = vaddq_s32(a_res0, b_res0);
      a_res1 = vaddq_s32(a_res1, b_res1);

      a_res0 =
          vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
      a_res1 =
          vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
      vst1q_s32(dst_ptr, a_res0);
      vst1q_s32(dst_ptr + 4, a_res1);

      A_tmp += 8;
      B_tmp += 8;
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w > 0);
    count++;
    h -= 1;
  } while (h > 0);
}

static inline int restoration_fast_internal(uint16_t *dgd16, int width,
                                            int height, int dgd_stride,
                                            int32_t *dst, int dst_stride,
                                            int bit_depth, int sgr_params_idx,
                                            int radius_idx) {
  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  const int r = params->r[radius_idx];
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  const int buf_stride = ((width_ext + 3) & ~3) + 16;

  const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
  int32_t *buf = aom_memalign(8, buf_size);
  if (!buf) return -1;

  int32_t *square_sum_buf = buf;
  int32_t *sum_buf = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
  uint16_t *tmp16_buf = (uint16_t *)(sum_buf + RESTORATION_PROC_UNIT_PELS);
  assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
             (char *)buf + buf_size &&
         "Allocated buffer is too small. Resize the buffer.");

  assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
  assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
         "Need SGRPROJ_BORDER_* >= r+1");

  assert(radius_idx == 0);
  assert(r == 2);

  // The input (dgd16) is 16-bit. The first-stage sums of pixels are produced
  // in 16 bits (tmp16_buf) and the final sums in 32 bits (sum_buf); the sums
  // of squares are kept in a 32-bit buffer (square_sum_buf).
  boxsum2((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
                      SGRPROJ_BORDER_HORZ),
          dgd_stride, (int16_t *)tmp16_buf, sum_buf, square_sum_buf, buf_stride,
          width_ext, height_ext);

  square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  tmp16_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;

  // Calculation of a, b. The a output is 16-bit (tmp16_buf) and lies in the
  // range [1, 256] for all bit depths; the b output is kept in a 32-bit
  // buffer.
#if CONFIG_AV1_HIGHBITDEPTH
  if (bit_depth > 8) {
    calc_ab_fast_internal_hbd(
        (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
        (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2,
        bit_depth, r, params->s[radius_idx], 2);
  } else {
    calc_ab_fast_internal_lbd(
        (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
        (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2, r,
        params->s[radius_idx], 2);
  }
#else
  (void)bit_depth;
  calc_ab_fast_internal_lbd((square_sum_buf - buf_stride - 1),
                            (tmp16_buf - buf_stride - 1),
                            (sum_buf - buf_stride - 1), buf_stride * 2,
                            width + 2, height + 2, r, params->s[radius_idx], 2);
#endif
  final_filter_fast_internal(tmp16_buf, sum_buf, buf_stride, (int16_t *)dgd16,
                             dgd_stride, dst, dst_stride, width, height);
  aom_free(buf);
  return 0;
}

static inline int restoration_internal(uint16_t *dgd16, int width, int height,
                                       int dgd_stride, int32_t *dst,
                                       int dst_stride, int bit_depth,
                                       int sgr_params_idx, int radius_idx) {
  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  const int r = params->r[radius_idx];
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  const int buf_stride = ((width_ext + 3) & ~3) + 16;

  const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
  int32_t *buf = aom_memalign(8, buf_size);
  if (!buf) return -1;

  int32_t *square_sum_buf = buf;
  int32_t *B = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
  uint16_t *A16 = (uint16_t *)(B + RESTORATION_PROC_UNIT_PELS);
  uint16_t *sum_buf = A16 + RESTORATION_PROC_UNIT_PELS;

  assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
             (char *)buf + buf_size &&
         "Allocated buffer is too small. Resize the buffer.");

  assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
  assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
         "Need SGRPROJ_BORDER_* >= r+1");

  assert(radius_idx == 1);
  assert(r == 1);

  // The input (dgd16) is 16-bit. The sums of pixels are produced in 16 bits
  // (sum_buf); the sums of squares are kept in a 32-bit buffer
  // (square_sum_buf).
  boxsum1((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
                      SGRPROJ_BORDER_HORZ),
          dgd_stride, sum_buf, square_sum_buf, buf_stride, width_ext,
          height_ext);

  square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  B += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  A16 += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
  sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;

#if CONFIG_AV1_HIGHBITDEPTH
  // Calculation of a, b. The a output is 16-bit (A16) and lies in the range
  // [1, 256] for all bit depths; the b output is kept in a 32-bit buffer.
  if (bit_depth > 8) {
    calc_ab_internal_hbd((square_sum_buf - buf_stride - 1),
                         (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                         (B - buf_stride - 1), buf_stride, width + 2,
                         height + 2, bit_depth, r, params->s[radius_idx], 1);
  } else {
    calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
                         (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                         (B - buf_stride - 1), buf_stride, width + 2,
                         height + 2, r, params->s[radius_idx], 1);
  }
#else
  (void)bit_depth;
  calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
                       (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                       (B - buf_stride - 1), buf_stride, width + 2, height + 2,
                       r, params->s[radius_idx], 1);
#endif
  final_filter_internal(A16, B, buf_stride, (int16_t *)dgd16, dgd_stride, dst,
                        dst_stride, width, height);
  aom_free(buf);
  return 0;
}

static inline void src_convert_u8_to_u16(const uint8_t *src,
                                         const int src_stride, uint16_t *dst,
                                         const int dst_stride, const int width,
                                         const int height) {
  const uint8_t *src_ptr;
  uint16_t *dst_ptr;
  int h, w, count = 0;

  uint8x8_t t1, t2, t3, t4;
  uint16x8_t s1, s2, s3, s4;
  h = height;
  do {
    src_ptr = src + (count << 2) * src_stride;
    dst_ptr = dst + (count << 2) * dst_stride;
    w = width;
    if (w >= 7) {
      do {
        load_u8_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
        s1 = vmovl_u8(t1);
        s2 = vmovl_u8(t2);
        s3 = vmovl_u8(t3);
        s4 = vmovl_u8(t4);
        store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);

        src_ptr += 8;
        dst_ptr += 8;
        w -= 8;
      } while (w > 7);
    }

    for (int y = 0; y < w; y++) {
      dst_ptr[y] = src_ptr[y];
      dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
      dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
      dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
    }
    count++;
    h -= 4;
  } while (h > 3);

  src_ptr = src + (count << 2) * src_stride;
  dst_ptr = dst + (count << 2) * dst_stride;
  for (int x = 0; x < h; x++) {
    for (int y = 0; y < width; y++) {
      dst_ptr[y + x * dst_stride] = src_ptr[y + x * src_stride];
    }
  }

  // memset the uninitialized rows below the image, as they are read by the
  // boxsum filter calculation.
  for (int x = height; x < height + 5; x++)
    memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
}

#if CONFIG_AV1_HIGHBITDEPTH
static inline void src_convert_hbd_copy(const uint16_t *src, int src_stride,
                                        uint16_t *dst, const int dst_stride,
                                        int width, int height) {
  const uint16_t *src_ptr;
  uint16_t *dst_ptr;
  int h, w, count = 0;
  uint16x8_t s1, s2, s3, s4;

  h = height;
  do {
    src_ptr = src + (count << 2) * src_stride;
    dst_ptr = dst + (count << 2) * dst_stride;
    w = width;
    do {
      load_u16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
      store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w > 7);

    for (int y = 0; y < w; y++) {
      dst_ptr[y] = src_ptr[y];
      dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
      dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
      dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
    }
    count++;
    h -= 4;
  } while (h > 3);

  src_ptr = src + (count << 2) * src_stride;
  dst_ptr = dst + (count << 2) * dst_stride;

  for (int x = 0; x < h; x++) {
    memcpy((dst_ptr + x * dst_stride), (src_ptr + x * src_stride),
           sizeof(uint16_t) * width);
  }
  // memset the uninitialized rows below the image, as they are read by the
  // boxsum filter calculation.
  for (int x = height; x < height + 5; x++)
    memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
}
#endif  // CONFIG_AV1_HIGHBITDEPTH
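
// Computes the two self-guided filter planes: the source is widened to 16
// bits, then the r[0] (radius-2, "fast") filter fills flt0 and the r[1]
// (radius-1) filter fills flt1.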
int av1_selfguided_restoration_neon(const uint8_t *dat8, int width, int height,
                                    int stride, int32_t *flt0, int32_t *flt1,
                                    int flt_stride, int sgr_params_idx,
                                    int bit_depth, int highbd) {
  const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
  assert(!(params->r[0] == 0 && params->r[1] == 0));

  uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
  const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
  uint16_t *dgd16 =
      dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  const int dgd_stride = stride;

#if CONFIG_AV1_HIGHBITDEPTH
  if (highbd) {
    const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
    src_convert_hbd_copy(
        dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  } else {
    src_convert_u8_to_u16(
        dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  }
#else
  (void)highbd;
  src_convert_u8_to_u16(
      dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
      dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
      dgd16_stride, width_ext, height_ext);
#endif

  if (params->r[0] > 0) {
    int ret =
        restoration_fast_internal(dgd16, width, height, dgd16_stride, flt0,
                                  flt_stride, bit_depth, sgr_params_idx, 0);
    if (ret != 0) return ret;
  }
  if (params->r[1] > 0) {
    int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
                                   flt_stride, bit_depth, sgr_params_idx, 1);
    if (ret != 0) return ret;
  }
  return 0;
}

int av1_apply_selfguided_restoration_neon(const uint8_t *dat8, int width,
                                          int height, int stride, int eps,
                                          const int *xqd, uint8_t *dst8,
                                          int dst_stride, int32_t *tmpbuf,
                                          int bit_depth, int highbd) {
  int32_t *flt0 = tmpbuf;
  int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
  assert(width * height <= RESTORATION_UNITPELS_MAX);
  uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
  const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
  uint16_t *dgd16 =
      dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  const int dgd_stride = stride;
  const sgr_params_type *const params = &av1_sgr_params[eps];
  int xq[2];

  assert(!(params->r[0] == 0 && params->r[1] == 0));

#if CONFIG_AV1_HIGHBITDEPTH
  if (highbd) {
    const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
    src_convert_hbd_copy(
        dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  } else {
    src_convert_u8_to_u16(
        dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  }
#else
  (void)highbd;
  src_convert_u8_to_u16(
      dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
      dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
      dgd16_stride, width_ext, height_ext);
#endif
  if (params->r[0] > 0) {
    int ret = restoration_fast_internal(dgd16, width, height, dgd16_stride,
                                        flt0, width, bit_depth, eps, 0);
    if (ret != 0) return ret;
  }
  if (params->r[1] > 0) {
    int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
                                   width, bit_depth, eps, 1);
    if (ret != 0) return ret;
  }

  av1_decode_xq(xqd, xq, params);

  {
    int16_t *src_ptr;
    uint8_t *dst_ptr;
#if CONFIG_AV1_HIGHBITDEPTH
    uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst8);
    uint16_t *dst16_ptr;
#endif
    int16x4_t d0, d4;
    int16x8_t r0, s0;
    uint16x8_t r4;
    int32x4_t u0, u4, v0, v4, f00, f10;
    uint8x8_t t0;
    int count = 0, w = width, h = height, rc = 0;

    const int32x4_t xq0_vec = vdupq_n_s32(xq[0]);
    const int32x4_t xq1_vec = vdupq_n_s32(xq[1]);
    const int16x8_t zero = vdupq_n_s16(0);
    const uint16x8_t max = vdupq_n_u16((1 << bit_depth) - 1);
    src_ptr = (int16_t *)dgd16;
    do {
      w = width;
      count = 0;
      dst_ptr = dst8 + rc * dst_stride;
#if CONFIG_AV1_HIGHBITDEPTH
      dst16_ptr = dst16 + rc * dst_stride;
#endif
      do {
        s0 = vld1q_s16(src_ptr + count);

        u0 = vshll_n_s16(vget_low_s16(s0), SGRPROJ_RST_BITS);
        u4 = vshll_n_s16(vget_high_s16(s0), SGRPROJ_RST_BITS);

        v0 = vshlq_n_s32(u0, SGRPROJ_PRJ_BITS);
        v4 = vshlq_n_s32(u4, SGRPROJ_PRJ_BITS);

        if (params->r[0] > 0) {
          f00 = vld1q_s32(flt0 + count);
          f10 = vld1q_s32(flt0 + count + 4);

          f00 = vsubq_s32(f00, u0);
          f10 = vsubq_s32(f10, u4);

          v0 = vmlaq_s32(v0, xq0_vec, f00);
          v4 = vmlaq_s32(v4, xq0_vec, f10);
        }

        if (params->r[1] > 0) {
          f00 = vld1q_s32(flt1 + count);
          f10 = vld1q_s32(flt1 + count + 4);

          f00 = vsubq_s32(f00, u0);
          f10 = vsubq_s32(f10, u4);

          v0 = vmlaq_s32(v0, xq1_vec, f00);
          v4 = vmlaq_s32(v4, xq1_vec, f10);
        }
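
        // v = (u << SGRPROJ_PRJ_BITS) + xq[0] * (flt0 - u) + xq[1] * (flt1 -
        // u), where u = src << SGRPROJ_RST_BITS; narrow back to the pixel
        // range with rounding.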
        d0 = vqrshrn_n_s32(v0, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
        d4 = vqrshrn_n_s32(v4, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);

        r0 = vcombine_s16(d0, d4);

        r4 = vreinterpretq_u16_s16(vmaxq_s16(r0, zero));

#if CONFIG_AV1_HIGHBITDEPTH
        if (highbd) {
          r4 = vminq_u16(r4, max);
          vst1q_u16(dst16_ptr, r4);
          dst16_ptr += 8;
        } else {
          t0 = vqmovn_u16(r4);
          vst1_u8(dst_ptr, t0);
          dst_ptr += 8;
        }
#else
        (void)max;
        t0 = vqmovn_u16(r4);
        vst1_u8(dst_ptr, t0);
        dst_ptr += 8;
#endif
        w -= 8;
        count += 8;
      } while (w > 0);

      src_ptr += dgd16_stride;
      flt1 += width;
      flt0 += width;
      rc++;
      h--;
    } while (h > 0);
  }
  return 0;
}