tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

highbd_sse_neon.c (10973B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 
     14 #include "config/aom_dsp_rtcd.h"
     15 #include "aom_dsp/arm/sum_neon.h"
     16 
     17 static inline void highbd_sse_8x1_init_neon(const uint16_t *src,
     18                                            const uint16_t *ref,
     19                                            uint32x4_t *sse_acc0,
     20                                            uint32x4_t *sse_acc1) {
     21  uint16x8_t s = vld1q_u16(src);
     22  uint16x8_t r = vld1q_u16(ref);
     23 
     24  uint16x8_t abs_diff = vabdq_u16(s, r);
     25  uint16x4_t abs_diff_lo = vget_low_u16(abs_diff);
     26  uint16x4_t abs_diff_hi = vget_high_u16(abs_diff);
     27 
     28  *sse_acc0 = vmull_u16(abs_diff_lo, abs_diff_lo);
     29  *sse_acc1 = vmull_u16(abs_diff_hi, abs_diff_hi);
     30 }
     31 
     32 static inline void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
     33                                       uint32x4_t *sse_acc0,
     34                                       uint32x4_t *sse_acc1) {
     35  uint16x8_t s = vld1q_u16(src);
     36  uint16x8_t r = vld1q_u16(ref);
     37 
     38  uint16x8_t abs_diff = vabdq_u16(s, r);
     39  uint16x4_t abs_diff_lo = vget_low_u16(abs_diff);
     40  uint16x4_t abs_diff_hi = vget_high_u16(abs_diff);
     41 
     42  *sse_acc0 = vmlal_u16(*sse_acc0, abs_diff_lo, abs_diff_lo);
     43  *sse_acc1 = vmlal_u16(*sse_acc1, abs_diff_hi, abs_diff_hi);
     44 }
     45 
     46 static inline int64_t highbd_sse_128xh_neon(const uint16_t *src, int src_stride,
     47                                            const uint16_t *ref, int ref_stride,
     48                                            int height) {
     49  uint32x4_t sse[16];
     50  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
     51  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
     52  highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
     53  highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
     54  highbd_sse_8x1_init_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]);
     55  highbd_sse_8x1_init_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]);
     56  highbd_sse_8x1_init_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]);
     57  highbd_sse_8x1_init_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]);
     58  highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]);
     59  highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]);
     60  highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]);
     61  highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]);
     62  highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]);
     63  highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]);
     64  highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]);
     65  highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]);
     66 
     67  src += src_stride;
     68  ref += ref_stride;
     69 
     70  while (--height != 0) {
     71    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
     72    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
     73    highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
     74    highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
     75    highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[8], &sse[9]);
     76    highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[10], &sse[11]);
     77    highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[12], &sse[13]);
     78    highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[14], &sse[15]);
     79    highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0], &sse[1]);
     80    highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[2], &sse[3]);
     81    highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[4], &sse[5]);
     82    highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[6], &sse[7]);
     83    highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[8], &sse[9]);
     84    highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[10], &sse[11]);
     85    highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[12], &sse[13]);
     86    highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[14], &sse[15]);
     87 
     88    src += src_stride;
     89    ref += ref_stride;
     90  }
     91 
     92  return horizontal_long_add_u32x4_x16(sse);
     93 }
     94 
     95 static inline int64_t highbd_sse_64xh_neon(const uint16_t *src, int src_stride,
     96                                           const uint16_t *ref, int ref_stride,
     97                                           int height) {
     98  uint32x4_t sse[8];
     99  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
    100  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
    101  highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
    102  highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
    103  highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]);
    104  highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]);
    105  highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]);
    106  highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]);
    107 
    108  src += src_stride;
    109  ref += ref_stride;
    110 
    111  while (--height != 0) {
    112    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
    113    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
    114    highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
    115    highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
    116    highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0], &sse[1]);
    117    highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[2], &sse[3]);
    118    highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[4], &sse[5]);
    119    highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[6], &sse[7]);
    120 
    121    src += src_stride;
    122    ref += ref_stride;
    123  }
    124 
    125  return horizontal_long_add_u32x4_x8(sse);
    126 }
    127 
    128 static inline int64_t highbd_sse_32xh_neon(const uint16_t *src, int src_stride,
    129                                           const uint16_t *ref, int ref_stride,
    130                                           int height) {
    131  uint32x4_t sse[8];
    132  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
    133  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
    134  highbd_sse_8x1_init_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
    135  highbd_sse_8x1_init_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
    136 
    137  src += src_stride;
    138  ref += ref_stride;
    139 
    140  while (--height != 0) {
    141    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
    142    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
    143    highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[4], &sse[5]);
    144    highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[6], &sse[7]);
    145 
    146    src += src_stride;
    147    ref += ref_stride;
    148  }
    149 
    150  return horizontal_long_add_u32x4_x8(sse);
    151 }
    152 
    153 static inline int64_t highbd_sse_16xh_neon(const uint16_t *src, int src_stride,
    154                                           const uint16_t *ref, int ref_stride,
    155                                           int height) {
    156  uint32x4_t sse[4];
    157  highbd_sse_8x1_init_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
    158  highbd_sse_8x1_init_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
    159 
    160  src += src_stride;
    161  ref += ref_stride;
    162 
    163  while (--height != 0) {
    164    highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0], &sse[1]);
    165    highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[2], &sse[3]);
    166 
    167    src += src_stride;
    168    ref += ref_stride;
    169  }
    170 
    171  return horizontal_long_add_u32x4_x4(sse);
    172 }
    173 
    174 static inline int64_t highbd_sse_8xh_neon(const uint16_t *src, int src_stride,
    175                                          const uint16_t *ref, int ref_stride,
    176                                          int height) {
    177  uint32x4_t sse[2];
    178  highbd_sse_8x1_init_neon(src, ref, &sse[0], &sse[1]);
    179 
    180  src += src_stride;
    181  ref += ref_stride;
    182 
    183  while (--height != 0) {
    184    highbd_sse_8x1_neon(src, ref, &sse[0], &sse[1]);
    185 
    186    src += src_stride;
    187    ref += ref_stride;
    188  }
    189 
    190  return horizontal_long_add_u32x4_x2(sse);
    191 }
    192 
    193 static inline int64_t highbd_sse_4xh_neon(const uint16_t *src, int src_stride,
    194                                          const uint16_t *ref, int ref_stride,
    195                                          int height) {
    196  // Peel the first loop iteration.
    197  uint16x4_t s = vld1_u16(src);
    198  uint16x4_t r = vld1_u16(ref);
    199 
    200  uint16x4_t abs_diff = vabd_u16(s, r);
    201  uint32x4_t sse = vmull_u16(abs_diff, abs_diff);
    202 
    203  src += src_stride;
    204  ref += ref_stride;
    205 
    206  while (--height != 0) {
    207    s = vld1_u16(src);
    208    r = vld1_u16(ref);
    209 
    210    abs_diff = vabd_u16(s, r);
    211    sse = vmlal_u16(sse, abs_diff, abs_diff);
    212 
    213    src += src_stride;
    214    ref += ref_stride;
    215  }
    216 
    217  return horizontal_long_add_u32x4(sse);
    218 }
    219 
    220 static inline int64_t highbd_sse_wxh_neon(const uint16_t *src, int src_stride,
    221                                          const uint16_t *ref, int ref_stride,
    222                                          int width, int height) {
    223  // { 0, 1, 2, 3, 4, 5, 6, 7 }
    224  uint16x8_t k01234567 = vmovl_u8(vcreate_u8(0x0706050403020100));
    225  uint16x8_t remainder_mask = vcltq_u16(k01234567, vdupq_n_u16(width & 7));
    226  uint64_t sse = 0;
    227 
    228  do {
    229    int w = width;
    230    int offset = 0;
    231 
    232    do {
    233      uint16x8_t s = vld1q_u16(src + offset);
    234      uint16x8_t r = vld1q_u16(ref + offset);
    235 
    236      if (w < 8) {
    237        // Mask out-of-range elements.
    238        s = vandq_u16(s, remainder_mask);
    239        r = vandq_u16(r, remainder_mask);
    240      }
    241 
    242      uint16x8_t abs_diff = vabdq_u16(s, r);
    243      uint16x4_t abs_diff_lo = vget_low_u16(abs_diff);
    244      uint16x4_t abs_diff_hi = vget_high_u16(abs_diff);
    245 
    246      uint32x4_t sse_u32 = vmull_u16(abs_diff_lo, abs_diff_lo);
    247      sse_u32 = vmlal_u16(sse_u32, abs_diff_hi, abs_diff_hi);
    248 
    249      sse += horizontal_long_add_u32x4(sse_u32);
    250 
    251      offset += 8;
    252      w -= 8;
    253    } while (w > 0);
    254 
    255    src += src_stride;
    256    ref += ref_stride;
    257  } while (--height != 0);
    258 
    259  return sse;
    260 }
    261 
    262 int64_t aom_highbd_sse_neon(const uint8_t *src8, int src_stride,
    263                            const uint8_t *ref8, int ref_stride, int width,
    264                            int height) {
    265  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
    266  uint16_t *ref = CONVERT_TO_SHORTPTR(ref8);
    267 
    268  switch (width) {
    269    case 4:
    270      return highbd_sse_4xh_neon(src, src_stride, ref, ref_stride, height);
    271    case 8:
    272      return highbd_sse_8xh_neon(src, src_stride, ref, ref_stride, height);
    273    case 16:
    274      return highbd_sse_16xh_neon(src, src_stride, ref, ref_stride, height);
    275    case 32:
    276      return highbd_sse_32xh_neon(src, src_stride, ref, ref_stride, height);
    277    case 64:
    278      return highbd_sse_64xh_neon(src, src_stride, ref, ref_stride, height);
    279    case 128:
    280      return highbd_sse_128xh_neon(src, src_stride, ref, ref_stride, height);
    281    default:
    282      return highbd_sse_wxh_neon(src, src_stride, ref, ref_stride, width,
    283                                 height);
    284  }
    285 }