tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

selfguided_neon.c (56051B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 #include <assert.h>
     14 
     15 #include "config/aom_config.h"
     16 #include "config/av1_rtcd.h"
     17 
     18 #include "aom_dsp/aom_dsp_common.h"
     19 #include "aom_dsp/txfm_common.h"
     20 #include "aom_dsp/arm/mem_neon.h"
     21 #include "aom_dsp/arm/transpose_neon.h"
     22 #include "aom_mem/aom_mem.h"
     23 #include "aom_ports/mem.h"
     24 #include "av1/common/av1_common_int.h"
     25 #include "av1/common/common.h"
     26 #include "av1/common/resize.h"
     27 #include "av1/common/restoration.h"
     28 
     29 // Constants used for right shift in final_filter calculation.
     30 #define NB_EVEN 5
     31 #define NB_ODD 4
     32 
// Shared tail of the "fast" (radius-2) a/b computation for one 4x4 tile.
//
// Inputs:
//   s0..s3: box sums of squared pixels for four rows (callers pre-normalise
//           these for high bit depth).
//   s4..s7: box sums of pixels, same normalisation.
//   sr4..sr7: the raw (un-normalised) signed box sums, used for the b term.
//   const_n_val: n, the pixel count of the box; s_vec: the sgrproj 's'
//   parameter; const_val: 255 (last index of av1_x_by_xplus1);
//   one_by_n_minus_1_vec: av1_one_by_x[n - 1]; sgrproj_sgr: SGRPROJ_SGR.
// Outputs:
//   dst_A16 <- a = av1_x_by_xplus1[clamped variance index]
//   src2    <- b = round(((SGRPROJ_SGR - a) * sum * av1_one_by_x[n - 1])
//                        >> SGRPROJ_RECIP_BITS)
//   src1 is reused as scratch to spill the indices for the scalar lookup.
static inline void calc_ab_fast_internal_common(
    uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
    uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, int32x4_t sr4, int32x4_t sr5,
    int32x4_t sr6, int32x4_t sr7, uint32x4_t const_n_val, uint32x4_t s_vec,
    uint32x4_t const_val, uint32x4_t one_by_n_minus_1_vec,
    uint16x4_t sgrproj_sgr, int32_t *src1, uint16_t *dst_A16, int32_t *src2,
    const int buf_stride) {
  uint32x4_t q0, q1, q2, q3;
  uint32x4_t p0, p1, p2, p3;
  uint16x4_t d0, d1, d2, d3;

  // n * sum_of_squares.
  s0 = vmulq_u32(s0, const_n_val);
  s1 = vmulq_u32(s1, const_n_val);
  s2 = vmulq_u32(s2, const_n_val);
  s3 = vmulq_u32(s3, const_n_val);

  // (sum)^2.
  q0 = vmulq_u32(s4, s4);
  q1 = vmulq_u32(s5, s5);
  q2 = vmulq_u32(s6, s6);
  q3 = vmulq_u32(s7, s7);

  // Lane mask: all-ones where sum^2 <= n * sum_sq, i.e. the subtraction
  // below does not underflow.
  p0 = vcleq_u32(q0, s0);
  p1 = vcleq_u32(q1, s1);
  p2 = vcleq_u32(q2, s2);
  p3 = vcleq_u32(q3, s3);

  // n * sum_sq - sum^2 (proportional to n^2 * variance; wraps if negative).
  q0 = vsubq_u32(s0, q0);
  q1 = vsubq_u32(s1, q1);
  q2 = vsubq_u32(s2, q2);
  q3 = vsubq_u32(s3, q3);

  // Zero out any lane whose subtraction underflowed.
  p0 = vandq_u32(p0, q0);
  p1 = vandq_u32(p1, q1);
  p2 = vandq_u32(p2, q2);
  p3 = vandq_u32(p3, q3);

  // Scale by the 's' parameter and round down to a table index.
  p0 = vmulq_u32(p0, s_vec);
  p1 = vmulq_u32(p1, s_vec);
  p2 = vmulq_u32(p2, s_vec);
  p3 = vmulq_u32(p3, s_vec);

  p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
  p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
  p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
  p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);

  // Clamp to 255, the last valid index of av1_x_by_xplus1[].
  p0 = vminq_u32(p0, const_val);
  p1 = vminq_u32(p1, const_val);
  p2 = vminq_u32(p2, const_val);
  p3 = vminq_u32(p3, const_val);

  {
    // Spill the indices and do the av1_x_by_xplus1[] lookup with scalar
    // code (no vector gather available), then reload the 16-bit results.
    store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);

    for (int x = 0; x < 4; x++) {
      for (int y = 0; y < 4; y++) {
        dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
      }
    }
    load_u16_4x4(dst_A16, buf_stride, &d0, &d1, &d2, &d3);
  }
  // Widening subtract: SGRPROJ_SGR - a.
  p0 = vsubl_u16(sgrproj_sgr, d0);
  p1 = vsubl_u16(sgrproj_sgr, d1);
  p2 = vsubl_u16(sgrproj_sgr, d2);
  p3 = vsubl_u16(sgrproj_sgr, d3);

  // b term uses the raw signed sums: sum * av1_one_by_x[n - 1] ...
  s4 = vmulq_u32(vreinterpretq_u32_s32(sr4), one_by_n_minus_1_vec);
  s5 = vmulq_u32(vreinterpretq_u32_s32(sr5), one_by_n_minus_1_vec);
  s6 = vmulq_u32(vreinterpretq_u32_s32(sr6), one_by_n_minus_1_vec);
  s7 = vmulq_u32(vreinterpretq_u32_s32(sr7), one_by_n_minus_1_vec);

  // ... times (SGRPROJ_SGR - a) ...
  s4 = vmulq_u32(s4, p0);
  s5 = vmulq_u32(s5, p1);
  s6 = vmulq_u32(s6, p2);
  s7 = vmulq_u32(s7, p3);

  // ... then rounding shift down by SGRPROJ_RECIP_BITS.
  p0 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
  p1 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
  p2 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
  p3 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);

  store_s32_4x4(src2, buf_stride, vreinterpretq_s32_u32(p0),
                vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
                vreinterpretq_s32_u32(p3));
}
// Shared tail of the general-radius a/b computation for one 4x8 tile.
//
// Inputs:
//   s0..s7: box sums of squared pixels, two 4-lane halves per row
//           (callers pre-normalise these for high bit depth).
//   s16_0..s16_3: raw 16-bit box sums of pixels per row (used for b).
//   s16_4..s16_7: the normalised box sums (equal to s16_0..s16_3 in the
//                 low-bit-depth caller), used for the variance term.
//   const_n_val: n, the pixel count of the box; s_vec: the sgrproj 's'
//   parameter; const_val: 255; one_by_n_minus_1_vec: av1_one_by_x[n - 1];
//   sgrproj_sgr: SGRPROJ_SGR.
// Outputs:
//   dst_A16 <- a = av1_x_by_xplus1[clamped variance index]
//   dst2    <- b = round(((SGRPROJ_SGR - a) * sum * av1_one_by_x[n - 1])
//                        >> SGRPROJ_RECIP_BITS)
//   src1 is reused as scratch to spill the indices for the scalar lookup.
static inline void calc_ab_internal_common(
    uint32x4_t s0, uint32x4_t s1, uint32x4_t s2, uint32x4_t s3, uint32x4_t s4,
    uint32x4_t s5, uint32x4_t s6, uint32x4_t s7, uint16x8_t s16_0,
    uint16x8_t s16_1, uint16x8_t s16_2, uint16x8_t s16_3, uint16x8_t s16_4,
    uint16x8_t s16_5, uint16x8_t s16_6, uint16x8_t s16_7,
    uint32x4_t const_n_val, uint32x4_t s_vec, uint32x4_t const_val,
    uint16x4_t one_by_n_minus_1_vec, uint16x8_t sgrproj_sgr, int32_t *src1,
    uint16_t *dst_A16, int32_t *dst2, const int buf_stride) {
  uint16x4_t d0, d1, d2, d3, d4, d5, d6, d7;
  uint32x4_t q0, q1, q2, q3, q4, q5, q6, q7;
  uint32x4_t p0, p1, p2, p3, p4, p5, p6, p7;

  // n * sum_of_squares.
  s0 = vmulq_u32(s0, const_n_val);
  s1 = vmulq_u32(s1, const_n_val);
  s2 = vmulq_u32(s2, const_n_val);
  s3 = vmulq_u32(s3, const_n_val);
  s4 = vmulq_u32(s4, const_n_val);
  s5 = vmulq_u32(s5, const_n_val);
  s6 = vmulq_u32(s6, const_n_val);
  s7 = vmulq_u32(s7, const_n_val);

  // Split the normalised 16-bit sums into 4-lane halves for widening math.
  d0 = vget_low_u16(s16_4);
  d1 = vget_low_u16(s16_5);
  d2 = vget_low_u16(s16_6);
  d3 = vget_low_u16(s16_7);
  d4 = vget_high_u16(s16_4);
  d5 = vget_high_u16(s16_5);
  d6 = vget_high_u16(s16_6);
  d7 = vget_high_u16(s16_7);

  // (sum)^2, widened to 32 bits.
  q0 = vmull_u16(d0, d0);
  q1 = vmull_u16(d1, d1);
  q2 = vmull_u16(d2, d2);
  q3 = vmull_u16(d3, d3);
  q4 = vmull_u16(d4, d4);
  q5 = vmull_u16(d5, d5);
  q6 = vmull_u16(d6, d6);
  q7 = vmull_u16(d7, d7);

  // Lane mask: all-ones where sum^2 <= n * sum_sq, i.e. the subtraction
  // below does not underflow.
  p0 = vcleq_u32(q0, s0);
  p1 = vcleq_u32(q1, s1);
  p2 = vcleq_u32(q2, s2);
  p3 = vcleq_u32(q3, s3);
  p4 = vcleq_u32(q4, s4);
  p5 = vcleq_u32(q5, s5);
  p6 = vcleq_u32(q6, s6);
  p7 = vcleq_u32(q7, s7);

  // n * sum_sq - sum^2 (proportional to n^2 * variance; wraps if negative).
  q0 = vsubq_u32(s0, q0);
  q1 = vsubq_u32(s1, q1);
  q2 = vsubq_u32(s2, q2);
  q3 = vsubq_u32(s3, q3);
  q4 = vsubq_u32(s4, q4);
  q5 = vsubq_u32(s5, q5);
  q6 = vsubq_u32(s6, q6);
  q7 = vsubq_u32(s7, q7);

  // Zero out any lane whose subtraction underflowed.
  p0 = vandq_u32(p0, q0);
  p1 = vandq_u32(p1, q1);
  p2 = vandq_u32(p2, q2);
  p3 = vandq_u32(p3, q3);
  p4 = vandq_u32(p4, q4);
  p5 = vandq_u32(p5, q5);
  p6 = vandq_u32(p6, q6);
  p7 = vandq_u32(p7, q7);

  // Scale by the 's' parameter and round down to a table index.
  p0 = vmulq_u32(p0, s_vec);
  p1 = vmulq_u32(p1, s_vec);
  p2 = vmulq_u32(p2, s_vec);
  p3 = vmulq_u32(p3, s_vec);
  p4 = vmulq_u32(p4, s_vec);
  p5 = vmulq_u32(p5, s_vec);
  p6 = vmulq_u32(p6, s_vec);
  p7 = vmulq_u32(p7, s_vec);

  p0 = vrshrq_n_u32(p0, SGRPROJ_MTABLE_BITS);
  p1 = vrshrq_n_u32(p1, SGRPROJ_MTABLE_BITS);
  p2 = vrshrq_n_u32(p2, SGRPROJ_MTABLE_BITS);
  p3 = vrshrq_n_u32(p3, SGRPROJ_MTABLE_BITS);
  p4 = vrshrq_n_u32(p4, SGRPROJ_MTABLE_BITS);
  p5 = vrshrq_n_u32(p5, SGRPROJ_MTABLE_BITS);
  p6 = vrshrq_n_u32(p6, SGRPROJ_MTABLE_BITS);
  p7 = vrshrq_n_u32(p7, SGRPROJ_MTABLE_BITS);

  // Clamp to 255, the last valid index of av1_x_by_xplus1[].
  p0 = vminq_u32(p0, const_val);
  p1 = vminq_u32(p1, const_val);
  p2 = vminq_u32(p2, const_val);
  p3 = vminq_u32(p3, const_val);
  p4 = vminq_u32(p4, const_val);
  p5 = vminq_u32(p5, const_val);
  p6 = vminq_u32(p6, const_val);
  p7 = vminq_u32(p7, const_val);

  {
    // Spill the indices and do the av1_x_by_xplus1[] lookup with scalar
    // code (no vector gather available), then reload the 16-bit results.
    store_u32_4x4((uint32_t *)src1, buf_stride, p0, p1, p2, p3);
    store_u32_4x4((uint32_t *)src1 + 4, buf_stride, p4, p5, p6, p7);

    for (int x = 0; x < 4; x++) {
      for (int y = 0; y < 8; y++) {
        dst_A16[x * buf_stride + y] = av1_x_by_xplus1[src1[x * buf_stride + y]];
      }
    }
    load_u16_8x4(dst_A16, buf_stride, &s16_4, &s16_5, &s16_6, &s16_7);
  }

  // SGRPROJ_SGR - a.
  s16_4 = vsubq_u16(sgrproj_sgr, s16_4);
  s16_5 = vsubq_u16(sgrproj_sgr, s16_5);
  s16_6 = vsubq_u16(sgrproj_sgr, s16_6);
  s16_7 = vsubq_u16(sgrproj_sgr, s16_7);

  // b term uses the raw sums: sum * av1_one_by_x[n - 1] (widening) ...
  s0 = vmull_u16(vget_low_u16(s16_0), one_by_n_minus_1_vec);
  s1 = vmull_u16(vget_low_u16(s16_1), one_by_n_minus_1_vec);
  s2 = vmull_u16(vget_low_u16(s16_2), one_by_n_minus_1_vec);
  s3 = vmull_u16(vget_low_u16(s16_3), one_by_n_minus_1_vec);
  s4 = vmull_u16(vget_high_u16(s16_0), one_by_n_minus_1_vec);
  s5 = vmull_u16(vget_high_u16(s16_1), one_by_n_minus_1_vec);
  s6 = vmull_u16(vget_high_u16(s16_2), one_by_n_minus_1_vec);
  s7 = vmull_u16(vget_high_u16(s16_3), one_by_n_minus_1_vec);

  // ... times (SGRPROJ_SGR - a) ...
  s0 = vmulq_u32(s0, vmovl_u16(vget_low_u16(s16_4)));
  s1 = vmulq_u32(s1, vmovl_u16(vget_low_u16(s16_5)));
  s2 = vmulq_u32(s2, vmovl_u16(vget_low_u16(s16_6)));
  s3 = vmulq_u32(s3, vmovl_u16(vget_low_u16(s16_7)));
  s4 = vmulq_u32(s4, vmovl_u16(vget_high_u16(s16_4)));
  s5 = vmulq_u32(s5, vmovl_u16(vget_high_u16(s16_5)));
  s6 = vmulq_u32(s6, vmovl_u16(vget_high_u16(s16_6)));
  s7 = vmulq_u32(s7, vmovl_u16(vget_high_u16(s16_7)));

  // ... then rounding shift down by SGRPROJ_RECIP_BITS.
  p0 = vrshrq_n_u32(s0, SGRPROJ_RECIP_BITS);
  p1 = vrshrq_n_u32(s1, SGRPROJ_RECIP_BITS);
  p2 = vrshrq_n_u32(s2, SGRPROJ_RECIP_BITS);
  p3 = vrshrq_n_u32(s3, SGRPROJ_RECIP_BITS);
  p4 = vrshrq_n_u32(s4, SGRPROJ_RECIP_BITS);
  p5 = vrshrq_n_u32(s5, SGRPROJ_RECIP_BITS);
  p6 = vrshrq_n_u32(s6, SGRPROJ_RECIP_BITS);
  p7 = vrshrq_n_u32(s7, SGRPROJ_RECIP_BITS);

  store_s32_4x4(dst2, buf_stride, vreinterpretq_s32_u32(p0),
                vreinterpretq_s32_u32(p1), vreinterpretq_s32_u32(p2),
                vreinterpretq_s32_u32(p3));
  store_s32_4x4(dst2 + 4, buf_stride, vreinterpretq_s32_u32(p4),
                vreinterpretq_s32_u32(p5), vreinterpretq_s32_u32(p6),
                vreinterpretq_s32_u32(p7));
}
    262 
    263 static inline void boxsum2_square_sum_calc(
    264    int16x4_t t1, int16x4_t t2, int16x4_t t3, int16x4_t t4, int16x4_t t5,
    265    int16x4_t t6, int16x4_t t7, int16x4_t t8, int16x4_t t9, int16x4_t t10,
    266    int16x4_t t11, int32x4_t *r0, int32x4_t *r1, int32x4_t *r2, int32x4_t *r3) {
    267  int32x4_t d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11;
    268  int32x4_t r12, r34, r67, r89, r1011;
    269  int32x4_t r345, r6789, r789;
    270 
    271  d1 = vmull_s16(t1, t1);
    272  d2 = vmull_s16(t2, t2);
    273  d3 = vmull_s16(t3, t3);
    274  d4 = vmull_s16(t4, t4);
    275  d5 = vmull_s16(t5, t5);
    276  d6 = vmull_s16(t6, t6);
    277  d7 = vmull_s16(t7, t7);
    278  d8 = vmull_s16(t8, t8);
    279  d9 = vmull_s16(t9, t9);
    280  d10 = vmull_s16(t10, t10);
    281  d11 = vmull_s16(t11, t11);
    282 
    283  r12 = vaddq_s32(d1, d2);
    284  r34 = vaddq_s32(d3, d4);
    285  r67 = vaddq_s32(d6, d7);
    286  r89 = vaddq_s32(d8, d9);
    287  r1011 = vaddq_s32(d10, d11);
    288  r345 = vaddq_s32(r34, d5);
    289  r6789 = vaddq_s32(r67, r89);
    290  r789 = vsubq_s32(r6789, d6);
    291  *r0 = vaddq_s32(r12, r345);
    292  *r1 = vaddq_s32(r67, r345);
    293  *r2 = vaddq_s32(d5, r6789);
    294  *r3 = vaddq_s32(r789, r1011);
    295 }
    296 
