highbd_pickrst_sve.c (5832B)
/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <arm_sve.h>

#include <assert.h>
#include <stdint.h>

#include "aom_dsp/arm/aom_neon_sve_bridge.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "av1/encoder/arm/pickrst_neon.h"
#include "av1/encoder/arm/pickrst_sve.h"
#include "av1/encoder/pickrst.h"

// Returns the integer (truncated) mean of the width x height block of 16-bit
// pixels at src. Each row is summed with widening u16 dot products against a
// vector of ones; the ragged tail of a row is loaded through a zeroing SVE
// predicate so lanes past the row end contribute 0 to the sum.
static inline uint16_t highbd_find_average_sve(const uint16_t *src,
                                               int src_stride, int width,
                                               int height) {
  uint64x2_t avg_u64 = vdupq_n_u64(0);
  uint16x8_t ones = vdupq_n_u16(1);

  // Use a predicate to compute the last columns. When width is a multiple of
  // 8 the "tail" is a full 8-lane load, so the predicate enables all 8 lanes.
  svbool_t pattern = svwhilelt_b16_u32(0, width % 8 == 0 ? 8 : width % 8);

  int h = height;
  do {
    int j = width;
    const uint16_t *src_ptr = src;
    // Full 8-pixel chunks; the final 1..8 pixels of the row are left for the
    // predicated load below (note the strict "> 8" so at least one element
    // always remains for the tail).
    while (j > 8) {
      uint16x8_t s = vld1q_u16(src_ptr);
      avg_u64 = aom_udotq_u16(avg_u64, s, ones);

      j -= 8;
      src_ptr += 8;
    }
    // Inactive lanes are zeroed by the predicated load, so they add nothing.
    uint16x8_t s_end = svget_neonq_u16(svld1_u16(pattern, src_ptr));
    avg_u64 = aom_udotq_u16(avg_u64, s_end, ones);

    src += src_stride;
  } while (--h != 0);
  return (uint16_t)(vaddvq_u64(avg_u64) / (width * height));
}

// Writes buf_avg[i][j] = buf[i][j] - avg for a width x height block, as
// int16_t. The row tail uses a zeroing predicate for both the load and the
// broadcast average, so the full-vector tail store writes 0 - 0 = 0 into the
// destination lanes beyond width (the destination rows are padded to a
// multiple of 8 — see the stride rounding in the caller).
static inline void sub_avg_block_highbd_sve(const uint16_t *buf, int buf_stride,
                                            int16_t avg, int width, int height,
                                            int16_t *buf_avg,
                                            int buf_avg_stride) {
  uint16x8_t avg_u16 = vdupq_n_u16(avg);

  // Use a predicate to compute the last columns.
  svbool_t pattern = svwhilelt_b16_u32(0, width % 8 == 0 ? 8 : width % 8);

  // Average broadcast only into the active tail lanes; inactive lanes are 0.
  uint16x8_t avg_end = svget_neonq_u16(svdup_n_u16_z(pattern, avg));

  do {
    int j = width;
    const uint16_t *buf_ptr = buf;
    int16_t *buf_avg_ptr = buf_avg;
    while (j > 8) {
      uint16x8_t d = vld1q_u16(buf_ptr);
      // Unsigned subtract then reinterpret: modular arithmetic gives the
      // correct signed difference in 16 bits.
      vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubq_u16(d, avg_u16)));

      j -= 8;
      buf_ptr += 8;
      buf_avg_ptr += 8;
    }
    uint16x8_t d_end = svget_neonq_u16(svld1_u16(pattern, buf_ptr));
    vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubq_u16(d_end, avg_end)));

    buf += buf_stride;
    buf_avg += buf_avg_stride;
  } while (--height > 0);
}

