tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

highbd_reconinter_neon.c (10696B)


      1 /*
      2 *
      3 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      4 *
      5 * This source code is subject to the terms of the BSD 2 Clause License and
      6 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      7 * was not distributed with this source code in the LICENSE file, you can
      8 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      9 * Media Patent License 1.0 was not distributed with this source code in the
     10 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     11 */
     12 
     13 #include <arm_neon.h>
     14 #include <assert.h>
     15 #include <stdbool.h>
     16 
     17 #include "aom_dsp/arm/mem_neon.h"
     18 #include "aom_dsp/blend.h"
     19 #include "aom_ports/mem.h"
     20 #include "config/av1_rtcd.h"
     21 
     22 static inline void diffwtd_mask_highbd_neon(uint8_t *mask, bool inverse,
     23                                            const uint16_t *src0,
     24                                            int src0_stride,
     25                                            const uint16_t *src1,
     26                                            int src1_stride, int h, int w,
     27                                            const unsigned int bd) {
     28  assert(DIFF_FACTOR > 0);
     29  uint8x16_t max_alpha = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA);
     30  uint8x16_t mask_base = vdupq_n_u8(38);
     31  uint8x16_t mask_diff = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA - 38);
     32 
     33  if (bd == 8) {
     34    if (w >= 16) {
     35      do {
     36        uint8_t *mask_ptr = mask;
     37        const uint16_t *src0_ptr = src0;
     38        const uint16_t *src1_ptr = src1;
     39        int width = w;
     40        do {
     41          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
     42          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
     43          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
     44          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);
     45 
     46          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
     47          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
     48          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, DIFF_FACTOR_LOG2);
     49          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, DIFF_FACTOR_LOG2);
     50          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);
     51 
     52          uint8x16_t m;
     53          if (inverse) {
     54            m = vqsubq_u8(mask_diff, diff);
     55          } else {
     56            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
     57          }
     58 
     59          vst1q_u8(mask_ptr, m);
     60 
     61          src0_ptr += 16;
     62          src1_ptr += 16;
     63          mask_ptr += 16;
     64          width -= 16;
     65        } while (width != 0);
     66        mask += w;
     67        src0 += src0_stride;
     68        src1 += src1_stride;
     69      } while (--h != 0);
     70    } else if (w == 8) {
     71      do {
     72        uint8_t *mask_ptr = mask;
     73        const uint16_t *src0_ptr = src0;
     74        const uint16_t *src1_ptr = src1;
     75        int width = w;
     76        do {
     77          uint16x8_t s0 = vld1q_u16(src0_ptr);
     78          uint16x8_t s1 = vld1q_u16(src1_ptr);
     79 
     80          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
     81          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
     82          uint8x8_t m;
     83          if (inverse) {
     84            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
     85          } else {
     86            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
     87                        vget_low_u8(max_alpha));
     88          }
     89 
     90          vst1_u8(mask_ptr, m);
     91 
     92          src0_ptr += 8;
     93          src1_ptr += 8;
     94          mask_ptr += 8;
     95          width -= 8;
     96        } while (width != 0);
     97        mask += w;
     98        src0 += src0_stride;
     99        src1 += src1_stride;
    100      } while (--h != 0);
    101    } else if (w == 4) {
    102      do {
    103        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
    104        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);
    105 
    106        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
    107        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
    108        uint8x8_t m;
    109        if (inverse) {
    110          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
    111        } else {
    112          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
    113                      vget_low_u8(max_alpha));
    114        }
    115 
    116        store_u8x4_strided_x2(mask, w, m);
    117 
    118        src0 += 2 * src0_stride;
    119        src1 += 2 * src1_stride;
    120        mask += 2 * w;
    121        h -= 2;
    122      } while (h != 0);
    123    }
    124  } else if (bd == 10) {
    125    if (w >= 16) {
    126      do {
    127        uint8_t *mask_ptr = mask;
    128        const uint16_t *src0_ptr = src0;
    129        const uint16_t *src1_ptr = src1;
    130        int width = w;
    131        do {
    132          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
    133          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
    134          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
    135          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);
    136 
    137          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
    138          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
    139          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, 2 + DIFF_FACTOR_LOG2);
    140          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, 2 + DIFF_FACTOR_LOG2);
    141          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);
    142 
    143          uint8x16_t m;
    144          if (inverse) {
    145            m = vqsubq_u8(mask_diff, diff);
    146          } else {
    147            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
    148          }
    149 
    150          vst1q_u8(mask_ptr, m);
    151 
    152          src0_ptr += 16;
    153          src1_ptr += 16;
    154          mask_ptr += 16;
    155          width -= 16;
    156        } while (width != 0);
    157        mask += w;
    158        src0 += src0_stride;
    159        src1 += src1_stride;
    160      } while (--h != 0);
    161    } else if (w == 8) {
    162      do {
    163        uint8_t *mask_ptr = mask;
    164        const uint16_t *src0_ptr = src0;
    165        const uint16_t *src1_ptr = src1;
    166        int width = w;
    167        do {
    168          uint16x8_t s0 = vld1q_u16(src0_ptr);
    169          uint16x8_t s1 = vld1q_u16(src1_ptr);
    170 
    171          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
    172          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
    173          uint8x8_t m;
    174          if (inverse) {
    175            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
    176          } else {
    177            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
    178                        vget_low_u8(max_alpha));
    179          }
    180 
    181          vst1_u8(mask_ptr, m);
    182 
    183          src0_ptr += 8;
    184          src1_ptr += 8;
    185          mask_ptr += 8;
    186          width -= 8;
    187        } while (width != 0);
    188        mask += w;
    189        src0 += src0_stride;
    190        src1 += src1_stride;
    191      } while (--h != 0);
    192    } else if (w == 4) {
    193      do {
    194        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
    195        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);
    196 
    197        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
    198        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
    199        uint8x8_t m;
    200        if (inverse) {
    201          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
    202        } else {
    203          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
    204                      vget_low_u8(max_alpha));
    205        }
    206 
    207        store_u8x4_strided_x2(mask, w, m);
    208 
    209        src0 += 2 * src0_stride;
    210        src1 += 2 * src1_stride;
    211        mask += 2 * w;
    212        h -= 2;
    213      } while (h != 0);
    214    }
    215  } else {
    216    assert(bd == 12);
    217    if (w >= 16) {
    218      do {
    219        uint8_t *mask_ptr = mask;
    220        const uint16_t *src0_ptr = src0;
    221        const uint16_t *src1_ptr = src1;
    222        int width = w;
    223        do {
    224          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
    225          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
    226          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
    227          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);
    228 
    229          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
    230          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
    231          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, 4 + DIFF_FACTOR_LOG2);
    232          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, 4 + DIFF_FACTOR_LOG2);
    233          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);
    234 
    235          uint8x16_t m;
    236          if (inverse) {
    237            m = vqsubq_u8(mask_diff, diff);
    238          } else {
    239            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
    240          }
    241 
    242          vst1q_u8(mask_ptr, m);
    243 
    244          src0_ptr += 16;
    245          src1_ptr += 16;
    246          mask_ptr += 16;
    247          width -= 16;
    248        } while (width != 0);
    249        mask += w;
    250        src0 += src0_stride;
    251        src1 += src1_stride;
    252      } while (--h != 0);
    253    } else if (w == 8) {
    254      do {
    255        uint8_t *mask_ptr = mask;
    256        const uint16_t *src0_ptr = src0;
    257        const uint16_t *src1_ptr = src1;
    258        int width = w;
    259        do {
    260          uint16x8_t s0 = vld1q_u16(src0_ptr);
    261          uint16x8_t s1 = vld1q_u16(src1_ptr);
    262 
    263          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
    264          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
    265          uint8x8_t m;
    266          if (inverse) {
    267            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
    268          } else {
    269            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
    270                        vget_low_u8(max_alpha));
    271          }
    272 
    273          vst1_u8(mask_ptr, m);
    274 
    275          src0_ptr += 8;
    276          src1_ptr += 8;
    277          mask_ptr += 8;
    278          width -= 8;
    279        } while (width != 0);
    280        mask += w;
    281        src0 += src0_stride;
    282        src1 += src1_stride;
    283      } while (--h != 0);
    284    } else if (w == 4) {
    285      do {
    286        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
    287        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);
    288 
    289        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
    290        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
    291        uint8x8_t m;
    292        if (inverse) {
    293          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
    294        } else {
    295          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
    296                      vget_low_u8(max_alpha));
    297        }
    298 
    299        store_u8x4_strided_x2(mask, w, m);
    300 
    301        src0 += 2 * src0_stride;
    302        src1 += 2 * src1_stride;
    303        mask += 2 * w;
    304        h -= 2;
    305      } while (h != 0);
    306    }
    307  }
    308 }
    309 
    310 void av1_build_compound_diffwtd_mask_highbd_neon(
    311    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
    312    int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
    313    int bd) {
    314  assert(h % 4 == 0);
    315  assert(w % 4 == 0);
    316  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);
    317 
    318  if (mask_type == DIFFWTD_38) {
    319    diffwtd_mask_highbd_neon(mask, /*inverse=*/false, CONVERT_TO_SHORTPTR(src0),
    320                             src0_stride, CONVERT_TO_SHORTPTR(src1),
    321                             src1_stride, h, w, bd);
    322  } else {  // mask_type == DIFFWTD_38_INV
    323    diffwtd_mask_highbd_neon(mask, /*inverse=*/true, CONVERT_TO_SHORTPTR(src0),
    324                             src0_stride, CONVERT_TO_SHORTPTR(src1),
    325                             src1_stride, h, w, bd);
    326  }
    327 }