highbd_reconinter_neon.c
/*
 *
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"
#include "aom_ports/mem.h"
#include "config/av1_rtcd.h"

// Computes the difference-weighted compound mask:
//   diff = |src0[x] - src1[x]| >> ((bd - 8) + DIFF_FACTOR_LOG2)
//   m    = min(38 + diff, AOM_BLEND_A64_MAX_ALPHA)
// and stores m, or AOM_BLEND_A64_MAX_ALPHA - m when inverse is true. The
// inverse paths fold the clamp into a single saturating subtract, since
// MAX - min(38 + diff, MAX) == sat_sub(MAX - 38, diff).
static inline void diffwtd_mask_highbd_neon(uint8_t *mask, bool inverse,
                                            const uint16_t *src0,
                                            int src0_stride,
                                            const uint16_t *src1,
                                            int src1_stride, int h, int w,
                                            const unsigned int bd) {
  assert(DIFF_FACTOR > 0);
  uint8x16_t max_alpha = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA);
  uint8x16_t mask_base = vdupq_n_u8(38);
  uint8x16_t mask_diff = vdupq_n_u8(AOM_BLEND_A64_MAX_ALPHA - 38);

  if (bd == 8) {
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 = vshrn_n_u16(diff_lo_u16, DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 = vshrn_n_u16(diff_hi_u16, DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  } else if (bd == 10) {
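    // 10-bit path: absolute differences occupy up to 10 bits, so narrow
    // with two extra bits of right shift, i.e. (bd - 8) + DIFF_FACTOR_LOG2.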
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 =
              vshrn_n_u16(diff_lo_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 =
              vshrn_n_u16(diff_hi_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 2 + DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  } else {
    assert(bd == 12);
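    // 12-bit path: same structure as above, with four extra bits of right
    // shift to bring the 12-bit differences into the 8-bit mask range.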
    if (w >= 16) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0_lo = vld1q_u16(src0_ptr);
          uint16x8_t s0_hi = vld1q_u16(src0_ptr + 8);
          uint16x8_t s1_lo = vld1q_u16(src1_ptr);
          uint16x8_t s1_hi = vld1q_u16(src1_ptr + 8);

          uint16x8_t diff_lo_u16 = vabdq_u16(s0_lo, s1_lo);
          uint16x8_t diff_hi_u16 = vabdq_u16(s0_hi, s1_hi);
          uint8x8_t diff_lo_u8 =
              vshrn_n_u16(diff_lo_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x8_t diff_hi_u8 =
              vshrn_n_u16(diff_hi_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x16_t diff = vcombine_u8(diff_lo_u8, diff_hi_u8);

          uint8x16_t m;
          if (inverse) {
            m = vqsubq_u8(mask_diff, diff);
          } else {
            m = vminq_u8(vaddq_u8(diff, mask_base), max_alpha);
          }

          vst1q_u8(mask_ptr, m);

          src0_ptr += 16;
          src1_ptr += 16;
          mask_ptr += 16;
          width -= 16;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8_t *mask_ptr = mask;
        const uint16_t *src0_ptr = src0;
        const uint16_t *src1_ptr = src1;
        int width = w;
        do {
          uint16x8_t s0 = vld1q_u16(src0_ptr);
          uint16x8_t s1 = vld1q_u16(src1_ptr);

          uint16x8_t diff_u16 = vabdq_u16(s0, s1);
          uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
          uint8x8_t m;
          if (inverse) {
            m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
          } else {
            m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                        vget_low_u8(max_alpha));
          }

          vst1_u8(mask_ptr, m);

          src0_ptr += 8;
          src1_ptr += 8;
          mask_ptr += 8;
          width -= 8;
        } while (width != 0);
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      } while (--h != 0);
    } else if (w == 4) {
      do {
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t diff_u16 = vabdq_u16(s0, s1);
        uint8x8_t diff_u8 = vshrn_n_u16(diff_u16, 4 + DIFF_FACTOR_LOG2);
        uint8x8_t m;
        if (inverse) {
          m = vqsub_u8(vget_low_u8(mask_diff), diff_u8);
        } else {
          m = vmin_u8(vadd_u8(diff_u8, vget_low_u8(mask_base)),
                      vget_low_u8(max_alpha));
        }

        store_u8x4_strided_x2(mask, w, m);

        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        mask += 2 * w;
        h -= 2;
      } while (h != 0);
    }
  }
}

void av1_build_compound_diffwtd_mask_highbd_neon(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
    int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
    int bd) {
  assert(h % 4 == 0);
  assert(w % 4 == 0);
  assert(mask_type == DIFFWTD_38_INV || mask_type == DIFFWTD_38);

  if (mask_type == DIFFWTD_38) {
    diffwtd_mask_highbd_neon(mask, /*inverse=*/false, CONVERT_TO_SHORTPTR(src0),
                             src0_stride, CONVERT_TO_SHORTPTR(src1),
                             src1_stride, h, w, bd);
  } else {  // mask_type == DIFFWTD_38_INV
    diffwtd_mask_highbd_neon(mask, /*inverse=*/true, CONVERT_TO_SHORTPTR(src0),
                             src0_stride, CONVERT_TO_SHORTPTR(src1),
                             src1_stride, h, w, bd);
  }
}
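For reference, a minimal scalar sketch of the per-pixel math the NEON paths
above vectorize. It is illustrative only, not part of the file: the helper
name is hypothetical and the libaom constants are inlined (64 for
AOM_BLEND_A64_MAX_ALPHA, 4 for DIFF_FACTOR_LOG2) so it stands alone.

    #include <stdint.h>
    #include <stdlib.h>

    // Scalar equivalent of diffwtd_mask_highbd_neon (illustrative sketch).
    static void diffwtd_mask_highbd_scalar(uint8_t *mask, int inverse,
                                           const uint16_t *src0,
                                           int src0_stride,
                                           const uint16_t *src1,
                                           int src1_stride, int h, int w,
                                           unsigned int bd) {
      // (bd - 8) + DIFF_FACTOR_LOG2, matching the vshrn_n_u16 shift amounts.
      const int shift = (int)(bd - 8) + 4;
      for (int y = 0; y < h; ++y) {
        for (int x = 0; x < w; ++x) {
          const int diff = abs((int)src0[x] - (int)src1[x]) >> shift;
          // m = min(38 + diff, AOM_BLEND_A64_MAX_ALPHA)
          const int m = (38 + diff < 64) ? 38 + diff : 64;
          mask[x] = (uint8_t)(inverse ? 64 - m : m);
        }
        mask += w;
        src0 += src0_stride;
        src1 += src1_stride;
      }
    }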