// Accumulates the Wiener filter statistics — cross-correlation vector M
// (wiener_win2 entries) and autocorrelation matrix H (wiener_win2 x
// wiener_win2) — over the restoration window [h_start, h_end) x
// [v_start, v_end). dgd_avg and src_avg are caller-provided scratch buffers
// that receive mean-subtracted copies of the degraded and source blocks.
// For 10-/12-bit input the accumulated stats are scaled down by 4/16 before
// H's upper triangle is mirrored into its lower triangle.
void av1_compute_stats_highbd_sve(int32_t wiener_win, const uint8_t *dgd8,
                                  const uint8_t *src8, int16_t *dgd_avg,
                                  int16_t *src_avg, int32_t h_start,
                                  int32_t h_end, int32_t v_start, int32_t v_end,
                                  int32_t dgd_stride, int32_t src_stride,
                                  int64_t *M, int64_t *H,
                                  aom_bit_depth_t bit_depth) {
  const int32_t wiener_win2 = wiener_win * wiener_win;
  const int32_t wiener_halfwin = (wiener_win >> 1);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *dgd = CONVERT_TO_SHORTPTR(dgd8);
  const int32_t width = h_end - h_start;
  const int32_t height = v_end - v_start;
  // Scratch strides rounded up to a multiple of 16 samples so the mean-
  // subtracted rows (including padding lanes zeroed by the tail store) fit.
  const int32_t d_stride = (width + 2 * wiener_halfwin + 15) & ~15;
  const int32_t s_stride = (width + 15) & ~15;

  const uint16_t *dgd_start = dgd + h_start + v_start * dgd_stride;
  const uint16_t *src_start = src + h_start + v_start * src_stride;
  const uint16_t avg =
      highbd_find_average_sve(dgd_start, dgd_stride, width, height);

  // The same dgd mean is subtracted from both blocks. The dgd copy is
  // enlarged by the filter half-window on every side so each filter tap has
  // valid mean-subtracted neighbours.
  sub_avg_block_highbd_sve(src_start, src_stride, avg, width, height, src_avg,
                           s_stride);
  sub_avg_block_highbd_sve(
      dgd + (v_start - wiener_halfwin) * dgd_stride + h_start - wiener_halfwin,
      dgd_stride, avg, width + 2 * wiener_halfwin, height + 2 * wiener_halfwin,
      dgd_avg, d_stride);

  if (wiener_win == WIENER_WIN) {
    compute_stats_win7_sve(dgd_avg, d_stride, src_avg, s_stride, width, height,
                           M, H);
  } else {
    assert(wiener_win == WIENER_WIN_CHROMA);
    compute_stats_win5_sve(dgd_avg, d_stride, src_avg, s_stride, width, height,
                           M, H);
  }

  // H is a symmetric matrix, so we only need to fill out the upper triangle.
  // We can copy it down to the lower triangle outside the (i, j) loops.
  if (bit_depth == AOM_BITS_8) {
    diagonal_copy_stats_neon(wiener_win2, H);
  } else if (bit_depth == AOM_BITS_10) {
    // 10-bit: scale M and H down by 4.
    const int32_t k4 = wiener_win2 & ~3;

    int32_t k = 0;
    do {
      // Divide 4 entries of M per iteration.
      int64x2_t dst = div4_neon(vld1q_s64(M + k));
      vst1q_s64(M + k, dst);
      dst = div4_neon(vld1q_s64(M + k + 2));
      vst1q_s64(M + k + 2, dst);
      // Only diagonal entries of H at k = 0, 4, 8, ... are divided here;
      // NOTE(review): the remaining diagonal entries appear to be covered by
      // div4_diagonal_copy_stats_neon — confirm against pickrst_neon.h.
      H[k * wiener_win2 + k] /= 4;
      k += 4;
    } while (k < k4);

    H[k * wiener_win2 + k] /= 4;

    // Scalar tail of M for the entries past the last multiple of 4.
    for (; k < wiener_win2; ++k) {
      M[k] /= 4;
    }

    div4_diagonal_copy_stats_neon(wiener_win2, H);
  } else {  // bit_depth == AOM_BITS_12
    // 12-bit: scale M and H down by 16. Mirrors the 10-bit branch above.
    const int32_t k4 = wiener_win2 & ~3;

    int32_t k = 0;
    do {
      int64x2_t dst = div16_neon(vld1q_s64(M + k));
      vst1q_s64(M + k, dst);
      dst = div16_neon(vld1q_s64(M + k + 2));
      vst1q_s64(M + k + 2, dst);
      // Only diagonal entries of H at k = 0, 4, 8, ... are divided here;
      // NOTE(review): the remaining diagonal entries appear to be covered by
      // div16_diagonal_copy_stats_neon — confirm against pickrst_neon.h.
      H[k * wiener_win2 + k] /= 16;
      k += 4;
    } while (k < k4);

    H[k * wiener_win2 + k] /= 16;

    // Scalar tail of M for the entries past the last multiple of 4.
    for (; k < wiener_win2; ++k) {
      M[k] /= 16;
    }

    div16_diagonal_copy_stats_neon(wiener_win2, H);
  }
}