tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

highbd_blend_a64_mask_neon.c (25799B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <arm_neon.h>
     13 #include <assert.h>
     14 
     15 #include "config/aom_config.h"
     16 #include "config/aom_dsp_rtcd.h"
     17 
     18 #include "aom_dsp/arm/blend_neon.h"
     19 #include "aom_dsp/arm/mem_neon.h"
     20 #include "aom_dsp/blend.h"
     21 
     22 #define HBD_BLEND_A64_D16_MASK(bd, round0_bits)                               \
     23  static inline uint16x8_t alpha_##bd##_blend_a64_d16_u16x8(                  \
     24      uint16x8_t m, uint16x8_t a, uint16x8_t b, int32x4_t round_offset) {     \
     25    const uint16x8_t m_inv =                                                  \
     26        vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m);                   \
     27                                                                              \
     28    uint32x4_t blend_u32_lo = vmlal_u16(vreinterpretq_u32_s32(round_offset),  \
     29                                        vget_low_u16(m), vget_low_u16(a));    \
     30    uint32x4_t blend_u32_hi = vmlal_u16(vreinterpretq_u32_s32(round_offset),  \
     31                                        vget_high_u16(m), vget_high_u16(a));  \
     32                                                                              \
     33    blend_u32_lo =                                                            \
     34        vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b));        \
     35    blend_u32_hi =                                                            \
     36        vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b));      \
     37                                                                              \
     38    uint16x4_t blend_u16_lo =                                                 \
     39        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_lo),                   \
     40                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS -           \
     41                           round0_bits - COMPOUND_ROUND1_BITS);               \
     42    uint16x4_t blend_u16_hi =                                                 \
     43        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_hi),                   \
     44                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS -           \
     45                           round0_bits - COMPOUND_ROUND1_BITS);               \
     46                                                                              \
     47    uint16x8_t blend_u16 = vcombine_u16(blend_u16_lo, blend_u16_hi);          \
     48    blend_u16 = vminq_u16(blend_u16, vdupq_n_u16((1 << bd) - 1));             \
     49                                                                              \
     50    return blend_u16;                                                         \
     51  }                                                                           \
     52                                                                              \
     53  static inline void highbd_##bd##_blend_a64_d16_mask_neon(                   \
     54      uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,          \
     55      uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,  \
     56      const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw,      \
     57      int subh) {                                                             \
     58    const int offset_bits = bd + 2 * FILTER_BITS - round0_bits;               \
     59    int32_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +      \
     60                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));   \
     61    int32x4_t offset =                                                        \
     62        vdupq_n_s32(-(round_offset << AOM_BLEND_A64_ROUND_BITS));             \
     63                                                                              \
     64    if ((subw | subh) == 0) {                                                 \
     65      if (w >= 8) {                                                           \
     66        do {                                                                  \
     67          int i = 0;                                                          \
     68          do {                                                                \
     69            uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));                      \
     70            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
     71            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
     72                                                                              \
     73            uint16x8_t blend =                                                \
     74                alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset);         \
     75                                                                              \
     76            vst1q_u16(dst + i, blend);                                        \
     77            i += 8;                                                           \
     78          } while (i < w);                                                    \
     79                                                                              \
     80          mask += mask_stride;                                                \
     81          src0 += src0_stride;                                                \
     82          src1 += src1_stride;                                                \
     83          dst += dst_stride;                                                  \
     84        } while (--h != 0);                                                   \
     85      } else {                                                                \
     86        do {                                                                  \
     87          uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride)); \
     88          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
     89          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
     90                                                                              \
     91          uint16x8_t blend =                                                  \
     92              alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset);           \
     93                                                                              \
     94          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
     95                                                                              \
     96          mask += 2 * mask_stride;                                            \
     97          src0 += 2 * src0_stride;                                            \
     98          src1 += 2 * src1_stride;                                            \
     99          dst += 2 * dst_stride;                                              \
    100          h -= 2;                                                             \
    101        } while (h != 0);                                                     \
    102      }                                                                       \
    103    } else if ((subw & subh) == 1) {                                          \
    104      if (w >= 8) {                                                           \
    105        do {                                                                  \
    106          int i = 0;                                                          \
    107          do {                                                                \
    108            uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i);         \
    109            uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i);         \
    110            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
    111            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
    112                                                                              \
    113            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(            \
    114                vget_low_u8(m0), vget_low_u8(m1), vget_high_u8(m0),           \
    115                vget_high_u8(m1)));                                           \
    116            uint16x8_t blend =                                                \
    117                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);      \
    118                                                                              \
    119            vst1q_u16(dst + i, blend);                                        \
    120            i += 8;                                                           \
    121          } while (i < w);                                                    \
    122                                                                              \
    123          mask += 2 * mask_stride;                                            \
    124          src0 += src0_stride;                                                \
    125          src1 += src1_stride;                                                \
    126          dst += dst_stride;                                                  \
    127        } while (--h != 0);                                                   \
    128      } else {                                                                \
    129        do {                                                                  \
    130          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);                     \
    131          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);                     \
    132          uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);                     \
    133          uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);                     \
    134          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
    135          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
    136                                                                              \
    137          uint16x8_t m_avg =                                                  \
    138              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));            \
    139          uint16x8_t blend =                                                  \
    140              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);        \
    141                                                                              \
    142          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
    143                                                                              \
    144          mask += 4 * mask_stride;                                            \
    145          src0 += 2 * src0_stride;                                            \
    146          src1 += 2 * src1_stride;                                            \
    147          dst += 2 * dst_stride;                                              \
    148          h -= 2;                                                             \
    149        } while (h != 0);                                                     \
    150      }                                                                       \
    151    } else if (subw == 1 && subh == 0) {                                      \
    152      if (w >= 8) {                                                           \
    153        do {                                                                  \
    154          int i = 0;                                                          \
    155          do {                                                                \
    156            uint8x8_t m0 = vld1_u8(mask + 2 * i);                             \
    157            uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);                         \
    158            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
    159            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
    160                                                                              \
    161            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));     \
    162            uint16x8_t blend =                                                \
    163                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);      \
    164                                                                              \
    165            vst1q_u16(dst + i, blend);                                        \
    166            i += 8;                                                           \
    167          } while (i < w);                                                    \
    168                                                                              \
    169          mask += mask_stride;                                                \
    170          src0 += src0_stride;                                                \
    171          src1 += src1_stride;                                                \
    172          dst += dst_stride;                                                  \
    173        } while (--h != 0);                                                   \
    174      } else {                                                                \
    175        do {                                                                  \
    176          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);                     \
    177          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);                     \
    178          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
    179          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
    180                                                                              \
    181          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));       \
    182          uint16x8_t blend =                                                  \
    183              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);        \
    184                                                                              \
    185          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
    186                                                                              \
    187          mask += 2 * mask_stride;                                            \
    188          src0 += 2 * src0_stride;                                            \
    189          src1 += 2 * src1_stride;                                            \
    190          dst += 2 * dst_stride;                                              \
    191          h -= 2;                                                             \
    192        } while (h != 0);                                                     \
    193      }                                                                       \
    194    } else {                                                                  \
    195      if (w >= 8) {                                                           \
    196        do {                                                                  \
    197          int i = 0;                                                          \
    198          do {                                                                \
    199            uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);               \
    200            uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);               \
    201            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
    202            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
    203                                                                              \
    204            uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));              \
    205            uint16x8_t blend =                                                \
    206                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);      \
    207                                                                              \
    208            vst1q_u16(dst + i, blend);                                        \
    209            i += 8;                                                           \
    210          } while (i < w);                                                    \
    211                                                                              \
    212          mask += 2 * mask_stride;                                            \
    213          src0 += src0_stride;                                                \
    214          src1 += src1_stride;                                                \
    215          dst += dst_stride;                                                  \
    216        } while (--h != 0);                                                   \
    217      } else {                                                                \
    218        do {                                                                  \
    219          uint8x8_t m0_2 =                                                    \
    220              load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); \
    221          uint8x8_t m1_3 =                                                    \
    222              load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); \
    223          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
    224          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
    225                                                                              \
    226          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));            \
    227          uint16x8_t blend =                                                  \
    228              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);        \
    229                                                                              \
    230          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
    231                                                                              \
    232          mask += 4 * mask_stride;                                            \
    233          src0 += 2 * src0_stride;                                            \
    234          src1 += 2 * src1_stride;                                            \
    235          dst += 2 * dst_stride;                                              \
    236          h -= 2;                                                             \
    237        } while (h != 0);                                                     \
    238      }                                                                       \
    239    }                                                                         \
    240  }
    241 
    242 // 12 bitdepth
    243 HBD_BLEND_A64_D16_MASK(12, (ROUND0_BITS + 2))
    244 // 10 bitdepth
    245 HBD_BLEND_A64_D16_MASK(10, ROUND0_BITS)
    246 // 8 bitdepth
    247 HBD_BLEND_A64_D16_MASK(8, ROUND0_BITS)
    248 
    249 void aom_highbd_blend_a64_d16_mask_neon(
    250    uint8_t *dst_8, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    251    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    252    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    253    ConvolveParams *conv_params, const int bd) {
    254  (void)conv_params;
    255  assert(h >= 1);
    256  assert(w >= 1);
    257  assert(IS_POWER_OF_TWO(h));
    258  assert(IS_POWER_OF_TWO(w));
    259 
    260  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);
    261  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
    262  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));
    263 
    264  if (bd == 12) {
    265    highbd_12_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
    266                                      src1_stride, mask, mask_stride, w, h,
    267                                      subw, subh);
    268  } else if (bd == 10) {
    269    highbd_10_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
    270                                      src1_stride, mask, mask_stride, w, h,
    271                                      subw, subh);
    272  } else {
    273    highbd_8_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
    274                                     src1_stride, mask, mask_stride, w, h, subw,
    275                                     subh);
    276  }
    277 }
    278 
    279 void aom_highbd_blend_a64_mask_neon(uint8_t *dst_8, uint32_t dst_stride,
    280                                    const uint8_t *src0_8, uint32_t src0_stride,
    281                                    const uint8_t *src1_8, uint32_t src1_stride,
    282                                    const uint8_t *mask, uint32_t mask_stride,
    283                                    int w, int h, int subw, int subh, int bd) {
    284  (void)bd;
    285 
    286  const uint16_t *src0 = CONVERT_TO_SHORTPTR(src0_8);
    287  const uint16_t *src1 = CONVERT_TO_SHORTPTR(src1_8);
    288  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);
    289 
    290  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
    291  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));
    292 
    293  assert(h >= 1);
    294  assert(w >= 1);
    295  assert(IS_POWER_OF_TWO(h));
    296  assert(IS_POWER_OF_TWO(w));
    297 
    298  assert(bd == 8 || bd == 10 || bd == 12);
    299 
    300  if ((subw | subh) == 0) {
    301    if (w >= 8) {
    302      do {
    303        int i = 0;
    304        do {
    305          uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));
    306          uint16x8_t s0 = vld1q_u16(src0 + i);
    307          uint16x8_t s1 = vld1q_u16(src1 + i);
    308 
    309          uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);
    310 
    311          vst1q_u16(dst + i, blend);
    312          i += 8;
    313        } while (i < w);
    314 
    315        mask += mask_stride;
    316        src0 += src0_stride;
    317        src1 += src1_stride;
    318        dst += dst_stride;
    319      } while (--h != 0);
    320    } else {
    321      do {
    322        uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride));
    323        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
    324        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);
    325 
    326        uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);
    327 
    328        store_u16x4_strided_x2(dst, dst_stride, blend);
    329 
    330        mask += 2 * mask_stride;
    331        src0 += 2 * src0_stride;
    332        src1 += 2 * src1_stride;
    333        dst += 2 * dst_stride;
    334        h -= 2;
    335      } while (h != 0);
    336    }
    337  } else if ((subw & subh) == 1) {
    338    if (w >= 8) {
    339      do {
    340        int i = 0;
    341        do {
    342          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i);
    343          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i);
    344          uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8);
    345          uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8);
    346          uint16x8_t s0 = vld1q_u16(src0 + i);
    347          uint16x8_t s1 = vld1q_u16(src1 + i);
    348 
    349          uint16x8_t m_avg =
    350              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));
    351 
    352          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);
    353 
    354          vst1q_u16(dst + i, blend);
    355 
    356          i += 8;
    357        } while (i < w);
    358 
    359        mask += 2 * mask_stride;
    360        src0 += src0_stride;
    361        src1 += src1_stride;
    362        dst += dst_stride;
    363      } while (--h != 0);
    364    } else {
    365      do {
    366        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
    367        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
    368        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
    369        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
    370        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
    371        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);
    372 
    373        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));
    374        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);
    375 
    376        store_u16x4_strided_x2(dst, dst_stride, blend);
    377 
    378        mask += 4 * mask_stride;
    379        src0 += 2 * src0_stride;
    380        src1 += 2 * src1_stride;
    381        dst += 2 * dst_stride;
    382        h -= 2;
    383      } while (h != 0);
    384    }
    385  } else if (subw == 1 && subh == 0) {
    386    if (w >= 8) {
    387      do {
    388        int i = 0;
    389 
    390        do {
    391          uint8x8_t m0 = vld1_u8(mask + 2 * i);
    392          uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);
    393          uint16x8_t s0 = vld1q_u16(src0 + i);
    394          uint16x8_t s1 = vld1q_u16(src1 + i);
    395 
    396          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
    397          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);
    398 
    399          vst1q_u16(dst + i, blend);
    400 
    401          i += 8;
    402        } while (i < w);
    403 
    404        mask += mask_stride;
    405        src0 += src0_stride;
    406        src1 += src1_stride;
    407        dst += dst_stride;
    408      } while (--h != 0);
    409    } else {
    410      do {
    411        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
    412        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
    413        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
    414        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);
    415 
    416        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
    417        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);
    418 
    419        store_u16x4_strided_x2(dst, dst_stride, blend);
    420 
    421        mask += 2 * mask_stride;
    422        src0 += 2 * src0_stride;
    423        src1 += 2 * src1_stride;
    424        dst += 2 * dst_stride;
    425        h -= 2;
    426      } while (h != 0);
    427    }
    428  } else {
    429    if (w >= 8) {
    430      do {
    431        int i = 0;
    432        do {
    433          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);
    434          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);
    435          uint16x8_t s0 = vld1q_u16(src0 + i);
    436          uint16x8_t s1 = vld1q_u16(src1 + i);
    437 
    438          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));
    439          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);
    440 
    441          vst1q_u16(dst + i, blend);
    442 
    443          i += 8;
    444        } while (i < w);
    445 
    446        mask += 2 * mask_stride;
    447        src0 += src0_stride;
    448        src1 += src1_stride;
    449        dst += dst_stride;
    450      } while (--h != 0);
    451    } else {
    452      do {
    453        uint8x8_t m0_2 =
    454            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
    455        uint8x8_t m1_3 =
    456            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
    457        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
    458        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);
    459 
    460        uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));
    461        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);
    462 
    463        store_u16x4_strided_x2(dst, dst_stride, blend);
    464 
    465        mask += 4 * mask_stride;
    466        src0 += 2 * src0_stride;
    467        src1 += 2 * src1_stride;
    468        dst += 2 * dst_stride;
    469        h -= 2;
    470      } while (h != 0);
    471    }
    472  }
    473 }