tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ml_sse3.c (12769B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <stdbool.h>
     13 #include <assert.h>
     14 
     15 #include "config/av1_rtcd.h"
     16 #include "av1/encoder/ml.h"
     17 #include "av1/encoder/x86/ml_sse3.h"
     18 
     19 // In order to avoid the high-latency of swapping between FPU and SIMD
     20 // operations, we keep the result in a 128-bit register even though we only
     21 // care about a single value.
     22 static void nn_propagate_8to1(const float *const inputs,
     23                              const float *const weights,
     24                              __m128 *const output) {
     25  const __m128 inputs_h = _mm_loadu_ps(&inputs[4]);
     26  const __m128 inputs_l = _mm_loadu_ps(inputs);
     27 
     28  const __m128 weights_h = _mm_loadu_ps(&weights[4]);
     29  const __m128 weights_l = _mm_loadu_ps(weights);
     30 
     31  const __m128 mul_h = _mm_mul_ps(inputs_h, weights_h);
     32  const __m128 mul_l = _mm_mul_ps(inputs_l, weights_l);
     33  // [7 6 5 4] [3 2 1 0] (weight and input indices)
     34 
     35  const __m128 vadd = _mm_add_ps(mul_l, mul_h);
     36  // [7+3 6+2 5+1 4+0]
     37  const __m128 hadd1 = _mm_hadd_ps(vadd, vadd);
     38  // [7+6+3+2 5+4+1+0 7+6+3+2 5+4+1+0]
     39  const __m128 hadd2 = _mm_hadd_ps(hadd1, hadd1);
     40  // [7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0 7+6+5+4+3+2+1+0]
     41  *output = _mm_add_ps(*output, hadd2);
     42 }
     43 
     44 void av1_nn_propagate_4to1_sse3(const float *const inputs,
     45                                const float *const weights,
     46                                __m128 *const output) {
     47  const __m128 inputs128 = _mm_loadu_ps(inputs);
     48 
     49  const __m128 weights128 = _mm_loadu_ps(weights);
     50 
     51  const __m128 mul = _mm_mul_ps(inputs128, weights128);
     52  // [3 2 1 0] (weight and input indices)
     53 
     54  const __m128 hadd1 = _mm_hadd_ps(mul, mul);
     55  // [3+2 1+0 3+2 1+0]
     56  const __m128 hadd2 = _mm_hadd_ps(hadd1, hadd1);
     57  // [3+2+1+0 3+2+1+0 3+2+1+0 3+2+1+0]
     58  *output = _mm_add_ps(*output, hadd2);
     59 }
     60 
     61 void av1_nn_propagate_4to4_sse3(const float *const inputs,
     62                                const float *const weights,
     63                                __m128 *const outputs, const int num_inputs) {
     64  const __m128 inputs128 = _mm_loadu_ps(inputs);
     65 
     66  __m128 hadd[2];
     67  for (int i = 0; i < 2; i++) {  // For each pair of outputs
     68    const __m128 weight0 = _mm_loadu_ps(&weights[2 * i * num_inputs]);
     69    const __m128 mul0 = _mm_mul_ps(weight0, inputs128);
     70    const __m128 weight1 = _mm_loadu_ps(&weights[(2 * i + 1) * num_inputs]);
     71    const __m128 mul1 = _mm_mul_ps(weight1, inputs128);
     72    hadd[i] = _mm_hadd_ps(mul0, mul1);
     73  }
     74  // hadd[0] = [7+6 5+4 3+2 1+0] (weight indices)
     75  // hadd[1] = [15+14 13+12 11+10 9+8]
     76 
     77  const __m128 hh = _mm_hadd_ps(hadd[0], hadd[1]);
     78  // [15+14+13+12 11+10+9+8 7+6+5+4 3+2+1+0]
     79 
     80  *outputs = _mm_add_ps(*outputs, hh);
     81 }
     82 
     83 void av1_nn_propagate_4to8_sse3(const float *const inputs,
     84                                const float *const weights, __m128 *const out_h,
     85                                __m128 *const out_l, const int num_inputs) {
     86  const __m128 inputs128 = _mm_loadu_ps(inputs);
     87 
     88  __m128 hadd[4];
     89  for (int i = 0; i < 4; i++) {  // For each pair of outputs
     90    const __m128 weight0 = _mm_loadu_ps(&weights[2 * i * num_inputs]);
     91    const __m128 weight1 = _mm_loadu_ps(&weights[(2 * i + 1) * num_inputs]);
     92    const __m128 mul0 = _mm_mul_ps(inputs128, weight0);
     93    const __m128 mul1 = _mm_mul_ps(inputs128, weight1);
     94    hadd[i] = _mm_hadd_ps(mul0, mul1);
     95  }
     96  // hadd[0] = [7+6 5+4 3+2 1+0] (weight indices)
     97  // hadd[1] = [15+14 13+12 11+10 9+8]
     98  // hadd[2] = [23+22 21+20 19+18 17+16]
     99  // hadd[3] = [31+30 29+28 27+26 25+24]
    100 
    101  const __m128 hh0 = _mm_hadd_ps(hadd[0], hadd[1]);
    102  // [15+14+13+12 11+10+9+8 7+6+5+4 3+2+1+0]
    103  const __m128 hh1 = _mm_hadd_ps(hadd[2], hadd[3]);
    104  // [31+30+29+28 27+26+25+24 23+22+21+20 19+18+17+16]
    105 
    106  *out_h = _mm_add_ps(*out_h, hh1);
    107  *out_l = _mm_add_ps(*out_l, hh0);
    108 }
    109 
    110 static void nn_propagate_8to4(const float *const inputs,
    111                              const float *const weights, __m128 *const outputs,
    112                              const int num_inputs) {
    113  const __m128 inputs_h = _mm_loadu_ps(inputs + 4);
    114  const __m128 inputs_l = _mm_loadu_ps(inputs);
    115  // [7 6 5 4] [3 2 1 0] (input indices)
    116 
    117  __m128 add[4];
    118  for (int i = 0; i < 4; i++) {  // For each output:
    119    const __m128 weight_h = _mm_loadu_ps(&weights[i * num_inputs + 4]);
    120    const __m128 weight_l = _mm_loadu_ps(&weights[i * num_inputs]);
    121    const __m128 mul_h = _mm_mul_ps(inputs_h, weight_h);
    122    const __m128 mul_l = _mm_mul_ps(inputs_l, weight_l);
    123    add[i] = _mm_add_ps(mul_l, mul_h);
    124  }
    125  // add[0] = [7+3 6+2 5+1 4+0]
    126  // add[1] = [15+11 14+10 13+9 12+8]
    127  // add[2] = [23+19 22+18 21+17 20+16]
    128  // add[3] = [31+27 30+26 29+25 28+24]
    129 
    130  const __m128 hadd_h = _mm_hadd_ps(add[2], add[3]);
    131  // [31+30+27+26 29+28+25+24 23+22+19+18 21+20+17+16]
    132  const __m128 hadd_l = _mm_hadd_ps(add[0], add[1]);
    133  // [15+14+11+10 13+12+9+8 7+6+3+2 5+4+1+0]
    134 
    135  const __m128 haddhadd = _mm_hadd_ps(hadd_l, hadd_h);
    136  // [31+30+29+28+27+26+25+24 23+22+21+20+19+18+17+16
    137  //  15+14+13+12+11+10+9+8 7+6+5+4+3+2+1+0]
    138 
    139  *outputs = _mm_add_ps(*outputs, haddhadd);
    140 }
    141 
    142 static void nn_activate8(__m128 *out_h, __m128 *out_l) {
    143  const __m128 zero = _mm_setzero_ps();
    144  *out_h = _mm_max_ps(*out_h, zero);
    145  *out_l = _mm_max_ps(*out_l, zero);
    146 }
    147 
    148 static void nn_activate4(__m128 *x) { *x = _mm_max_ps(*x, _mm_setzero_ps()); }
    149 
    150 // Calculate prediction based on the given input features and neural net config.
    151 // Assume there are no more than NN_MAX_NODES_PER_LAYER nodes in each hidden
    152 // layer.
    153 void av1_nn_predict_sse3(const float *input_nodes,
    154                         const NN_CONFIG *const nn_config, int reduce_prec,
    155                         float *const output) {
    156  float buf[2][NN_MAX_NODES_PER_LAYER];
    157  int buf_index = 0;
    158  int num_inputs = nn_config->num_inputs;
    159 
    160  // Hidden layers, except the final iteration is the output layer.
    161  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    162    const float *layer_weights = nn_config->weights[layer];
    163    const float *layer_bias = nn_config->bias[layer];
    164    bool output_layer = (layer == nn_config->num_hidden_layers);
    165    float *const output_nodes = output_layer ? output : &buf[buf_index][0];
    166    const int num_outputs = output_layer ? nn_config->num_outputs
    167                                         : nn_config->num_hidden_nodes[layer];
    168 
    169    if (num_inputs % 4 == 0 && num_outputs % 8 == 0) {
    170      for (int out = 0; out < num_outputs; out += 8) {
    171        __m128 out_h = _mm_loadu_ps(&layer_bias[out + 4]);
    172        __m128 out_l = _mm_loadu_ps(&layer_bias[out]);
    173        for (int in = 0; in < num_inputs; in += 4) {
    174          av1_nn_propagate_4to8_sse3(&input_nodes[in],
    175                                     &layer_weights[out * num_inputs + in],
    176                                     &out_h, &out_l, num_inputs);
    177        }
    178        if (!output_layer) nn_activate8(&out_h, &out_l);
    179        _mm_storeu_ps(&output_nodes[out + 4], out_h);
    180        _mm_storeu_ps(&output_nodes[out], out_l);
    181      }
    182    } else if (num_inputs % 8 == 0 && num_outputs % 4 == 0) {
    183      for (int out = 0; out < num_outputs; out += 4) {
    184        __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
    185        for (int in = 0; in < num_inputs; in += 8) {
    186          nn_propagate_8to4(&input_nodes[in],
    187                            &layer_weights[out * num_inputs + in], &outputs,
    188                            num_inputs);
    189        }
    190        if (!output_layer) nn_activate4(&outputs);
    191        _mm_storeu_ps(&output_nodes[out], outputs);
    192      }
    193    } else if (num_inputs % 4 == 0 && num_outputs % 4 == 0) {
    194      for (int out = 0; out < num_outputs; out += 4) {
    195        __m128 outputs = _mm_loadu_ps(&layer_bias[out]);
    196        for (int in = 0; in < num_inputs; in += 4) {
    197          av1_nn_propagate_4to4_sse3(&input_nodes[in],
    198                                     &layer_weights[out * num_inputs + in],
    199                                     &outputs, num_inputs);
    200        }
    201        if (!output_layer) nn_activate4(&outputs);
    202        _mm_storeu_ps(&output_nodes[out], outputs);
    203      }
    204    } else if (num_inputs % 8 == 0) {
    205      for (int out = 0; out < num_outputs; out++) {
    206        __m128 total = _mm_load1_ps(&layer_bias[out]);
    207        for (int in = 0; in < num_inputs; in += 8) {
    208          nn_propagate_8to1(&input_nodes[in],
    209                            &layer_weights[out * num_inputs + in], &total);
    210        }
    211        if (!output_layer) nn_activate4(&total);
    212        output_nodes[out] = _mm_cvtss_f32(total);
    213      }
    214    } else if (num_inputs % 4 == 0) {
    215      for (int out = 0; out < num_outputs; out++) {
    216        __m128 total = _mm_load1_ps(&layer_bias[out]);
    217        for (int in = 0; in < num_inputs; in += 4) {
    218          av1_nn_propagate_4to1_sse3(
    219              &input_nodes[in], &layer_weights[out * num_inputs + in], &total);
    220        }
    221        if (!output_layer) nn_activate4(&total);
    222        output_nodes[out] = _mm_cvtss_f32(total);
    223      }
    224    } else {
    225      // Use SSE instructions for scalar operations to avoid the latency of
    226      // swapping between SIMD and FPU modes.
    227      for (int out = 0; out < num_outputs; out++) {
    228        __m128 total = _mm_load1_ps(&layer_bias[out]);
    229        for (int in_node = 0; in_node < num_inputs; in_node++) {
    230          __m128 input = _mm_load1_ps(&input_nodes[in_node]);
    231          __m128 weight =
    232              _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
    233          total = _mm_add_ps(total, _mm_mul_ps(input, weight));
    234        }
    235        if (!output_layer) nn_activate4(&total);
    236        output_nodes[out] = _mm_cvtss_f32(total);
    237      }
    238    }
    239    input_nodes = output_nodes;
    240    num_inputs = num_outputs;
    241    buf_index = 1 - buf_index;
    242  }
    243  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
    244 }
    245 
    246 // Based on N. N. Schraudolph. A Fast, Compact Approximation of the Exponential
    247 // Function. Neural Computation, 11(4):853–862, 1999.
    248 static inline __m128 approx_exp(__m128 y) {
    249 #define A ((1 << 23) / 0.69314718056f)  // (1 << 23) / ln(2)
    250 #define B \
    251  127  // Offset for the exponent according to IEEE floating point standard.
    252 #define C 60801  // Magic number controls the accuracy of approximation
    253  const __m128 multiplier = _mm_set1_ps(A);
    254  const __m128i offset = _mm_set1_epi32(B * (1 << 23) - C);
    255 
    256  y = _mm_mul_ps(y, multiplier);
    257  y = _mm_castsi128_ps(_mm_add_epi32(_mm_cvtps_epi32(y), offset));
    258  return y;
    259 #undef A
    260 #undef B
    261 #undef C
    262 }
    263 
    264 static inline __m128 reduce_max(__m128 reg) {
    265  __m128 tmp_reg;
    266 
    267  tmp_reg = _mm_shuffle_ps(reg, reg, 0x4e);  // 01 00 11 10
    268  reg = _mm_max_ps(reg, tmp_reg);
    269 
    270  tmp_reg = _mm_shuffle_ps(reg, reg, 0xb1);  // 10 11 00 01
    271  reg = _mm_max_ps(reg, tmp_reg);
    272 
    273  return reg;
    274 }
    275 
    276 static inline __m128 reduce_sum(__m128 reg) {
    277  __m128 tmp_reg;
    278 
    279  tmp_reg = _mm_shuffle_ps(reg, reg, 0x4e);  // 01 00 11 10
    280  reg = _mm_add_ps(reg, tmp_reg);
    281 
    282  tmp_reg = _mm_shuffle_ps(reg, reg, 0xb1);  // 10 11 00 01
    283  reg = _mm_add_ps(reg, tmp_reg);
    284 
    285  return reg;
    286 }
    287 
    288 void av1_nn_fast_softmax_16_sse3(const float *input, float *output) {
    289  // Clips at -10 to avoid underflowing
    290  const __m128 clipper = _mm_set1_ps(-10.0f);
    291 
    292  // Load in 16 values
    293  __m128 in_0 = _mm_loadu_ps(&input[0]);
    294  __m128 in_1 = _mm_loadu_ps(&input[4]);
    295  __m128 in_2 = _mm_loadu_ps(&input[8]);
    296  __m128 in_3 = _mm_loadu_ps(&input[12]);
    297 
    298  // Get the max
    299  __m128 max_0 = _mm_max_ps(in_0, in_1);
    300  __m128 max_1 = _mm_max_ps(in_2, in_3);
    301 
    302  max_0 = _mm_max_ps(max_0, max_1);
    303  max_0 = reduce_max(max_0);
    304 
    305  // Subtract the max off and clip
    306  in_0 = _mm_sub_ps(in_0, max_0);
    307  in_1 = _mm_sub_ps(in_1, max_0);
    308  in_2 = _mm_sub_ps(in_2, max_0);
    309  in_3 = _mm_sub_ps(in_3, max_0);
    310 
    311  in_0 = _mm_max_ps(in_0, clipper);
    312  in_1 = _mm_max_ps(in_1, clipper);
    313  in_2 = _mm_max_ps(in_2, clipper);
    314  in_3 = _mm_max_ps(in_3, clipper);
    315 
    316  // Exponentiate and compute the denominator
    317  __m128 sum = in_0 = approx_exp(in_0);
    318  in_1 = approx_exp(in_1);
    319  sum = _mm_add_ps(sum, in_1);
    320  in_2 = approx_exp(in_2);
    321  sum = _mm_add_ps(sum, in_2);
    322  in_3 = approx_exp(in_3);
    323  sum = _mm_add_ps(sum, in_3);
    324  sum = reduce_sum(sum);
    325 
    326  // Divide to get the probability
    327  in_0 = _mm_div_ps(in_0, sum);
    328  in_1 = _mm_div_ps(in_1, sum);
    329  in_2 = _mm_div_ps(in_2, sum);
    330  in_3 = _mm_div_ps(in_3, sum);
    331 
    332  _mm_storeu_ps(&output[0], in_0);
    333  _mm_storeu_ps(&output[4], in_1);
    334  _mm_storeu_ps(&output[8], in_2);
    335  _mm_storeu_ps(&output[12], in_3);
    336 }