// Two-pass 5x5 box sums for the radius-2 self-guided filter.
// dst16 (pass 1) / dst32 (pass 2) accumulate sums of src; dst2 accumulates
// sums of src * src.  Pass 1 computes vertical 5-tap sums down columns of
// 8; pass 2 transposes 4x4 tiles, reuses the same 5-tap pattern
// horizontally, and transposes back.
static inline void boxsum2(int16_t *src, const int src_stride, int16_t *dst16,
                           int32_t *dst32, int32_t *dst2, const int dst_stride,
                           const int width, const int height) {
  assert(width > 2 * SGRPROJ_BORDER_HORZ);
  assert(height > 2 * SGRPROJ_BORDER_VERT);

  int16_t *dst1_16_ptr, *src_ptr;
  int32_t *dst2_ptr;
  int h, w, count = 0;
  const int dst_stride_2 = (dst_stride << 1);
  const int dst_stride_8 = (dst_stride << 3);

  dst1_16_ptr = dst16;
  dst2_ptr = dst2;
  src_ptr = src;
  w = width;
  // Pass 1: vertical 5-tap sums, processing 8 columns per outer iteration.
  {
    int16x8_t t1, t2, t3, t4, t5, t6, t7;
    int16x8_t t8, t9, t10, t11, t12;

    int16x8_t q12345, q56789, q34567, q7891011;
    int16x8_t q12, q34, q67, q89, q1011;
    int16x8_t q345, q6789, q789;

    int32x4_t r12345, r56789, r34567, r7891011;

    do {
      h = height;
      dst1_16_ptr = dst16 + (count << 3);
      dst2_ptr = dst2 + (count << 3);
      src_ptr = src + (count << 3);

      // Output starts two rows down; rows 0..1 are zeroed below.
      dst1_16_ptr += dst_stride_2;
      dst2_ptr += dst_stride_2;
      do {
        // Load 12 rows but only advance by 8: consecutive iterations'
        // 5-row windows overlap.  t12 is unused; the 8x4 load helper
        // always fills four rows.
        load_s16_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
        src_ptr += 4 * src_stride;
        load_s16_8x4(src_ptr, src_stride, &t5, &t6, &t7, &t8);
        src_ptr += 4 * src_stride;
        load_s16_8x4(src_ptr, src_stride, &t9, &t10, &t11, &t12);

        // Build the four overlapping 5-row sums from shared partials:
        // q12345 = rows 1..5, q34567 = rows 3..7, q56789 = rows 5..9,
        // q7891011 = rows 7..11.
        q12 = vaddq_s16(t1, t2);
        q34 = vaddq_s16(t3, t4);
        q67 = vaddq_s16(t6, t7);
        q89 = vaddq_s16(t8, t9);
        q1011 = vaddq_s16(t10, t11);
        q345 = vaddq_s16(q34, t5);
        q6789 = vaddq_s16(q67, q89);
        q789 = vaddq_s16(q89, t7);
        q12345 = vaddq_s16(q12, q345);
        q34567 = vaddq_s16(q67, q345);
        q56789 = vaddq_s16(t5, q6789);
        q7891011 = vaddq_s16(q789, q1011);

        // Outputs land on every other row (stride 2 * dst_stride).
        store_s16_8x4(dst1_16_ptr, dst_stride_2, q12345, q34567, q56789,
                      q7891011);
        dst1_16_ptr += dst_stride_8;

        // Same four windows for the squared sums, low half then high half.
        boxsum2_square_sum_calc(
            vget_low_s16(t1), vget_low_s16(t2), vget_low_s16(t3),
            vget_low_s16(t4), vget_low_s16(t5), vget_low_s16(t6),
            vget_low_s16(t7), vget_low_s16(t8), vget_low_s16(t9),
            vget_low_s16(t10), vget_low_s16(t11), &r12345, &r34567, &r56789,
            &r7891011);

        store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r34567, r56789, r7891011);

        boxsum2_square_sum_calc(
            vget_high_s16(t1), vget_high_s16(t2), vget_high_s16(t3),
            vget_high_s16(t4), vget_high_s16(t5), vget_high_s16(t6),
            vget_high_s16(t7), vget_high_s16(t8), vget_high_s16(t9),
            vget_high_s16(t10), vget_high_s16(t11), &r12345, &r34567, &r56789,
            &r7891011);

        store_s32_4x4(dst2_ptr + 4, dst_stride_2, r12345, r34567, r56789,
                      r7891011);
        dst2_ptr += (dst_stride_8);
        h -= 8;
      } while (h > 0);
      w -= 8;
      count++;
    } while (w > 0);

    // memset needed for row pixels as 2nd stage of boxsum filter uses
    // first 2 rows of dst16, dst2 buffer which is not filled in first stage.
    for (int x = 0; x < 2; x++) {
      memset(dst16 + x * dst_stride, 0, (width + 4) * sizeof(*dst16));
      memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
    }

    // memset needed for extra columns as 2nd stage of boxsum filter uses
    // last 2 columns of dst16, dst2 buffer which is not filled in first stage.
    for (int x = 2; x < height + 2; x++) {
      int dst_offset = x * dst_stride + width + 2;
      memset(dst16 + dst_offset, 0, 3 * sizeof(*dst16));
      memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
    }
  }

  // Pass 2: horizontal 5-tap sums.  4x4 tiles are transposed so the same
  // overlapping-window adds run along what were rows, then transposed back
  // before storing.
  {
    int16x4_t s1, s2, s3, s4, s5, s6, s7, s8;
    int32x4_t d1, d2, d3, d4, d5, d6, d7, d8;
    int32x4_t q12345, q34567, q23456, q45678;
    int32x4_t q23, q45, q67;
    int32x4_t q2345, q4567;

    int32x4_t r12345, r34567, r23456, r45678;
    int32x4_t r23, r45, r67;
    int32x4_t r2345, r4567;

    int32_t *src2_ptr, *dst1_32_ptr;
    int16_t *src1_ptr;
    count = 0;
    h = height;
    do {
      // Skip the two zeroed rows; each outer iteration covers 8 rows
      // (4 rows at the pass-1 stride of 2 * dst_stride).
      dst1_32_ptr = dst32 + count * dst_stride_8 + (dst_stride_2);
      dst2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
      src1_ptr = dst16 + count * dst_stride_8 + (dst_stride_2);
      src2_ptr = dst2 + count * dst_stride_8 + (dst_stride_2);
      w = width;

      // Output is offset two columns right of the input window.
      dst1_32_ptr += 2;
      dst2_ptr += 2;
      // Prime the first four columns (transposed into s1..s4 / d1..d4).
      load_s16_4x4(src1_ptr, dst_stride_2, &s1, &s2, &s3, &s4);
      transpose_elems_inplace_s16_4x4(&s1, &s2, &s3, &s4);
      load_s32_4x4(src2_ptr, dst_stride_2, &d1, &d2, &d3, &d4);
      transpose_elems_inplace_s32_4x4(&d1, &d2, &d3, &d4);
      do {
        // Next four columns become s5..s8 / d5..d8; with the carried
        // s1..s4 / d1..d4 this gives the 8-column window needed for four
        // overlapping 5-tap sums.
        src1_ptr += 4;
        src2_ptr += 4;
        load_s16_4x4(src1_ptr, dst_stride_2, &s5, &s6, &s7, &s8);
        transpose_elems_inplace_s16_4x4(&s5, &s6, &s7, &s8);
        load_s32_4x4(src2_ptr, dst_stride_2, &d5, &d6, &d7, &d8);
        transpose_elems_inplace_s32_4x4(&d5, &d6, &d7, &d8);
        // q12345 = cols 1..5, q23456 = cols 2..6, q34567 = cols 3..7,
        // q45678 = cols 4..8 (widening adds: 16-bit sums -> 32-bit).
        q23 = vaddl_s16(s2, s3);
        q45 = vaddl_s16(s4, s5);
        q67 = vaddl_s16(s6, s7);
        q2345 = vaddq_s32(q23, q45);
        q4567 = vaddq_s32(q45, q67);
        q12345 = vaddq_s32(vmovl_s16(s1), q2345);
        q23456 = vaddq_s32(q2345, vmovl_s16(s6));
        q34567 = vaddq_s32(q4567, vmovl_s16(s3));
        q45678 = vaddq_s32(q4567, vmovl_s16(s8));

        transpose_elems_inplace_s32_4x4(&q12345, &q23456, &q34567, &q45678);
        store_s32_4x4(dst1_32_ptr, dst_stride_2, q12345, q23456, q34567,
                      q45678);
        dst1_32_ptr += 4;
        // Slide the window: last four columns become the next first four.
        s1 = s5;
        s2 = s6;
        s3 = s7;
        s4 = s8;

        // Identical windowing for the sums of squares.
        r23 = vaddq_s32(d2, d3);
        r45 = vaddq_s32(d4, d5);
        r67 = vaddq_s32(d6, d7);
        r2345 = vaddq_s32(r23, r45);
        r4567 = vaddq_s32(r45, r67);
        r12345 = vaddq_s32(d1, r2345);
        r23456 = vaddq_s32(r2345, d6);
        r34567 = vaddq_s32(r4567, d3);
        r45678 = vaddq_s32(r4567, d8);

        transpose_elems_inplace_s32_4x4(&r12345, &r23456, &r34567, &r45678);
        store_s32_4x4(dst2_ptr, dst_stride_2, r12345, r23456, r34567, r45678);
        dst2_ptr += 4;
        d1 = d5;
        d2 = d6;
        d3 = d7;
        d4 = d8;
        w -= 4;
      } while (w > 0);
      h -= 8;
      count++;
    } while (h > 0);
  }
}
    474 
    475 static inline void calc_ab_internal_lbd(int32_t *A, uint16_t *A16,
    476                                        uint16_t *B16, int32_t *B,
    477                                        const int buf_stride, const int width,
    478                                        const int height, const int r,
    479                                        const int s, const int ht_inc) {
    480  int32_t *src1, *dst2, count = 0;
    481  uint16_t *dst_A16, *src2;
    482  const uint32_t n = (2 * r + 1) * (2 * r + 1);
    483  const uint32x4_t const_n_val = vdupq_n_u32(n);
    484  const uint16x8_t sgrproj_sgr = vdupq_n_u16(SGRPROJ_SGR);
    485  const uint16x4_t one_by_n_minus_1_vec = vdup_n_u16(av1_one_by_x[n - 1]);
    486  const uint32x4_t const_val = vdupq_n_u32(255);
    487 
    488  uint16x8_t s16_0, s16_1, s16_2, s16_3, s16_4, s16_5, s16_6, s16_7;
    489 
    490  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
    491 
    492  const uint32x4_t s_vec = vdupq_n_u32(s);
    493  int w, h = height;
    494 
    495  do {
    496    dst_A16 = A16 + (count << 2) * buf_stride;
    497    src1 = A + (count << 2) * buf_stride;
    498    src2 = B16 + (count << 2) * buf_stride;
    499    dst2 = B + (count << 2) * buf_stride;
    500    w = width;
    501    do {
    502      load_u32_4x4((uint32_t *)src1, buf_stride, &s0, &s1, &s2, &s3);
    503      load_u32_4x4((uint32_t *)src1 + 4, buf_stride, &s4, &s5, &s6, &s7);
    504      load_u16_8x4(src2, buf_stride, &s16_0, &s16_1, &s16_2, &s16_3);
    505 
    506      s16_4 = s16_0;
    507      s16_5 = s16_1;
    508      s16_6 = s16_2;
    509      s16_7 = s16_3;
    510 
    511      calc_ab_internal_common(
    512          s0, s1, s2, s3, s4, s5, s6, s7, s16_0, s16_1, s16_2, s16_3, s16_4,
    513          s16_5, s16_6, s16_7, const_n_val, s_vec, const_val,
    514          one_by_n_minus_1_vec, sgrproj_sgr, src1, dst_A16, dst2, buf_stride);
    515 
    516      w -= 8;
    517      dst2 += 8;
    518      src1 += 8;
    519      src2 += 8;
    520      dst_A16 += 8;
    521    } while (w > 0);
    522    count++;
    523    h -= (ht_inc * 4);
    524  } while (h > 0);
    525 }
    526 
    527 #if CONFIG_AV1_HIGHBITDEPTH
    528 static inline void calc_ab_internal_hbd(int32_t *A, uint16_t *A16,
    529                                        uint16_t *B16, int32_t *B,
    530                                        const int buf_stride, const int width,
    531                                        const int height, const int bit_depth,
    532                                        const int r, const int s,
    533                                        const int ht_inc) {
    534  int32_t *src1, *dst2, count = 0;
    535  uint16_t *dst_A16, *src2;
    536  const uint32_t n = (2 * r + 1) * (2 * r + 1);
    537  const int16x8_t bd_min_2_vec = vdupq_n_s16(-(bit_depth - 8));
    538  const int32x4_t bd_min_1_vec = vdupq_n_s32(-((bit_depth - 8) << 1));
    539  const uint32x4_t const_n_val = vdupq_n_u32(n);
    540  const uint16x8_t sgrproj_sgr = vdupq_n_u16(SGRPROJ_SGR);
    541  const uint16x4_t one_by_n_minus_1_vec = vdup_n_u16(av1_one_by_x[n - 1]);
    542  const uint32x4_t const_val = vdupq_n_u32(255);
    543 
    544  int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
    545  uint16x8_t s16_0, s16_1, s16_2, s16_3;
    546  uint16x8_t s16_4, s16_5, s16_6, s16_7;
    547  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
    548 
    549  const uint32x4_t s_vec = vdupq_n_u32(s);
    550  int w, h = height;
    551 
    552  do {
    553    src1 = A + (count << 2) * buf_stride;
    554    src2 = B16 + (count << 2) * buf_stride;
    555    dst2 = B + (count << 2) * buf_stride;
    556    dst_A16 = A16 + (count << 2) * buf_stride;
    557    w = width;
    558    do {
    559      load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
    560      load_s32_4x4(src1 + 4, buf_stride, &sr4, &sr5, &sr6, &sr7);
    561      load_u16_8x4(src2, buf_stride, &s16_0, &s16_1, &s16_2, &s16_3);
    562 
    563      s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), bd_min_1_vec);
    564      s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), bd_min_1_vec);
    565      s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), bd_min_1_vec);
    566      s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), bd_min_1_vec);
    567      s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), bd_min_1_vec);
    568      s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), bd_min_1_vec);
    569      s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), bd_min_1_vec);
    570      s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), bd_min_1_vec);
    571 
    572      s16_4 = vrshlq_u16(s16_0, bd_min_2_vec);
    573      s16_5 = vrshlq_u16(s16_1, bd_min_2_vec);
    574      s16_6 = vrshlq_u16(s16_2, bd_min_2_vec);
    575      s16_7 = vrshlq_u16(s16_3, bd_min_2_vec);
    576 
    577      calc_ab_internal_common(
    578          s0, s1, s2, s3, s4, s5, s6, s7, s16_0, s16_1, s16_2, s16_3, s16_4,
    579          s16_5, s16_6, s16_7, const_n_val, s_vec, const_val,
    580          one_by_n_minus_1_vec, sgrproj_sgr, src1, dst_A16, dst2, buf_stride);
    581 
    582      w -= 8;
    583      dst2 += 8;
    584      src1 += 8;
    585      src2 += 8;
    586      dst_A16 += 8;
    587    } while (w > 0);
    588    count++;
    589    h -= (ht_inc * 4);
    590  } while (h > 0);
    591 }
    592 #endif  // CONFIG_AV1_HIGHBITDEPTH
    593 
    594 static inline void calc_ab_fast_internal_lbd(int32_t *A, uint16_t *A16,
    595                                             int32_t *B, const int buf_stride,
    596                                             const int width, const int height,
    597                                             const int r, const int s,
    598                                             const int ht_inc) {
    599  int32_t *src1, *src2, count = 0;
    600  uint16_t *dst_A16;
    601  const uint32_t n = (2 * r + 1) * (2 * r + 1);
    602  const uint32x4_t const_n_val = vdupq_n_u32(n);
    603  const uint16x4_t sgrproj_sgr = vdup_n_u16(SGRPROJ_SGR);
    604  const uint32x4_t one_by_n_minus_1_vec = vdupq_n_u32(av1_one_by_x[n - 1]);
    605  const uint32x4_t const_val = vdupq_n_u32(255);
    606 
    607  int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
    608  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
    609 
    610  const uint32x4_t s_vec = vdupq_n_u32(s);
    611  int w, h = height;
    612 
    613  do {
    614    src1 = A + (count << 2) * buf_stride;
    615    src2 = B + (count << 2) * buf_stride;
    616    dst_A16 = A16 + (count << 2) * buf_stride;
    617    w = width;
    618    do {
    619      load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
    620      load_s32_4x4(src2, buf_stride, &sr4, &sr5, &sr6, &sr7);
    621 
    622      s0 = vreinterpretq_u32_s32(sr0);
    623      s1 = vreinterpretq_u32_s32(sr1);
    624      s2 = vreinterpretq_u32_s32(sr2);
    625      s3 = vreinterpretq_u32_s32(sr3);
    626      s4 = vreinterpretq_u32_s32(sr4);
    627      s5 = vreinterpretq_u32_s32(sr5);
    628      s6 = vreinterpretq_u32_s32(sr6);
    629      s7 = vreinterpretq_u32_s32(sr7);
    630 
    631      calc_ab_fast_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, sr4, sr5,
    632                                   sr6, sr7, const_n_val, s_vec, const_val,
    633                                   one_by_n_minus_1_vec, sgrproj_sgr, src1,
    634                                   dst_A16, src2, buf_stride);
    635 
    636      w -= 4;
    637      src1 += 4;
    638      src2 += 4;
    639      dst_A16 += 4;
    640    } while (w > 0);
    641    count++;
    642    h -= (ht_inc * 4);
    643  } while (h > 0);
    644 }
    645 
    646 #if CONFIG_AV1_HIGHBITDEPTH
    647 static inline void calc_ab_fast_internal_hbd(int32_t *A, uint16_t *A16,
    648                                             int32_t *B, const int buf_stride,
    649                                             const int width, const int height,
    650                                             const int bit_depth, const int r,
    651                                             const int s, const int ht_inc) {
    652  int32_t *src1, *src2, count = 0;
    653  uint16_t *dst_A16;
    654  const uint32_t n = (2 * r + 1) * (2 * r + 1);
    655  const int32x4_t bd_min_2_vec = vdupq_n_s32(-(bit_depth - 8));
    656  const int32x4_t bd_min_1_vec = vdupq_n_s32(-((bit_depth - 8) << 1));
    657  const uint32x4_t const_n_val = vdupq_n_u32(n);
    658  const uint16x4_t sgrproj_sgr = vdup_n_u16(SGRPROJ_SGR);
    659  const uint32x4_t one_by_n_minus_1_vec = vdupq_n_u32(av1_one_by_x[n - 1]);
    660  const uint32x4_t const_val = vdupq_n_u32(255);
    661 
    662  int32x4_t sr0, sr1, sr2, sr3, sr4, sr5, sr6, sr7;
    663  uint32x4_t s0, s1, s2, s3, s4, s5, s6, s7;
    664 
    665  const uint32x4_t s_vec = vdupq_n_u32(s);
    666  int w, h = height;
    667 
    668  do {
    669    src1 = A + (count << 2) * buf_stride;
    670    src2 = B + (count << 2) * buf_stride;
    671    dst_A16 = A16 + (count << 2) * buf_stride;
    672    w = width;
    673    do {
    674      load_s32_4x4(src1, buf_stride, &sr0, &sr1, &sr2, &sr3);
    675      load_s32_4x4(src2, buf_stride, &sr4, &sr5, &sr6, &sr7);
    676 
    677      s0 = vrshlq_u32(vreinterpretq_u32_s32(sr0), bd_min_1_vec);
    678      s1 = vrshlq_u32(vreinterpretq_u32_s32(sr1), bd_min_1_vec);
    679      s2 = vrshlq_u32(vreinterpretq_u32_s32(sr2), bd_min_1_vec);
    680      s3 = vrshlq_u32(vreinterpretq_u32_s32(sr3), bd_min_1_vec);
    681      s4 = vrshlq_u32(vreinterpretq_u32_s32(sr4), bd_min_2_vec);
    682      s5 = vrshlq_u32(vreinterpretq_u32_s32(sr5), bd_min_2_vec);
    683      s6 = vrshlq_u32(vreinterpretq_u32_s32(sr6), bd_min_2_vec);
    684      s7 = vrshlq_u32(vreinterpretq_u32_s32(sr7), bd_min_2_vec);
    685 
    686      calc_ab_fast_internal_common(s0, s1, s2, s3, s4, s5, s6, s7, sr4, sr5,
    687                                   sr6, sr7, const_n_val, s_vec, const_val,
    688                                   one_by_n_minus_1_vec, sgrproj_sgr, src1,
    689                                   dst_A16, src2, buf_stride);
    690 
    691      w -= 4;
    692      src1 += 4;
    693      src2 += 4;
    694      dst_A16 += 4;
    695    } while (w > 0);
    696    count++;
    697    h -= (ht_inc * 4);
    698  } while (h > 0);
    699 }
    700 #endif  // CONFIG_AV1_HIGHBITDEPTH
    701 
