ml_avx2.c
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <immintrin.h>

#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"
#include "av1/encoder/x86/ml_sse3.h"

// Multiplies 8 inputs against two consecutive weight rows (rows 2*i and
// 2*i + 1) and collects the horizontally added products in hadd[i].
#define CALC_OUTPUT_FOR_2ROWS                                               \
  const int index = weight_idx + (2 * i * tot_num_inputs);                  \
  const __m256 weight0 = _mm256_loadu_ps(&weights[index]);                  \
  const __m256 weight1 = _mm256_loadu_ps(&weights[index + tot_num_inputs]); \
  const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);                    \
  const __m256 mul1 = _mm256_mul_ps(inputs256, weight1);                    \
  hadd[i] = _mm256_hadd_ps(mul0, mul1);

static inline void nn_propagate_8to1(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    int num_outputs, float *const output_nodes, int is_clip_required) {
  // Process one output row at a time.
  for (int out = 0; out < num_outputs; out++) {
    __m256 in_result = _mm256_setzero_ps();
    float bias_val = bias[out];
    for (int in = 0; in < num_inputs_to_process; in += 8) {
      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
      const int weight_idx = in + (out * tot_num_inputs);
      const __m256 weight0 = _mm256_loadu_ps(&weights[weight_idx]);
      const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);
      in_result = _mm256_add_ps(in_result, mul0);
    }
    // Reduce the 8-lane accumulator to a single sum: add the two 128-bit
    // halves, pairwise-add, then fold the two remaining partial sums
    // together, leaving the total in lane 0.
    const __m128 low_128 = _mm256_castps256_ps128(in_result);
    const __m128 high_128 = _mm256_extractf128_ps(in_result, 1);
    const __m128 sum_par_0 = _mm_add_ps(low_128, high_128);
    const __m128 sum_par_1 = _mm_hadd_ps(sum_par_0, sum_par_0);
    const __m128 sum_tot =
        _mm_add_ps(_mm_shuffle_ps(sum_par_1, sum_par_1, 0x99), sum_par_1);

    bias_val += _mm_cvtss_f32(sum_tot);
    if (is_clip_required) bias_val = AOMMAX(bias_val, 0);
    output_nodes[out] = bias_val;
  }
}
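// The SIMD kernels above and below all compute the same thing: each output
// node is the dot product of the input vector with one row of the weight
// matrix, plus a bias, optionally clipped at zero (ReLU). A scalar sketch of
// that reference computation follows; the function name is illustrative and
// not part of libaom.
#if 0
static void nn_propagate_scalar_ref(const float *inputs, const float *weights,
                                    const float *bias, int num_inputs,
                                    int num_outputs, float *output_nodes,
                                    int is_clip_required) {
  for (int out = 0; out < num_outputs; out++) {
    float sum = bias[out];  // Start the accumulation from the bias term.
    for (int in = 0; in < num_inputs; in++)
      sum += inputs[in] * weights[out * num_inputs + in];
    // ReLU is applied for hidden layers only.
    output_nodes[out] = is_clip_required ? AOMMAX(sum, 0.0f) : sum;
  }
}
#endif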
static inline void nn_propagate_8to4(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    int num_outputs, float *const output_nodes, int is_clip_required) {
  __m256 hadd[2];
  for (int out = 0; out < num_outputs; out += 4) {
    __m128 bias_reg = _mm_loadu_ps(&bias[out]);
    __m128 in_result = _mm_setzero_ps();
    for (int in = 0; in < num_inputs_to_process; in += 8) {
      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
      const int weight_idx = in + (out * tot_num_inputs);
      // Process two output rows at a time.
      for (int i = 0; i < 2; i++) {
        CALC_OUTPUT_FOR_2ROWS
      }

      const __m256 sum_par = _mm256_hadd_ps(hadd[0], hadd[1]);
      const __m128 low_128 = _mm256_castps256_ps128(sum_par);
      const __m128 high_128 = _mm256_extractf128_ps(sum_par, 1);
      const __m128 result = _mm_add_ps(low_128, high_128);

      in_result = _mm_add_ps(in_result, result);
    }

    in_result = _mm_add_ps(in_result, bias_reg);
    if (is_clip_required) in_result = _mm_max_ps(in_result, _mm_setzero_ps());
    _mm_storeu_ps(&output_nodes[out], in_result);
  }
}

static inline void nn_propagate_8to8(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    int num_outputs, float *const output_nodes, int is_clip_required) {
  __m256 hadd[4];
  for (int out = 0; out < num_outputs; out += 8) {
    __m256 bias_reg = _mm256_loadu_ps(&bias[out]);
    __m256 in_result = _mm256_setzero_ps();
    for (int in = 0; in < num_inputs_to_process; in += 8) {
      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
      const int weight_idx = in + (out * tot_num_inputs);
      // Process two output rows at a time.
      for (int i = 0; i < 4; i++) {
        CALC_OUTPUT_FOR_2ROWS
      }
      // _mm256_hadd_ps adds within 128-bit lanes, so permute the lanes to
      // line up the partial sums of all eight rows before the final add.
      const __m256 hh0 = _mm256_hadd_ps(hadd[0], hadd[1]);
      const __m256 hh1 = _mm256_hadd_ps(hadd[2], hadd[3]);

      __m256 ht_0 = _mm256_permute2f128_ps(hh0, hh1, 0x20);
      __m256 ht_1 = _mm256_permute2f128_ps(hh0, hh1, 0x31);

      __m256 result = _mm256_add_ps(ht_0, ht_1);
      in_result = _mm256_add_ps(in_result, result);
    }
    in_result = _mm256_add_ps(in_result, bias_reg);
    if (is_clip_required)
      in_result = _mm256_max_ps(in_result, _mm256_setzero_ps());
    _mm256_storeu_ps(&output_nodes[out], in_result);
  }
}

static inline void nn_propagate_input_multiple_of_8(
    const float *const inputs, const float *const weights,
    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    bool is_output_layer, int num_outputs, float *const output_nodes) {
  // Clip (ReLU) the outputs only for hidden layers, and only when this call
  // accumulates the full set of inputs; when a remainder of inputs is still
  // pending, the caller clips after the remainder has been added.
  const int is_clip_required =
      !is_output_layer && num_inputs_to_process == tot_num_inputs;
  if (num_outputs % 8 == 0) {
    nn_propagate_8to8(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
  } else if (num_outputs % 4 == 0) {
    nn_propagate_8to4(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
  } else {
    nn_propagate_8to1(inputs, weights, bias, num_inputs_to_process,
                      tot_num_inputs, num_outputs, output_nodes,
                      is_clip_required);
  }
}
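// Worked example of the dispatch above (illustrative values): with
// tot_num_inputs = 24 and num_outputs = 12, nn_propagate_8to4 runs three
// groups of four output rows, each accumulating 24 inputs in three 8-wide
// steps. The weight row stride is always tot_num_inputs, i.e.
// weights[out * tot_num_inputs + in], even when only a prefix of the inputs
// is processed here and the SSE3 helpers in av1_nn_predict_avx2 finish the
// remaining columns.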
void av1_nn_predict_avx2(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  assert(num_inputs > 0 && num_inputs <= NN_MAX_NODES_PER_LAYER);

  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool is_output_layer = layer == nn_config->num_hidden_layers;
    float *const output_nodes = is_output_layer ? output : &buf[buf_index][0];
    const int num_outputs = is_output_layer
                                ? nn_config->num_outputs
                                : nn_config->num_hidden_nodes[layer];
    assert(num_outputs > 0 && num_outputs <= NN_MAX_NODES_PER_LAYER);

    // When the number of inputs is a multiple of 8, AVX2 intrinsics handle
    // the whole layer.
    if (num_inputs % 8 == 0) {
      nn_propagate_input_multiple_of_8(input_nodes, layer_weights, layer_bias,
                                       num_inputs, num_inputs, is_output_layer,
                                       num_outputs, output_nodes);
    } else {
      // When the number of inputs is not a multiple of 8, use a hybrid of
      // AVX2 and SSE3: AVX2 for the largest multiple of 8, SSE3 (or SSE
      // scalar operations) for the remainder.
      const int in_mul_8 = num_inputs / 8;
      const int num_inputs_to_process = in_mul_8 * 8;
      int bias_is_considered = 0;
      if (in_mul_8) {
        nn_propagate_input_multiple_of_8(
            input_nodes, layer_weights, layer_bias, num_inputs_to_process,
            num_inputs, is_output_layer, num_outputs, output_nodes);
        bias_is_considered = 1;
      }

      // If the AVX2 pass already added the bias, continue accumulating from
      // its partial outputs; otherwise start from the bias itself.
      const float *out_temp = bias_is_considered ? output_nodes : layer_bias;
      const int input_remaining = num_inputs % 8;
      if (input_remaining % 4 == 0 && num_outputs % 8 == 0) {
        for (int out = 0; out < num_outputs; out += 8) {
          __m128 out_h = _mm_loadu_ps(&out_temp[out + 4]);
          __m128 out_l = _mm_loadu_ps(&out_temp[out]);
          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
            av1_nn_propagate_4to8_sse3(&input_nodes[in],
                                       &layer_weights[out * num_inputs + in],
                                       &out_h, &out_l, num_inputs);
          }
          if (!is_output_layer) {
            const __m128 zero = _mm_setzero_ps();
            out_h = _mm_max_ps(out_h, zero);
            out_l = _mm_max_ps(out_l, zero);
          }
          _mm_storeu_ps(&output_nodes[out + 4], out_h);
          _mm_storeu_ps(&output_nodes[out], out_l);
        }
      } else if (input_remaining % 4 == 0 && num_outputs % 4 == 0) {
        for (int out = 0; out < num_outputs; out += 4) {
          __m128 outputs = _mm_loadu_ps(&out_temp[out]);
          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
            av1_nn_propagate_4to4_sse3(&input_nodes[in],
                                       &layer_weights[out * num_inputs + in],
                                       &outputs, num_inputs);
          }
          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
          _mm_storeu_ps(&output_nodes[out], outputs);
        }
      } else if (input_remaining % 4 == 0) {
        for (int out = 0; out < num_outputs; out++) {
          __m128 outputs = _mm_load1_ps(&out_temp[out]);
          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
            av1_nn_propagate_4to1_sse3(&input_nodes[in],
                                       &layer_weights[out * num_inputs + in],
                                       &outputs);
          }
          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
          output_nodes[out] = _mm_cvtss_f32(outputs);
        }
      } else {
        // Use SSE instructions for scalar operations to avoid the latency
        // of swapping between SIMD and FPU modes.
        for (int out = 0; out < num_outputs; out++) {
          __m128 outputs = _mm_load1_ps(&out_temp[out]);
          for (int in_node = in_mul_8 * 8; in_node < num_inputs; in_node++) {
            __m128 input = _mm_load1_ps(&input_nodes[in_node]);
            __m128 weight =
                _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
            outputs = _mm_add_ps(outputs, _mm_mul_ps(input, weight));
          }
          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
          output_nodes[out] = _mm_cvtss_f32(outputs);
        }
      }
    }
    // The output of the current layer becomes the input to the next layer.
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}
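// Usage sketch (illustrative; not part of the original file): a minimal
// network with one hidden layer of 4 nodes, exercising both the AVX2 path
// (8 inputs) and the SSE3 remainder path (4 hidden nodes feeding the output
// layer). The NN_CONFIG field names are those declared in av1/encoder/ml.h;
// all data values here are made up.
#if 0
static void example_predict(void) {
  static const float hidden_weights[4 * 8] = { 0 };  // 4 rows x 8 inputs.
  static const float hidden_bias[4] = { 0 };
  static const float out_weights[1 * 4] = { 0 };  // 1 row x 4 hidden nodes.
  static const float out_bias[1] = { 0 };
  const NN_CONFIG cfg = {
    .num_inputs = 8,
    .num_outputs = 1,
    .num_hidden_layers = 1,
    .num_hidden_nodes = { 4 },
    .weights = { hidden_weights, out_weights },
    .bias = { hidden_bias, out_bias },
  };
  float inputs[8] = { 0 };
  float prediction;
  av1_nn_predict_avx2(inputs, &cfg, /*reduce_prec=*/0, &prediction);
}
#endif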