tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

ml_neon.c (13384B)


/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdbool.h>
#include <assert.h>
#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"
#include "av1/encoder/ml.h"

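// ReLU activation applied in place: clamp every lane to zero from below.
// nn_activate8 covers eight values split across two vectors; nn_activate4
// covers a single vector of four.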
static void nn_activate8(float32x4_t *out_h, float32x4_t *out_l,
                         const float32x4_t *zero) {
  *out_h = vmaxq_f32(*out_h, *zero);
  *out_l = vmaxq_f32(*out_l, *zero);
}

static void nn_activate4(float32x4_t *x, const float32x4_t *zero) {
  *x = vmaxq_f32(*x, *zero);
}

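// Scalar ReLU. Written as an assignment expression so it can be used both as
// a statement and as a value.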
#define CLAMP_0(x) (x = x > 0 ? x : 0)

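// Propagate a layer whose input count is a multiple of 8 to a single output
// node: multiply-accumulate eight lanes per iteration, reduce the vector
// accumulator horizontally, add the bias, and apply ReLU unless this is the
// output layer.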
static void nn_propagate_8to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
  }
#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

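// Same reduction for an arbitrary input count: vectorize in groups of eight,
// then finish the remaining inputs with a scalar loop. Note that ReLU is
// applied unconditionally here.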
static void nn_propagate_xto1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes) {
  float32x4_t vadd = vdupq_n_f32(0);

  float total = *layer_bias;
  int j = num_inputs;
  int in = 0;
  while (j > 7) {
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);

    const float32x4_t weights_h = vld1q_f32(&weights[in + 4]);
    const float32x4_t weights_l = vld1q_f32(&weights[in]);

    vadd = vmlaq_f32(vadd, inputs_h, weights_h);
    vadd = vmlaq_f32(vadd, inputs_l, weights_l);
    in += 8;
    j -= 8;
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);

#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

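// Single-output propagation for small input counts: handle the first four
// lanes with one vector multiply on AArch64, the rest with a scalar loop
// (fully scalar on 32-bit Arm).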
static void nn_propagate_xsto1(int num_inputs, const float *const inputs,
                               const float *const weights,
                               const float *layer_bias,
                               float *const output_nodes) {
  float total = *layer_bias;
#if AOM_ARCH_AARCH64
  const float32x4_t v_inputs = vld1q_f32(inputs);
  const float32x4_t v_weights = vld1q_f32(weights);
  const float32x4_t vadd = vmulq_f32(v_inputs, v_weights);
  total += vaddvq_f32(vadd);
  int in = 4;
#else
  int in = 0;
#endif
  for (; in < num_inputs; in++) total += weights[in] * inputs[in];

  *output_nodes = CLAMP_0(total);
}

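// Single-output propagation when the input count is a multiple of 4: one
// four-lane multiply-accumulate per iteration, then a horizontal reduction.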
static void nn_propagate_4to1(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t vadd = zero;
  float total = *layer_bias;

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_inputs = vld1q_f32(&inputs[in]);
    const float32x4_t v_weights = vld1q_f32(&weights[in]);
    vadd = vmlaq_f32(vadd, v_inputs, v_weights);
  }

#if AOM_ARCH_AARCH64
  total += vaddvq_f32(vadd);
#else
  float32x2_t vadd_lo = vadd_f32(vget_low_f32(vadd), vget_high_f32(vadd));
  vadd_lo = vpadd_f32(vadd_lo, vadd_lo);
  total += vget_lane_f32(vadd_lo, 0);
#endif

  if (!output_layer) CLAMP_0(total);
  *output_nodes = total;
}

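// Four output nodes at once, input count a multiple of 4: keep four dot
// products in flight, then fold them with pairwise adds so each lane of the
// result holds one node's sum.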
static void nn_propagate_4to4(int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);

  float32x4_t mul0[2] = { zero, zero };
  float32x4_t mul1[2] = { zero, zero };
  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);

    for (int i = 0; i < 2; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], weight0, v_input);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul1[i] = vmlaq_f32(mul1[i], weight1, v_input);
    }
  }
  for (int i = 0; i < 2; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh = vpaddq_f32(mul0[0], mul0[1]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
#endif

  outputs = vaddq_f32(outputs, hh);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

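// Eight output nodes at once, input count a multiple of 4: the same scheme as
// nn_propagate_4to4 with two accumulator banks feeding the low and high
// halves of the output.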
static void nn_propagate_4to8(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t out_h = vld1q_f32(&layer_bias[4]);
  float32x4_t out_l = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t mul0[4] = { zero, zero, zero, zero };
  float32x4_t mul1[4] = { zero, zero, zero, zero };

  for (int in = 0; in < num_inputs; in += 4) {
    const float32x4_t v_input = vld1q_f32(&inputs[in]);
    for (int i = 0; i < 4; i++) {
      const float32x4_t weight0 = vld1q_f32(&weights[in + 2 * i * num_inputs]);
      const float32x4_t weight1 =
          vld1q_f32(&weights[in + (2 * i + 1) * num_inputs]);
      mul0[i] = vmlaq_f32(mul0[i], v_input, weight0);
      mul1[i] = vmlaq_f32(mul1[i], v_input, weight1);
    }
  }
  for (int i = 0; i < 4; i++)
#if AOM_ARCH_AARCH64
    mul0[i] = vpaddq_f32(mul0[i], mul1[i]);
  const float32x4_t hh0 = vpaddq_f32(mul0[0], mul0[1]);
  const float32x4_t hh1 = vpaddq_f32(mul0[2], mul0[3]);
#else
    mul0[i] =
        vcombine_f32(vpadd_f32(vget_low_f32(mul0[i]), vget_high_f32(mul0[i])),
                     vpadd_f32(vget_low_f32(mul1[i]), vget_high_f32(mul1[i])));
  const float32x4_t hh0 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[0]), vget_high_f32(mul0[0])),
                   vpadd_f32(vget_low_f32(mul0[1]), vget_high_f32(mul0[1])));
  const float32x4_t hh1 =
      vcombine_f32(vpadd_f32(vget_low_f32(mul0[2]), vget_high_f32(mul0[2])),
                   vpadd_f32(vget_low_f32(mul0[3]), vget_high_f32(mul0[3])));
#endif

  out_h = vaddq_f32(out_h, hh1);
  out_l = vaddq_f32(out_l, hh0);

  if (!output_layer) nn_activate8(&out_h, &out_l, &zero);
  vst1q_f32(&output_nodes[4], out_h);
  vst1q_f32(output_nodes, out_l);
}

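// Four output nodes at once, input count a multiple of 8: two four-lane
// multiply-accumulates per node per iteration, reduced with a tree of
// pairwise adds.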
static void nn_propagate_8to4(const int num_inputs, const float *const inputs,
                              const float *const weights,
                              const float *layer_bias,
                              float *const output_nodes, bool output_layer) {
  float32x4_t outputs = vld1q_f32(layer_bias);
  const float32x4_t zero = vdupq_n_f32(0);
  float32x4_t add[4] = { zero, zero, zero, zero };
  for (int in = 0; in < num_inputs; in += 8) {
    const float32x4_t inputs_l = vld1q_f32(&inputs[in]);
    const float32x4_t inputs_h = vld1q_f32(&inputs[in + 4]);

    for (int i = 0; i < 4; i++) {
      const float32x4_t weight_l = vld1q_f32(&weights[in + i * num_inputs]);
      const float32x4_t weight_h = vld1q_f32(&weights[in + i * num_inputs + 4]);
      add[i] = vmlaq_f32(add[i], inputs_l, weight_l);
      add[i] = vmlaq_f32(add[i], inputs_h, weight_h);
    }
  }
#if AOM_ARCH_AARCH64
  const float32x4_t hadd_h = vpaddq_f32(add[2], add[3]);
  const float32x4_t hadd_l = vpaddq_f32(add[0], add[1]);
  const float32x4_t haddhadd = vpaddq_f32(hadd_l, hadd_h);
#else
  const float32x4_t hadd_h =
      vcombine_f32(vpadd_f32(vget_low_f32(add[2]), vget_high_f32(add[2])),
                   vpadd_f32(vget_low_f32(add[3]), vget_high_f32(add[3])));
  const float32x4_t hadd_l =
      vcombine_f32(vpadd_f32(vget_low_f32(add[0]), vget_high_f32(add[0])),
                   vpadd_f32(vget_low_f32(add[1]), vget_high_f32(add[1])));
  const float32x4_t haddhadd =
      vcombine_f32(vpadd_f32(vget_low_f32(hadd_l), vget_high_f32(hadd_l)),
                   vpadd_f32(vget_low_f32(hadd_h), vget_high_f32(hadd_h)));
#endif

  outputs = vaddq_f32(outputs, haddhadd);
  if (!output_layer) nn_activate4(&outputs, &zero);
  vst1q_f32(output_nodes, outputs);
}

// Calculate prediction based on the given input features and neural net config.
// Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
// layer.
void av1_nn_predict_neon(const float *input_nodes,
                         const NN_CONFIG *const nn_config, int reduce_prec,
                         float *const output) {
  float buf[2][NN_MAX_NODES_PER_LAYER];
  int buf_index = 0;
  int num_inputs = nn_config->num_inputs;
  // Hidden layers, except the final iteration is the output layer.
  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    const float *layer_weights = nn_config->weights[layer];
    const float *layer_bias = nn_config->bias[layer];
    bool output_layer = (layer == nn_config->num_hidden_layers);
    float *const output_nodes = output_layer ? output : buf[buf_index];
    const int num_outputs = output_layer ? nn_config->num_outputs
                                         : nn_config->num_hidden_nodes[layer];

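    // Dispatch to the widest kernel the layer dimensions allow; fall back to
    // plain scalar code when there are fewer than four inputs.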
    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out += 8) {
        nn_propagate_4to8(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_8to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out += 4) {
        nn_propagate_4to4(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 8 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_8to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs % 4 == 0) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_4to1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out], output_layer);
      }
    } else if (num_inputs > 8) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xto1(num_inputs, input_nodes,
                          &layer_weights[out * num_inputs], &layer_bias[out],
                          &output_nodes[out]);
      }
    } else if (num_inputs >= 4) {
      for (int out = 0; out < num_outputs; out++) {
        nn_propagate_xsto1(num_inputs, input_nodes,
                           &layer_weights[out * num_inputs], &layer_bias[out],
                           &output_nodes[out]);
      }
    } else {
      for (int node = 0; node < num_outputs; ++node) {
        float val = layer_bias[node];
        for (int i = 0; i < num_inputs; ++i)
          val += layer_weights[node * num_inputs + i] * input_nodes[i];
        // ReLU as activation function.
        val = val > 0.0f ? val : 0.0f;  // Could use AOMMAX().
        output_nodes[node] = val;
      }
    }
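    // This layer's outputs feed the next layer; ping-pong between the two
    // scratch buffers.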
    input_nodes = output_nodes;
    num_inputs = num_outputs;
    buf_index = 1 - buf_index;
  }
  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
}