tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

reconintra_neon.c (12807B)


      1 /*
      2 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 #include <assert.h>
     14 
     15 #include "config/aom_config.h"
     16 #include "config/av1_rtcd.h"
     17 
     18 #include "aom/aom_integer.h"
     19 #include "aom_dsp/arm/mem_neon.h"
     20 #include "aom_dsp/arm/sum_neon.h"
     21 
     22 #define MAX_UPSAMPLE_SZ 16
     23 
// These kernels are a transposed version of those defined in reconintra.c,
// with the absolute value of the negatives taken in the top row.
//
// Each mode provides seven 8-lane tap vectors (f0..f6). As used by
// av1_filter_intra_predictor_neon below: f0 weights the above-left pixel
// (and is subtracted, hence the absolute values here), f1..f4 weight the
// four pixels above the 4x2 output patch, and f5/f6 weight the two pixels
// in the column to its left. Within each vector, lanes 0-3 contribute to
// the patch's top output row and lanes 4-7 to its bottom row.
// NOTE(review): the outer index presumably follows the FILTER_INTRA_MODE
// enum order (DC, V, H, D157, PAETH) — confirm against enums.h.
DECLARE_ALIGNED(16, const uint8_t,
                av1_filter_intra_taps_neon[FILTER_INTRA_MODES][7][8]) = {
  // clang-format off
  {
      {  6,  5,  3,  3,  4,  3,  3,  3 },
      { 10,  2,  1,  1,  6,  2,  2,  1 },
      {  0, 10,  1,  1,  0,  6,  2,  2 },
      {  0,  0, 10,  2,  0,  0,  6,  2 },
      {  0,  0,  0, 10,  0,  0,  0,  6 },
      { 12,  9,  7,  5,  2,  2,  2,  3 },
      {  0,  0,  0,  0, 12,  9,  7,  5 }
  },
  {
      { 10,  6,  4,  2, 10,  6,  4,  2 },
      { 16,  0,  0,  0, 16,  0,  0,  0 },
      {  0, 16,  0,  0,  0, 16,  0,  0 },
      {  0,  0, 16,  0,  0,  0, 16,  0 },
      {  0,  0,  0, 16,  0,  0,  0, 16 },
      { 10,  6,  4,  2,  0,  0,  0,  0 },
      {  0,  0,  0,  0, 10,  6,  4,  2 }
  },
  {
      {  8,  8,  8,  8,  4,  4,  4,  4 },
      {  8,  0,  0,  0,  4,  0,  0,  0 },
      {  0,  8,  0,  0,  0,  4,  0,  0 },
      {  0,  0,  8,  0,  0,  0,  4,  0 },
      {  0,  0,  0,  8,  0,  0,  0,  4 },
      { 16, 16, 16, 16,  0,  0,  0,  0 },
      {  0,  0,  0,  0, 16, 16, 16, 16 }
  },
  {
      {  2,  1,  1,  0,  1,  1,  1,  1 },
      {  8,  3,  2,  1,  4,  3,  2,  2 },
      {  0,  8,  3,  2,  0,  4,  3,  2 },
      {  0,  0,  8,  3,  0,  0,  4,  3 },
      {  0,  0,  0,  8,  0,  0,  0,  4 },
      { 10,  6,  4,  2,  3,  4,  4,  3 },
      {  0,  0,  0,  0, 10,  6,  4,  3 }
  },
  {
      { 12, 10,  9,  8, 10,  9,  8,  7 },
      { 14,  0,  0,  0, 12,  1,  0,  0 },
      {  0, 14,  0,  0,  0, 12,  0,  0 },
      {  0,  0, 14,  0,  0,  0, 12,  1 },
      {  0,  0,  0, 14,  0,  0,  0, 12 },
      { 14, 12, 11, 10,  0,  0,  1,  1 },
      {  0,  0,  0,  0, 14, 12, 11,  9 }
  }
  // clang-format on
};
     76 
     77 #define FILTER_INTRA_SCALE_BITS 4
     78 
// Filter-intra prediction for 8-bit pixels.
//
// dst/stride : output block of tx_size_wide[tx_size] x tx_size_high[tx_size]
//              pixels (both at most 32).
// above      : row above the block; above[-1] is the above-left neighbour.
// left       : column to the left of the block.
// mode       : FILTER_INTRA_MODE index selecting a tap set from
//              av1_filter_intra_taps_neon.
//
// The block is produced in raster order as a grid of 4x2 patches. Each
// patch is a 7-tap weighted sum of the above-left pixel (s0, subtracted),
// the four pixels above the patch (s1..s4) and the two pixels to its left
// (s5, s6). Because completed patches feed the patches below and to the
// right of them, results are also written into a 33x33 scratch buffer
// whose first row/column hold the above/left neighbours.
void av1_filter_intra_predictor_neon(uint8_t *dst, ptrdiff_t stride,
                                     TX_SIZE tx_size, const uint8_t *above,
                                     const uint8_t *left, int mode) {
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  assert(width <= 32 && height <= 32);

  // One tap vector per filter row; in each, lanes 0-3 produce the patch's
  // top output row and lanes 4-7 its bottom output row.
  const uint8x8_t f0 = vld1_u8(av1_filter_intra_taps_neon[mode][0]);
  const uint8x8_t f1 = vld1_u8(av1_filter_intra_taps_neon[mode][1]);
  const uint8x8_t f2 = vld1_u8(av1_filter_intra_taps_neon[mode][2]);
  const uint8x8_t f3 = vld1_u8(av1_filter_intra_taps_neon[mode][3]);
  const uint8x8_t f4 = vld1_u8(av1_filter_intra_taps_neon[mode][4]);
  const uint8x8_t f5 = vld1_u8(av1_filter_intra_taps_neon[mode][5]);
  const uint8x8_t f6 = vld1_u8(av1_filter_intra_taps_neon[mode][6]);

  uint8_t buffer[33][33];
  // Populate the top row in the scratch buffer with data from above.
  memcpy(buffer[0], &above[-1], (width + 1) * sizeof(uint8_t));
  // Populate the first column in the scratch buffer with data from the left.
  int r = 0;
  do {
    buffer[r + 1][0] = left[r];
  } while (++r < height);

  // Computing 4 cols per iteration (instead of 8) for 8x<h> blocks is faster.
  if (width <= 8) {
    r = 1;
    do {
      int c = 1;
      // Left-hand inputs of the first patch in this row come from the
      // scratch buffer's first column.
      uint8x8_t s0 = vld1_dup_u8(&buffer[r - 1][c - 1]);
      uint8x8_t s5 = vld1_dup_u8(&buffer[r + 0][c - 1]);
      uint8x8_t s6 = vld1_dup_u8(&buffer[r + 1][c - 1]);

      do {
        // The four pixels directly above the current 4x2 patch, broadcast
        // one per tap vector.
        uint8x8_t s1234 = load_unaligned_u8_4x1(&buffer[r - 1][c - 1] + 1);
        uint8x8_t s1 = vdup_lane_u8(s1234, 0);
        uint8x8_t s2 = vdup_lane_u8(s1234, 1);
        uint8x8_t s3 = vdup_lane_u8(s1234, 2);
        uint8x8_t s4 = vdup_lane_u8(s1234, 3);

        uint16x8_t sum = vmull_u8(s1, f1);
        // First row of each filter has all negative values so subtract.
        sum = vmlsl_u8(sum, s0, f0);
        sum = vmlal_u8(sum, s2, f2);
        sum = vmlal_u8(sum, s3, f3);
        sum = vmlal_u8(sum, s4, f4);
        sum = vmlal_u8(sum, s5, f5);
        sum = vmlal_u8(sum, s6, f6);

        // Round, scale down by 2^FILTER_INTRA_SCALE_BITS and saturate to
        // 8 bits. The sum is reinterpreted as signed because the f0 term
        // can drive it negative.
        uint8x8_t res =
            vqrshrun_n_s16(vreinterpretq_s16_u16(sum), FILTER_INTRA_SCALE_BITS);

        // Store buffer[r + 0][c] and buffer[r + 1][c].
        store_u8x4_strided_x2(&buffer[r][c], 33, res);

        store_u8x4_strided_x2(dst + (r - 1) * stride + c - 1, stride, res);

        // The rightmost column of this patch becomes the left-hand input
        // of the next patch in the row.
        s0 = s4;
        s5 = vdup_lane_u8(res, 3);
        s6 = vdup_lane_u8(res, 7);
        c += 4;
      } while (c < width + 1);

      r += 2;
    } while (r < height + 1);
  } else {
    // Wide blocks: process two adjacent 4x2 patches (8 output columns,
    // suffixed _lo and _hi) per iteration.
    r = 1;
    do {
      int c = 1;
      uint8x8_t s0_lo = vld1_dup_u8(&buffer[r - 1][c - 1]);
      uint8x8_t s5_lo = vld1_dup_u8(&buffer[r + 0][c - 1]);
      uint8x8_t s6_lo = vld1_dup_u8(&buffer[r + 1][c - 1]);

      do {
        // Eight pixels above the pair of patches; lanes 0-3 feed the low
        // patch, lanes 4-7 the high patch.
        uint8x8_t s1234 = vld1_u8(&buffer[r - 1][c - 1] + 1);
        uint8x8_t s1_lo = vdup_lane_u8(s1234, 0);
        uint8x8_t s2_lo = vdup_lane_u8(s1234, 1);
        uint8x8_t s3_lo = vdup_lane_u8(s1234, 2);
        uint8x8_t s4_lo = vdup_lane_u8(s1234, 3);

        uint16x8_t sum_lo = vmull_u8(s1_lo, f1);
        // First row of each filter has all negative values so subtract.
        sum_lo = vmlsl_u8(sum_lo, s0_lo, f0);
        sum_lo = vmlal_u8(sum_lo, s2_lo, f2);
        sum_lo = vmlal_u8(sum_lo, s3_lo, f3);
        sum_lo = vmlal_u8(sum_lo, s4_lo, f4);
        sum_lo = vmlal_u8(sum_lo, s5_lo, f5);
        sum_lo = vmlal_u8(sum_lo, s6_lo, f6);

        uint8x8_t res_lo = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_lo),
                                          FILTER_INTRA_SCALE_BITS);

        // The high patch depends on the low patch's rightmost results.
        uint8x8_t s0_hi = s4_lo;
        uint8x8_t s1_hi = vdup_lane_u8(s1234, 4);
        uint8x8_t s2_hi = vdup_lane_u8(s1234, 5);
        uint8x8_t s3_hi = vdup_lane_u8(s1234, 6);
        uint8x8_t s4_hi = vdup_lane_u8(s1234, 7);
        uint8x8_t s5_hi = vdup_lane_u8(res_lo, 3);
        uint8x8_t s6_hi = vdup_lane_u8(res_lo, 7);

        uint16x8_t sum_hi = vmull_u8(s1_hi, f1);
        // First row of each filter has all negative values so subtract.
        sum_hi = vmlsl_u8(sum_hi, s0_hi, f0);
        sum_hi = vmlal_u8(sum_hi, s2_hi, f2);
        sum_hi = vmlal_u8(sum_hi, s3_hi, f3);
        sum_hi = vmlal_u8(sum_hi, s4_hi, f4);
        sum_hi = vmlal_u8(sum_hi, s5_hi, f5);
        sum_hi = vmlal_u8(sum_hi, s6_hi, f6);

        uint8x8_t res_hi = vqrshrun_n_s16(vreinterpretq_s16_u16(sum_hi),
                                          FILTER_INTRA_SCALE_BITS);

        // Interleave the two patches back into row order: val[0] holds
        // both top rows, val[1] both bottom rows.
        uint32x2x2_t res =
            vzip_u32(vreinterpret_u32_u8(res_lo), vreinterpret_u32_u8(res_hi));

        vst1_u8(&buffer[r + 0][c], vreinterpret_u8_u32(res.val[0]));
        vst1_u8(&buffer[r + 1][c], vreinterpret_u8_u32(res.val[1]));

        vst1_u8(dst + (r - 1) * stride + c - 1,
                vreinterpret_u8_u32(res.val[0]));
        vst1_u8(dst + (r + 0) * stride + c - 1,
                vreinterpret_u8_u32(res.val[1]));

        // Carry the high patch's rightmost column into the next iteration.
        s0_lo = s4_hi;
        s5_lo = vdup_lane_u8(res_hi, 3);
        s6_lo = vdup_lane_u8(res_hi, 7);
        c += 8;
      } while (c < width + 1);

      r += 2;
    } while (r < height + 1);
  }
}
    212 
// Low-pass filter the intra-prediction edge p[0..sz-1] in place.
//
// strength selects the smoothing kernel (all normalised by 16):
//   1        : {4, 8, 4}
//   2        : {5, 6, 5}
//   otherwise: {2, 4, 4, 4, 2}
// strength == 0 is a no-op, and p[0] is never modified.
void av1_filter_intra_edge_neon(uint8_t *p, int sz, int strength) {
  if (!strength) return;
  assert(sz >= 0 && sz <= 129);

  // Filter a padded copy of the edge so that vector loads just past either
  // end read valid, clamped data instead of out-of-bounds memory.
  uint8_t edge[160];  // Max value of sz + enough padding for vector accesses.
  memcpy(edge + 1, p, sz * sizeof(*p));

  // Populate extra space appropriately.
  edge[0] = edge[1];
  edge[sz + 1] = edge[sz];
  edge[sz + 2] = edge[sz];

  // Don't overwrite first pixel.
  uint8_t *dst = p + 1;
  sz--;

  if (strength == 1) {  // Filter: {4, 8, 4}.
    const uint8_t *src = edge + 1;

    while (sz >= 8) {
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      // Make use of the identity:
      // (4*a + 8*b + 4*c) >> 4 == (a + (b << 1) + c) >> 2
      uint16x8_t t0 = vaddl_u8(s0, s2);
      uint16x8_t t1 = vaddl_u8(s1, s1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 2);

      vst1_u8(dst, res);

      src += 8;
      dst += 8;
      sz -= 8;
    }

    if (sz > 0) {  // Handle sz < 8 to avoid modifying out-of-bounds values.
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      uint16x8_t t0 = vaddl_u8(s0, s2);
      uint16x8_t t1 = vaddl_u8(s1, s1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 2);

      // Mask off out-of-bounds indices: lanes < sz take the filtered value,
      // the rest keep the original contents of dst.
      uint8x8_t current_dst = vld1_u8(dst);
      uint8x8_t mask = vcgt_u8(vdup_n_u8(sz), vcreate_u8(0x0706050403020100));
      res = vbsl_u8(mask, res, current_dst);

      vst1_u8(dst, res);
    }
  } else if (strength == 2) {  // Filter: {5, 6, 5}.
    const uint8_t *src = edge + 1;

    // No shift identity applies to {5, 6, 5}, so use explicit multiplies.
    const uint8x8x3_t filter = { { vdup_n_u8(5), vdup_n_u8(6), vdup_n_u8(5) } };

    while (sz >= 8) {
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      uint16x8_t accum = vmull_u8(s0, filter.val[0]);
      accum = vmlal_u8(accum, s1, filter.val[1]);
      accum = vmlal_u8(accum, s2, filter.val[2]);
      uint8x8_t res = vrshrn_n_u16(accum, 4);

      vst1_u8(dst, res);

      src += 8;
      dst += 8;
      sz -= 8;
    }

    if (sz > 0) {  // Handle sz < 8 to avoid modifying out-of-bounds values.
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);

      uint16x8_t accum = vmull_u8(s0, filter.val[0]);
      accum = vmlal_u8(accum, s1, filter.val[1]);
      accum = vmlal_u8(accum, s2, filter.val[2]);
      uint8x8_t res = vrshrn_n_u16(accum, 4);

      // Mask off out-of-bounds indices: lanes < sz take the filtered value,
      // the rest keep the original contents of dst.
      uint8x8_t current_dst = vld1_u8(dst);
      uint8x8_t mask = vcgt_u8(vdup_n_u8(sz), vcreate_u8(0x0706050403020100));
      res = vbsl_u8(mask, res, current_dst);

      vst1_u8(dst, res);
    }
  } else {  // Filter {2, 4, 4, 4, 2}.
    // The 5-tap kernel is centred on src + 2, so start from edge (one pixel
    // earlier than the 3-tap cases).
    const uint8_t *src = edge;

    while (sz >= 8) {
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);
      uint8x8_t s3 = vld1_u8(src + 3);
      uint8x8_t s4 = vld1_u8(src + 4);

      // Make use of the identity:
      // (2*a + 4*b + 4*c + 4*d + 2*e) >> 4 == (a + ((b + c + d) << 1) + e) >> 3
      uint16x8_t t0 = vaddl_u8(s0, s4);
      uint16x8_t t1 = vaddl_u8(s1, s2);
      t1 = vaddw_u8(t1, s3);
      t1 = vaddq_u16(t1, t1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 3);

      vst1_u8(dst, res);

      src += 8;
      dst += 8;
      sz -= 8;
    }

    if (sz > 0) {  // Handle sz < 8 to avoid modifying out-of-bounds values.
      uint8x8_t s0 = vld1_u8(src);
      uint8x8_t s1 = vld1_u8(src + 1);
      uint8x8_t s2 = vld1_u8(src + 2);
      uint8x8_t s3 = vld1_u8(src + 3);
      uint8x8_t s4 = vld1_u8(src + 4);

      uint16x8_t t0 = vaddl_u8(s0, s4);
      uint16x8_t t1 = vaddl_u8(s1, s2);
      t1 = vaddw_u8(t1, s3);
      t1 = vaddq_u16(t1, t1);
      uint16x8_t sum = vaddq_u16(t0, t1);
      uint8x8_t res = vrshrn_n_u16(sum, 3);

      // Mask off out-of-bounds indices: lanes < sz take the filtered value,
      // the rest keep the original contents of dst.
      uint8x8_t current_dst = vld1_u8(dst);
      uint8x8_t mask = vcgt_u8(vdup_n_u8(sz), vcreate_u8(0x0706050403020100));
      res = vbsl_u8(mask, res, current_dst);

      vst1_u8(dst, res);
    }
  }
}
    356 
    357 void av1_upsample_intra_edge_neon(uint8_t *p, int sz) {
    358  if (!sz) return;
    359 
    360  assert(sz <= MAX_UPSAMPLE_SZ);
    361 
    362  uint8_t edge[MAX_UPSAMPLE_SZ + 3];
    363  const uint8_t *src = edge;
    364 
    365  // Copy p[-1..(sz-1)] and pad out both ends.
    366  edge[0] = p[-1];
    367  edge[1] = p[-1];
    368  memcpy(edge + 2, p, sz);
    369  edge[sz + 2] = p[sz - 1];
    370  p[-2] = p[-1];
    371 
    372  uint8_t *dst = p - 1;
    373 
    374  do {
    375    uint8x8_t s0 = vld1_u8(src);
    376    uint8x8_t s1 = vld1_u8(src + 1);
    377    uint8x8_t s2 = vld1_u8(src + 2);
    378    uint8x8_t s3 = vld1_u8(src + 3);
    379 
    380    int16x8_t t0 = vreinterpretq_s16_u16(vaddl_u8(s0, s3));
    381    int16x8_t t1 = vreinterpretq_s16_u16(vaddl_u8(s1, s2));
    382    t1 = vmulq_n_s16(t1, 9);
    383    t1 = vsubq_s16(t1, t0);
    384 
    385    uint8x8x2_t res = { { vqrshrun_n_s16(t1, 4), s2 } };
    386 
    387    vst2_u8(dst, res);
    388 
    389    src += 8;
    390    dst += 16;
    391    sz -= 8;
    392  } while (sz > 0);
    393 }