// First-stage box sum used by the r == 1 self-guided filter path.
// For every pixel, computes the sum (dst1, 16 bit) and sum of squares
// (dst2, 32 bit) of its 3x3 neighbourhood, as two separable passes:
// a vertical pass over 8-column strips, then a horizontal pass performed
// through 4x4 transposes so the same vertical-sum kernel can be reused.
// dst1/dst2 share dst_stride; results are written with a 2-row/2-column
// offset so output (y, x) holds the sum centred on input (y-1, x-1).
static inline void boxsum1(int16_t *src, const int src_stride, uint16_t *dst1,
                          int32_t *dst2, const int dst_stride, const int width,
                          const int height) {
 assert(width > 2 * SGRPROJ_BORDER_HORZ);
 assert(height > 2 * SGRPROJ_BORDER_VERT);

 int16_t *src_ptr;
 int32_t *dst2_ptr;
 uint16_t *dst1_ptr;
 int h, w, count = 0;

 w = width;
 // Pass 1 (vertical): for each 8-wide column strip, produce sliding
 // 3-row sums of pixels and of squared pixels down the full height.
 {
   int16x8_t s1, s2, s3, s4, s5, s6, s7, s8;
   int16x8_t q23, q34, q56, q234, q345, q456, q567;
   int32x4_t r23, r56, r345, r456, r567, r78, r678;
   int32x4_t r4_low, r4_high, r34_low, r34_high, r234_low, r234_high;
   int32x4_t r2, r3, r5, r6, r7, r8;
   int16x8_t q678, q78;

   do {
     dst1_ptr = dst1 + (count << 3);
     dst2_ptr = dst2 + (count << 3);
     src_ptr = src + (count << 3);
     h = height;

     // Prime the pipeline with the first 4 rows. q234 is the first full
     // 3-row sum; q34 is the partial sum carried into the next iteration.
     // (s1 is loaded as part of the 4-row batch but row 1 does not
     // contribute to any output sum.)
     load_s16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
     src_ptr += 4 * src_stride;

     q23 = vaddq_s16(s2, s3);
     q234 = vaddq_s16(q23, s4);
     q34 = vaddq_s16(s3, s4);
     dst1_ptr += (dst_stride << 1);

     // Same 3-row sums on squared pixels; squaring widens to 32 bit, so
     // the 8 lanes are processed as low/high 4-lane halves.
     r2 = vmull_s16(vget_low_s16(s2), vget_low_s16(s2));
     r3 = vmull_s16(vget_low_s16(s3), vget_low_s16(s3));
     r4_low = vmull_s16(vget_low_s16(s4), vget_low_s16(s4));
     r23 = vaddq_s32(r2, r3);
     r234_low = vaddq_s32(r23, r4_low);
     r34_low = vaddq_s32(r3, r4_low);

     r2 = vmull_s16(vget_high_s16(s2), vget_high_s16(s2));
     r3 = vmull_s16(vget_high_s16(s3), vget_high_s16(s3));
     r4_high = vmull_s16(vget_high_s16(s4), vget_high_s16(s4));
     r23 = vaddq_s32(r2, r3);
     r234_high = vaddq_s32(r23, r4_high);
     r34_high = vaddq_s32(r3, r4_high);

     dst2_ptr += (dst_stride << 1);

     // Steady state: consume 4 new rows per iteration and emit 4 output
     // rows of 3-row sums, carrying s4/q34/q234 (and the squared
     // equivalents) across iterations.
     do {
       load_s16_8x4(src_ptr, src_stride, &s5, &s6, &s7, &s8);
       src_ptr += 4 * src_stride;

       q345 = vaddq_s16(s5, q34);
       q56 = vaddq_s16(s5, s6);
       q456 = vaddq_s16(s4, q56);
       q567 = vaddq_s16(s7, q56);
       q78 = vaddq_s16(s7, s8);
       q678 = vaddq_s16(s6, q78);

       store_s16_8x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
       dst1_ptr += (dst_stride << 2);

       // Rotate the carried state for the next 4 rows.
       s4 = s8;
       q34 = q78;
       q234 = q678;

       r5 = vmull_s16(vget_low_s16(s5), vget_low_s16(s5));
       r6 = vmull_s16(vget_low_s16(s6), vget_low_s16(s6));
       r7 = vmull_s16(vget_low_s16(s7), vget_low_s16(s7));
       r8 = vmull_s16(vget_low_s16(s8), vget_low_s16(s8));

       r345 = vaddq_s32(r5, r34_low);
       r56 = vaddq_s32(r5, r6);
       r456 = vaddq_s32(r4_low, r56);
       r567 = vaddq_s32(r7, r56);
       r78 = vaddq_s32(r7, r8);
       r678 = vaddq_s32(r6, r78);
       store_s32_4x4(dst2_ptr, dst_stride, r234_low, r345, r456, r567);

       r4_low = r8;
       r34_low = r78;
       r234_low = r678;

       r5 = vmull_s16(vget_high_s16(s5), vget_high_s16(s5));
       r6 = vmull_s16(vget_high_s16(s6), vget_high_s16(s6));
       r7 = vmull_s16(vget_high_s16(s7), vget_high_s16(s7));
       r8 = vmull_s16(vget_high_s16(s8), vget_high_s16(s8));

       r345 = vaddq_s32(r5, r34_high);
       r56 = vaddq_s32(r5, r6);
       r456 = vaddq_s32(r4_high, r56);
       r567 = vaddq_s32(r7, r56);
       r78 = vaddq_s32(r7, r8);
       r678 = vaddq_s32(r6, r78);
       store_s32_4x4((dst2_ptr + 4), dst_stride, r234_high, r345, r456, r567);
       dst2_ptr += (dst_stride << 2);

       r4_high = r8;
       r34_high = r78;
       r234_high = r678;

       h -= 4;
     } while (h > 0);
     w -= 8;
     count++;
   } while (w > 0);

   // memset needed for row pixels as 2nd stage of boxsum filter uses
   // first 2 rows of dst1, dst2 buffer which is not filled in first stage.
   for (int x = 0; x < 2; x++) {
     memset(dst1 + x * dst_stride, 0, (width + 4) * sizeof(*dst1));
     memset(dst2 + x * dst_stride, 0, (width + 4) * sizeof(*dst2));
   }

   // memset needed for extra columns as 2nd stage of boxsum filter uses
   // last 2 columns of dst1, dst2 buffer which is not filled in first stage.
   for (int x = 2; x < height + 2; x++) {
     int dst_offset = x * dst_stride + width + 2;
     memset(dst1 + dst_offset, 0, 3 * sizeof(*dst1));
     memset(dst2 + dst_offset, 0, 3 * sizeof(*dst2));
   }
 }

 // Pass 2 (horizontal), done in place on dst1/dst2: load 4x4 tiles,
 // transpose so columns become rows, apply the same sliding 3-sum kernel,
 // transpose back and store. Processes 4 output rows per outer iteration.
 {
   int16x4_t d1, d2, d3, d4, d5, d6, d7, d8;
   int16x4_t q23, q34, q56, q234, q345, q456, q567;
   int32x4_t r23, r56, r234, r345, r456, r567, r34, r78, r678;
   int32x4_t r1, r2, r3, r4, r5, r6, r7, r8;
   int16x4_t q678, q78;

   int32_t *src2_ptr;
   uint16_t *src1_ptr;
   count = 0;
   h = height;
   w = width;
   do {
     // Source and destination alias the same rows; reads stay 2 columns
     // ahead of writes, so in-place operation is safe.
     dst1_ptr = dst1 + (count << 2) * dst_stride;
     dst2_ptr = dst2 + (count << 2) * dst_stride;
     src1_ptr = dst1 + (count << 2) * dst_stride;
     src2_ptr = dst2 + (count << 2) * dst_stride;
     w = width;

     load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d1, &d2, &d3, &d4);
     transpose_elems_inplace_s16_4x4(&d1, &d2, &d3, &d4);
     load_s32_4x4(src2_ptr, dst_stride, &r1, &r2, &r3, &r4);
     transpose_elems_inplace_s32_4x4(&r1, &r2, &r3, &r4);
     src1_ptr += 4;
     src2_ptr += 4;

     q23 = vadd_s16(d2, d3);
     q234 = vadd_s16(q23, d4);
     q34 = vadd_s16(d3, d4);
     dst1_ptr += 2;
     r23 = vaddq_s32(r2, r3);
     r234 = vaddq_s32(r23, r4);
     r34 = vaddq_s32(r3, r4);
     dst2_ptr += 2;

     do {
       load_s16_4x4((int16_t *)src1_ptr, dst_stride, &d5, &d6, &d7, &d8);
       transpose_elems_inplace_s16_4x4(&d5, &d6, &d7, &d8);
       load_s32_4x4(src2_ptr, dst_stride, &r5, &r6, &r7, &r8);
       transpose_elems_inplace_s32_4x4(&r5, &r6, &r7, &r8);
       src1_ptr += 4;
       src2_ptr += 4;

       q345 = vadd_s16(d5, q34);
       q56 = vadd_s16(d5, d6);
       q456 = vadd_s16(d4, q56);
       q567 = vadd_s16(d7, q56);
       q78 = vadd_s16(d7, d8);
       q678 = vadd_s16(d6, q78);
       transpose_elems_inplace_s16_4x4(&q234, &q345, &q456, &q567);
       store_s16_4x4((int16_t *)dst1_ptr, dst_stride, q234, q345, q456, q567);
       dst1_ptr += 4;

       d4 = d8;
       q34 = q78;
       q234 = q678;

       r345 = vaddq_s32(r5, r34);
       r56 = vaddq_s32(r5, r6);
       r456 = vaddq_s32(r4, r56);
       r567 = vaddq_s32(r7, r56);
       r78 = vaddq_s32(r7, r8);
       r678 = vaddq_s32(r6, r78);
       transpose_elems_inplace_s32_4x4(&r234, &r345, &r456, &r567);
       store_s32_4x4(dst2_ptr, dst_stride, r234, r345, r456, r567);
       dst2_ptr += 4;

       r4 = r8;
       r34 = r78;
       r234 = r678;
       w -= 4;
     } while (w > 0);
     h -= 4;
     count++;
   } while (h > 0);
 }
}
    904 
    905 static inline int32x4_t cross_sum_inp_s32(int32_t *buf, int buf_stride) {
    906  int32x4_t xtr, xt, xtl, xl, x, xr, xbr, xb, xbl;
    907  int32x4_t fours, threes, res;
    908 
    909  xtl = vld1q_s32(buf - buf_stride - 1);
    910  xt = vld1q_s32(buf - buf_stride);
    911  xtr = vld1q_s32(buf - buf_stride + 1);
    912  xl = vld1q_s32(buf - 1);
    913  x = vld1q_s32(buf);
    914  xr = vld1q_s32(buf + 1);
    915  xbl = vld1q_s32(buf + buf_stride - 1);
    916  xb = vld1q_s32(buf + buf_stride);
    917  xbr = vld1q_s32(buf + buf_stride + 1);
    918 
    919  fours = vaddq_s32(xl, vaddq_s32(xt, vaddq_s32(xr, vaddq_s32(xb, x))));
    920  threes = vaddq_s32(xtl, vaddq_s32(xtr, vaddq_s32(xbr, xbl)));
    921  res = vsubq_s32(vshlq_n_s32(vaddq_s32(fours, threes), 2), threes);
    922  return res;
    923 }
    924 
    925 static inline void cross_sum_inp_u16(uint16_t *buf, int buf_stride,
    926                                     int32x4_t *a0, int32x4_t *a1) {
    927  uint16x8_t xtr, xt, xtl, xl, x, xr, xbr, xb, xbl;
    928  uint16x8_t r0, r1;
    929 
    930  xtl = vld1q_u16(buf - buf_stride - 1);
    931  xt = vld1q_u16(buf - buf_stride);
    932  xtr = vld1q_u16(buf - buf_stride + 1);
    933  xl = vld1q_u16(buf - 1);
    934  x = vld1q_u16(buf);
    935  xr = vld1q_u16(buf + 1);
    936  xbl = vld1q_u16(buf + buf_stride - 1);
    937  xb = vld1q_u16(buf + buf_stride);
    938  xbr = vld1q_u16(buf + buf_stride + 1);
    939 
    940  xb = vaddq_u16(xb, x);
    941  xt = vaddq_u16(xt, xr);
    942  xl = vaddq_u16(xl, xb);
    943  xl = vaddq_u16(xl, xt);
    944 
    945  r0 = vshlq_n_u16(xl, 2);
    946 
    947  xbl = vaddq_u16(xbl, xbr);
    948  xtl = vaddq_u16(xtl, xtr);
    949  xtl = vaddq_u16(xtl, xbl);
    950 
    951  r1 = vshlq_n_u16(xtl, 2);
    952  r1 = vsubq_u16(r1, xtl);
    953 
    954  *a0 = vreinterpretq_s32_u32(
    955      vaddq_u32(vmovl_u16(vget_low_u16(r0)), vmovl_u16(vget_low_u16(r1))));
    956  *a1 = vreinterpretq_s32_u32(
    957      vaddq_u32(vmovl_u16(vget_high_u16(r0)), vmovl_u16(vget_high_u16(r1))));
    958 }
    959 
    960 static inline int32x4_t cross_sum_fast_even_row(int32_t *buf, int buf_stride) {
    961  int32x4_t xtr, xt, xtl, xbr, xb, xbl;
    962  int32x4_t fives, sixes, fives_plus_sixes;
    963 
    964  xtl = vld1q_s32(buf - buf_stride - 1);
    965  xt = vld1q_s32(buf - buf_stride);
    966  xtr = vld1q_s32(buf - buf_stride + 1);
    967  xbl = vld1q_s32(buf + buf_stride - 1);
    968  xb = vld1q_s32(buf + buf_stride);
    969  xbr = vld1q_s32(buf + buf_stride + 1);
    970 
    971  fives = vaddq_s32(xtl, vaddq_s32(xtr, vaddq_s32(xbr, xbl)));
    972  sixes = vaddq_s32(xt, xb);
    973  fives_plus_sixes = vaddq_s32(fives, sixes);
    974 
    975  return vaddq_s32(
    976      vaddq_s32(vshlq_n_s32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
    977 }
    978 
    979 static inline void cross_sum_fast_even_row_inp16(uint16_t *buf, int buf_stride,
    980                                                 int32x4_t *a0, int32x4_t *a1) {
    981  uint16x8_t xtr, xt, xtl, xbr, xb, xbl, xb0;
    982 
    983  xtl = vld1q_u16(buf - buf_stride - 1);
    984  xt = vld1q_u16(buf - buf_stride);
    985  xtr = vld1q_u16(buf - buf_stride + 1);
    986  xbl = vld1q_u16(buf + buf_stride - 1);
    987  xb = vld1q_u16(buf + buf_stride);
    988  xbr = vld1q_u16(buf + buf_stride + 1);
    989 
    990  xbr = vaddq_u16(xbr, xbl);
    991  xtr = vaddq_u16(xtr, xtl);
    992  xbr = vaddq_u16(xbr, xtr);
    993  xtl = vshlq_n_u16(xbr, 2);
    994  xbr = vaddq_u16(xtl, xbr);
    995 
    996  xb = vaddq_u16(xb, xt);
    997  xb0 = vshlq_n_u16(xb, 1);
    998  xb = vshlq_n_u16(xb, 2);
    999  xb = vaddq_u16(xb, xb0);
   1000 
   1001  *a0 = vreinterpretq_s32_u32(
   1002      vaddq_u32(vmovl_u16(vget_low_u16(xbr)), vmovl_u16(vget_low_u16(xb))));
   1003  *a1 = vreinterpretq_s32_u32(
   1004      vaddq_u32(vmovl_u16(vget_high_u16(xbr)), vmovl_u16(vget_high_u16(xb))));
   1005 }
   1006 
   1007 static inline int32x4_t cross_sum_fast_odd_row(int32_t *buf) {
   1008  int32x4_t xl, x, xr;
   1009  int32x4_t fives, sixes, fives_plus_sixes;
   1010 
   1011  xl = vld1q_s32(buf - 1);
   1012  x = vld1q_s32(buf);
   1013  xr = vld1q_s32(buf + 1);
   1014  fives = vaddq_s32(xl, xr);
   1015  sixes = x;
   1016  fives_plus_sixes = vaddq_s32(fives, sixes);
   1017 
   1018  return vaddq_s32(
   1019      vaddq_s32(vshlq_n_s32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
   1020 }
   1021 
   1022 static inline void cross_sum_fast_odd_row_inp16(uint16_t *buf, int32x4_t *a0,
   1023                                                int32x4_t *a1) {
   1024  uint16x8_t xl, x, xr;
   1025  uint16x8_t x0;
   1026 
   1027  xl = vld1q_u16(buf - 1);
   1028  x = vld1q_u16(buf);
   1029  xr = vld1q_u16(buf + 1);
   1030  xl = vaddq_u16(xl, xr);
   1031  x0 = vshlq_n_u16(xl, 2);
   1032  xl = vaddq_u16(xl, x0);
   1033 
   1034  x0 = vshlq_n_u16(x, 1);
   1035  x = vshlq_n_u16(x, 2);
   1036  x = vaddq_u16(x, x0);
   1037 
   1038  *a0 = vreinterpretq_s32_u32(
   1039      vaddq_u32(vmovl_u16(vget_low_u16(xl)), vmovl_u16(vget_low_u16(x))));
   1040  *a1 = vreinterpretq_s32_u32(
   1041      vaddq_u32(vmovl_u16(vget_high_u16(xl)), vmovl_u16(vget_high_u16(x))));
   1042 }
   1043 
   1044 static void final_filter_fast_internal(uint16_t *A, int32_t *B,
   1045                                       const int buf_stride, int16_t *src,
   1046                                       const int src_stride, int32_t *dst,
   1047                                       const int dst_stride, const int width,
   1048                                       const int height) {
   1049  int16x8_t s0;
   1050  int32_t *B_tmp, *dst_ptr;
   1051  uint16_t *A_tmp;
   1052  int16_t *src_ptr;
   1053  int32x4_t a_res0, a_res1, b_res0, b_res1;
   1054  int w, h, count = 0;
   1055  assert(SGRPROJ_SGR_BITS == 8);
   1056  assert(SGRPROJ_RST_BITS == 4);
   1057 
   1058  A_tmp = A;
   1059  B_tmp = B;
   1060  src_ptr = src;
   1061  dst_ptr = dst;
   1062  h = height;
   1063  do {
   1064    A_tmp = (A + count * buf_stride);
   1065    B_tmp = (B + count * buf_stride);
   1066    src_ptr = (src + count * src_stride);
   1067    dst_ptr = (dst + count * dst_stride);
   1068    w = width;
   1069    if (!(count & 1)) {
   1070      do {
   1071        s0 = vld1q_s16(src_ptr);
   1072        cross_sum_fast_even_row_inp16(A_tmp, buf_stride, &a_res0, &a_res1);
   1073        a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
   1074        a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);
   1075 
   1076        b_res0 = cross_sum_fast_even_row(B_tmp, buf_stride);
   1077        b_res1 = cross_sum_fast_even_row(B_tmp + 4, buf_stride);
   1078        a_res0 = vaddq_s32(a_res0, b_res0);
   1079        a_res1 = vaddq_s32(a_res1, b_res1);
   1080 
   1081        a_res0 =
   1082            vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
   1083        a_res1 =
   1084            vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
   1085 
   1086        vst1q_s32(dst_ptr, a_res0);
   1087        vst1q_s32(dst_ptr + 4, a_res1);
   1088 
   1089        A_tmp += 8;
   1090        B_tmp += 8;
   1091        src_ptr += 8;
   1092        dst_ptr += 8;
   1093        w -= 8;
   1094      } while (w > 0);
   1095    } else {
   1096      do {
   1097        s0 = vld1q_s16(src_ptr);
   1098        cross_sum_fast_odd_row_inp16(A_tmp, &a_res0, &a_res1);
   1099        a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
   1100        a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);
   1101 
   1102        b_res0 = cross_sum_fast_odd_row(B_tmp);
   1103        b_res1 = cross_sum_fast_odd_row(B_tmp + 4);
   1104        a_res0 = vaddq_s32(a_res0, b_res0);
   1105        a_res1 = vaddq_s32(a_res1, b_res1);
   1106 
   1107        a_res0 =
   1108            vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);
   1109        a_res1 =
   1110            vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_ODD - SGRPROJ_RST_BITS);
   1111 
   1112        vst1q_s32(dst_ptr, a_res0);
   1113        vst1q_s32(dst_ptr + 4, a_res1);
   1114 
   1115        A_tmp += 8;
   1116        B_tmp += 8;
   1117        src_ptr += 8;
   1118        dst_ptr += 8;
   1119        w -= 8;
   1120      } while (w > 0);
   1121    }
   1122    count++;
   1123    h -= 1;
   1124  } while (h > 0);
   1125 }
   1126 
   1127 static void final_filter_internal(uint16_t *A, int32_t *B, const int buf_stride,
   1128                                  int16_t *src, const int src_stride,
   1129                                  int32_t *dst, const int dst_stride,
   1130                                  const int width, const int height) {
   1131  int16x8_t s0;
   1132  int32_t *B_tmp, *dst_ptr;
   1133  uint16_t *A_tmp;
   1134  int16_t *src_ptr;
   1135  int32x4_t a_res0, a_res1, b_res0, b_res1;
   1136  int w, h, count = 0;
   1137 
   1138  assert(SGRPROJ_SGR_BITS == 8);
   1139  assert(SGRPROJ_RST_BITS == 4);
   1140  h = height;
   1141 
   1142  do {
   1143    A_tmp = (A + count * buf_stride);
   1144    B_tmp = (B + count * buf_stride);
   1145    src_ptr = (src + count * src_stride);
   1146    dst_ptr = (dst + count * dst_stride);
   1147    w = width;
   1148    do {
   1149      s0 = vld1q_s16(src_ptr);
   1150      cross_sum_inp_u16(A_tmp, buf_stride, &a_res0, &a_res1);
   1151      a_res0 = vmulq_s32(vmovl_s16(vget_low_s16(s0)), a_res0);
   1152      a_res1 = vmulq_s32(vmovl_s16(vget_high_s16(s0)), a_res1);
   1153 
   1154      b_res0 = cross_sum_inp_s32(B_tmp, buf_stride);
   1155      b_res1 = cross_sum_inp_s32(B_tmp + 4, buf_stride);
   1156      a_res0 = vaddq_s32(a_res0, b_res0);
   1157      a_res1 = vaddq_s32(a_res1, b_res1);
   1158 
   1159      a_res0 =
   1160          vrshrq_n_s32(a_res0, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
   1161      a_res1 =
   1162          vrshrq_n_s32(a_res1, SGRPROJ_SGR_BITS + NB_EVEN - SGRPROJ_RST_BITS);
   1163      vst1q_s32(dst_ptr, a_res0);
   1164      vst1q_s32(dst_ptr + 4, a_res1);
   1165 
   1166      A_tmp += 8;
   1167      B_tmp += 8;
   1168      src_ptr += 8;
   1169      dst_ptr += 8;
   1170      w -= 8;
   1171    } while (w > 0);
   1172    count++;
   1173    h -= 1;
   1174  } while (h > 0);
   1175 }
   1176 
// "Fast" self-guided restoration pass (radius_idx == 0, asserted r == 2):
// box sums via boxsum2, a/b coefficients via calc_ab_fast_internal_*, then
// the even/odd final filter into dst. dgd16 points inside a buffer with at
// least SGRPROJ_BORDER_VERT/HORZ rows/cols of valid border around it.
// Returns 0 on success, -1 if the scratch allocation fails.
static inline int restoration_fast_internal(uint16_t *dgd16, int width,
                                           int height, int dgd_stride,
                                           int32_t *dst, int dst_stride,
                                           int bit_depth, int sgr_params_idx,
                                           int radius_idx) {
 const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
 const int r = params->r[radius_idx];
 const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
 const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
 // Round up to a multiple of 4 and pad; the pad gives room for the
 // +/-1-column reads done by the cross-sum helpers.
 const int buf_stride = ((width_ext + 3) & ~3) + 16;

 // One scratch allocation carved into three equal-sized planes below.
 const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
 int32_t *buf = aom_memalign(8, buf_size);
 if (!buf) return -1;

 int32_t *square_sum_buf = buf;
 int32_t *sum_buf = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
 uint16_t *tmp16_buf = (uint16_t *)(sum_buf + RESTORATION_PROC_UNIT_PELS);
 assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
            (char *)buf + buf_size &&
        "Allocated buffer is too small. Resize the buffer.");

 assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
 assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
        "Need SGRPROJ_BORDER_* >= r+1");

 assert(radius_idx == 0);
 assert(r == 2);

 // input(dgd16) is 16bit.
 // sum of pixels 1st stage output will be in 16bit(tmp16_buf). End output is
 // kept in 32bit [sum_buf]. sum of squares output is kept in 32bit
 // buffer(square_sum_buf).
 boxsum2((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
                     SGRPROJ_BORDER_HORZ),
         dgd_stride, (int16_t *)tmp16_buf, sum_buf, square_sum_buf, buf_stride,
         width_ext, height_ext);

 // Advance past the border region so the pointers address pixel (0, 0).
 square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
 sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
 tmp16_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;

 // Calculation of a, b. a output is in 16bit tmp_buf which is in range of
 // [1, 256] for all bit depths. b output is kept in 32bit buffer.

