ml_neon.c
/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"

static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l,
                         const float32x4_t *zero) {
  *out_h = vmaxq_f32(*out_h, *zero);
  *out_l = vmaxq_f32(*out_l, *zero);
}

static void nn_activate4(float32x4_t *x, const float32x4_t *zero) {
  *x = vmaxq_f32(*x, *zero);
}

#define CLAMP_0(x) (x = x > 0 ? x : 0)

static void nn_propagate_8to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
  }
#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

static void nn_propagate_xto1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes) {
  float32x4_t vadd = vdupq_n_f32(0);

  float total = *layer_bias;
  int j = num_inputs;
  int in = 0;
  while (j > 7) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
    in += 8;
    j -= 8;
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

static void nn_propagate_xsto1(int num_inputs, const float *const inputs,
                               const float *const weights,
                               const float *layer_bias,
                               float *const output_nodes) {
  float total = *layer_bias;
#if AOM_ARCH_AARCH64
  const float32x4_t v_inputs = vld1q_f32(inputs);
  const float32x4_t v_weights = vld1q_f32(weights);
  const float32x4_t vadd = vmulq_f32(v_inputs, v_weights);
  total += vaddvq_f32(vadd);
  int in = 4;
#else
  int in = 0;
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}
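/*
 * For reference: each of the *to1 kernels above computes the same scalar
 * result, a bias-initialized dot product with an optional ReLU. A minimal
 * scalar sketch (hypothetical helper, not part of this file):
 *
 *   static float nn_dot_relu(int n, const float *x, const float *w,
 *                            float bias, bool output_layer) {
 *     float total = bias;
 *     for (int i = 0; i < n; i++) total += w[i] * x[i];
 *     // Hidden layers clamp at zero (CLAMP_0); the output layer stays linear.
 *     return (!output_layer && total < 0.0f) ? 0.0f : total;
 *   }
 */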
static void nn_propagate_4to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_inputs = vld1q_f32(&inputs[in]);
    const float32x4_t v_weights = vld1q_f32(&weights[in]);
    vadd = vmlaq_f32(vadd, v_inputs, v_weights);
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

static void nn_propagate_4to4(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);

  float32x4_t mul0[2] = { zero, zero };
  float32x4_t mul1[2] = { zero, zero };
  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);

    for (int i = 0; i < 2; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], weight0, v_input);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul1[i] = vmlaq_f32(mul1[i], weight1, v_input);
    }
  }
  for (int i = 0; i < 2; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh = vpaddq_f32(mul0[0], mul0[1]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
#endif

  outputs = vaddq_f32(outputs, hh);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

static void nn_propagate_4to8(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t out_h = vld1q_f32(&layer_bias[4]);
  float32x4_t out_l = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t mul0[4] = { zero, zero, zero, zero };
  float32x4_t mul1[4] = { zero, zero, zero, zero };

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);
    for (int i = 0; i < 4; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], v_input, weight0);
      mul1[i] = vmlaq_f32(mul1[i], v_input, weight1);
    }
  }
  for (int i = 0; i < 4; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh0 = vpaddq_f32(mul0[0], mul0[1]);
  const float32x4_t hh1 = vpaddq_f32(mul0[2], mul0[3]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh0 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
  const float32x4_t hh1 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[2]), vget_high_f32(mul0[2])),
                   vpadd_f32(vget_low_f32(mul0[3]), vget_high_f32(mul0[3])));
#endif

  out_h = vaddq_f32(out_h, hh1);
  out_l = vaddq_f32(out_l, hh0);

  if (!output_layer) nn_activate8(&out_h, &out_l, &zero);
  vst1q_f32(&output_nodes[4], out_h);
  vst1q_f32(output_nodes, out_l);
}
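/*
 * The pairwise-add trees in the kernels above and below fold one accumulator
 * per output node into a single vector of sums. vpaddq_f32(a, b) returns
 * { a[0]+a[1], a[2]+a[3], b[0]+b[1], b[2]+b[3] }, so two levels reduce four
 * accumulators a0..a3 to the equivalent of
 * { vaddvq_f32(a0), vaddvq_f32(a1), vaddvq_f32(a2), vaddvq_f32(a3) }:
 *
 *   vpaddq_f32(vpaddq_f32(a0, a1), vpaddq_f32(a2, a3))
 *
 * The non-AArch64 paths build the same result from vpadd_f32/vcombine_f32
 * because vpaddq_f32 is only available on AArch64.
 */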
static void nn_propagate_8to4(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t add[4] = { zero, zero, zero, zero };
  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);

    for (int i = 0; i < 4; i++) {
      const float32x4_t weight_l = vld1q_f32(&weights[in + i * num_inputs]);
      const float32x4_t weight_h = vld1q_f32(&weights[in + i * num_inputs + 4]);
      add[i] = vmlaq_f32(add[i], inputs_l, weight_l);
      add[i] = vmlaq_f32(add[i], inputs_h, weight_h);
    }
  }
#if AOM_ARCH_AARCH64
  const float32x4_t hadd_h = vpaddq_f32(add[2], add[3]);
  const float32x4_t hadd_l = vpaddq_f32(add[0], add[1]);
  const float32x4_t haddhadd = vpaddq_f32(hadd_l, hadd_h);
#else
  const float32x4_t hadd_h =
      vcombine_f32(vpadd_f32(vget_low_f32(add[2]), vget_high_f32(add[2])),
                   vpadd_f32(vget_low_f32(add[3]), vget_high_f32(add[3])));
  const float32x4_t hadd_l =
      vcombine_f32(vpadd_f32(vget_low_f32(add[0]), vget_high_f32(add[0])),
                   vpadd_f32(vget_low_f32(add[1]), vget_high_f32(add[1])));
  const float32x4_t haddhadd =
      vcombine_f32(vpadd_f32(vget_low_f32(hadd_l), vget_high_f32(hadd_l)),
                   vpadd_f32(vget_low_f32(hadd_h), vget_high_f32(hadd_h)));
#endif

  outputs = vaddq_f32(outputs, haddhadd);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}
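/*
 * Note on layout: in every kernel above, a layer's weights are stored
 * row-major, so the weight vector for output node `out` starts at
 * layer_weights[out * num_inputs]. That is how av1_nn_predict_neon below
 * indexes them when it dispatches to a kernel.
 */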
// Calculate prediction based on the given input features and neural net
// config. Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each
// hidden layer.
void av1_nn_predict_neon(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  // Hidden layers, except the final iteration is the output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        nn_propagate_4to8(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_8to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_4to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_8to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_4to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs > 8) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xto1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out]);
      }
    } else if (num_inputs >= 4) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xsto1(num_inputs, input_nodes,
                           &layer_weights[out * num_inputs], &layer_bias[out],
                           &output_nodes[out]);
      }
    } else {
      for (int node = 0; node < num_outputs; ++node) {
        float val = layer_bias[node];
        for (int i = 0; i < num_inputs; ++i)
          val += layer_weights[node * num_inputs + i] * input_nodes[i];
        // ReLU as activation function.
        val = val > 0.0f ? val : 0.0f;  // Could use AOMMAX().
        output_nodes[node] = val;
      }
    }
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}
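/*
 * Example usage: a minimal sketch of driving av1_nn_predict_neon with a
 * one-hidden-layer network. The weight/bias values are placeholders, and the
 * field names are the NN_CONFIG members referenced above (declared in
 * av1/encoder/ml.h):
 *
 *   static const float weights_hidden[4 * 8] = { ... };  // 8 inputs -> 4 nodes
 *   static const float weights_out[1 * 4] = { ... };     // 4 nodes -> 1 output
 *   static const float bias_hidden[4] = { ... };
 *   static const float bias_out[1] = { ... };
 *   static const NN_CONFIG config = {
 *     .num_inputs = 8,
 *     .num_outputs = 1,
 *     .num_hidden_layers = 1,
 *     .num_hidden_nodes = { 4 },
 *     .weights = { weights_hidden, weights_out },
 *     .bias = { bias_hidden, bias_out },
 *   };
 *   float features[8] = { ... };
 *   float score;
 *   av1_nn_predict_neon(features, &config, 0, &score);  // reduce_prec = 0
 *
 * With these dimensions the hidden layer takes the 8to4 path and the output
 * layer takes the 4to1 path of the dispatch above.
 */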