tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

ml_avx2.c (10861B)


      1 /*
      2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <stdbool.h>
     13 #include <assert.h>
     14 #include <immintrin.h>
     15 
     16 #include "config/av1_rtcd.h"
     17 #include "av1/encoder/ml.h"
     18 #include "av1/encoder/x86/ml_sse3.h"
     19 
     20 #define CALC_OUTPUT_FOR_2ROWS                                               \
     21  const int index = weight_idx + (2 * i * tot_num_inputs);                  \
     22  const __m256 weight0 = _mm256_loadu_ps(&weights[index]);                  \
     23  const __m256 weight1 = _mm256_loadu_ps(&weights[index + tot_num_inputs]); \
     24  const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);                    \
     25  const __m256 mul1 = _mm256_mul_ps(inputs256, weight1);                    \
     26  hadd[i] = _mm256_hadd_ps(mul0, mul1);
     27 
     28 static inline void nn_propagate_8to1(
     29    const float *const inputs, const float *const weights,
     30    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
     31    int num_outputs, float *const output_nodes, int is_clip_required) {
     32  // Process one output row at a time.
     33  for (int out = 0; out < num_outputs; out++) {
     34    __m256 in_result = _mm256_setzero_ps();
     35    float bias_val = bias[out];
     36    for (int in = 0; in < num_inputs_to_process; in += 8) {
     37      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
     38      const int weight_idx = in + (out * tot_num_inputs);
     39      const __m256 weight0 = _mm256_loadu_ps(&weights[weight_idx]);
     40      const __m256 mul0 = _mm256_mul_ps(inputs256, weight0);
     41      in_result = _mm256_add_ps(in_result, mul0);
     42    }
     43    const __m128 low_128 = _mm256_castps256_ps128(in_result);
     44    const __m128 high_128 = _mm256_extractf128_ps(in_result, 1);
     45    const __m128 sum_par_0 = _mm_add_ps(low_128, high_128);
     46    const __m128 sum_par_1 = _mm_hadd_ps(sum_par_0, sum_par_0);
     47    const __m128 sum_tot =
     48        _mm_add_ps(_mm_shuffle_ps(sum_par_1, sum_par_1, 0x99), sum_par_1);
     49 
     50    bias_val += _mm_cvtss_f32(sum_tot);
     51    if (is_clip_required) bias_val = AOMMAX(bias_val, 0);
     52    output_nodes[out] = bias_val;
     53  }
     54 }
     55 
     56 static inline void nn_propagate_8to4(
     57    const float *const inputs, const float *const weights,
     58    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
     59    int num_outputs, float *const output_nodes, int is_clip_required) {
     60  __m256 hadd[2];
     61  for (int out = 0; out < num_outputs; out += 4) {
     62    __m128 bias_reg = _mm_loadu_ps(&bias[out]);
     63    __m128 in_result = _mm_setzero_ps();
     64    for (int in = 0; in < num_inputs_to_process; in += 8) {
     65      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
     66      const int weight_idx = in + (out * tot_num_inputs);
     67      // Process two output row at a time.
     68      for (int i = 0; i < 2; i++) {
     69        CALC_OUTPUT_FOR_2ROWS
     70      }
     71 
     72      const __m256 sum_par = _mm256_hadd_ps(hadd[0], hadd[1]);
     73      const __m128 low_128 = _mm256_castps256_ps128(sum_par);
     74      const __m128 high_128 = _mm256_extractf128_ps(sum_par, 1);
     75      const __m128 result = _mm_add_ps(low_128, high_128);
     76 
     77      in_result = _mm_add_ps(in_result, result);
     78    }
     79 
     80    in_result = _mm_add_ps(in_result, bias_reg);
     81    if (is_clip_required) in_result = _mm_max_ps(in_result, _mm_setzero_ps());
     82    _mm_storeu_ps(&output_nodes[out], in_result);
     83  }
     84 }
     85 
     86 static inline void nn_propagate_8to8(
     87    const float *const inputs, const float *const weights,
     88    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
     89    int num_outputs, float *const output_nodes, int is_clip_required) {
     90  __m256 hadd[4];
     91  for (int out = 0; out < num_outputs; out += 8) {
     92    __m256 bias_reg = _mm256_loadu_ps(&bias[out]);
     93    __m256 in_result = _mm256_setzero_ps();
     94    for (int in = 0; in < num_inputs_to_process; in += 8) {
     95      const __m256 inputs256 = _mm256_loadu_ps(&inputs[in]);
     96      const int weight_idx = in + (out * tot_num_inputs);
     97      // Process two output rows at a time.
     98      for (int i = 0; i < 4; i++) {
     99        CALC_OUTPUT_FOR_2ROWS
    100      }
    101      const __m256 hh0 = _mm256_hadd_ps(hadd[0], hadd[1]);
    102      const __m256 hh1 = _mm256_hadd_ps(hadd[2], hadd[3]);
    103 
    104      __m256 ht_0 = _mm256_permute2f128_ps(hh0, hh1, 0x20);
    105      __m256 ht_1 = _mm256_permute2f128_ps(hh0, hh1, 0x31);
    106 
    107      __m256 result = _mm256_add_ps(ht_0, ht_1);
    108      in_result = _mm256_add_ps(in_result, result);
    109    }
    110    in_result = _mm256_add_ps(in_result, bias_reg);
    111    if (is_clip_required)
    112      in_result = _mm256_max_ps(in_result, _mm256_setzero_ps());
    113    _mm256_storeu_ps(&output_nodes[out], in_result);
    114  }
    115 }
    116 
    117 static inline void nn_propagate_input_multiple_of_8(
    118    const float *const inputs, const float *const weights,
    119    const float *const bias, int num_inputs_to_process, int tot_num_inputs,
    120    bool is_output_layer, int num_outputs, float *const output_nodes) {
    121  // The saturation of output is considered for hidden layer which is not equal
    122  // to final hidden layer.
    123  const int is_clip_required =
    124      !is_output_layer && num_inputs_to_process == tot_num_inputs;
    125  if (num_outputs % 8 == 0) {
    126    nn_propagate_8to8(inputs, weights, bias, num_inputs_to_process,
    127                      tot_num_inputs, num_outputs, output_nodes,
    128                      is_clip_required);
    129  } else if (num_outputs % 4 == 0) {
    130    nn_propagate_8to4(inputs, weights, bias, num_inputs_to_process,
    131                      tot_num_inputs, num_outputs, output_nodes,
    132                      is_clip_required);
    133  } else {
    134    nn_propagate_8to1(inputs, weights, bias, num_inputs_to_process,
    135                      tot_num_inputs, num_outputs, output_nodes,
    136                      is_clip_required);
    137  }
    138 }
    139 
    140 void av1_nn_predict_avx2(const float *input_nodes,
    141                         const NN_CONFIG *const nn_config, int reduce_prec,
    142                         float *const output) {
    143  float buf[2][NN_MAX_NODES_PER_LAYER];
    144  int buf_index = 0;
    145  int num_inputs = nn_config->num_inputs;
    146  assert(num_inputs > 0 && num_inputs <= NN_MAX_NODES_PER_LAYER);
    147 
    148  for (int layer = 0; layer <= nn_config->num_hidden_layers; layer++) {
    149    const float *layer_weights = nn_config->weights[layer];
    150    const float *layer_bias = nn_config->bias[layer];
    151    bool is_output_layer = layer == nn_config->num_hidden_layers;
    152    float *const output_nodes = is_output_layer ? output : &buf[buf_index][0];
    153    const int num_outputs = is_output_layer
    154                                ? nn_config->num_outputs
    155                                : nn_config->num_hidden_nodes[layer];
    156    assert(num_outputs > 0 && num_outputs <= NN_MAX_NODES_PER_LAYER);
    157 
    158    // Process input multiple of 8 using AVX2 intrinsic.
    159    if (num_inputs % 8 == 0) {
    160      nn_propagate_input_multiple_of_8(input_nodes, layer_weights, layer_bias,
    161                                       num_inputs, num_inputs, is_output_layer,
    162                                       num_outputs, output_nodes);
    163    } else {
    164      // When number of inputs is not multiple of 8, use hybrid approach of AVX2
    165      // and SSE3 based on the need.
    166      const int in_mul_8 = num_inputs / 8;
    167      const int num_inputs_to_process = in_mul_8 * 8;
    168      int bias_is_considered = 0;
    169      if (in_mul_8) {
    170        nn_propagate_input_multiple_of_8(
    171            input_nodes, layer_weights, layer_bias, num_inputs_to_process,
    172            num_inputs, is_output_layer, num_outputs, output_nodes);
    173        bias_is_considered = 1;
    174      }
    175 
    176      const float *out_temp = bias_is_considered ? output_nodes : layer_bias;
    177      const int input_remaining = num_inputs % 8;
    178      if (input_remaining % 4 == 0 && num_outputs % 8 == 0) {
    179        for (int out = 0; out < num_outputs; out += 8) {
    180          __m128 out_h = _mm_loadu_ps(&out_temp[out + 4]);
    181          __m128 out_l = _mm_loadu_ps(&out_temp[out]);
    182          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
    183            av1_nn_propagate_4to8_sse3(&input_nodes[in],
    184                                       &layer_weights[out * num_inputs + in],
    185                                       &out_h, &out_l, num_inputs);
    186          }
    187          if (!is_output_layer) {
    188            const __m128 zero = _mm_setzero_ps();
    189            out_h = _mm_max_ps(out_h, zero);
    190            out_l = _mm_max_ps(out_l, zero);
    191          }
    192          _mm_storeu_ps(&output_nodes[out + 4], out_h);
    193          _mm_storeu_ps(&output_nodes[out], out_l);
    194        }
    195      } else if (input_remaining % 4 == 0 && num_outputs % 4 == 0) {
    196        for (int out = 0; out < num_outputs; out += 4) {
    197          __m128 outputs = _mm_loadu_ps(&out_temp[out]);
    198          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
    199            av1_nn_propagate_4to4_sse3(&input_nodes[in],
    200                                       &layer_weights[out * num_inputs + in],
    201                                       &outputs, num_inputs);
    202          }
    203          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
    204          _mm_storeu_ps(&output_nodes[out], outputs);
    205        }
    206      } else if (input_remaining % 4 == 0) {
    207        for (int out = 0; out < num_outputs; out++) {
    208          __m128 outputs = _mm_load1_ps(&out_temp[out]);
    209          for (int in = in_mul_8 * 8; in < num_inputs; in += 4) {
    210            av1_nn_propagate_4to1_sse3(&input_nodes[in],
    211                                       &layer_weights[out * num_inputs + in],
    212                                       &outputs);
    213          }
    214          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
    215          output_nodes[out] = _mm_cvtss_f32(outputs);
    216        }
    217      } else {
    218        // Use SSE instructions for scalar operations to avoid the latency
    219        // of swapping between SIMD and FPU modes.
    220        for (int out = 0; out < num_outputs; out++) {
    221          __m128 outputs = _mm_load1_ps(&out_temp[out]);
    222          for (int in_node = in_mul_8 * 8; in_node < num_inputs; in_node++) {
    223            __m128 input = _mm_load1_ps(&input_nodes[in_node]);
    224            __m128 weight =
    225                _mm_load1_ps(&layer_weights[num_inputs * out + in_node]);
    226            outputs = _mm_add_ps(outputs, _mm_mul_ps(input, weight));
    227          }
    228          if (!is_output_layer) outputs = _mm_max_ps(outputs, _mm_setzero_ps());
    229          output_nodes[out] = _mm_cvtss_f32(outputs);
    230        }
    231      }
    232    }
    233    // Before processing the next layer, treat the output of current layer as
    234    // input to next layer.
    235    input_nodes = output_nodes;
    236    num_inputs = num_outputs;
    237    buf_index = 1 - buf_index;
    238  }
    239  if (reduce_prec) av1_nn_output_prec_reduce(output, nn_config->num_outputs);
    240 }