#if CONFIG_AV1_HIGHBITDEPTH
 if (bit_depth > 8) {
   calc_ab_fast_internal_hbd(
       (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
       (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2,
       bit_depth, r, params->s[radius_idx], 2);
 } else {
   calc_ab_fast_internal_lbd(
       (square_sum_buf - buf_stride - 1), (tmp16_buf - buf_stride - 1),
       (sum_buf - buf_stride - 1), buf_stride * 2, width + 2, height + 2, r,
       params->s[radius_idx], 2);
 }
#else
 (void)bit_depth;
 calc_ab_fast_internal_lbd((square_sum_buf - buf_stride - 1),
                           (tmp16_buf - buf_stride - 1),
                           (sum_buf - buf_stride - 1), buf_stride * 2,
                           width + 2, height + 2, r, params->s[radius_idx], 2);
#endif
 final_filter_fast_internal(tmp16_buf, sum_buf, buf_stride, (int16_t *)dgd16,
                            dgd_stride, dst, dst_stride, width, height);
 aom_free(buf);
 return 0;
}
   1246 
// Self-guided restoration pass for radius_idx == 1 (asserted r == 1):
// box sums via boxsum1, a/b coefficients via calc_ab_internal_*, then the
// final filter into dst. dgd16 points inside a buffer with at least
// SGRPROJ_BORDER_VERT/HORZ rows/cols of valid border around it.
// Returns 0 on success, -1 if the scratch allocation fails.
static inline int restoration_internal(uint16_t *dgd16, int width, int height,
                                      int dgd_stride, int32_t *dst,
                                      int dst_stride, int bit_depth,
                                      int sgr_params_idx, int radius_idx) {
 const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
 const int r = params->r[radius_idx];
 const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
 const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
 // Round up to a multiple of 4 and pad; the pad gives room for the
 // +/-1-column reads done by the cross-sum helpers.
 const int buf_stride = ((width_ext + 3) & ~3) + 16;

 // One scratch allocation carved into the four planes below (B and the
 // two 16-bit planes A16/sum_buf share the space after square_sum_buf).
 const size_t buf_size = 3 * sizeof(int32_t) * RESTORATION_PROC_UNIT_PELS;
 int32_t *buf = aom_memalign(8, buf_size);
 if (!buf) return -1;

 int32_t *square_sum_buf = buf;
 int32_t *B = square_sum_buf + RESTORATION_PROC_UNIT_PELS;
 uint16_t *A16 = (uint16_t *)(B + RESTORATION_PROC_UNIT_PELS);
 uint16_t *sum_buf = A16 + RESTORATION_PROC_UNIT_PELS;

 assert((char *)(sum_buf + RESTORATION_PROC_UNIT_PELS) <=
            (char *)buf + buf_size &&
        "Allocated buffer is too small. Resize the buffer.");

 assert(r <= MAX_RADIUS && "Need MAX_RADIUS >= r");
 assert(r <= SGRPROJ_BORDER_VERT - 1 && r <= SGRPROJ_BORDER_HORZ - 1 &&
        "Need SGRPROJ_BORDER_* >= r+1");

 assert(radius_idx == 1);
 assert(r == 1);

 // input(dgd16) is 16bit.
 // sum of pixels output will be in 16bit(sum_buf).
 // sum of squares output is kept in 32bit buffer(square_sum_buf).
 boxsum1((int16_t *)(dgd16 - dgd_stride * SGRPROJ_BORDER_VERT -
                     SGRPROJ_BORDER_HORZ),
         dgd_stride, sum_buf, square_sum_buf, buf_stride, width_ext,
         height_ext);

 // Advance past the border region so the pointers address pixel (0, 0).
 square_sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
 B += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
 A16 += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;
 sum_buf += SGRPROJ_BORDER_VERT * buf_stride + SGRPROJ_BORDER_HORZ;

#if CONFIG_AV1_HIGHBITDEPTH
 // Calculation of a, b. a output is in 16bit tmp_buf which is in range of
 // [1, 256] for all bit depths. b output is kept in 32bit buffer.
 if (bit_depth > 8) {
   calc_ab_internal_hbd((square_sum_buf - buf_stride - 1),
                        (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                        (B - buf_stride - 1), buf_stride, width + 2,
                        height + 2, bit_depth, r, params->s[radius_idx], 1);
 } else {
   calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
                        (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                        (B - buf_stride - 1), buf_stride, width + 2,
                        height + 2, r, params->s[radius_idx], 1);
 }
#else
 (void)bit_depth;
 calc_ab_internal_lbd((square_sum_buf - buf_stride - 1),
                      (A16 - buf_stride - 1), (sum_buf - buf_stride - 1),
                      (B - buf_stride - 1), buf_stride, width + 2, height + 2,
                      r, params->s[radius_idx], 1);
#endif
 final_filter_internal(A16, B, buf_stride, (int16_t *)dgd16, dgd_stride, dst,
                       dst_stride, width, height);
 aom_free(buf);
 return 0;
}
   1316 
// Widens an 8-bit source block into the 16-bit working buffer used by the
// self-guided filter. Bulk of the work is done 4 rows x 8 cols per NEON
// iteration, with scalar fall-backs for the remaining columns and rows,
// then zero-fills the extra rows the box-sum stage reads past the image.
static inline void src_convert_u8_to_u16(const uint8_t *src,
                                        const int src_stride, uint16_t *dst,
                                        const int dst_stride, const int width,
                                        const int height) {
 const uint8_t *src_ptr;
 uint16_t *dst_ptr;
 int h, w, count = 0;

 uint8x8_t t1, t2, t3, t4;
 uint16x8_t s1, s2, s3, s4;
 h = height;
 do {
   src_ptr = src + (count << 2) * src_stride;
   dst_ptr = dst + (count << 2) * dst_stride;
   w = width;
   if (w >= 7) {
     // 8-wide vector loop. NOTE(review): with w == 7 this reads/writes
     // one element past `width` — presumably the caller's buffers are
     // padded to allow this; verify against the callers.
     do {
       load_u8_8x4(src_ptr, src_stride, &t1, &t2, &t3, &t4);
       s1 = vmovl_u8(t1);
       s2 = vmovl_u8(t2);
       s3 = vmovl_u8(t3);
       s4 = vmovl_u8(t4);
       store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);

       src_ptr += 8;
       dst_ptr += 8;
       w -= 8;
     } while (w > 7);
   }

   // Scalar tail for the leftover (< 8) columns of this 4-row group.
   for (int y = 0; y < w; y++) {
     dst_ptr[y] = src_ptr[y];
     dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
     dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
     dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
   }
   count++;
   h -= 4;
 } while (h > 3);

 // Scalar tail for the leftover (< 4) rows.
 src_ptr = src + (count << 2) * src_stride;
 dst_ptr = dst + (count << 2) * dst_stride;
 for (int x = 0; x < h; x++) {
   for (int y = 0; y < width; y++) {
     dst_ptr[y + x * dst_stride] = src_ptr[y + x * src_stride];
   }
 }

 // memset uninitialized rows of src buffer as they are needed for the
 // boxsum filter calculation.
 for (int x = height; x < height + 5; x++)
   memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
}
   1370 
   1371 #if CONFIG_AV1_HIGHBITDEPTH
// Copies a 16-bit (high bit depth) source block into the working buffer.
// Bulk copy is 4 rows x 8 cols per NEON iteration with scalar/memcpy
// fall-backs for the remaining columns and rows, then zero-fills the extra
// rows the box-sum stage reads past the image.
static inline void src_convert_hbd_copy(const uint16_t *src, int src_stride,
                                       uint16_t *dst, const int dst_stride,
                                       int width, int height) {
 const uint16_t *src_ptr;
 uint16_t *dst_ptr;
 int h, w, count = 0;
 uint16x8_t s1, s2, s3, s4;

 h = height;
 do {
   src_ptr = src + (count << 2) * src_stride;
   dst_ptr = dst + (count << 2) * dst_stride;
   w = width;
   // 8-wide vector loop. NOTE(review): runs at least once, so it assumes
   // width >= 8 (or padded buffers); verify against the callers.
   do {
     load_u16_8x4(src_ptr, src_stride, &s1, &s2, &s3, &s4);
     store_u16_8x4(dst_ptr, dst_stride, s1, s2, s3, s4);
     src_ptr += 8;
     dst_ptr += 8;
     w -= 8;
   } while (w > 7);

   // Scalar tail for the leftover (< 8) columns of this 4-row group.
   for (int y = 0; y < w; y++) {
     dst_ptr[y] = src_ptr[y];
     dst_ptr[y + 1 * dst_stride] = src_ptr[y + 1 * src_stride];
     dst_ptr[y + 2 * dst_stride] = src_ptr[y + 2 * src_stride];
     dst_ptr[y + 3 * dst_stride] = src_ptr[y + 3 * src_stride];
   }
   count++;
   h -= 4;
 } while (h > 3);

 // Row-wise copy for the leftover (< 4) rows.
 src_ptr = src + (count << 2) * src_stride;
 dst_ptr = dst + (count << 2) * dst_stride;

 for (int x = 0; x < h; x++) {
   memcpy((dst_ptr + x * dst_stride), (src_ptr + x * src_stride),
          sizeof(uint16_t) * width);
 }
 // memset uninitialized rows of src buffer as they are needed for the
 // boxsum filter calculation.
 for (int x = height; x < height + 5; x++)
   memset(dst + x * dst_stride, 0, (width + 2) * sizeof(*dst));
}
   1415 #endif  // CONFIG_AV1_HIGHBITDEPTH
   1416 
// NEON implementation of self-guided restoration. Converts the source
// (8-bit, or 16-bit when highbd) into a bordered 16-bit working buffer,
// then runs the enabled filter passes: the r[0] "fast" pass (asserted
// r == 2) into flt0 and/or the r[1] pass (asserted r == 1) into flt1.
// Returns 0 on success, or the failing pass's error code (-1 on
// allocation failure).
int av1_selfguided_restoration_neon(const uint8_t *dat8, int width, int height,
                                   int stride, int32_t *flt0, int32_t *flt1,
                                   int flt_stride, int sgr_params_idx,
                                   int bit_depth, int highbd) {
 const sgr_params_type *const params = &av1_sgr_params[sgr_params_idx];
 assert(!(params->r[0] == 0 && params->r[1] == 0));

 // Working copy with SGRPROJ_BORDER_VERT/HORZ of border around the image;
 // dgd16 points at pixel (0, 0) inside the bordered buffer.
 uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
 const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
 uint16_t *dgd16 =
     dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
 const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
 const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
 const int dgd_stride = stride;

#if CONFIG_AV1_HIGHBITDEPTH
 if (highbd) {
   const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
   src_convert_hbd_copy(
       dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
       dgd_stride,
       dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
       dgd16_stride, width_ext, height_ext);
 } else {
   src_convert_u8_to_u16(
       dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
       dgd_stride,
       dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
       dgd16_stride, width_ext, height_ext);
 }
#else
 (void)highbd;
 src_convert_u8_to_u16(
     dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
     dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
     dgd16_stride, width_ext, height_ext);
#endif

 if (params->r[0] > 0) {
   int ret =
       restoration_fast_internal(dgd16, width, height, dgd16_stride, flt0,
                                 flt_stride, bit_depth, sgr_params_idx, 0);
   if (ret != 0) return ret;
 }
 if (params->r[1] > 0) {
   int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
                                  flt_stride, bit_depth, sgr_params_idx, 1);
   if (ret != 0) return ret;
 }
 return 0;
}
   1468 
