avg_pred_neon.c (4337B)
1 /* 2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <arm_neon.h> 13 #include <assert.h> 14 15 #include "config/aom_dsp_rtcd.h" 16 17 #include "aom_dsp/arm/blend_neon.h" 18 #include "aom_dsp/arm/dist_wtd_avg_neon.h" 19 #include "aom_dsp/arm/mem_neon.h" 20 #include "aom_dsp/blend.h" 21 22 void aom_comp_avg_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width, 23 int height, const uint8_t *ref, int ref_stride) { 24 if (width > 8) { 25 do { 26 const uint8_t *pred_ptr = pred; 27 const uint8_t *ref_ptr = ref; 28 uint8_t *comp_pred_ptr = comp_pred; 29 int w = width; 30 31 do { 32 const uint8x16_t p = vld1q_u8(pred_ptr); 33 const uint8x16_t r = vld1q_u8(ref_ptr); 34 const uint8x16_t avg = vrhaddq_u8(p, r); 35 36 vst1q_u8(comp_pred_ptr, avg); 37 38 ref_ptr += 16; 39 pred_ptr += 16; 40 comp_pred_ptr += 16; 41 w -= 16; 42 } while (w != 0); 43 44 ref += ref_stride; 45 pred += width; 46 comp_pred += width; 47 } while (--height != 0); 48 } else if (width == 8) { 49 int h = height / 2; 50 51 do { 52 const uint8x16_t p = vld1q_u8(pred); 53 const uint8x16_t r = load_u8_8x2(ref, ref_stride); 54 const uint8x16_t avg = vrhaddq_u8(p, r); 55 56 vst1q_u8(comp_pred, avg); 57 58 ref += 2 * ref_stride; 59 pred += 16; 60 comp_pred += 16; 61 } while (--h != 0); 62 } else { 63 int h = height / 4; 64 assert(width == 4); 65 66 do { 67 const uint8x16_t p = vld1q_u8(pred); 68 const uint8x16_t r = load_unaligned_u8q(ref, ref_stride); 69 const uint8x16_t avg = vrhaddq_u8(p, r); 70 71 vst1q_u8(comp_pred, avg); 72 73 ref += 4 * ref_stride; 74 pred += 16; 75 comp_pred += 16; 76 } while (--h != 0); 77 } 78 } 79 80 void aom_comp_mask_pred_neon(uint8_t *comp_pred, const uint8_t *pred, int width, 81 int height, const uint8_t *ref, int ref_stride, 82 const uint8_t *mask, int mask_stride, 83 int invert_mask) { 84 const uint8_t *src0 = invert_mask ? pred : ref; 85 const uint8_t *src1 = invert_mask ? ref : pred; 86 const int src_stride0 = invert_mask ? width : ref_stride; 87 const int src_stride1 = invert_mask ? ref_stride : width; 88 89 if (width > 8) { 90 do { 91 const uint8_t *src0_ptr = src0; 92 const uint8_t *src1_ptr = src1; 93 const uint8_t *mask_ptr = mask; 94 uint8_t *comp_pred_ptr = comp_pred; 95 int w = width; 96 97 do { 98 const uint8x16_t s0 = vld1q_u8(src0_ptr); 99 const uint8x16_t s1 = vld1q_u8(src1_ptr); 100 const uint8x16_t m0 = vld1q_u8(mask_ptr); 101 102 uint8x16_t blend_u8 = alpha_blend_a64_u8x16(m0, s0, s1); 103 104 vst1q_u8(comp_pred_ptr, blend_u8); 105 106 src0_ptr += 16; 107 src1_ptr += 16; 108 mask_ptr += 16; 109 comp_pred_ptr += 16; 110 w -= 16; 111 } while (w != 0); 112 113 src0 += src_stride0; 114 src1 += src_stride1; 115 mask += mask_stride; 116 comp_pred += width; 117 } while (--height != 0); 118 } else if (width == 8) { 119 do { 120 const uint8x8_t s0 = vld1_u8(src0); 121 const uint8x8_t s1 = vld1_u8(src1); 122 const uint8x8_t m0 = vld1_u8(mask); 123 124 uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, s0, s1); 125 126 vst1_u8(comp_pred, blend_u8); 127 128 src0 += src_stride0; 129 src1 += src_stride1; 130 mask += mask_stride; 131 comp_pred += 8; 132 } while (--height != 0); 133 } else { 134 int h = height / 2; 135 assert(width == 4); 136 137 do { 138 const uint8x8_t s0 = load_unaligned_u8(src0, src_stride0); 139 const uint8x8_t s1 = load_unaligned_u8(src1, src_stride1); 140 const uint8x8_t m0 = load_unaligned_u8(mask, mask_stride); 141 142 uint8x8_t blend_u8 = alpha_blend_a64_u8x8(m0, s0, s1); 143 144 vst1_u8(comp_pred, blend_u8); 145 146 src0 += 2 * src_stride0; 147 src1 += 2 * src_stride1; 148 mask += 2 * mask_stride; 149 comp_pred += 8; 150 } while (--h != 0); 151 } 152 }