tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

cnn_avx2.c (25637B)


      1 /*
      2 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <immintrin.h>
     14 #include <math.h>
     15 
     16 #include "aom_dsp/aom_dsp_common.h"
     17 #include "av1/common/av1_common_int.h"
     18 #include "av1/encoder/cnn.h"
     19 
     20 // This mask rearranges source pixels in the order shown below.
     21 // shuffle_src_layer0[0][8]: applied on source pixels 0 to 7.
     22 // shuffle_src_layer0[1][8]: applied on source pixels 7 to 14.
     23 // This shuffling is needed to process 3 5x5 blocks which need
     24 // source pixels in the following order.
     25 // 1st 5x5 block: source pixels needed are 0 to 4,
     26 // 2nd 5x5 block: source pixels needed are 4 to 8,
     27 // 3rd 5x5 block: source pixels needed are 8 to 12.
      28 // Source pixels are loaded as shown below.
     29 // load_src0 : 0, 1, 2, 3, 4, 5, 6, 7
     30 // load_src1 : 7, 8, 9, 10, 11, 12, 13, 14
     31 // After applying masks, source bytes will be in the order:
     32 // load_src0 : 0, 1, 2, 3, 4, 4, 5, 6
      33 //             consists of the 5 pixels needed for the 1st 5x5 block and the
      34 //             first 3 pixels needed for the 2nd 5x5 block.
     35 // load_src1 : 7, 8, 8, 9, 10, 11, 12, x
      36 //             consists of the last 2 pixels needed for the 2nd 5x5 block and
      37 //             the 5 pixels needed for the 3rd 5x5 block.
     38 DECLARE_ALIGNED(32, static const uint32_t,
     39                shuffle_src_layer0[2][8]) = { { 0, 1, 2, 3, 4, 4, 5, 6 },
     40                                              { 0, 1, 1, 2, 3, 4, 5, 0 } };
     41 
      42 // This mask rearranges the weights to match the shuffled source pixel order.
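         // For a filter row with weights w0..w4 loaded into lanes 0-7 (lanes
         // 5-7 are zero), the first mask yields w0 w1 w2 w3 w4 w0 w1 w2 and
         // the second yields w3 w4 w0 w1 w2 w3 w4 w0. These line up with the
         // shuffled sources above: lanes 0-4 of load_src0 cover the 1st block,
         // lanes 5-7 of load_src0 plus lanes 0-1 of load_src1 cover the 2nd
         // block, and lanes 2-6 of load_src1 cover the 3rd block.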
     43 DECLARE_ALIGNED(32, static const uint32_t,
     44                shuffle_weight_layer0[2][8]) = { { 0, 1, 2, 3, 4, 0, 1, 2 },
     45                                                 { 3, 4, 0, 1, 2, 3, 4, 0 } };
     46 
      47 // Shuffle masks used to rearrange the weights for layer 1 and layer 2.
      48 // For layer 1 and layer 2, convolution happens on 2x2 blocks as the
      49 // filter_width and filter_height are equal to 2, so the weights are
      50 // rearranged in the order shown below to match the source pixels.
      51 // Essentially, each pair of weights is replicated across the register width.
     52 DECLARE_ALIGNED(32, static const uint32_t,
     53                shuffle_weight_layer_1_and_2[2][8]) = {
     54  { 0, 1, 0, 1, 0, 1, 0, 1 }, { 2, 3, 2, 3, 2, 3, 2, 3 }
     55 };
     56 
      57 // After the stages of multiplication and accumulation, the output values
      58 // in the register will be jumbled. In order to store the register into the
      59 // output buffer in the proper order, the following mask is applied to the
      60 // output register.
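         // For example, _mm256_hadd_ps(a, b) yields, per 128-bit half,
         // a0+a1 a2+a3 b0+b1 b2+b3 | a4+a5 a6+a7 b4+b5 b6+b7, so the eight
         // 2x2 block sums land in the order 0 1 4 5 2 3 6 7. Permuting with
         // the mask below restores the linear order 0..7.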
     61 DECLARE_ALIGNED(32, static const uint32_t,
     62                shuffle_output_layer_1_and_2[8]) = { 0, 1, 4, 5, 2, 3, 6, 7 };
     63 
     64 // Load weights needed for layer 0 (for 5x5 block processing),
     65 // and fill the registers appropriately to match source pixel mapping.
     66 static inline void prepare_weights_for_5x5_convolve(
     67    const float *layer_config_weights, int off, float weight[5][8],
     68    const int cstep, __m256 *shuffle_weight, const __m256i weight_mask_0,
     69    const __m256i weight_mask_1) {
     70  for (int row = 0; row < 5; ++row) {
     71    for (int col = 0; col < 5; ++col) {
     72      weight[row][col] = layer_config_weights[off];
     73      off += cstep;
     74    }
     75  }
     76  shuffle_weight[0] = _mm256_loadu_ps(weight[0]);
     77  shuffle_weight[1] = _mm256_loadu_ps(weight[1]);
     78  shuffle_weight[2] = _mm256_loadu_ps(weight[2]);
     79  shuffle_weight[3] = _mm256_loadu_ps(weight[3]);
     80  shuffle_weight[4] = _mm256_loadu_ps(weight[4]);
     81 
     82  shuffle_weight[0] =
     83      _mm256_permutevar8x32_ps(shuffle_weight[0], weight_mask_0);
     84  shuffle_weight[1] =
     85      _mm256_permutevar8x32_ps(shuffle_weight[1], weight_mask_0);
     86  shuffle_weight[2] =
     87      _mm256_permutevar8x32_ps(shuffle_weight[2], weight_mask_0);
     88  shuffle_weight[3] =
     89      _mm256_permutevar8x32_ps(shuffle_weight[3], weight_mask_0);
     90  shuffle_weight[4] =
     91      _mm256_permutevar8x32_ps(shuffle_weight[4], weight_mask_0);
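          // Note: weight_mask_1 is applied to rows already shuffled by
          // weight_mask_0; this is safe because weight_mask_0 keeps lanes 0-4
          // in place and weight_mask_1 only selects from lanes 0-4.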
     92  shuffle_weight[5] =
     93      _mm256_permutevar8x32_ps(shuffle_weight[0], weight_mask_1);
     94  shuffle_weight[6] =
     95      _mm256_permutevar8x32_ps(shuffle_weight[1], weight_mask_1);
     96  shuffle_weight[7] =
     97      _mm256_permutevar8x32_ps(shuffle_weight[2], weight_mask_1);
     98  shuffle_weight[8] =
     99      _mm256_permutevar8x32_ps(shuffle_weight[3], weight_mask_1);
    100  shuffle_weight[9] =
    101      _mm256_permutevar8x32_ps(shuffle_weight[4], weight_mask_1);
    102 }
    103 
     104 // For each row, loads source pixels 0 to 7 (load_src_0) and 7 to 14
     105 // (load_src_1), and arranges them appropriately to process 3 blocks.
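         // After all 5 rows are accumulated, accum_src_0 lanes 0-4 hold the
         // partial sums of the 1st 5x5 block and lanes 5-7 the first three
         // columns of the 2nd block; accum_src_1 lanes 0-1 hold the last two
         // columns of the 2nd block and lanes 2-6 the 3rd block (lane 7 is
         // unused).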
    106 #define PERFORM_CONVOLVE_FOR_3_5X5_BLOCKS()                            \
    107  do {                                                                 \
    108    for (int row = 0; row < 5; row++) {                                \
    109      load_src_0 = _mm256_loadu_ps(input_ptr);                         \
    110      load_src_1 = _mm256_loadu_ps(input_ptr + 7);                     \
    111      load_src_0 = _mm256_permutevar8x32_ps(load_src_0, block0_1);     \
    112      load_src_1 = _mm256_permutevar8x32_ps(load_src_1, block1_2);     \
    113      load_src_0 = _mm256_mul_ps(load_src_0, shuffle_weight[0 + row]); \
    114      load_src_1 = _mm256_mul_ps(load_src_1, shuffle_weight[5 + row]); \
    115      accum_src_0 = _mm256_add_ps(load_src_0, accum_src_0);            \
    116      accum_src_1 = _mm256_add_ps(load_src_1, accum_src_1);            \
    117      input_ptr += in_stride;                                          \
    118    }                                                                  \
    119  } while (0)
    120 
    121 // Load masks needed for shuffling of output and weights.
    122 static inline void load_shuffle_masks_for_2x2_convolve(__m256i *output_mask,
    123                                                       __m256i *weight_mask) {
    124  // Load shuffle buffer needed to sort the output.
    125  *output_mask =
    126      _mm256_load_si256((const __m256i *)shuffle_output_layer_1_and_2);
    127 
    128  // Load shuffle buffers needed for weight.
    129  weight_mask[0] =
    130      _mm256_load_si256((const __m256i *)shuffle_weight_layer_1_and_2[0]);
    131  weight_mask[1] =
    132      _mm256_load_si256((const __m256i *)shuffle_weight_layer_1_and_2[1]);
    133 }
    134 
    135 // Load weights needed for layer 1 and 2 (for 2x2 block processing),
    136 // and fill the registers appropriately to match source pixel mapping.
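         // The resulting registers are shuffle_weight[0] = w0 w1 repeated four
         // times and shuffle_weight[1] = w2 w3 repeated four times; the callers
         // multiply shuffle_weight[0] with the upper row of each 2x2 block and
         // shuffle_weight[1] with the lower row.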
    137 static inline void prepare_weights_for_2x2_convolve(
    138    const float *layer_config_weights, int off, const int cstep,
    139    __m256 *shuffle_weight, __m256i *weight_mask) {
    140  // Weights needed for 2x2 block.
    141  float weight[4] = { 0 };
    142  for (int i = 0; i < 4; ++i) {
    143    weight[i] = layer_config_weights[off];
    144    off += cstep;
    145  }
    146 
    147  const __m256 weight_vec = _mm256_castps128_ps256(_mm_loadu_ps(weight));
    148  shuffle_weight[0] = _mm256_permutevar8x32_ps(weight_vec, weight_mask[0]);
    149  shuffle_weight[1] = _mm256_permutevar8x32_ps(weight_vec, weight_mask[1]);
    150 }
    151 
    152 // Do convolution of one 5x5 block.
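         // Each row loads 4 source pixels into an SSE register and multiplies
         // them with the low 128 bits of the shuffled weight register; the 5th
         // column is accumulated separately in the scalar last_column_sum
         // provided by the caller.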
    153 #define PERFORM_CONVOLVE_FOR_1_5X5_BLOCK(w, accum0, in_stride)           \
    154  do {                                                                   \
    155    __m128 load_src[5];                                                  \
    156    load_src[0] = _mm_loadu_ps(input_ptr);                               \
    157    last_column_sum += input_ptr[4] * weight[0][4];                      \
    158    input_ptr += in_stride;                                              \
    159    load_src[1] = _mm_loadu_ps(input_ptr);                               \
    160    last_column_sum += input_ptr[4] * weight[1][4];                      \
    161    input_ptr += in_stride;                                              \
    162    load_src[2] = _mm_loadu_ps(input_ptr);                               \
    163    last_column_sum += input_ptr[4] * weight[2][4];                      \
    164    input_ptr += in_stride;                                              \
    165    load_src[3] = _mm_loadu_ps(input_ptr);                               \
    166    last_column_sum += input_ptr[4] * weight[3][4];                      \
    167    input_ptr += in_stride;                                              \
    168    load_src[4] = _mm_loadu_ps(input_ptr);                               \
    169    last_column_sum += input_ptr[4] * weight[4][4];                      \
    170                                                                         \
    171    load_src[0] = _mm_mul_ps(load_src[0], _mm256_castps256_ps128(w[0])); \
    172    load_src[1] = _mm_mul_ps(load_src[1], _mm256_castps256_ps128(w[1])); \
    173    load_src[2] = _mm_mul_ps(load_src[2], _mm256_castps256_ps128(w[2])); \
    174    load_src[3] = _mm_mul_ps(load_src[3], _mm256_castps256_ps128(w[3])); \
    175    load_src[4] = _mm_mul_ps(load_src[4], _mm256_castps256_ps128(w[4])); \
    176                                                                         \
    177    accum0 = _mm_add_ps(load_src[0], accum0);                            \
    178    load_src[1] = _mm_add_ps(load_src[1], load_src[2]);                  \
    179    load_src[3] = _mm_add_ps(load_src[3], load_src[4]);                  \
    180    load_src[1] = _mm_add_ps(load_src[1], load_src[3]);                  \
    181    accum0 = _mm_add_ps(accum0, load_src[1]);                            \
    182  } while (0)
    183 
    184 // Do convolution on 8 horizontal 2x2 blocks.
    185 static inline void perform_convolve_for_8h_2x2_blocks(
    186    const float *input_ptr, int in_stride, __m256 *weight, __m256 *out_accum,
    187    __m256i shuffle_output_mask) {
    188  __m256 load_src[4];
    189  // Load input into source registers.
    190  load_src[0] = _mm256_loadu_ps(input_ptr);
    191  load_src[1] = _mm256_loadu_ps(input_ptr + 8);
    192  load_src[2] = _mm256_loadu_ps(input_ptr + in_stride);
    193  load_src[3] = _mm256_loadu_ps(input_ptr + in_stride + 8);
    194 
    195  // Multiply the loaded input with corresponding weights.
    196  load_src[0] = _mm256_mul_ps(load_src[0], weight[0]);
    197  load_src[1] = _mm256_mul_ps(load_src[1], weight[0]);
    198  load_src[2] = _mm256_mul_ps(load_src[2], weight[1]);
    199  load_src[3] = _mm256_mul_ps(load_src[3], weight[1]);
    200 
    201  // Accumulate across 2x2 blocks.
    202  load_src[0] = _mm256_add_ps(load_src[0], load_src[2]);
    203  load_src[1] = _mm256_add_ps(load_src[1], load_src[3]);
    204  load_src[0] = _mm256_hadd_ps(load_src[0], load_src[1]);
    205 
    206  // Sort the output in order to store into output buffer.
    207  load_src[0] = _mm256_permutevar8x32_ps(load_src[0], shuffle_output_mask);
    208  *out_accum = _mm256_add_ps(*out_accum, load_src[0]);
    209 }
    210 
    211 // Do convolution on 8 (4 horizontal x 2 vertical) 2x2 blocks.
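         // After the horizontal add and the output shuffle, the low half of
         // out_accum holds the four outputs of one output row and the high
         // half the four outputs of the next output row, which is why layer 2
         // stores each out_accum register across two consecutive output rows.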
    212 static inline void perform_convolve_for_4hx2v_2x2_blocks(
    213    const float *input_ptr, int in_stride, __m256 *weight, __m256 *out_accum,
    214    __m256i shuffle_output_mask) {
    215  __m256 load_src[4];
    216  // Load input into source registers.
    217  load_src[0] = _mm256_loadu_ps(input_ptr);
    218  load_src[1] = _mm256_loadu_ps(input_ptr + in_stride);
    219  load_src[2] = _mm256_loadu_ps(input_ptr + (in_stride * 2));
    220  load_src[3] = _mm256_loadu_ps(input_ptr + (in_stride * 3));
    221 
    222  // Multiply the loaded input with corresponding weights.
    223  load_src[0] = _mm256_mul_ps(load_src[0], weight[0]);
    224  load_src[1] = _mm256_mul_ps(load_src[1], weight[1]);
    225  load_src[2] = _mm256_mul_ps(load_src[2], weight[0]);
    226  load_src[3] = _mm256_mul_ps(load_src[3], weight[1]);
    227 
    228  // Accumulate across 2x2 blocks.
    229  load_src[0] = _mm256_add_ps(load_src[0], load_src[1]);
    230  load_src[2] = _mm256_add_ps(load_src[2], load_src[3]);
    231  load_src[0] = _mm256_hadd_ps(load_src[0], load_src[2]);
    232 
    233  // Sort the output in order to store into output buffer.
    234  load_src[0] = _mm256_permutevar8x32_ps(load_src[0], shuffle_output_mask);
    235  *out_accum = _mm256_add_ps(*out_accum, load_src[0]);
    236 }
    237 
    238 // AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c(), when
    239 // filter_width and filter_height are equal to 5.
    240 // CNN convolve parsing is based on av1_intra_mode_cnn_partition_cnn_config.
     241 // Based on the configuration set for each layer, the current encoder
     242 // always chooses the no_maxpool_padding_valid case.
     243 // Also, for layer 0, convolution happens on 5x5 blocks as the
     244 // filter_width and filter_height are set to 5.
    245 static void cnn_convolve_no_maxpool_padding_valid_5x5_avx2(
    246    const float **input, int in_width, int in_height, int in_stride,
    247    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    248    int start_idx, const int cstep, const int channel_step) {
    249  const int kFilterWidth = 5;
    250  const int kFilterHeight = 5;
    251  const int kSkipWidth = 4;
    252  const int kSkipHeight = 4;
    253  assert(layer_config->filter_width == kFilterWidth &&
    254         layer_config->filter_height == kFilterHeight);
    255  assert(layer_config->skip_width == kSkipWidth &&
    256         layer_config->skip_height == kSkipHeight);
    257 
    258  // Load shuffle buffers needed for source.
    259  const __m256i block0_1 =
    260      _mm256_load_si256((const __m256i *)shuffle_src_layer0[0]);
    261  const __m256i block1_2 =
    262      _mm256_load_si256((const __m256i *)shuffle_src_layer0[1]);
    263 
    264  // Load shuffle buffers needed for weight.
    265  const __m256i weight_mask_0 =
    266      _mm256_load_si256((const __m256i *)shuffle_weight_layer0[0]);
    267  const __m256i weight_mask_1 =
    268      _mm256_load_si256((const __m256i *)shuffle_weight_layer0[1]);
    269 
     270  // Width to advance by for the next iteration of processing 3 5x5 blocks.
    271  const int kSkipWidthForNextIter = kSkipWidth * 3;
    272 
    273  // Minimum width required to process 3 5x5 blocks at a time.
     274  // min width (for processing 3 5x5 blocks) = 2*skip_width + filter_width
     275  // Here, skip_width specifies how far the window moves between one block
     276  // convolution and the next, and filter_width specifies how many pixels
     277  // the filter is applied to.
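          // With skip_width = 4 and filter_width = 5 this comes to
          // (4 * 2) + 5 = 13 pixels.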
    278  const int kMinWidthFor3_5x5Blocks = (kSkipWidth * 2) + kFilterWidth;
    279  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    280    const float out_ch_bias = layer_config->bias[i];
    281    for (int k = 0; k < layer_config->in_channels; ++k) {
    282      __m256 shuffle_weight[10];
    283 
     284      // The weights needed are 5x5; for SIMD purposes the array is sized 5x8.
    285      float weight[5][8] = { { 0 } };
    286      int off = k * layer_config->out_channels + i;
    287 
     288      // In layer 0, the convolution process happens on 5x5 blocks.
     289      // The same 5x5 weights are used for every block position within a given
     290      // in-channel, which is why the weights are loaded once per in-channel.
    291      prepare_weights_for_5x5_convolve(layer_config->weights, off, weight,
    292                                       cstep, shuffle_weight, weight_mask_0,
    293                                       weight_mask_1);
    294 
    295      for (int h = 0, u = 0; h < in_height - kFilterHeight + 1;
    296           h += kSkipHeight, ++u) {
    297        const int out_h = u * out_stride;
    298        int v = 0;
    299        int w = 0;
    300        int rem_width = in_width;
    301        // Processing 3 5x5 blocks at a time, if sufficient width is present.
    302        while (rem_width >= kMinWidthFor3_5x5Blocks) {
    303          __m256 load_src_0, load_src_1;
    304          __m256 accum_src_0 = _mm256_setzero_ps();
    305          __m256 accum_src_1 = _mm256_setzero_ps();
    306          const float *input_ptr = &input[k][h * in_stride + w];
    307          PERFORM_CONVOLVE_FOR_3_5X5_BLOCKS();
    308 
    309          // Accumulate across column.
    310          __m256 accum = _mm256_hadd_ps(accum_src_0, accum_src_1);
    311          __m128 tmp_reg_0 = _mm256_extractf128_ps(accum_src_0, 1);
    312          __m128 tmp_reg_1 = _mm256_extractf128_ps(accum_src_1, 1);
    313 
    314          __m128 accum_l = _mm256_castps256_ps128(accum);
    315          __m128 accum_h = _mm256_extractf128_ps(accum, 1);
    316 
    317          __m128 tmp_reg_2 = _mm_add_ps(accum_l, tmp_reg_0);
    318          __m128 tmp_reg_3 = _mm_add_ps(tmp_reg_0, accum_h);
    319          __m128 tmp_reg_4 = _mm_add_ps(tmp_reg_1, accum_h);
    320 
    321          // 1st 5x5 block output.
    322          output[i][out_h + v] =
    323              out_ch_bias + _mm_cvtss_f32(tmp_reg_2) +
    324              _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 1));
    325 
    326          // 2nd 5x5 block output.
    327          output[i][out_h + v + 1] =
    328              out_ch_bias +
    329              _mm_cvtss_f32(_mm_shuffle_ps(tmp_reg_3, tmp_reg_3, 1)) +
    330              _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 2));
    331 
    332          // 3rd 5x5 block output.
    333          output[i][out_h + v + 2] =
    334              out_ch_bias +
    335              _mm_cvtss_f32(_mm_shuffle_ps(tmp_reg_4, tmp_reg_4, 2)) +
    336              _mm_cvtss_f32(_mm_shuffle_ps(accum_l, accum_l, 3));
    337 
    338          v += 3;
    339          w += kSkipWidthForNextIter;
    340          rem_width -= kSkipWidthForNextIter;
    341        }
    342 
    343        // Process remaining blocks as single 5x5 block at a time.
    344        while (rem_width >= kFilterWidth) {
    345          float last_column_sum = 0;
    346          __m128 accum = _mm_setzero_ps();
    347          const float *input_ptr = &input[k][h * in_stride + w];
    348          PERFORM_CONVOLVE_FOR_1_5X5_BLOCK(shuffle_weight, accum, in_stride);
    349 
    350          // Accumulate across column.
    351          accum = _mm_hadd_ps(accum, accum);
    352          output[i][out_h + v] = out_ch_bias + last_column_sum +
    353                                 _mm_cvtss_f32(accum) +
    354                                 _mm_cvtss_f32(_mm_shuffle_ps(accum, accum, 1));
    355 
    356          v += 1;
    357          w += kSkipWidth;
    358          rem_width -= kSkipWidth;
    359        }
    360      }
    361    }
    362  }
    363 }
    364 
    365 // AVX2 implementation for layer 1.
    366 static inline void cnn_convolve_no_maxpool_padding_valid_layer1_avx2(
    367    const float **input, int in_stride,
    368    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    369    int start_idx, const int cstep, const int channel_step) {
    370  __m256i weight_mask[2];
    371  __m256i shuffle_output_mask;
    372  load_shuffle_masks_for_2x2_convolve(&shuffle_output_mask, weight_mask);
    373 
    374  const int kInHeight = 16;
    375  const int kFilterHeight = 2;
    376  const int kSkipHeight = 2;
    377  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    378    __m256 bias_reg = _mm256_set1_ps(layer_config->bias[i]);
     379    // out_accum registers are used to store the 2x2 convolve outputs
     380    // (calculated over the input block size), which are accumulated across the
     381    // in_channels. As per the design, each iteration of the for loop processes
     382    // 8 (horizontal) 2x2 blocks and stores them in the corresponding out_accum
     383    // register (as the input size is 16x16, a total of 64 2x2 blocks are
     384    // present and 8 out_accum registers are enough to store the outputs).
     385    // Hence the for loops corresponding to 'j' and 'h', below, run over the
     386    // number of out_accum registers.
    387    __m256 out_accum[8];
    388    for (int j = 0; j < 8; ++j) out_accum[j] = bias_reg;
    389    for (int k = 0; k < layer_config->in_channels; ++k) {
    390      __m256 shuffle_weight[2];
    391      int off = k * layer_config->out_channels + i;
     392      // In layer 1, the convolution process happens on 2x2 blocks.
     393      // The same 2x2 weights are used for every block position within a given
     394      // in-channel, which is why the weights are loaded once per in-channel.
    395      prepare_weights_for_2x2_convolve(layer_config->weights, off, cstep,
    396                                       shuffle_weight, weight_mask);
    397 
    398      for (int h = 0, u = 0; h < kInHeight - kFilterHeight + 1;
    399           h += kSkipHeight, ++u) {
    400        const float *input_ptr = &input[k][h * in_stride];
    401        perform_convolve_for_8h_2x2_blocks(input_ptr, in_stride, shuffle_weight,
    402                                           &out_accum[u], shuffle_output_mask);
    403      }
    404    }
    405    // Store output of layer 1.
    406    for (int j = 0; j < 8; ++j) {
    407      _mm256_storeu_ps(&output[i][j * out_stride], out_accum[j]);
    408    }
    409  }
    410 }
    411 
    412 // AVX2 implementation for layer 2.
    413 static inline void cnn_convolve_no_maxpool_padding_valid_layer2_avx2(
    414    const float **input, int in_stride,
    415    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    416    int start_idx, const int cstep, const int channel_step) {
    417  __m256i weight_mask[2];
    418  __m256i shuffle_output_mask;
    419  load_shuffle_masks_for_2x2_convolve(&shuffle_output_mask, weight_mask);
    420 
    421  const int kInHeight = 8;
    422  const int kFilterHeight = 2;
    423  const int kSkipHeight = 2;
    424  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    425    __m256 bias_reg = _mm256_set1_ps(layer_config->bias[i]);
     426    // out_accum registers are used to store the 2x2 convolve outputs
     427    // (calculated over the input block size), which are accumulated across the
     428    // in_channels. As per the design, each iteration of the for loop processes
     429    // 8 (4 horizontal x 2 vertical) 2x2 blocks and stores them in the
     430    // corresponding out_accum register (as the input size is 8x8, a total of
     431    // 16 2x2 blocks are present and 2 out_accum registers are enough to store
     432    // the outputs). Hence the for loops corresponding to 'j' and 'h', below,
     433    // run over the number of out_accum registers.
    434    __m256 out_accum[2];
    435 
     436    // Height to advance by for the next iteration, as 2 2x2 blocks are
     437    // processed vertically in each iteration.
    438    const int kSkipHeightForNextIter = kSkipHeight * 2;
    439    for (int j = 0; j < 2; ++j) out_accum[j] = bias_reg;
    440    for (int k = 0; k < layer_config->in_channels; ++k) {
    441      __m256 shuffle_weight[2];
    442      int off = k * layer_config->out_channels + i;
     443      // In layer 2, the convolution process happens on 2x2 blocks.
     444      // The same 2x2 weights are used for every block position within a given
     445      // in-channel, which is why the weights are loaded once per in-channel.
    446      prepare_weights_for_2x2_convolve(layer_config->weights, off, cstep,
    447                                       shuffle_weight, weight_mask);
    448 
    449      for (int h = 0, u = 0; h < kInHeight - kFilterHeight + 1;
    450           h += kSkipHeightForNextIter, ++u) {
    451        const float *input_ptr = &input[k][h * in_stride];
    452        perform_convolve_for_4hx2v_2x2_blocks(input_ptr, in_stride,
    453                                              shuffle_weight, &out_accum[u],
    454                                              shuffle_output_mask);
    455      }
    456    }
    457    // Store output of layer 2.
    458    for (int j = 0; j < 2; ++j) {
    459      _mm256_storeu_ps(&output[i][j * out_stride * 2], out_accum[j]);
    460    }
    461  }
    462 }
    463 
    464 // AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c(), when
    465 // filter_width and filter_height are equal to 2.
     466 // As per the layer config set by av1_intra_mode_cnn_partition_cnn_config,
     467 // the filter_width and filter_height are equal to 2 for layers >= 1, so
     468 // convolution happens on 2x2 blocks for those layers.
    469 static void cnn_convolve_no_maxpool_padding_valid_2x2_avx2(
    470    const float **input, int in_width, int in_height, int in_stride,
    471    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    472    int start_idx, const int cstep, const int channel_step) {
    473  assert(layer_config->filter_width == 2 && layer_config->filter_height == 2);
    474  assert(layer_config->skip_width == 2 && layer_config->skip_height == 2);
    475 
    476  if (in_width == 16 && in_height == 16) {
    477    // This case of in_width and in_height equal to 16 corresponds to layer 1.
    478    // The output size of this layer is 8x8.
    479    cnn_convolve_no_maxpool_padding_valid_layer1_avx2(
    480        input, in_stride, layer_config, output, out_stride, start_idx, cstep,
    481        channel_step);
    482  } else if (in_width == 8 && in_height == 8) {
    483    // This case of in_width and in_height equal to 8 corresponds to layer 2.
    484    // The output size of this layer is 4x4.
    485    cnn_convolve_no_maxpool_padding_valid_layer2_avx2(
    486        input, in_stride, layer_config, output, out_stride, start_idx, cstep,
    487        channel_step);
    488  } else {
     489    // For layers 3 and 4, the input is of size 4x4 and 2x2
     490    // respectively. Implementing SIMD for these small sizes is unlikely to
     491    // be worthwhile, which is why the C path is called for layers >= 3.
    492    av1_cnn_convolve_no_maxpool_padding_valid_c(
    493        input, in_width, in_height, in_stride, layer_config, output, out_stride,
    494        start_idx, cstep, channel_step);
    495  }
    496 }
    497 
    498 // AVX2 variant of av1_cnn_convolve_no_maxpool_padding_valid_c().
     499 // In the current encoder, the av1_cnn_convolve function gets called for a
     500 // block size of 64x64. av1_cnn_convolve() uses the layer config values
     501 // set by av1_intra_mode_cnn_partition_cnn_config. The following are a few
     502 // details of each layer's config parameters.
    503 // Layer_Number in_size out_size filter_wd filter_ht skip_wd skip_ht
    504 //     0         64x64    16x16      5         5         4       4
    505 //     1         16x16    8x8        2         2         2       2
    506 //     2         8x8      4x4        2         2         2       2
    507 //     3         4x4      2x2        2         2         2       2
    508 //     4         2x2      1x1        2         2         2       2
    509 // Here,
    510 // filter_wd = filter_width and filter_ht = filter_height,
    511 // skip_wd = skip_width and skip_ht = skip_height.
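         // In the dispatch below, layer 0 takes the 5x5 path, layers 1 and 2
         // take the dedicated 2x2 AVX2 paths, and layers 3 and 4 fall back to
         // the C implementation inside the 2x2 path.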
    512 void av1_cnn_convolve_no_maxpool_padding_valid_avx2(
    513    const float **input, int in_width, int in_height, int in_stride,
    514    const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
    515    int start_idx, int cstep, int channel_step) {
    516  if (layer_config->filter_width == 5 && layer_config->filter_height == 5 &&
    517      layer_config->skip_width == 4 && layer_config->skip_height == 4) {
    518    cnn_convolve_no_maxpool_padding_valid_5x5_avx2(
    519        input, in_width, in_height, in_stride, layer_config, output, out_stride,
    520        start_idx, cstep, channel_step);
    521  } else if (layer_config->filter_width == 2 &&
    522             layer_config->filter_height == 2 &&
    523             layer_config->skip_width == 2 && layer_config->skip_height == 2) {
    524    cnn_convolve_no_maxpool_padding_valid_2x2_avx2(
    525        input, in_width, in_height, in_stride, layer_config, output, out_stride,
    526        start_idx, cstep, channel_step);
    527  } else {
    528    av1_cnn_convolve_no_maxpool_padding_valid_c(
    529        input, in_width, in_height, in_stride, layer_config, output, out_stride,
    530        start_idx, cstep, channel_step);
    531  }
    532 }