// Applies self-guided restoration (SGR) to a block of pixels using NEON.
//
// Pipeline:
//   1. Copy the source (plus a SGRPROJ_BORDER_VERT/HORZ border) into a local
//      16-bit working buffer dgd16_, widening 8-bit input when !highbd.
//   2. Run the "fast" box filter (radius params->r[0]) into flt0 and/or the
//      standard box filter (radius params->r[1]) into flt1, as enabled.
//   3. Combine the filter outputs with the decoded projection coefficients
//      xq[0]/xq[1] and write the clamped result to dst8 (8-bit) or the
//      CONVERT_TO_SHORTPTR'd 16-bit destination (highbd builds).
//
// Parameters:
//   dat8       - source pixels; when highbd is set this is actually a
//                uint16_t* wrapped via CONVERT_TO_SHORTPTR.
//   width, height, stride - geometry of the source block.
//   eps        - index into av1_sgr_params[] selecting the SGR parameter set.
//   xqd        - coded projection parameters; decoded into xq[] below.
//   dst8       - destination pixels (same 8/16-bit convention as dat8).
//   dst_stride - destination stride in pixels.
//   tmpbuf     - caller-provided scratch holding both filter outputs
//                (2 * RESTORATION_UNITPELS_MAX int32_t values).
//   bit_depth  - pixel bit depth; used to clamp the highbd output.
//   highbd     - nonzero selects the high-bit-depth input/output path.
//
// Returns 0 on success, or the nonzero error code propagated from
// restoration_fast_internal() / restoration_internal().
int av1_apply_selfguided_restoration_neon(const uint8_t *dat8, int width,
                                          int height, int stride, int eps,
                                          const int *xqd, uint8_t *dst8,
                                          int dst_stride, int32_t *tmpbuf,
                                          int bit_depth, int highbd) {
  // Split the scratch buffer into the two per-filter output planes.
  int32_t *flt0 = tmpbuf;
  int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
  assert(width * height <= RESTORATION_UNITPELS_MAX);
  // Local 16-bit working copy of the source, including the SGR border so the
  // box filters can read outside the nominal width x height area.
  uint16_t dgd16_[RESTORATION_PROC_UNIT_PELS];
  const int dgd16_stride = width + 2 * SGRPROJ_BORDER_HORZ;
  // dgd16 points at the first "interior" pixel, past the top/left border.
  uint16_t *dgd16 =
      dgd16_ + dgd16_stride * SGRPROJ_BORDER_VERT + SGRPROJ_BORDER_HORZ;
  const int width_ext = width + 2 * SGRPROJ_BORDER_HORZ;
  const int height_ext = height + 2 * SGRPROJ_BORDER_VERT;
  const int dgd_stride = stride;
  const sgr_params_type *const params = &av1_sgr_params[eps];
  int xq[2];

  // At least one of the two filters must be enabled for this parameter set.
  assert(!(params->r[0] == 0 && params->r[1] == 0));

#if CONFIG_AV1_HIGHBITDEPTH
  if (highbd) {
    // Source is already 16-bit: plain bordered copy into the working buffer.
    const uint16_t *dgd16_tmp = CONVERT_TO_SHORTPTR(dat8);
    src_convert_hbd_copy(
        dgd16_tmp - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  } else {
    // Widen 8-bit source to 16-bit, border included.
    src_convert_u8_to_u16(
        dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ,
        dgd_stride,
        dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
        dgd16_stride, width_ext, height_ext);
  }
#else
  (void)highbd;
  src_convert_u8_to_u16(
      dat8 - SGRPROJ_BORDER_VERT * dgd_stride - SGRPROJ_BORDER_HORZ, dgd_stride,
      dgd16 - SGRPROJ_BORDER_VERT * dgd16_stride - SGRPROJ_BORDER_HORZ,
      dgd16_stride, width_ext, height_ext);
#endif
  // Run whichever of the two SGR box filters this parameter set enables,
  // bailing out on the first error.
  if (params->r[0] > 0) {
    int ret = restoration_fast_internal(dgd16, width, height, dgd16_stride,
                                        flt0, width, bit_depth, eps, 0);
    if (ret != 0) return ret;
  }
  if (params->r[1] > 0) {
    int ret = restoration_internal(dgd16, width, height, dgd16_stride, flt1,
                                   width, bit_depth, eps, 1);
    if (ret != 0) return ret;
  }

  // Decode the coded projection parameters into the actual multipliers.
  av1_decode_xq(xqd, xq, params);

  {
    int16_t *src_ptr;
    uint8_t *dst_ptr;
#if CONFIG_AV1_HIGHBITDEPTH
    uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst8);
    uint16_t *dst16_ptr;
#endif
    int16x4_t d0, d4;
    int16x8_t r0, s0;
    uint16x8_t r4;
    int32x4_t u0, u4, v0, v4, f00, f10;
    uint8x8_t t0;
    int count = 0, w = width, h = height, rc = 0;

    const int32x4_t xq0_vec = vdupq_n_s32(xq[0]);
    const int32x4_t xq1_vec = vdupq_n_s32(xq[1]);
    const int16x8_t zero = vdupq_n_s16(0);
    // Clamp ceiling for the highbd path: (1 << bit_depth) - 1.
    const uint16x8_t max = vdupq_n_u16((1 << bit_depth) - 1);
    src_ptr = (int16_t *)dgd16;
    // Row loop: combine the filter outputs with the source, 8 pixels at a
    // time. NOTE(review): the inner loop rounds width up to a multiple of 8;
    // presumably the destination rows are padded to absorb the overrun, as is
    // conventional for these restoration buffers — confirm against callers.
    do {
      w = width;
      count = 0;
      dst_ptr = dst8 + rc * dst_stride;
#if CONFIG_AV1_HIGHBITDEPTH
      dst16_ptr = dst16 + rc * dst_stride;
#endif
      do {
        s0 = vld1q_s16(src_ptr + count);

        // u = src << SGRPROJ_RST_BITS (widened to 32-bit); the filters
        // produced their outputs at this precision.
        u0 = vshll_n_s16(vget_low_s16(s0), SGRPROJ_RST_BITS);
        u4 = vshll_n_s16(vget_high_s16(s0), SGRPROJ_RST_BITS);

        // v starts as the source scaled so that adding xq-weighted filter
        // residuals keeps everything in one fixed-point domain.
        v0 = vshlq_n_s32(u0, SGRPROJ_PRJ_BITS);
        v4 = vshlq_n_s32(u4, SGRPROJ_PRJ_BITS);

        if (params->r[0] > 0) {
          f00 = vld1q_s32(flt0 + count);
          f10 = vld1q_s32(flt0 + count + 4);

          // Residual of filter 0 against the (scaled) source.
          f00 = vsubq_s32(f00, u0);
          f10 = vsubq_s32(f10, u4);

          // v += xq[0] * (flt0 - u)
          v0 = vmlaq_s32(v0, xq0_vec, f00);
          v4 = vmlaq_s32(v4, xq0_vec, f10);
        }

        if (params->r[1] > 0) {
          f00 = vld1q_s32(flt1 + count);
          f10 = vld1q_s32(flt1 + count + 4);

          // Residual of filter 1 against the (scaled) source.
          f00 = vsubq_s32(f00, u0);
          f10 = vsubq_s32(f10, u4);

          // v += xq[1] * (flt1 - u)
          v0 = vmlaq_s32(v0, xq1_vec, f00);
          v4 = vmlaq_s32(v4, xq1_vec, f10);
        }

        // Rounding shift back to pixel precision, with signed saturation.
        d0 = vqrshrn_n_s32(v0, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
        d4 = vqrshrn_n_s32(v4, SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);

        r0 = vcombine_s16(d0, d4);

        // Clamp negatives to zero before the unsigned reinterpretation.
        r4 = vreinterpretq_u16_s16(vmaxq_s16(r0, zero));

#if CONFIG_AV1_HIGHBITDEPTH
        if (highbd) {
          // Clamp to the bit-depth maximum and store 16-bit pixels.
          r4 = vminq_u16(r4, max);
          vst1q_u16(dst16_ptr, r4);
          dst16_ptr += 8;
        } else {
          // Saturating narrow to 8-bit handles the upper clamp implicitly.
          t0 = vqmovn_u16(r4);
          vst1_u8(dst_ptr, t0);
          dst_ptr += 8;
        }
#else
        (void)max;
        t0 = vqmovn_u16(r4);
        vst1_u8(dst_ptr, t0);
        dst_ptr += 8;
#endif
        w -= 8;
        count += 8;
      } while (w > 0);

      // Advance to the next row of the working buffer and filter outputs.
      src_ptr += dgd16_stride;
      flt1 += width;
      flt0 += width;
      rc++;
      h--;
    } while (h > 0);
  }
  return 0;
}