tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

cnn_test.cc (118995B)


      1 /*
      2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <assert.h>
     13 #include <math.h>
     14 #include <stdio.h>
     15 
     16 #include "gtest/gtest.h"
     17 
     18 #include "config/av1_rtcd.h"
     19 
     20 #include "aom_ports/aom_timer.h"
     21 #include "av1/encoder/cnn.h"
     22 #include "av1/encoder/partition_cnn_weights.h"
     23 #include "test/acm_random.h"
     24 #include "test/function_equivalence_test.h"
     25 #include "test/util.h"
     26 
        // Squares its argument. The argument text appears twice in the
        // expansion, so it is evaluated twice.
     27 #define SQR(x) ((x) * (x))
     28 
     29 // Best possible pixelwise guaranteed precision given each float has at most
     30 // 3 specified decimals.
        // Used as the per-value EXPECT_NEAR bound in
        // CNNTest::RunMultiOutCNNTest().
     31 #define PIXELWISE_FLOAT_TOL 1E-2
     32 
        // Upper bounds on the mean squared error between expected and actual
        // output, passed as the `tolerance` argument of the CNNTest helpers.
        // MSE_INT_TOL is what the integer-valued tests in this part of the file
        // pass; MSE_FLOAT_TOL is presumably for tests with fractional data (no
        // use is visible in this part of the file -- confirm).
     33 #define MSE_FLOAT_TOL 1E-6
     34 #define MSE_INT_TOL 0
     35 
     36 // CNN convolve pixelwise error threshold for functional equivalence.
     37 #define CNN_CONVOLVE_PIXELWISE_FLOAT_TOL 1E-3f
     38 
     39 namespace {
     40 
     41 class CNNTest : public ::testing::Test {
     42 protected:
     43  static void RunCNNTest(int image_width, int image_height, const float *input,
     44                         const float *expected, const CNN_CONFIG *cnn_config,
     45                         int in_stride, CNN_THREAD_DATA *thread_data,
     46                         double tolerance) {
     47    int out_width, out_height, out_channels;
     48    av1_find_cnn_output_size(image_width, image_height, cnn_config, &out_width,
     49                             &out_height, &out_channels);
     50 
     51    const int out_size = out_width * out_height;
     52    const int out_stride = out_width;
     53 
     54    float *output_ =
     55        (float *)aom_malloc(sizeof(*output_) * out_size * out_channels);
     56    ASSERT_NE(output_, nullptr);
     57    float *output[CNN_MAX_CHANNELS] = { nullptr };
     58    for (int channel = 0; channel < out_channels; ++channel) {
     59      output[channel] = output_ + (channel * out_size);
     60    }
     61    const int num_outputs = 1;
     62    const int output_chs[1] = { out_channels };
     63    const int output_strides[1] = { out_stride };
     64    CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_strides,
     65                                    output };
     66 
     67    RunMultiOutCNNTest(&input, image_width, image_height, in_stride, cnn_config,
     68                       thread_data, &output_struct, &expected, tolerance);
     69 
     70    aom_free(output_);
     71  }
     72 
     73  static void RunMultiOutCNNTest(const float **input, int image_width,
     74                                 int image_height, int in_stride,
     75                                 const CNN_CONFIG *cnn_config,
     76                                 CNN_THREAD_DATA *thread_data,
     77                                 CNN_MULTI_OUT *output, const float **expected,
     78                                 double tolerance) {
     79    const int num_outputs = output->num_outputs;
     80    const int *output_chs = output->output_channels;
     81 
     82    int *out_widths = (int *)aom_calloc(num_outputs, sizeof(*out_widths));
     83    int *out_heights = (int *)aom_calloc(num_outputs, sizeof(*out_heights));
     84    int *not_used = (int *)aom_calloc(num_outputs, sizeof(*not_used));
     85    ASSERT_NE(out_widths, nullptr);
     86    ASSERT_NE(out_heights, nullptr);
     87    ASSERT_NE(not_used, nullptr);
     88 
     89    av1_find_cnn_output_size(image_width, image_height, cnn_config, out_widths,
     90                             out_heights, not_used);
     91    ASSERT_TRUE(av1_cnn_predict(input, image_width, image_height, in_stride,
     92                                cnn_config, thread_data, output));
     93 
     94    int channel_offset = 0;
     95    for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
     96      const float *expected_out = expected[output_idx];
     97      const int curr_output_chs = output_chs[output_idx];
     98      const int out_size = out_widths[output_idx] * out_heights[output_idx];
     99 
    100      double mse = 0;
    101      int expected_ite = 0;
    102      for (int channel = 0; channel < curr_output_chs; ++channel) {
    103        const float *buf_out = output->output_buffer[channel_offset];
    104 
    105        for (int i = 0; i < out_size; ++i) {
    106          EXPECT_NEAR(expected_out[expected_ite], buf_out[i],
    107                      PIXELWISE_FLOAT_TOL)
    108              << " output " << output_idx << " channel " << channel << " pixel "
    109              << expected_ite % out_size << ": " << expected_out[expected_ite]
    110              << "/" << buf_out[i] << std::endl;
    111          mse += SQR(expected_out[expected_ite] - buf_out[i]);
    112          expected_ite++;
    113        }
    114 
    115        channel_offset++;
    116      }
    117      mse /= (out_size * curr_output_chs);
    118      EXPECT_LE(mse, tolerance) << " output " << output_idx << std::endl;
    119    }
    120 
    121    aom_free(out_widths);
    122    aom_free(out_heights);
    123    aom_free(not_used);
    124  }
    125 
    126  static void AssignLayerWeightsBiases(CNN_CONFIG *cnn_config, float *weights,
    127                                       float *bias) {
    128    size_t weight_offset = 0;
    129    size_t bias_offset = 0;
    130    for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
    131      CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
    132      layer_config->weights = weights + weight_offset;
    133      layer_config->bias = bias + bias_offset;
    134      weight_offset += layer_config->filter_width *
    135                       layer_config->filter_height * layer_config->in_channels *
    136                       layer_config->out_channels;
    137      bias_offset += layer_config->out_channels;
    138 
    139      ASSERT_NE(layer_config->weights, nullptr);
    140      ASSERT_NE(layer_config->bias, nullptr);
    141    }
    142  }
    143 };
    144 
    145 }  // namespace
    146 
    147 TEST_F(CNNTest, TestMultilayerConvolution) {
    148  int image_height = 16;
    149  int image_width = 16;
    150  int filter_height = 5;
    151  int filter_width = 4;
    152 
    153  float input[] = {
    154    -3, 1,  -3, 2,  -2, -2, 2,  -2, 1,  -2, -3, 1,  2,  2,  2,  -2, 0,  1,  -1,
    155    -3, -1, -1, 1,  0,  -3, 1,  0,  -1, 1,  0,  0,  -3, -3, -3, 0,  2,  1,  -1,
    156    2,  0,  1,  -3, -1, 2,  2,  1,  -2, 0,  -1, 0,  -2, -2, -1, 1,  0,  0,  0,
    157    -2, -2, -2, 1,  1,  -2, 1,  1,  -2, -2, 1,  -2, -1, -2, -3, 2,  -3, -1, 1,
    158    0,  -2, -2, -2, 1,  -2, -2, -1, -1, 2,  2,  2,  -1, 1,  -3, -3, 0,  2,  0,
    159    2,  1,  -3, -3, 1,  2,  2,  1,  -2, -3, 0,  -3, 0,  -3, -2, 0,  1,  1,  0,
    160    -3, 2,  -1, 2,  1,  0,  1,  -2, 1,  -1, -1, 2,  0,  -2, -3, 1,  1,  -2, -1,
    161    -3, -3, -1, 0,  -3, -2, 0,  0,  1,  0,  -3, -2, -1, 1,  0,  2,  1,  0,  -3,
    162    -2, -3, -3, -1, 0,  -2, 2,  -1, -3, 0,  -1, -1, 2,  0,  -3, -2, -1, 0,  0,
    163    1,  -2, 1,  2,  1,  2,  2,  -3, 2,  -1, 0,  0,  -1, 0,  2,  2,  -1, 2,  -2,
    164    1,  1,  -3, -3, 1,  -1, -1, -2, 2,  -2, -2, 2,  -1, -3, 2,  -3, 1,  -1, -1,
    165    -3, 1,  -1, 1,  0,  -3, -3, 1,  -3, -3, 0,  2,  2,  -2, -1, 2,  0,  2,  1,
    166    -1, -3, 0,  0,  -1, -1, 1,  0,  2,  0,  -3, 2,  1,  0,  1,  -3, 2,  -3, -3,
    167    -1, -3, -3, 2,  0,  2,  -2, 1,  -1,
    168  };
    169 
    170  float weights[] = {
    171    -2, 2,  -2, 2,  -1, -3, 2,  2,  0,  0,  -3, -1, -2, -3, 1,  -1, 0,  0,  0,
    172    2,  -2, 2,  -2, -3, 1,  1,  1,  -3, -1, 0,  1,  2,  -2, 0,  -1, -3, -1, -2,
    173    2,  -3, -3, 1,  -2, -3, 0,  2,  1,  -3, -3, -1, -3, -2, -1, -3, -1, -3, -2,
    174    -1, -3, -1, -2, -2, -3, 2,  0,  -3, 0,  -3, -3, 1,  -3, -1, 0,  -1, 1,  1,
    175    -1, 1,  -2, 0,  2,  0,  -3, 1,  -1, -1, 2,  0,  1,  -3, -3, 1,  2,  -3, -3,
    176    1,  -3, 2,  0,  -3, 1,  2,  2,  -2, -1, -2, 1,  1,  0,  -2, -2, 1,  2,  -1,
    177    -3, 1,  -2, 2,  -3, -2, -3, 2,  1,  0,  -2, 0,  1,  -3, 2,  -2, -2, 0,  2,
    178    -3, 2,  0,  0,  1,  -2, 1,  1,  -2, -1, -2, 1,  -2, 0,  -2, -2, 0,  -1, -1,
    179    -3, -3, -3, 1,  -3, -2, 2,  -1, 2,  0,  2,  -2, 2,  -2, 1,  -3, -3, -1, 0,
    180    2,  2,  1,  -1, -3, -1, -3, 2,  1,  -2, 0,  -3, -1, -3, -1, 2,  1,  0,  2,
    181    -1, 1,  0,  1,  2,  -1, -2, 2,  1,  -3, -1, -3, 0,  1,  -2, 0,  -2, -3, 0,
    182    -2, 2,  2,  0,  0,  2,  -3, 2,  -3, -2, 1,  2,  -3, -3, -1, -3, 0,  -3, -3,
    183    -2, -2, -2, 0,  0,  1,  0,  0,  -1, 0,  0,  -3, 0,  -3, -1, -2, 1,  -2, -1,
    184    2,  -2, 0,  0,  1,  0,  -2, -1, 0,  -3, 1,  0,  -1, -3, 1,  -1, 1,  -1, -3,
    185    1,  0,  1,  1,  -1, 2,  2,  0,  0,  1,  -3, 2,  -2, -2, -3, -2, -1, -2, 2,
    186    0,  2,  -2, -3, -1, -3, 2,  2,  -1, 2,  2,  -1, 0,  -3, 1,
    187  };
    188 
    189  float bias[] = {
    190    1, -1, 0, 1, 1, 1, -2,
    191  };
    192 
    193  float expected_same[] = {
    194    -1125, 2926,  6406,  631,   -1244, 97,    -1454, 2526,  1065,  3292,  3464,
    195    2553,  -330,  532,   1038,  1182,  -402,  3758,  3392,  9854,  4365,  1408,
    196    4736,  3134,  3838,  2409,  3221,  4350,  6750,  4045,  815,   1188,  2959,
    197    9802,  9590,  4572,  5740,  4253,  1701,  7974,  7012,  6854,  7093,  3907,
    198    4539,  3886,  4267,  3505,  465,   7824,  9219,  10026, 7968,  957,   2295,
    199    5594,  10811, 9641,  5950,  10043, 8783,  3132,  1421,  1110,  4108,  13929,
    200    10660, -84,   -61,   3932,  -180,  6811,  13393, 15147, 15640, 9337,  6961,
    201    3808,  1604,  1398,  1047,  6739,  10144, 6517,  4698,  2678,  7389,  2595,
    202    5248,  12075, 11272, 13951, 8820,  1090,  2199,  2206,  2788,  12116, 6683,
    203    2612,  -291,  3183,  9414,  12316, 14524, 12333, 13208, 7832,  4664,  4657,
    204    3534,  1298,  -666,  4250,  7707,  9103,  5760,  688,   9571,  15782, 14203,
    205    14878, 17339, 14684, 8690,  5671,  875,   1429,  1531,  6173,  2984,  5558,
    206    2996,  7928,  6733,  16117, 15262, 12757, 7980,  3923,  4795,  5973,  2051,
    207    455,   -1922, 1816,  5906,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
    208    7451,  6666,  74,    -1645, -35,   -391,  3813,  7324,  892,   1656,  6095,
    209    12193, 14648, 12156, 14663, 10251, 10325, 7821,  3925,  323,   697,   442,
    210    1324,  4669,  7002,  5485,  5171,  5086,  10582, 11053, 9709,  11353, 8543,
    211    5256,  2873,  235,   -628,  1496,  1878,  -867,  3420,  6865,  5937,  10182,
    212    13277, 10069, 10789, 5998,  624,   -2082, 4417,  1258,  -1080, -819,  -1430,
    213    1033,  5220,  6335,  8471,  8980,  11908, 14430, 12584, 8404,  1576,  -803,
    214    985,   1481,  1367,  -193,  873,   3684,  2288,  6676,  9477,  11155, 9602,
    215    9707,  10507, 4739,  3174,  -575,  -178,  3002,  1710,  423,   -477,  554,
    216    3088,  2029,  5113,  5000,  3771,  6090,  5365,  1185,  2855,  399,   -312,
    217    -1577, 176,   955,
    218  };
    219 
    220  float expected_replicate[] = {
    221    13768, 13528, 12999, 6906,  4618,  4043,  2611,  9955,  6685,  4776,  2753,
    222    1036,  3063,  4544,  5183,  7349,  12451, 12501, 9131,  12753, 8908,  4058,
    223    6299,  7542,  7115,  3307,  3360,  3543,  9754,  7808,  5991,  9019,  14320,
    224    14919, 12492, 6871,  7373,  3336,  2085,  10604, 9377,  6882,  5009,  3103,
    225    6220,  6278,  7588,  10196, 11045, 11563, 11842, 11911, 8279,  2030,  1858,
    226    6368,  12123, 9909,  6347,  10345, 9365,  4038,  1673,  3051,  16492, 16649,
    227    12276, 408,   -301,  4122,  -654,  7864,  14038, 15279, 15315, 9744,  8243,
    228    5298,  746,   380,   9824,  9124,  10895, 6640,  4712,  2669,  6980,  2759,
    229    5385,  12345, 11336, 13129, 8600,  2370,  3682,  5219,  12407, 13123, 6784,
    230    2612,  -291,  3183,  9414,  12316, 14524, 12333, 13397, 7543,  3916,  4153,
    231    4477,  4314,  7983,  8418,  9163,  9103,  5760,  688,   9571,  15782, 14203,
    232    14878, 17718, 14570, 7940,  6642,  5094,  7133,  9964,  10219, 3224,  5558,
    233    2996,  7928,  6733,  16117, 15262, 12757, 7958,  4401,  5187,  5476,  5529,
    234    6055,  2206,  3909,  6015,  3321,  10908, 10910, 7377,  12204, 12809, 11195,
    235    6967,  6840,  481,   -1600, 274,   1,     10373, 8514,  1123,  2117,  6758,
    236    12736, 16223, 13585, 15988, 11771, 10600, 7918,  4156,  2840,  3111,  3287,
    237    6359,  7652,  8813,  6530,  6967,  7789,  13671, 13990, 13247, 13241, 9836,
    238    5251,  3024,  2313,  1834,  4187,  2637,  -1312, 2139,  7378,  7665,  11933,
    239    15591, 15314, 15678, 9531,  2820,  -1516, 3400,  1314,  22,    363,   -2896,
    240    -898,  5906,  7308,  10650, 12975, 16978, 20370, 18817, 12381, 4118,  -861,
    241    -137,  236,   1802,  1632,  -350,  2334,  3400,  8680,  14064, 18216, 18675,
    242    21765, 22871, 11491, 4937,  -1555, -11,   1669,  2392,  3265,  -5254, -217,
    243    5001,  8063,  13444, 18884, 19706, 22794, 21064, 9545,  6689,  -7,    289,
    244    -2021, 504,   2347,
    245  };
    246 
    247  float expected_valid[] = {
    248    2612,  -291,  3183,  9414,  12316, 14524, 12333, 9103,  5760,  688,
    249    9571,  15782, 14203, 14878, 5558,  2996,  7928,  6733,  16117, 15262,
    250    12757, 3321,  10908, 10910, 7377,  12204, 12809, 11195,
    251  };
    252 
    253  CNN_CONFIG cnn_config = { 3,
    254                            0,
    255                            0,
    256                            0,
    257                            0,
    258                            {
    259                                {
    260                                    1,
    261                                    filter_width,
    262                                    filter_height,
    263                                    3,
    264                                    1,
    265                                    1,
    266                                    0,
    267                                    nullptr,
    268                                    nullptr,
    269                                    PADDING_SAME_ZERO,
    270                                    NONE,
    271                                    0,
    272                                    0,
    273                                    BRANCH_NO_COPY,
    274                                    BRANCH_NOC,
    275                                    {},
    276                                    {},
    277                                    -1,
    278                                },
    279                                {
    280                                    3,
    281                                    filter_width,
    282                                    filter_height,
    283                                    3,
    284                                    1,
    285                                    1,
    286                                    0,
    287                                    nullptr,
    288                                    nullptr,
    289                                    PADDING_SAME_ZERO,
    290                                    NONE,
    291                                    0,
    292                                    0,
    293                                    BRANCH_NO_COPY,
    294                                    BRANCH_NOC,
    295                                    {},
    296                                    {},
    297                                    -1,
    298                                },
    299                                {
    300                                    3,
    301                                    filter_width,
    302                                    filter_height,
    303                                    1,
    304                                    1,
    305                                    1,
    306                                    0,
    307                                    nullptr,
    308                                    nullptr,
    309                                    PADDING_SAME_ZERO,
    310                                    NONE,
    311                                    0,
    312                                    0,
    313                                    BRANCH_NO_COPY,
    314                                    BRANCH_NOC,
    315                                    {},
    316                                    {},
    317                                    0,
    318                                },
    319                            } };
    320 
    321  // Weights and biases need to be specified separately because
    322  // of the offset.
    323  AssignLayerWeightsBiases(&cnn_config, weights, bias);
    324 
    325  CNN_THREAD_DATA thread_data = { 1, nullptr };
    326 
    327  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
    328             image_width, &thread_data, MSE_INT_TOL);
    329 
    330  for (int i = 0; i < cnn_config.num_layers; ++i) {
    331    cnn_config.layer_config[i].pad = PADDING_SAME_REPLICATE;
    332  }
    333 
    334  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
    335             image_width, &thread_data, MSE_INT_TOL);
    336 
    337  for (int i = 0; i < cnn_config.num_layers; ++i) {
    338    cnn_config.layer_config[i].pad = PADDING_VALID;
    339  }
    340 
    341  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
    342             image_width, &thread_data, MSE_INT_TOL);
    343 }
    344 
    345 TEST_F(CNNTest, TestRELUSingleLayer) {
    346  int image_width = 8;
    347  int image_height = 8;
    348  int filter_height = 5;
    349  int filter_width = 4;
    350  float input[] = {
    351    0, -2, -3, 1,  -1, 2,  -2, 1,  -3, -1, 0,  1,  -2, -3, -2, -2,
    352    1, -3, 2,  -3, -1, -1, 2,  0,  -2, -3, 0,  -2, -3, 1,  -1, -1,
    353    2, -2, 0,  -2, -3, -3, 1,  1,  -1, 1,  0,  1,  -3, 0,  2,  2,
    354    0, -3, 1,  -3, 2,  -2, 1,  -1, -1, -2, -3, -2, -1, -3, -2, -1,
    355  };
    356  float expected_same[] = {
    357    9,  0,  1,  1,  0,  3,  0,  19, 0,  12, 10, 0,  0,  0,  5, 0,
    358    0,  18, 21, 7,  19, 4,  3,  0,  0,  9,  16, 0,  11, 16, 0, 11,
    359    12, 2,  0,  11, 0,  16, 6,  0,  8,  22, 13, 10, 12, 0,  0, 0,
    360    0,  1,  2,  12, 29, 6,  10, 0,  13, 0,  0,  5,  8,  10, 0, 0,
    361  };
    362  float expected_replicate[] = {
    363    18, 17, 12, 2,  0,  0,  5,  11, 0,  17, 22, 6,  0,  0,  17, 0,
    364    0,  18, 21, 7,  19, 4,  3,  5,  3,  9,  16, 0,  11, 16, 0,  3,
    365    3,  2,  0,  11, 0,  16, 6,  0,  17, 22, 13, 10, 12, 0,  0,  0,
    366    0,  4,  1,  10, 30, 7,  10, 0,  23, 8,  0,  13, 15, 19, 8,  10,
    367  };
    368  float expected_valid[] = {
    369    18, 21, 7, 19, 4, 9, 16, 0, 11, 16, 2, 0, 11, 0, 16, 22, 13, 10, 12, 0,
    370  };
    371  float weights[] = {
    372    -2, -3, 1, 2, 2, -2, -3, 0, -3, 2, 2, -3, -3, -2, 0, 1, 2, 0, -1, -1,
    373  };
    374  float bias[] = { -3 };
    375 
    376  CNN_CONFIG cnn_config = { 1,
    377                            0,
    378                            0,
    379                            0,
    380                            0,
    381                            { {
    382                                1,
    383                                filter_width,
    384                                filter_height,
    385                                1,
    386                                1,
    387                                1,
    388                                0,
    389                                weights,
    390                                bias,
    391                                PADDING_SAME_ZERO,
    392                                RELU,
    393                                0,
    394                                0,
    395                                BRANCH_NO_COPY,
    396                                BRANCH_NOC,
    397                                {},
    398                                {},
    399                                0,
    400                            } } };
    401 
    402  CNN_THREAD_DATA thread_data = { 1, nullptr };
    403 
    404  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
    405             image_width, &thread_data, MSE_INT_TOL);
    406 
    407  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
    408 
    409  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
    410             image_width, &thread_data, MSE_INT_TOL);
    411 
    412  cnn_config.layer_config[0].pad = PADDING_VALID;
    413 
    414  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
    415             image_width, &thread_data, MSE_INT_TOL);
    416 }
    417 
    418 TEST_F(CNNTest, TestVaryingStridesVaryingDimImages) {
    419  float weights[] = {
    420    1,  -5, -3, -4, -1, 1,  2,  -3, 2,  2,  -1, 1,  -5, 1,  1,
    421    -3, -5, 3,  1,  4,  -2, -5, -2, -3, -5, 0,  -1, -5, 2,  -2,
    422    -2, 1,  -2, -4, 1,  3,  -2, 2,  0,  -3, 2,  -3, -2, -3,
    423  };
    424  float bias[] = { 2 };
    425 
    426  CNN_CONFIG cnn_config = { 1,
    427                            0,
    428                            0,
    429                            0,
    430                            0,
    431                            {
    432                                {
    433                                    1,
    434                                    4,
    435                                    11,
    436                                    1,
    437                                    7,
    438                                    6,
    439                                    0,
    440                                    weights,
    441                                    bias,
    442                                    PADDING_SAME_ZERO,
    443                                    NONE,
    444                                    0,
    445                                    0,
    446                                    BRANCH_NO_COPY,
    447                                    BRANCH_NOC,
    448                                    {},
    449                                    {},
    450                                    0,
    451                                },
    452                            } };
    453 
    454  int image_height = 24;
    455  int image_width = 17;
    456  float input[] = {
    457    -1, -3, 4,  4,  -5, 4,  3,  -5, -1, -3, 4,  -4, 2,  -3, 3,  -5, 2,  -1, -5,
    458    1,  -1, 3,  1,  -3, -3, 4,  0,  2,  -3, -5, -5, -4, 0,  -5, -2, -3, -1, -2,
    459    2,  -5, 4,  4,  0,  -4, -3, 1,  -3, -5, -4, -4, 1,  -2, -3, 3,  -3, -3, -1,
    460    -5, -5, -2, 3,  1,  -1, -5, -5, 1,  -4, -2, -1, -2, -4, -4, 2,  -2, 2,  1,
    461    -2, -4, -1, 1,  -2, -5, 3,  -2, -1, -1, -5, -3, 1,  -2, -2, -3, -1, -2, -4,
    462    -2, 1,  -4, -1, 4,  3,  -4, 0,  4,  2,  2,  4,  -3, -5, 2,  2,  1,  -1, -4,
    463    -2, 1,  3,  2,  0,  4,  -1, -3, 2,  1,  -4, 2,  2,  -4, -2, 0,  -2, -1, 4,
    464    4,  2,  3,  -4, 2,  -4, -5, 4,  -1, -3, -1, 0,  -4, 1,  3,  -1, -3, -5, 3,
    465    -2, -4, 1,  2,  -2, -3, -3, -5, 1,  -3, -1, 0,  -1, 3,  -4, -1, -5, -5, 1,
    466    0,  0,  -2, -2, 2,  -2, 0,  0,  2,  0,  -3, 0,  -1, -4, -4, -1, 3,  -4, -4,
    467    -1, 0,  -5, -3, -2, 4,  -3, -4, -4, 0,  -5, 1,  -2, -3, -3, -4, 4,  3,  4,
    468    3,  3,  -1, 3,  1,  -3, -2, 3,  3,  0,  2,  -4, -3, 2,  2,  0,  -2, 4,  -2,
    469    2,  -2, -1, -4, -2, 2,  -4, 3,  -1, 4,  1,  1,  4,  -1, -4, -4, 1,  1,  -2,
    470    4,  -1, 3,  2,  -3, 4,  3,  1,  4,  0,  -4, 2,  0,  2,  4,  -2, -2, 4,  2,
    471    -1, -2, 1,  -3, 2,  3,  -5, -3, 4,  4,  2,  -5, -4, -5, -2, -4, 2,  0,  2,
    472    -5, 4,  -4, -2, -5, 2,  1,  0,  4,  1,  -2, -3, -4, -3, -4, 3,  3,  2,  0,
    473    -3, 1,  -5, 4,  0,  4,  -1, 3,  -5, -5, -2, -1, -1, 4,  3,  3,  4,  3,  -4,
    474    4,  -3, -3, -1, -4, -1, -4, -1, -2, 4,  -2, -4, 4,  4,  -3, -4, -1, 1,  2,
    475    -1, -2, -2, 3,  2,  2,  -3, 0,  -1, 0,  3,  2,  -5, 0,  -4, 0,  0,  2,  -4,
    476    -1, -1, 0,  -2, 0,  1,  0,  0,  4,  -5, -1, -5, 2,  -1, 0,  2,  -1, 1,  3,
    477    -3, -5, -2, -3, 4,  -2, -2, -1, -3, -4, -1, -2, -4, 1,  4,  -3, -2, -1, 3,
    478    -3, -2, 3,  2,  1,  -4, -3, -5, 1,
    479  };
    480  float expected_1[] = {
    481    41, -26, 5, 76, 13, 83, -21, 53, -54, -14, 21, 121,
    482  };
    483 
    484  CNN_THREAD_DATA thread_data = { 1, nullptr };
    485 
    486  RunCNNTest(image_width, image_height, input, expected_1, &cnn_config,
    487             image_width, &thread_data, MSE_INT_TOL);
    488 
    489  cnn_config.layer_config[0].skip_width = 6;
    490  cnn_config.layer_config[0].skip_height = 7;
    491 
    492  float expected_2[] = {
    493    21, -50, 41, 20, 72, 127, -21, 103, 62, -37, 83, -3,
    494  };
    495  RunCNNTest(image_width, image_height, input, expected_2, &cnn_config,
    496             image_width, &thread_data, MSE_INT_TOL);
    497 
    498  cnn_config.layer_config[0].skip_width = 3;
    499  cnn_config.layer_config[0].skip_height = 10;
    500 
    501  float expected_3[] = {
    502    -26, -21, -35, 69, 49,  4,  -51, -43, -56,
    503    -41, 15,  -44, 40, -62, 63, 38,  27,  47,
    504  };
    505  RunCNNTest(image_width, image_height, input, expected_3, &cnn_config,
    506             image_width, &thread_data, MSE_INT_TOL);
    507 
    508  cnn_config.layer_config[0].skip_width = 10;
    509  cnn_config.layer_config[0].skip_height = 3;
    510 
    511  float expected_4[] = {
    512    21, 49, 28, 87, 50, 40, 102, 81, 58, 85, 51, 66, 36, 19, -37, -45,
    513  };
    514 
    515  RunCNNTest(image_width, image_height, input, expected_4, &cnn_config,
    516             image_width, &thread_data, MSE_INT_TOL);
    517 }
    518 
    519 TEST_F(CNNTest, TestMaxPool) {
    520  int image_width = 8;
    521  int image_height = 8;
    522  int stride = 3;
    523  float input[] = {
    524    1,  -4, -4, 8, 0, 7, -5, -2, 8, 2, 2, 8,  5,  -1, -1, 9,
    525    -3, 0,  -2, 0, 6, 3, -4, 8,  7, 8, 7, -1, 4,  -1, 0,  2,
    526    -5, -2, 8,  5, 5, 4, 2,  7,  4, 6, 2, 8,  8,  -4, -3, -4,
    527    -3, -1, 2,  3, 3, 6, -5, 8,  9, 5, 0, -2, -1, 6,  5,  7,
    528  };
    529 
    530  float expected[] = {
    531    49, 58, 70, 68, 68, 70, 48, 57, 88,
    532  };
    533 
    534  float weights[] = {
    535    3, 1, 3, 4, -1, 5, -2, 1, -4,
    536  };
    537 
    538  float bias[] = {
    539    -3,
    540  };
    541 
    542  CNN_CONFIG cnn_config = { 1,
    543                            0,
    544                            0,
    545                            0,
    546                            0,
    547                            { {
    548                                1,
    549                                3,
    550                                3,
    551                                1,
    552                                stride,
    553                                stride,
    554                                1,
    555                                weights,
    556                                bias,
    557                                PADDING_SAME_ZERO,
    558                                NONE,
    559                                0,
    560                                0,
    561                                BRANCH_NO_COPY,
    562                                BRANCH_NOC,
    563                                {},
    564                                {},
    565                                0,
    566                            } } };
    567 
    568  CNN_THREAD_DATA thread_data = { 1, nullptr };
    569 
    570  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
    571             image_width, &thread_data, MSE_INT_TOL);
    572 }
    573 
    574 TEST_F(CNNTest, TestDeconvolveNonActivationSingleLayerSingleKernel) {
    575  int image_width = 4;
    576  int image_height = 7;
    577  float input[] = {
    578    9,  6,   181, 9,  218, 30, 80,  108, 68,  216, 70, 128, 179, 228,
    579    33, 212, 34,  14, 48,  27, 230, 23,  202, 113, 80, 56,  122, 112,
    580  };
    581 
    582  float expected_1_same[] = {
    583    15,   -30,  36,   -525,  377, -193, 558, 531,  6,   -24,  -15,  124,
    584    166,  -561, -356, -754,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
    585    433,  -311, 711,  381,   247, -317, 453, 129,  215, -627, -409, -885,
    586    17,   -255, -55,  -647,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
    587    133,  -719, 633,  -225,  785, 191,  463, 79,   65,  9,    77,   -853,
    588    -365, -949, -15,  -667,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
    589    355,  -866, 990,  207,   747, 12,   520, -116, 176, -312, -133, -1370,
    590    -426, -802, 143,  -771,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
    591    65,   -79,  127,  -59,   135, -90,  195, 114,  31,  -91,  -57,  -133,
    592    17,   -176, -72,  -276,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
    593    457,  -302, 733,  58,    470, -475, 829, 490,  227, -670, -440, -790,
    594    153,  -588, -294, -1150, -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
    595    157,  -251, 349,  -185,  409, -293, 587, 251,  77,  -187, -107, -369,
    596    7,    -481, -135, -827,  -3,  -3,   -3,  -3,   -3,  -3,   -3,   -3,
    597  };
    598  float expected_1_valid[] = {
    599    -30,  15,   -30,  36,   -525,  377,  -193,  558,  531,  24,   24,   6,
    600    6,    -24,  -15,  124,  166,   -561, -356,  -754, -21,  -39,  -3,   -3,
    601    -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -657, 433,  -311,
    602    711,  381,  247,  -317, 453,   129,  321,   321,  215,  215,  -627, -409,
    603    -885, 17,   -255, -55,  -647,  -219, -435,  -3,   -3,   -3,   -3,   -3,
    604    -3,   -3,   -3,   -3,   -3,    -3,   -207,  133,  -719, 633,  -225, 785,
    605    191,  463,  79,   381,  381,   65,   65,    9,    77,   -853, -365, -949,
    606    -15,  -667, -259, -515, -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
    607    -3,   -3,   -3,   -540, 355,   -866, 990,   207,  747,  12,   520,  -116,
    608    633,  633,  176,  176,  -312,  -133, -1370, -426, -802, 143,  -771, -427,
    609    -851, -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -3,   -3,   -3,
    610    -105, 65,   -79,  127,  -59,   135,  -90,   195,  114,  78,   78,   31,
    611    31,   -91,  -57,  -133, 17,    -176, -72,   -276, -57,  -111, -3,   -3,
    612    -3,   -3,   -3,   -3,   -3,    -3,   -3,    -3,   -3,   -693, 457,  -302,
    613    733,  58,   470,  -475, 829,   490,  336,   336,  227,  227,  -670, -440,
    614    -790, 153,  -588, -294, -1150, -229, -455,  -3,   -3,   -3,   -3,   -3,
    615    -3,   -3,   -3,   -3,   -3,    -3,   -243,  157,  -251, 349,  -185, 409,
    616    -293, 587,  251,  333,  333,   77,   77,    -187, -107, -369, 7,    -481,
    617    -135, -827, -227, -451,
    618  };
    619  float weights_1[] = { -3, 2, -1, 3, 3, 1, 1, -3, -2, -4 };
    620  float bias_1[] = { -3 };
    621 
    622  CNN_CONFIG cnn_config = { 1,
    623                            0,
    624                            0,
    625                            0,
    626                            0,
    627                            { {
    628                                1,
    629                                5,
    630                                2,
    631                                1,
    632                                2,
    633                                3,
    634                                0,
    635                                weights_1,
    636                                bias_1,
    637                                PADDING_SAME_ZERO,
    638                                NONE,
    639                                1,
    640                                0,
    641                                BRANCH_NO_COPY,
    642                                BRANCH_NOC,
    643                                {},
    644                                {},
    645                                0,
    646                            } } };
    647 
    648  CNN_THREAD_DATA thread_data = { 1, nullptr };
    649 
    650  RunCNNTest(image_width, image_height, input, expected_1_same, &cnn_config,
    651             image_width, &thread_data, MSE_INT_TOL);
    652 
    653  // Change padding to valid
    654  cnn_config.layer_config[0].pad = PADDING_VALID;
    655 
    656  RunCNNTest(image_width, image_height, input, expected_1_valid, &cnn_config,
    657             image_width, &thread_data, MSE_INT_TOL);
    658 
    659  float expected_12_same[] = {
    660    15,  -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,  24,
    661    6,   -30,  -15,  -33,  -21,  166,  154,  -546, -356, -718, -30,  -21,
    662    433, -221, 561,  711,  -33,  -153, 247,  -83,  -87,  453,  -111, 321,
    663    215, -657, -409, -845, -93,  17,   -43,  -243, -55,  -215, -327, -219,
    664    133, -71,  -447, 633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,
    665    65,  -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259,
    666    355, -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215, 633,
    667    176, -540, -133, -491, -687, -426, -882, -102, 143,  77,   -639, -427,
    668    65,  -37,  57,   127,  -17,  -105, 135,  -51,  60,   195,  -30,  78,
    669    31,  -105, -57,  -125, -45,  17,   -11,  -147, -72,  -168, -84,  -57,
    670    457, -233, 618,  733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,
    671    227, -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229,
    672    157, -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115, 333,
    673    77,  -243, -107, -267, -171, 7,    -105, -369, -135, -379, -339, -227,
    674  };
    675  float expected_12_valid[] = {
    676    -30,  15,   -12,  6,    36,   -9,   -528, 377,  -184, 513,  558,  -12,
    677    24,   24,   6,    6,    -30,  -15,  -33,  -21,  166,  154,  -546, -356,
    678    -718, -30,  -21,  -39,  -657, 433,  -221, 561,  711,  -33,  -153, 247,
    679    -83,  -87,  453,  -111, 321,  321,  215,  215,  -657, -409, -845, -93,
    680    17,   -43,  -243, -55,  -215, -327, -219, -435, -207, 133,  -71,  -447,
    681    633,  -219, 435,  785,  -73,  -177, 463,  -131, 381,  381,  65,   65,
    682    -207, 77,   -59,  -651, -365, -797, -213, -15,  -155, -387, -259, -515,
    683    -540, 355,  -182, -150, 990,  -231, 582,  747,  -36,  -540, 520,  -215,
    684    633,  633,  176,  176,  -540, -133, -491, -687, -426, -882, -102, 143,
    685    77,   -639, -427, -851, -105, 65,   -37,  57,   127,  -17,  -105, 135,
    686    -51,  60,   195,  -30,  78,   78,   31,   31,   -105, -57,  -125, -45,
    687    17,   -11,  -147, -72,  -168, -84,  -57,  -111, -693, 457,  -233, 618,
    688    733,  -26,  -540, 470,  -205, 264,  829,  -116, 336,  336,  227,  227,
    689    -693, -440, -900, -72,  153,  107,  -609, -294, -698, -342, -229, -455,
    690    -243, 157,  -83,  69,   349,  -59,  -201, 409,  -125, 27,   587,  -115,
    691    333,  333,  77,   77,   -243, -107, -267, -171, 7,    -105, -369, -135,
    692    -379, -339, -227, -451,
    693  };
    694 
    695  // Change skip_width, skip_height to {3, 2}
    696  cnn_config.layer_config[0].skip_width = 3;
    697  cnn_config.layer_config[0].skip_height = 2;
    698  // Set padding to same
    699  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
    700 
    701  RunCNNTest(image_width, image_height, input, expected_12_same, &cnn_config,
    702             image_width, &thread_data, MSE_INT_TOL);
    703 
    704  // Change padding to valid
    705  cnn_config.layer_config[0].pad = PADDING_VALID;
    706  RunCNNTest(image_width, image_height, input, expected_12_valid, &cnn_config,
    707             image_width, &thread_data, MSE_INT_TOL);
    708 
    709  cnn_config.layer_config[0].filter_width = 4;
    710  cnn_config.layer_config[0].filter_height = 3;
    711  float weights_2[] = { -1, -3, -1, -3, 0, 2, -2, 4, 3, 0, 1, 4 };
    712  float bias_2[] = { -4 };
    713  cnn_config.layer_config[0].weights = weights_2;
    714  cnn_config.layer_config[0].bias = bias_2;
    715 
    716  cnn_config.layer_config[0].skip_width = 5;
    717  cnn_config.layer_config[0].skip_height = 2;
    718  float expected_2_same[] = {
    719    -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
    720    -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   -4,   14,   -22,  32,
    721    -4,   -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,
    722    14,   -22,  32,   -4,   -195, -658, -213, -622, -4,   -16,  -94,  -28,
    723    -70,  -4,   459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,
    724    -4,   432,  -440, 868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,
    725    -164, 316,  -4,   -4,   212,  -220, 428,  -4,   582,  -208, 146,  664,
    726    -4,   -130, -652, -190, -532, -4,   166,  -214, 6,    106,  -4,   192,
    727    -388, -24,  44,   -4,   -4,   132,  -140, 268,  -4,   -4,   428,  -436,
    728    860,  -4,   -4,   136,  -144, 276,  -4,   -4,   252,  -260, 508,  -4,
    729    21,   -541, -115, -269, -4,   416,  -688, -16,  176,  -4,   173,  -103,
    730    33,   177,  -4,   168,  -640, -88,  -128, -4,   -4,   354,  -362, 712,
    731    -4,   -4,   452,  -460, 908,  -4,   -4,   62,   -70,  128,  -4,   -4,
    732    420,  -428, 844,  -4,   499,  -106, 141,  610,  -4,   666,  -46,  210,
    733    866,  -4,   47,   -148, -19,  -16,  -4,   605,  -85,  181,  763,  -4,
    734    -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,   -4,   -4,   92,
    735    -100, 188,  -4,   -4,   50,   -58,  104,  -4,   -132, -694, -200, -558,
    736    -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418, -4,   -36,
    737    -343, -90,  -235, -4,   -4,   456,  -464, 916,  -4,   -4,   42,   -50,
    738    88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,  -4,
    739    606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
    740    76,   438,  -4,   223,  -340, -3,   112,  -4,   -4,   156,  -164, 316,
    741    -4,   -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,
    742    220,  -228, 444,  -4,
    743  };
    744  float expected_2_valid[] = {
    745    -13,  -31,  -13,  -31,  -4,   -10,  -22,  -10,  -22,  -4,   -185, -547,
    746    -185, -547, -4,   -13,  -31,  -13,  -31,  -4,   14,   -22,  32,   -4,
    747    -4,   8,    -16,  20,   -4,   -4,   358,  -366, 720,  -4,   -4,   14,
    748    -22,  32,   -195, -658, -213, -622, -4,   -16,  -94,  -28,  -70,  -4,
    749    459,  -244, 97,   480,  -4,   -85,  -328, -103, -292, -4,   432,  -440,
    750    868,  -4,   -4,   56,   -64,  116,  -4,   -4,   156,  -164, 316,  -4,
    751    -4,   212,  -220, 428,  582,  -208, 146,  664,  -4,   -130, -652, -190,
    752    -532, -4,   166,  -214, 6,    106,  -4,   192,  -388, -24,  44,   -4,
    753    132,  -140, 268,  -4,   -4,   428,  -436, 860,  -4,   -4,   136,  -144,
    754    276,  -4,   -4,   252,  -260, 508,  21,   -541, -115, -269, -4,   416,
    755    -688, -16,  176,  -4,   173,  -103, 33,   177,  -4,   168,  -640, -88,
    756    -128, -4,   354,  -362, 712,  -4,   -4,   452,  -460, 908,  -4,   -4,
    757    62,   -70,  128,  -4,   -4,   420,  -428, 844,  499,  -106, 141,  610,
    758    -4,   666,  -46,  210,  866,  -4,   47,   -148, -19,  -16,  -4,   605,
    759    -85,  181,  763,  -4,   64,   -72,  132,  -4,   -4,   24,   -32,  52,
    760    -4,   -4,   92,   -100, 188,  -4,   -4,   50,   -58,  104,  -132, -694,
    761    -200, -558, -4,   15,   -73,  -13,  -17,  -4,   -62,  -610, -158, -418,
    762    -4,   -36,  -343, -90,  -235, -4,   456,  -464, 916,  -4,   -4,   42,
    763    -50,  88,   -4,   -4,   400,  -408, 804,  -4,   -4,   222,  -230, 448,
    764    606,  -244, 146,  676,  -4,   9,    -172, -37,  -80,  -4,   480,  -370,
    765    76,   438,  -4,   223,  -340, -3,   112,  -4,   156,  -164, 316,  -4,
    766    -4,   108,  -116, 220,  -4,   -4,   240,  -248, 484,  -4,   -4,   220,
    767    -228, 444,  236,  -4,   76,   316,  -4,   164,  -4,   52,   220,  -4,
    768    362,  -4,   118,  484,  -4,   332,  -4,   108,  444,
    769  };
    770  // Set padding to same
    771  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
    772 
    773  RunCNNTest(image_width, image_height, input, expected_2_same, &cnn_config,
    774             image_width, &thread_data, MSE_INT_TOL);
    775 
    776  cnn_config.layer_config[0].pad = PADDING_VALID;
    777 
    778  RunCNNTest(image_width, image_height, input, expected_2_valid, &cnn_config,
    779             image_width, &thread_data, MSE_INT_TOL);
    780 
    781  cnn_config.layer_config[0].skip_width = 2;
    782  cnn_config.layer_config[0].skip_height = 5;
    783  float expected_21_same[] = {
    784    -31,  -19,  -49,   -191, -565, -194, -574, -13,  14,   -22,  44,   -16,
    785    382,  -366, 738,   -22,  -4,   23,   32,   545,  20,   204,  720,  5,
    786    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    787    -4,   -4,   -4,    -4,   -658, -252, -748, -114, -334, -192, -568, -112,
    788    432,  -440, 928,   -64,  276,  -164, 532,  -220, -4,   304,  868,  266,
    789    116,  400,  316,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    790    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -208, -288, -856, -290,
    791    -862, -202, -598,  -132, 132,  -140, 700,  -436, 1000, -144, 532,  -260,
    792    -4,   712,  268,   422,  860,  450,  276,  124,  -4,   -4,   -4,   -4,
    793    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    794    -541, -411, -1225, -265, -787, -249, -739, -216, 354,  -362, 1168, -460,
    795    974,  -70,  552,   -428, -4,   859,  712,  323,  908,  665,  128,  208,
    796    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    797    -4,   -4,   -4,    -4,   -106, -52,  -148, -66,  -190, -79,  -229, -31,
    798    64,   -72,  160,   -32,  148,  -100, 242,  -58,  -4,   72,   132,  154,
    799    52,   125,  188,   23,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    800    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -694, -257, -763, -229,
    801    -679, -319, -949,  -117, 456,  -464, 962,  -50,  492,  -408, 1030, -230,
    802    -4,   295,  916,   625,  88,   537,  804,  109,  -4,   -4,   -4,   -4,
    803    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    804    -244, -140, -412,  -182, -538, -238, -706, -116, 156,  -164, 428,  -116,
    805    464,  -248, 708,   -228, -4,   244,  316,  418,  220,  454,  484,  108,
    806    -4,   -4,   -4,    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,
    807    -4,   -4,   -4,    -4,
    808  };
    809  float expected_21_valid[] = {
    810    -13,  -31,  -19,  -49,  -191, -565, -194, -574, -13,  -31,   -4,   14,
    811    -22,  44,   -16,  382,  -366, 738,  -22,  32,   23,   -4,    23,   32,
    812    545,  20,   204,  720,  5,    32,   -4,   -4,   -4,   -4,    -4,   -4,
    813    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    814    -4,   -4,   -222, -658, -252, -748, -114, -334, -192, -568,  -112, -328,
    815    -4,   432,  -440, 928,  -64,  276,  -164, 532,  -220, 428,   650,  -4,
    816    304,  868,  266,  116,  400,  316,  104,  428,  -4,   -4,    -4,   -4,
    817    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    818    -4,   -4,   -4,   -4,   -72,  -208, -288, -856, -290, -862,  -202, -598,
    819    -132, -388, -4,   132,  -140, 700,  -436, 1000, -144, 532,   -260, 508,
    820    200,  -4,   712,  268,  422,  860,  450,  276,  124,  508,   -4,   -4,
    821    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    822    -4,   -4,   -4,   -4,   -4,   -4,   -183, -541, -411, -1225, -265, -787,
    823    -249, -739, -216, -640, -4,   354,  -362, 1168, -460, 974,   -70,  552,
    824    -428, 844,  533,  -4,   859,  712,  323,  908,  665,  128,   208,  844,
    825    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    826    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -38,  -106,  -52,  -148,
    827    -66,  -190, -79,  -229, -31,  -85,  -4,   64,   -72,  160,   -32,  148,
    828    -100, 242,  -58,  104,  98,   -4,   72,   132,  154,  52,    125,  188,
    829    23,   104,  -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    830    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -234, -694,
    831    -257, -763, -229, -679, -319, -949, -117, -343, -4,   456,   -464, 962,
    832    -50,  492,  -408, 1030, -230, 448,  686,  -4,   295,  916,   625,  88,
    833    537,  804,  109,  448,  -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    834    -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,   -4,    -4,   -4,
    835    -84,  -244, -140, -412, -182, -538, -238, -706, -116, -340,  -4,   156,
    836    -164, 428,  -116, 464,  -248, 708,  -228, 444,  236,  -4,    244,  316,
    837    418,  220,  454,  484,  108,  444,
    838  };
    839 
    840  cnn_config.layer_config[0].pad = PADDING_SAME_ZERO;
    841 
    842  RunCNNTest(image_width, image_height, input, expected_21_same, &cnn_config,
    843             image_width, &thread_data, MSE_INT_TOL);
    844 
    845  cnn_config.layer_config[0].pad = PADDING_VALID;
    846 
    847  RunCNNTest(image_width, image_height, input, expected_21_valid, &cnn_config,
    848             image_width, &thread_data, MSE_INT_TOL);
    849 }
    850 
    851 TEST_F(CNNTest, TestLargeKernelsAndStrides) {
         // Runs a single-layer CNN whose filter is larger than the input image in
         // both dimensions, in two configurations:
         //   1. image_height 10 / image_width 11 with a 23-wide, 20-high filter,
         //      where exactly one output value is expected;
         //   2. image_height 11 / image_width 10 with a 20-wide, 23-high filter and
         //      skip_width = skip_height = 1, where 110 output values are expected.
         // Both runs use PADDING_SAME_ZERO and pass MSE_INT_TOL as the tolerance.

         // 110 input samples for the first run (image_height * image_width below).
    852  float input_10x11[] = {
    853    4,  4,  2,  4,  2,  -5, -2, 3, -1, 0,  0,  1,  2,  0,  -5, -2, -5, 1,  -3,
    854    -1, 4,  -3, 2,  -2, 1,  0,  1, -3, -3, -4, -2, -2, 1,  -4, -1, 4,  1,  -4,
    855    -4, -4, 3,  2,  -5, 3,  -5, 1, 2,  -4, 1,  -1, 3,  4,  -2, 3,  -3, 3,  0,
    856    2,  -4, -5, -5, -2, -1, -2, 1, 1,  1,  -2, 4,  -5, 4,  -1, -1, 2,  3,  -4,
    857    2,  2,  3,  0,  0,  1,  0,  3, 2,  3,  1,  -2, 3,  -4, 3,  2,  4,  -2, 0,
    858    4,  -4, 1,  -3, -3, -3, -5, 1, -3, -5, 0,  4,  -1, -3, 2,
    859  };
    860 
         // 460 filter taps (23 * 20) for the first run.
         // NOTE(review): the memory layout of the taps is defined by the CNN
         // implementation, not visible here -- confirm before editing values.
    861  float weights_10x11[] = {
    862    -3, 4,  -4, -3, -5, 1,  -2, 3,  1,  -4, -4, 0,  -1, 0,  3,  1,  -3, -2, 0,
    863    -1, 1,  3,  -4, -4, -3, -3, -2, 4,  3,  -5, 4,  2,  -3, 4,  -2, -1, 2,  -1,
    864    -5, 0,  -3, 0,  3,  -5, -5, 3,  -4, -1, -5, 3,  4,  0,  4,  -5, 2,  -1, 2,
    865    -1, -1, -1, -5, 0,  -4, 3,  -1, 1,  1,  -1, 3,  2,  -5, -4, 0,  -4, 4,  -5,
    866    -3, 4,  -5, 2,  -5, -4, -4, -1, 3,  3,  0,  2,  -4, 1,  -2, 1,  1,  0,  3,
    867    -2, 0,  1,  2,  4,  -3, -1, -5, -5, 2,  -4, 1,  1,  2,  -4, -2, -2, 2,  1,
    868    3,  4,  -5, 1,  -1, -3, -3, -1, -2, -5, 1,  -1, 0,  1,  4,  4,  0,  0,  4,
    869    -3, -1, -5, -3, 0,  1,  1,  1,  -5, 3,  4,  3,  -5, 3,  -2, -2, 0,  -4, 0,
    870    0,  -2, 1,  -4, -1, 0,  -5, -2, -2, -5, -3, -3, 1,  1,  -3, 2,  4,  2,  4,
    871    -4, -3, 3,  1,  1,  3,  -4, 4,  -2, -3, -3, -3, -3, -4, -2, 3,  -5, 2,  4,
    872    -1, -4, -4, 4,  -2, -1, 3,  -3, -4, -4, -2, 4,  1,  0,  2,  -1, 4,  -3, 1,
    873    4,  -3, 4,  4,  0,  -4, 3,  -2, -3, 2,  3,  -1, -3, 2,  1,  4,  -2, -3, 1,
    874    4,  -2, 2,  -2, -5, -2, 1,  4,  -1, -4, 4,  -5, 2,  -5, -4, -1, -2, 3,  1,
    875    2,  1,  -5, 1,  -5, -4, -1, -2, 2,  -2, -4, -3, -2, -2, 4,  -1, 2,  2,  -4,
    876    2,  -2, 4,  -4, -2, -2, 1,  -1, 1,  1,  1,  -4, -5, -2, 3,  -4, -1, 3,  -2,
    877    3,  2,  -5, -4, 0,  3,  -2, -4, -5, 3,  -2, -4, 2,  -2, 1,  -4, 0,  2,  -5,
    878    1,  -4, -1, -1, 4,  -5, -4, 0,  -5, -4, -3, -5, -4, 0,  2,  0,  -4, 2,  -2,
    879    1,  1,  -3, 2,  0,  -4, 0,  -4, 1,  0,  -5, -1, -1, -1, -5, 4,  2,  2,  -4,
    880    3,  -2, -2, 2,  -3, -2, -1, 2,  -4, -5, 2,  -2, -4, -5, -5, -1, 2,  -1, 0,
    881    -5, -2, -2, -5, 0,  1,  -1, -5, 0,  3,  2,  3,  0,  -3, -2, 0,  -5, -1, -2,
    882    2,  -4, -1, 2,  2,  -5, 2,  -4, 0,  3,  -3, 1,  0,  0,  1,  -5, -3, 1,  -1,
    883    0,  -4, -3, 2,  -4, -4, 4,  -1, 0,  1,  2,  -4, -5, 4,  -2, 1,  -4, -4, -3,
    884    -1, -1, 1,  -1, -4, -1, -4, -3, 2,  -1, -2, -4, 1,  1,  0,  -2, 0,  -4, 3,
    885    -3, 0,  -4, -1, -4, 2,  -1, -2, -5, -1, -2, -3, 3,  -1, 0,  -3, 0,  1,  -5,
    886    1,  -5, 0,  1,
    887  };
    888 
    889  float bias_10x11[] = { 3 };
    890 
         // The first run is expected to produce a single output sample.
    891  float expected_10x11[] = {
    892    118,
    893  };
    894 
         // Positional initialisation of one layer. Slots 2 and 3 of the layer entry
         // (23, 20) are filter_width and filter_height -- the same slots that
         // TestSoftsignSingleLayer fills with variables of those names.
         // NOTE(review): the 15 and 20 that follow are presumably skip_width and
         // skip_height, and the leading 1 of CNN_CONFIG presumably the layer count
         // -- confirm against the CNN_LAYER_CONFIG / CNN_CONFIG declarations.
    895  CNN_CONFIG cnn_config = { 1,
    896                            0,
    897                            0,
    898                            0,
    899                            0,
    900                            { {
    901                                1,
    902                                23,
    903                                20,
    904                                1,
    905                                15,
    906                                20,
    907                                0,
    908                                weights_10x11,
    909                                bias_10x11,
    910                                PADDING_SAME_ZERO,
    911                                NONE,
    912                                0,
    913                                0,
    914                                BRANCH_NO_COPY,
    915                                BRANCH_NOC,
    916                                {},
    917                                {},
    918                                0,
    919                            } } };
    920 
    921  int image_height = 10;
    922  int image_width = 11;
    923 
    924  CNN_THREAD_DATA thread_data = { 1, nullptr };
    925 
         // image_width is passed twice: once as the width and once as in_stride.
    926  RunCNNTest(image_width, image_height, input_10x11, expected_10x11,
    927             &cnn_config, image_width, &thread_data, MSE_INT_TOL);
    928 
         // Second run: 110 input samples for image_height 11 / image_width 10.
    929  float input_11x10[] = {
    930    -2, -2, 3,  -5, -1, -3, 1,  3,  2,  1,  1,  -5, 4,  1,  3,  -5, 3,  -3, -5,
    931    0,  -1, -3, -3, 1,  1,  -5, -1, -5, -5, -3, 0,  1,  -3, -1, -3, -3, 0,  3,
    932    4,  -4, -1, 3,  -3, -1, -3, 1,  -3, -2, -1, -4, -3, 2,  -4, 1,  -4, -1, -3,
    933    -5, -1, 2,  3,  0,  2,  2,  -5, 4,  1,  2,  -1, -4, 4,  -4, -4, 0,  -1, 1,
    934    -1, 1,  -3, -3, -2, 1,  2,  4,  4,  4,  -3, -3, 0,  1,  0,  1,  4,  1,  3,
    935    4,  -3, -2, -4, 4,  2,  0,  3,  4,  -1, 2,  -2, 1,  -3, -2,
    936  };
    937 
         // 460 filter taps (20 * 23) for the second run.
    938  float weights_11x10[] = {
    939    4,  -1, 1,  -1, 2,  4,  3,  3,  -4, 3,  -5, 1,  -1, -1, -2, -2, 0,  2,  -3,
    940    -2, 3,  -5, -1, 0,  -1, -2, -2, -1, 2,  4,  3,  1,  0,  0,  -3, 3,  -4, -1,
    941    -5, 4,  -2, -2, 1,  2,  -1, -3, 1,  2,  -5, 1,  -3, 3,  3,  0,  -4, -4, -5,
    942    -3, -4, -4, 4,  -2, 4,  4,  -2, 2,  -5, -1, -2, -5, -1, 4,  -3, 3,  -2, 0,
    943    -4, -3, 0,  -1, -2, 4,  2,  0,  -2, -5, -4, 1,  4,  -4, -2, 2,  -2, 1,  1,
    944    -4, 1,  -4, -4, -2, 4,  2,  -1, -5, -5, 1,  -3, -3, 3,  -3, -5, -3, 4,  -1,
    945    -1, -3, 0,  -4, 3,  -1, 0,  -2, 0,  -5, -2, -5, 2,  0,  -5, 2,  3,  -2, 2,
    946    4,  -1, 1,  -3, 2,  3,  2,  0,  -5, -4, -5, 2,  1,  1,  -1, -2, 3,  4,  2,
    947    -2, 4,  -2, 3,  1,  -4, -3, -1, 4,  4,  -3, -5, -2, 2,  0,  3,  -2, 3,  -1,
    948    -4, 0,  -2, 0,  3,  4,  -2, -3, -2, 0,  3,  4,  2,  -4, 0,  1,  2,  2,  -1,
    949    -1, 4,  1,  4,  -2, -1, -1, -5, 1,  -3, 3,  3,  -1, -4, 3,  -5, 0,  0,  -1,
    950    -4, -1, -2, 4,  -2, 3,  3,  -3, 1,  -1, 2,  -1, 4,  4,  -2, -2, 4,  -2, 0,
    951    3,  -3, -5, -1, -2, 4,  -4, 2,  -4, 0,  -2, 3,  -3, 2,  2,  -2, -5, -1, 4,
    952    3,  -2, -1, 3,  3,  -1, 3,  0,  -3, 0,  4,  2,  0,  -1, 4,  1,  1,  2,  1,
    953    3,  1,  1,  1,  -3, -5, -4, 4,  -4, 2,  0,  0,  -4, 1,  4,  -5, 4,  4,  0,
    954    1,  0,  -2, -4, -4, -3, 0,  1,  -5, 4,  0,  -3, -2, -4, 2,  4,  1,  -5, 1,
    955    -4, 1,  0,  -3, -3, 0,  2,  -5, 4,  3,  -2, -5, 3,  1,  -1, 0,  3,  -2, -2,
    956    3,  -2, -5, 4,  1,  -2, 2,  -1, 0,  4,  0,  -5, 3,  -2, 1,  2,  1,  -5, -3,
    957    -2, -5, 4,  -4, 0,  3,  2,  -1, -4, -1, 2,  1,  -2, 3,  -1, -4, 2,  0,  -3,
    958    1,  -1, 2,  -5, -4, -1, -5, 1,  4,  3,  4,  2,  -3, 1,  -5, -1, 3,  0,  -1,
    959    -4, 3,  4,  -5, 4,  4,  -3, 2,  -3, -1, -3, -5, -3, 2,  -3, -2, 1,  1,  0,
    960    -5, 3,  2,  1,  -5, 1,  1,  1,  3,  4,  -4, -1, -2, 0,  -5, -3, -5, -2, -4,
    961    3,  3,  3,  4,  0,  -4, -1, -5, 0,  -3, 1,  4,  4,  -4, 4,  -5, -5, -1, -2,
    962    -5, 3,  -4, 4,  3,  0,  -3, 2,  -2, 0,  0,  4,  4,  0,  -2, 1,  -1, -3, 2,
    963    -1, 1,  -3, -5,
    964  };
    965 
    966  float bias_11x10[] = {
    967    -5,
    968  };
    969 
         // 110 expected outputs: with unit strides and PADDING_SAME_ZERO the output
         // has as many samples as the 11x10 input.
    970  float expected_11x10[] = {
    971    36,  -84,  95,   45,  18,   46,   77,  -54, -99,  -149, 66,  49,  161, 11,
    972    39,  61,   -66,  61,  4,    -3,   34,  -44, -23,  31,   64,  29,  47,  72,
    973    -27, -27,  121,  -3,  100,  1,    30,  -78, -12,  -89,  -59, 8,   -16, 112,
    974    91,  -102, -26,  -4,  30,   54,   4,   -84, -24,  -58,  27,  -53, -33, 5,
    975    53,  -26,  63,   50,  -103, -130, -23, 6,   -104, -207, 73,  23,  77,  132,
    976    38,  32,   -130, -44, -60,  7,    27,  176, 45,   -32,  -2,  99,  -97, 63,
    977    69,  126,  47,   63,  136,  -57,  5,   16,  -40,  -157, 8,   38,  -44, -10,
    978    91,  7,    122,  140, 30,   -105, 4,   -1,  113,  64,   180, 141,
    979  };
    980 
         // Reuse cnn_config for the second run: swap in the new weights and bias,
         // transpose the filter dimensions and switch to unit strides. The padding
         // mode is left at PADDING_SAME_ZERO.
    981  cnn_config.layer_config[0].weights = weights_11x10;
    982  cnn_config.layer_config[0].bias = bias_11x10;
    983  cnn_config.layer_config[0].filter_width = 20;
    984  cnn_config.layer_config[0].filter_height = 23;
    985  cnn_config.layer_config[0].skip_width = 1;
    986  cnn_config.layer_config[0].skip_height = 1;
    987  image_height = 11;
    988  image_width = 10;
    989 
    990  RunCNNTest(image_width, image_height, input_11x10, expected_11x10,
    991             &cnn_config, image_width, &thread_data, MSE_INT_TOL);
    992 }
    993 
    994 TEST_F(CNNTest, TestSoftsignSingleLayer) {
         // Runs one convolution layer (filter_width 4, filter_height 5) with the
         // SOFTSIGN activation over an 8x8 float image and checks the output under
         // three padding modes in turn: PADDING_SAME_ZERO, PADDING_SAME_REPLICATE
         // and PADDING_VALID. Every run passes MSE_FLOAT_TOL as the tolerance.
    995  int image_width = 8;
    996  int image_height = 8;
    997  int filter_height = 5;
    998  int filter_width = 4;
         // 64 input samples (image_width * image_height).
    999  float input[] = {
   1000    -0.5220f, 0.8410f,  -0.8990f, -0.0090f, 0.6710f,  -0.9470f, -0.8240f,
   1001    -0.0870f, 0.5380f,  0.4750f,  0.570f,   -0.3760f, -0.6960f, -0.5940f,
   1002    -0.3830f, 0.080f,   -0.0980f, -0.4940f, -0.4030f, 0.9460f,  -0.6020f,
   1003    0.4220f,  0.6190f,  0.6640f,  -0.9210f, -0.1470f, -0.2480f, -0.1120f,
   1004    -0.580f,  -0.0650f, 0.3330f,  0.9860f,  -0.7430f, 0.7610f,  0.4840f,
   1005    0.1030f,  0.9570f,  0.6120f,  -0.5240f, -0.1220f, -0.5850f, -0.270f,
   1006    0.7840f,  -0.9790f, 0.7290f,  -0.30f,   -0.6460f, 0.0780f,  0.4750f,
   1007    -0.0510f, 0.4550f,  0.3850f,  -0.7230f, 0.4460f,  -0.6260f, -0.810f,
   1008    0.8720f,  -0.2120f, -0.580f,  -0.9510f, -0.8430f, -0.1340f, -0.0850f,
   1009    0.9190f,
   1010  };
         // Expected output for PADDING_SAME_ZERO: 64 values, same count as the input.
   1011  float expected_same[] = {
   1012    0.430f,   0.660f,  0.5510f,  -0.610f,  0.450f,  -0.1610f, 0.0520f,  0.3240f,
   1013    0.6820f,  0.3820f, 0.6360f,  0.7480f,  0.3080f, 0.090f,   0.3910f,  0.1730f,
   1014    0.340f,   0.6660f, -0.4990f, 0.4280f,  0.1540f, 0.120f,   0.4670f,  0.6150f,
   1015    -0.3880f, 0.7590f, 0.4190f,  0.7350f,  0.5310f, -0.5160f, -0.1760f, 0.6790f,
   1016    -0.6780f, 0.5470f, 0.5750f,  -0.6420f, 0.7210f, -0.4620f, 0.5430f,  0.770f,
   1017    -0.1990f, 0.3950f, 0.7860f,  -0.4380f, 0.7540f, 0.2640f,  -0.6430f, 0.4510f,
   1018    -0.1260f, 0.1590f, -0.2110f, -0.0560f, 0.6570f, 0.680f,   0.5870f,  0.4720f,
   1019    0.4040f,  0.3630f, 0.670f,   0.2360f,  0.410f,  0.6980f,  -0.5350f, 0.3940f,
   1020  };
         // Expected output for PADDING_SAME_REPLICATE: also 64 values.
   1021  float expected_replicate[] = {
   1022    0.540f,   0.7230f,  -0.3530f, -0.2130f, 0.7440f,  -0.4470f, -0.6260f,
   1023    -0.2050f, 0.7230f,  0.4630f,  0.5920f,  0.7440f,  0.6080f,  0.3130f,
   1024    -0.5670f, -0.4720f, 0.5480f,  0.6660f,  -0.4990f, 0.4280f,  0.1540f,
   1025    0.120f,   0.3390f,  0.6090f,  0.4160f,  0.7590f,  0.4190f,  0.7350f,
   1026    0.5310f,  -0.5160f, -0.490f,  0.4450f,  -0.610f,  0.5470f,  0.5750f,
   1027    -0.6420f, 0.7210f,  -0.4620f, 0.3150f,  0.7370f,  -0.5820f, 0.3950f,
   1028    0.7860f,  -0.4380f, 0.7540f,  0.2640f,  -0.7430f, -0.5340f, -0.6270f,
   1029    0.4430f,  0.4730f,  0.4570f,  0.7450f,  0.630f,   0.2620f,  0.3140f,
   1030    -0.1840f, 0.1810f,  0.7210f,  0.2760f,  0.6430f,  0.6720f,  -0.4390f,
   1031    0.2040f,
   1032  };
         // Expected output for PADDING_VALID: 20 values.
         // NOTE(review): presumably 5 columns x 4 rows, i.e. (8 - 4 + 1) by
         // (8 - 5 + 1) -- confirm against the CNN output-size computation.
   1033  float expected_valid[] = {
   1034    0.6660f,  -0.4990f, 0.4280f,  0.1540f,  0.120f,  0.7590f,  0.4190f,
   1035    0.7350f,  0.5310f,  -0.5160f, 0.5470f,  0.5750f, -0.6420f, 0.7210f,
   1036    -0.4620f, 0.3950f,  0.7860f,  -0.4380f, 0.7540f, 0.2640f,
   1037  };
         // 20 filter taps (filter_width * filter_height).
   1038  float weights[] = {
   1039    0.6210f,  0.3710f,  -0.2770f, -0.7230f, -0.2450f, 0.6770f,  0.3080f,
   1040    -0.9880f, -0.080f,  0.7190f,  -0.6760f, -0.0170f, -0.8970f, 0.8260f,
   1041    0.7390f,  -0.4550f, -0.4260f, -0.6330f, 0.0880f,  -0.9390f,
   1042  };
   1043  float bias[] = {
   1044    0.750f,
   1045  };
   1046 
         // Positional initialisation of a single layer; the filter dimensions,
         // weights, bias, initial padding mode and SOFTSIGN activation are the
         // entries identifiable by name below.
         // NOTE(review): the meaning of the remaining numeric slots comes from the
         // CNN_CONFIG / CNN_LAYER_CONFIG declarations, which are not visible here.
   1047  CNN_CONFIG cnn_config = { 1,
   1048                            0,
   1049                            0,
   1050                            0,
   1051                            0,
   1052                            { {
   1053                                1,
   1054                                filter_width,
   1055                                filter_height,
   1056                                1,
   1057                                1,
   1058                                1,
   1059                                0,
   1060                                weights,
   1061                                bias,
   1062                                PADDING_SAME_ZERO,
   1063                                SOFTSIGN,
   1064                                0,
   1065                                0,
   1066                                BRANCH_NO_COPY,
   1067                                BRANCH_NOC,
   1068                                {},
   1069                                {},
   1070                                0,
   1071                            } } };
   1072 
   1073  CNN_THREAD_DATA thread_data = { 1, nullptr };
   1074 
         // Run 1: zero padding, as configured in the initialiser above.
   1075  RunCNNTest(image_width, image_height, input, expected_same, &cnn_config,
   1076             image_width, &thread_data, MSE_FLOAT_TOL);
   1077 
         // Run 2: same layer, only the padding mode changes.
   1078  cnn_config.layer_config[0].pad = PADDING_SAME_REPLICATE;
   1079 
   1080  RunCNNTest(image_width, image_height, input, expected_replicate, &cnn_config,
   1081             image_width, &thread_data, MSE_FLOAT_TOL);
   1082 
         // Run 3: valid padding, checked against the 20-value expectation.
   1083  cnn_config.layer_config[0].pad = PADDING_VALID;
   1084 
   1085  RunCNNTest(image_width, image_height, input, expected_valid, &cnn_config,
   1086             image_width, &thread_data, MSE_FLOAT_TOL);
   1087 }
   1088 
   1089 TEST_F(CNNTest, TestBranchTensorAdd) {
   1090  int filter_width = 2;
   1091  int filter_height = 3;
   1092 
   1093  int image_width = 4;
   1094  int image_height = 4;
   1095 
   1096  float input[] = {
   1097    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
   1098  };
   1099 
   1100  float weights[] = {
   1101    -3, -1, 4,  -1, -3, 3,  3,  0,  2,  0,  3,  2,  4,  4, 4,  -5, 1, -4,
   1102    2,  -4, 1,  -3, 0,  4,  -5, 4,  0,  -4, -3, -1, 0,  0, -2, 0,  0, 2,
   1103    -5, -1, 1,  -3, 3,  4,  3,  0,  1,  -1, 1,  1,  2,  4, -2, -5, 2, -2,
   1104    3,  -2, 4,  -1, 0,  2,  3,  2,  -2, -1, -3, 1,  3,  4, -1, -3, 0, -4,
   1105    4,  2,  -3, -3, -1, 0,  1,  0,  3,  3,  -3, 0,  3,  2, -5, -3, 4, -5,
   1106    3,  -1, -1, -3, 0,  1,  -1, -4, 2,  4,  -1, 4,  -1, 1, 3,  4,  4, 4,
   1107    0,  -1, -3, -3, -3, -3, 2,  -3, -2, 2,  3,  -3,
   1108  };
   1109 
   1110  float bias[] = {
   1111    3, 4, -1, -1, 2, 1, -2, 1, 4, 1, 3,
   1112  };
   1113 
   1114  float expected[] = {
   1115    -11502, -4101, -3424, 668,   -17950, -5470, -5504, 626,
   1116    4835,   446,   1779,  -3483, 3679,   -4214, 4578,  -105,
   1117  };
   1118 
   1119  int channels = 2;
   1120 
   1121  CNN_CONFIG cnn_config = { 6,
   1122                            0,
   1123                            0,
   1124                            0,
   1125                            0,
   1126                            { {
   1127                                  1,
   1128                                  filter_width,
   1129                                  filter_height,
   1130                                  channels,
   1131                                  1,
   1132                                  1,
   1133                                  0,
   1134                                  weights,
   1135                                  bias,
   1136                                  PADDING_SAME_ZERO,
   1137                                  NONE,
   1138                                  0,
   1139                                  0,
   1140                                  BRANCH_NO_COPY,
   1141                                  BRANCH_NOC,
   1142                                  {},
   1143                                  {},
   1144                                  -1,
   1145                              },
   1146                              {
   1147                                  channels,
   1148                                  filter_width,
   1149                                  filter_height,
   1150                                  channels,
   1151                                  1,
   1152                                  1,
   1153                                  0,
   1154                                  nullptr,
   1155                                  nullptr,
   1156                                  PADDING_SAME_ZERO,
   1157                                  NONE,
   1158                                  0,
   1159                                  0,
   1160                                  BRANCH_INPUT,
   1161                                  BRANCH_NOC,
   1162                                  {
   1163                                      0x02,
   1164                                      0,
   1165                                      0x00,
   1166                                  },
   1167                                  {},
   1168                                  -1,
   1169                              },
   1170                              {
   1171                                  channels,
   1172                                  filter_width,
   1173                                  filter_height,
   1174                                  channels,
   1175                                  1,
   1176                                  1,
   1177                                  0,
   1178                                  nullptr,
   1179                                  nullptr,
   1180                                  PADDING_SAME_ZERO,
   1181                                  NONE,
   1182                                  0,
   1183                                  1,
   1184                                  BRANCH_NO_COPY,
   1185                                  BRANCH_NOC,
   1186                                  {},
   1187                                  {},
   1188                                  -1,
   1189                              },
   1190                              {
   1191                                  channels,
   1192                                  filter_width,
   1193                                  filter_height,
   1194                                  channels,
   1195                                  1,
   1196                                  1,
   1197                                  0,
   1198                                  nullptr,
   1199                                  nullptr,
   1200                                  PADDING_SAME_ZERO,
   1201                                  NONE,
   1202                                  0,
   1203                                  1,
   1204                                  BRANCH_NO_COPY,
   1205                                  BRANCH_NOC,
   1206                                  {},
   1207                                  {},
   1208                                  -1,
   1209                              },
   1210                              {
   1211                                  channels,
   1212                                  filter_width,
   1213                                  filter_height,
   1214                                  channels,
   1215                                  1,
   1216                                  1,
   1217                                  0,
   1218                                  nullptr,
   1219                                  nullptr,
   1220                                  PADDING_SAME_ZERO,
   1221                                  NONE,
   1222                                  0,
   1223                                  0,
   1224                                  BRANCH_NO_COPY,
   1225                                  BRANCH_ADD,
   1226                                  {
   1227                                      0x00,
   1228                                      0,
   1229                                      0x02,
   1230                                  },
   1231                                  {},
   1232                                  -1,
   1233                              },
   1234                              {
   1235                                  channels,
   1236                                  filter_width,
   1237                                  filter_height,
   1238                                  1,
   1239                                  1,
   1240                                  1,
   1241                                  0,
   1242                                  nullptr,
   1243                                  nullptr,
   1244                                  PADDING_SAME_ZERO,
   1245                                  NONE,
   1246                                  0,
   1247                                  0,
   1248                                  BRANCH_NO_COPY,
   1249                                  BRANCH_NOC,
   1250                                  {},
   1251                                  {},
   1252                                  0,
   1253                              } } };
   1254 
   1255  // Weights and biases need to be specified separately because
   1256  // of the offset.
   1257  AssignLayerWeightsBiases(&cnn_config, weights, bias);
   1258 
   1259  CNN_THREAD_DATA thread_data = { 1, nullptr };
   1260 
   1261  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
   1262             image_width, &thread_data, MSE_INT_TOL);
   1263 }
   1264 
   1265 TEST_F(CNNTest, TestBranchTensorConcatenation) {
   1266  int filter_width = 2;
   1267  int filter_height = 3;
   1268 
   1269  int image_width = 4;
   1270  int image_height = 4;
   1271 
   1272  float input[] = {
   1273    -3, -2, -2, 0, -1, 3, 2, -2, 1, 3, 4, 0, 2, -5, -4, 0,
   1274  };
   1275 
   1276  float weights[] = {
   1277    3,  0,  2,  0,  2,  3,  1,  -3, 1,  -5, -3, 0,  -4, 4,  0,  -5, 0,  -5, -1,
   1278    -2, -5, 0,  -3, 2,  -4, 2,  0,  2,  -1, 0,  -4, 3,  0,  0,  -1, -5, 2,  -1,
   1279    4,  -4, -2, -3, -3, 3,  4,  -2, -1, -4, -1, 4,  4,  -1, 4,  3,  -4, 2,  -2,
   1280    -4, -3, -2, 3,  -3, -5, -1, 3,  -2, 4,  1,  -4, -3, -5, -5, -3, 4,  -2, -2,
   1281    -1, -5, -5, 0,  -1, -2, -3, 3,  -4, -5, 2,  -3, 1,  0,  -5, 2,  2,  -2, 0,
   1282    2,  2,  -2, 4,  2,  2,  0,  1,  -5, -3, 0,  2,  -2, 1,  2,  -5, 2,  3,  3,
   1283    -1, 3,  0,  -3, 3,  -4, -4, 3,  3,  -4, -2, 2,  -2, 2,  -2, -1, 3,  0,
   1284  };
   1285 
   1286  float bias[] = {
   1287    -3, -5, 4, -4, -3, -2, 0, 3, -4, 4, -3,
   1288  };
   1289 
   1290  float expected[] = {
   1291    -33533, -32087, -6741,  -2124, 39979, 41453, 14034, 689,
   1292    -22611, -42203, -14882, -239,  15781, 15963, 9524,  837,
   1293  };
   1294 
   1295  int channels = 2;
   1296 
   1297  CNN_CONFIG cnn_config = { 6,
   1298                            0,
   1299                            0,
   1300                            0,
   1301                            0,
   1302                            { {
   1303                                  1,
   1304                                  filter_width,
   1305                                  filter_height,
   1306                                  channels,
   1307                                  1,
   1308                                  1,
   1309                                  0,
   1310                                  weights,
   1311                                  bias,
   1312                                  PADDING_SAME_ZERO,
   1313                                  NONE,
   1314                                  0,
   1315                                  0,
   1316                                  BRANCH_NO_COPY,
   1317                                  BRANCH_NOC,
   1318                                  {},
   1319                                  {},
   1320                                  -1,
   1321                              },
   1322                              {
   1323                                  channels,
   1324                                  filter_width,
   1325                                  filter_height,
   1326                                  channels,
   1327                                  1,
   1328                                  1,
   1329                                  0,
   1330                                  nullptr,
   1331                                  nullptr,
   1332                                  PADDING_SAME_ZERO,
   1333                                  NONE,
   1334                                  0,
   1335                                  0,
   1336                                  BRANCH_INPUT,
   1337                                  BRANCH_NOC,
   1338                                  {
   1339                                      0x02,
   1340                                      0,
   1341                                      0x00,
   1342                                  },
   1343                                  {},
   1344                                  -1,
   1345                              },
   1346                              {
   1347                                  channels,
   1348                                  filter_width,
   1349                                  filter_height,
   1350                                  channels,
   1351                                  1,
   1352                                  1,
   1353                                  0,
   1354                                  nullptr,
   1355                                  nullptr,
   1356                                  PADDING_SAME_ZERO,
   1357                                  NONE,
   1358                                  0,
   1359                                  1,
   1360                                  BRANCH_NO_COPY,
   1361                                  BRANCH_NOC,
   1362                                  {},
   1363                                  {},
   1364                                  -1,
   1365                              },
   1366                              {
   1367                                  channels,
   1368                                  filter_width,
   1369                                  filter_height,
   1370                                  channels,
   1371                                  1,
   1372                                  1,
   1373                                  0,
   1374                                  nullptr,
   1375                                  nullptr,
   1376                                  PADDING_SAME_ZERO,
   1377                                  NONE,
   1378                                  0,
   1379                                  1,
   1380                                  BRANCH_NO_COPY,
   1381                                  BRANCH_NOC,
   1382                                  {},
   1383                                  {},
   1384                                  -1,
   1385                              },
   1386                              {
   1387                                  channels,
   1388                                  filter_width,
   1389                                  filter_height,
   1390                                  channels,
   1391                                  1,
   1392                                  1,
   1393                                  0,
   1394                                  nullptr,
   1395                                  nullptr,
   1396                                  PADDING_SAME_ZERO,
   1397                                  NONE,
   1398                                  0,
   1399                                  0,
   1400                                  BRANCH_NO_COPY,
   1401                                  BRANCH_CAT,
   1402                                  {
   1403                                      0x00,
   1404                                      0,
   1405                                      0x02,
   1406                                  },
   1407                                  {},
   1408                                  -1,
   1409                              },
   1410                              {
   1411                                  channels + channels,
   1412                                  filter_width,
   1413                                  filter_height,
   1414                                  1,
   1415                                  1,
   1416                                  1,
   1417                                  0,
   1418                                  nullptr,
   1419                                  nullptr,
   1420                                  PADDING_SAME_ZERO,
   1421                                  NONE,
   1422                                  0,
   1423                                  0,
   1424                                  BRANCH_NO_COPY,
   1425                                  BRANCH_NOC,
   1426                                  {},
   1427                                  {},
   1428                                  0,
   1429                              } } };
   1430 
   1431  // Weights and biases need to be specified separately because
   1432  // of the offset.
   1433  AssignLayerWeightsBiases(&cnn_config, weights, bias);
   1434 
   1435  CNN_THREAD_DATA thread_data = { 1, nullptr };
   1436 
   1437  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
   1438             image_width, &thread_data, MSE_INT_TOL);
   1439 }
   1440 
   1441 // TODO(logangw): Add test to test all combinations of branch_copy_type.
   1442 
   1443 TEST_F(CNNTest, TestBranchCombinations) {
   1444  int filter_width = 2;
   1445  int filter_height = 3;
   1446 
   1447  int image_width = 4;
   1448  int image_height = 4;
   1449 
   1450  float input[] = {
   1451    3, 2, -5, -4, 4, -2, -4, -3, 4, 2, -3, 2, -3, 1, -5, -1,
   1452  };
   1453 
   1454  float weights[] = {
   1455    2,  3,  0,  4,  4,  3,  1,  0,  1,  -5, 4,  -3, 3,  0,  4,  -1, -1, -5,
   1456    2,  1,  -3, -5, 3,  -1, -3, -2, 0,  -2, 3,  0,  -2, -4, -2, -2, 2,  -5,
   1457    4,  -5, 0,  1,  -5, -4, -3, -4, 2,  -2, 1,  0,  3,  -2, -4, 3,  4,  -4,
   1458    -1, -1, -3, -2, -2, -1, 2,  0,  2,  -1, 2,  -4, -4, -1, 2,  0,  3,  -2,
   1459    -2, 3,  -3, 4,  -2, 4,  3,  4,  1,  0,  -2, -3, -5, 1,  -3, 2,  0,  -2,
   1460    -2, -1, -1, -5, -2, -3, -1, 3,  3,  4,  4,  0,  2,  1,  3,  -3, 2,  -5,
   1461    -5, 1,  -5, -1, 3,  3,  2,  -4, -1, 3,  -4, -2, -5, -2, 1,  3,  2,  2,
   1462    -5, -2, -3, -1, -2, -4, -1, -2, 2,  1,  -4, -4, 2,  0,  2,  0,  2,  -3,
   1463    -2, -4, 4,  0,  1,  -3, -5, 4,  -1, 2,  3,  -5, -1, 0,  4,  -1, -1, 3,
   1464    -1, -3, 3,  1,  4,  3,  4,  3,  -4, -5, -1, 3,  3,  -4, 3,  1,  3,  -5,
   1465    3,  4,  -5, 4,  2,  -1, -5, 2,  1,  0,  4,  0,  -3, 2,  0,  2,  -2, 1,
   1466    -1, -2, -1, -5, 4,  3,  3,  -2, 2,  4,  -5, -5, -3, -2, 4,  0,  -4, 1,
   1467  };
   1468 
   1469  float bias[] = {
   1470    -1, 4, 0, 2, 2, -2, 0, -4, -5, -1, 1, -2, 3, 0, 4, -2, 1, 0, 0,
   1471  };
   1472 
   1473  float expected[] = {
   1474    149496, 15553,  -24193, -20956, 134094, 86432,  -68283, -6366,
   1475    -53031, 133739, 67407,  -13539, -53205, -58635, -20033, 1979,
   1476  };
   1477 
   1478  int channels = 2;
   1479 
   1480  CNN_CONFIG cnn_config = { 10,
   1481                            0,
   1482                            0,
   1483                            0,
   1484                            0,
   1485                            {
   1486                                {
   1487                                    1,
   1488                                    filter_width,
   1489                                    filter_height,
   1490                                    channels,
   1491                                    1,
   1492                                    1,
   1493                                    0,
   1494                                    weights,
   1495                                    bias,
   1496                                    PADDING_SAME_ZERO,
   1497                                    NONE,
   1498                                    0,
   1499                                    0,
   1500                                    BRANCH_NO_COPY,
   1501                                    BRANCH_NOC,
   1502                                    {},
   1503                                    {},
   1504                                    -1,
   1505                                },
   1506                                {
   1507                                    channels,
   1508                                    filter_width,
   1509                                    filter_height,
   1510                                    channels,
   1511                                    1,
   1512                                    1,
   1513                                    0,
   1514                                    nullptr,
   1515                                    nullptr,
   1516                                    PADDING_SAME_ZERO,
   1517                                    NONE,
   1518                                    0,
   1519                                    0,
   1520                                    BRANCH_INPUT,
   1521                                    BRANCH_NOC,
   1522                                    {
   1523                                        0x06,
   1524                                        0,
   1525                                        0x00,
   1526                                    },
   1527                                    {},
   1528                                    -1,
   1529                                },
   1530                                {
   1531                                    channels,
   1532                                    filter_width,
   1533                                    filter_height,
   1534                                    channels,
   1535                                    1,
   1536                                    1,
   1537                                    0,
   1538                                    nullptr,
   1539                                    nullptr,
   1540                                    PADDING_SAME_ZERO,
   1541                                    NONE,
   1542                                    0,
   1543                                    2,
   1544                                    BRANCH_OUTPUT,
   1545                                    BRANCH_NOC,
   1546                                    {
   1547                                        0x08,
   1548                                        0,
   1549                                        0x00,
   1550                                    },
   1551                                    {},
   1552                                    -1,
   1553                                },
   1554                                {
   1555                                    channels,
   1556                                    filter_width,
   1557                                    filter_height,
   1558                                    channels,
   1559                                    1,
   1560                                    1,
   1561                                    0,
   1562                                    nullptr,
   1563                                    nullptr,
   1564                                    PADDING_SAME_ZERO,
   1565                                    NONE,
   1566                                    0,
   1567                                    3,
   1568                                    BRANCH_NO_COPY,
   1569                                    BRANCH_NOC,
   1570                                    {},
   1571                                    {},
   1572                                    -1,
   1573                                },
   1574                                {
   1575                                    channels,
   1576                                    filter_width,
   1577                                    filter_height,
   1578                                    channels,
   1579                                    1,
   1580                                    1,
   1581                                    0,
   1582                                    nullptr,
   1583                                    nullptr,
   1584                                    PADDING_SAME_ZERO,
   1585                                    NONE,
   1586                                    0,
   1587                                    2,
   1588                                    BRANCH_NO_COPY,
   1589                                    BRANCH_ADD,
   1590                                    {
   1591                                        0x00,
   1592                                        0,
   1593                                        0x08,
   1594                                    },
   1595                                    {},
   1596                                    -1,
   1597                                },
   1598                                {
   1599                                    channels,
   1600                                    filter_width,
   1601                                    filter_height,
   1602                                    channels,
   1603                                    1,
   1604                                    1,
   1605                                    0,
   1606                                    nullptr,
   1607                                    nullptr,
   1608                                    PADDING_SAME_ZERO,
   1609                                    NONE,
   1610                                    0,
   1611                                    2,
   1612                                    BRANCH_NO_COPY,
   1613                                    BRANCH_NOC,
   1614                                    {},
   1615                                    {},
   1616                                    -1,
   1617                                },
   1618                                {
   1619                                    channels,
   1620                                    filter_width,
   1621                                    filter_height,
   1622                                    channels,
   1623                                    1,
   1624                                    1,
   1625                                    0,
   1626                                    nullptr,
   1627                                    nullptr,
   1628                                    PADDING_SAME_ZERO,
   1629                                    NONE,
   1630                                    0,
   1631                                    1,
   1632                                    BRANCH_NO_COPY,
   1633                                    BRANCH_NOC,
   1634                                    {},
   1635                                    {},
   1636                                    -1,
   1637                                },
   1638                                {
   1639                                    channels,
   1640                                    filter_width,
   1641                                    filter_height,
   1642                                    channels,
   1643                                    1,
   1644                                    1,
   1645                                    0,
   1646                                    nullptr,
   1647                                    nullptr,
   1648                                    PADDING_SAME_ZERO,
   1649                                    NONE,
   1650                                    0,
   1651                                    1,
   1652                                    BRANCH_NO_COPY,
   1653                                    BRANCH_ADD,
   1654                                    {
   1655                                        0x00,
   1656                                        0,
   1657                                        0x0C,
   1658                                    },
   1659                                    {},
   1660                                    -1,
   1661                                },
   1662                                {
   1663                                    channels,
   1664                                    filter_width,
   1665                                    filter_height,
   1666                                    channels,
   1667                                    1,
   1668                                    1,
   1669                                    0,
   1670                                    nullptr,
   1671                                    nullptr,
   1672                                    PADDING_SAME_ZERO,
   1673                                    NONE,
   1674                                    0,
   1675                                    0,
   1676                                    BRANCH_NO_COPY,
   1677                                    BRANCH_ADD,
   1678                                    {
   1679                                        0x00,
   1680                                        0,
   1681                                        0x02,
   1682                                    },
   1683                                    {},
   1684                                    -1,
   1685                                },
   1686                                {
   1687                                    channels,
   1688                                    filter_width,
   1689                                    filter_height,
   1690                                    1,
   1691                                    1,
   1692                                    1,
   1693                                    0,
   1694                                    nullptr,
   1695                                    nullptr,
   1696                                    PADDING_SAME_ZERO,
   1697                                    NONE,
   1698                                    0,
   1699                                    0,
   1700                                    BRANCH_NO_COPY,
   1701                                    BRANCH_NOC,
   1702                                    {},
   1703                                    {},
   1704                                    0,
   1705                                },
   1706                            } };
   1707 
   1708  // Weights and biases need to be specified separately because
   1709  // of the offset.
   1710  AssignLayerWeightsBiases(&cnn_config, weights, bias);
   1711 
   1712  CNN_THREAD_DATA thread_data = { 1, nullptr };
   1713 
   1714  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
   1715             image_width, &thread_data, MSE_INT_TOL);
   1716 }
   1717 
   1718 TEST_F(CNNTest, TestSplittingTensors) {
   1719  int filter_width = 2;
   1720  int filter_height = 3;
   1721 
   1722  int image_width = 4;
   1723  int image_height = 4;
   1724 
   1725  float input[] = {
   1726    -1, -1, 2, 1, 3, 2, 4, -3, -4, -2, 2, -3, 1, -3, 4, -2,
   1727  };
   1728 
   1729  float weights[] = {
   1730    -4, 1,  0,  2,  3,  4,  4,  -4, -5, -3, 2,  2,  -4, -3, 3,  2,
   1731    4,  -4, -3, -4, -4, 1,  -3, -5, -3, 4,  2,  -2, 2,  -1, -4, -1,
   1732    -2, -3, 1,  1,  0,  -5, -1, 3,  3,  -5, -3, 0,  -3, 1,  -3, -1,
   1733    1,  -3, -2, -2, 4,  -2, 0,  1,  2,  2,  -4, 2,  4,  0,  -5, -2,
   1734    4,  4,  -5, 1,  0,  2,  -2, -5, -5, -3, -5, -5, 4,  -3, 0,  0,
   1735    -4, -4, 0,  -5, -4, 0,  0,  -3, -5, -3, -1, 2,  -1, 4,  -1, 2,
   1736  };
   1737 
   1738  float bias[] = {
   1739    -4, -2, -3, -3, 3, 1, -2,
   1740  };
   1741 
   1742  float expected[] = {
   1743    530,  -762,  1469, 777,  849,   -771, -1698, 600,
   1744    -658, -1821, 98,   -668, -1798, 30,   887,   -971,
   1745  };
   1746 
   1747  CNN_CONFIG cnn_config = { 3,
   1748                            0,
   1749                            0,
   1750                            0,
   1751                            0,
   1752                            {
   1753                                {
   1754                                    1,
   1755                                    filter_width,
   1756                                    filter_height,
   1757                                    4,
   1758                                    1,
   1759                                    1,
   1760                                    0,
   1761                                    nullptr,
   1762                                    nullptr,
   1763                                    PADDING_SAME_ZERO,
   1764                                    NONE,
   1765                                    0,
   1766                                    0,
   1767                                    BRANCH_OUTPUT,
   1768                                    BRANCH_NOC,
   1769                                    {
   1770                                        0x02,
   1771                                        2,
   1772                                        0x00,
   1773                                    },
   1774                                    {},
   1775                                    -1,
   1776                                },
   1777                                {
   1778                                    4,
   1779                                    filter_width,
   1780                                    filter_height,
   1781                                    2,
   1782                                    1,
   1783                                    1,
   1784                                    0,
   1785                                    nullptr,
   1786                                    nullptr,
   1787                                    PADDING_SAME_ZERO,
   1788                                    NONE,
   1789                                    0,
   1790                                    0,
   1791                                    BRANCH_NO_COPY,
   1792                                    BRANCH_CAT,
   1793                                    {
   1794                                        0x00,
   1795                                        0,
   1796                                        0x02,
   1797                                    },
   1798                                    {},
   1799                                    -1,
   1800                                },
   1801                                {
   1802                                    4,
   1803                                    filter_width,
   1804                                    filter_height,
   1805                                    1,
   1806                                    1,
   1807                                    1,
   1808                                    0,
   1809                                    nullptr,
   1810                                    nullptr,
   1811                                    PADDING_SAME_ZERO,
   1812                                    NONE,
   1813                                    0,
   1814                                    0,
   1815                                    BRANCH_NO_COPY,
   1816                                    BRANCH_NOC,
   1817                                    {},
   1818                                    {},
   1819                                    0,
   1820                                },
   1821                            } };
   1822 
   1823  // Weights and biases need to be specified separately because
   1824  // of the offset.
   1825  AssignLayerWeightsBiases(&cnn_config, weights, bias);
   1826 
   1827  CNN_THREAD_DATA thread_data = { 1, nullptr };
   1828 
   1829  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
   1830             image_width, &thread_data, MSE_INT_TOL);
   1831 }
   1832 
   1833 TEST_F(CNNTest, TestOutputChannelsCount) {  // Runs a 3-layer, branched 1x1-conv config on all-zero data.
   1834  int filter_width = 1;
   1835  int filter_height = 1;
   1836  // 2x2 input image (4 samples).
   1837  int image_width = 2;
   1838  int image_height = 2;
   1839  // Input, weights, biases and expected output are all zero in this test.
   1840  float input[] = { 0, 0, 0, 0 };
   1841  // 8 weights: matches 2 + 2 + 4 for the three 1x1 layers (1->2, 1->2, 2->2 channels).
   1842  float weights[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
   1843  // 6 biases: matches 2 per layer (out_channels is 2 in every layer).
   1844  float bias[] = { 0, 0, 0, 0, 0, 0 };
   1845  // 28 zeros = 7 values per pixel of the 2x2 image. NOTE(review): presumably 7 output planes after the BRANCH_CAT steps - confirm.
   1846  float expected[] = {
   1847    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   1848    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
   1849  };
   1850 
   1851  CNN_CONFIG cnn_config = { 3,  // num_layers
   1852                            0,  // is_residue
   1853                            0,  // ext_width
   1854                            0,  // ext_height
   1855                            0,  // strict_bounds
   1856                            {  // layer_config
   1857                                {  // layer 0
   1858                                    1,  // in_channels
   1859                                    filter_width,  // filter_width
   1860                                    filter_height,  // filter_height
   1861                                    2,  // out_channels
   1862                                    1,  // skip_width
   1863                                    1,  // skip_height
   1864                                    0,  // max_pool
   1865                                    weights,  // weights
   1866                                    bias,  // bias
   1867                                    PADDING_SAME_ZERO,  // pad
   1868                                    NONE,  // activation
   1869                                    0,  // deconvolve
   1870                                    0,  // branch
   1871                                    BRANCH_INPUT,  // branch_copy_type
   1872                                    BRANCH_NOC,  // branch_combine_type
   1873                                    {  // branch_config; NOTE(review): member names are not visible in this file - confirm in cnn.h
   1874                                        0x06,
   1875                                        0,
   1876                                        0x00,
   1877                                    },
   1878                                    {},  // bn_params
   1879                                    -1,  // output_num
   1880                                },
   1881                                {  // layer 1
   1882                                    1,  // in_channels
   1883                                    filter_width,  // filter_width
   1884                                    filter_height,  // filter_height
   1885                                    2,  // out_channels
   1886                                    1,  // skip_width
   1887                                    1,  // skip_height
   1888                                    0,  // max_pool
   1889                                    weights,  // weights
   1890                                    bias,  // bias
   1891                                    PADDING_SAME_ZERO,  // pad
   1892                                    NONE,  // activation
   1893                                    0,  // deconvolve
   1894                                    2,  // branch
   1895                                    BRANCH_NO_COPY,  // branch_copy_type
   1896                                    BRANCH_CAT,  // branch_combine_type
   1897                                    {  // branch_config
   1898                                        0x00,
   1899                                        0,
   1900                                        0x03,
   1901                                    },
   1902                                    {},  // bn_params
   1903                                    -1,  // output_num
   1904                                },
   1905                                {  // layer 2 (the only layer here with a non-negative output_num)
   1906                                    2,  // in_channels
   1907                                    filter_width,  // filter_width
   1908                                    filter_height,  // filter_height
   1909                                    2,  // out_channels
   1910                                    1,  // skip_width
   1911                                    1,  // skip_height
   1912                                    0,  // max_pool
   1913                                    weights,  // weights
   1914                                    bias,  // bias
   1915                                    PADDING_SAME_ZERO,  // pad
   1916                                    NONE,  // activation
   1917                                    0,  // deconvolve
   1918                                    0,  // branch
   1919                                    BRANCH_NO_COPY,  // branch_copy_type
   1920                                    BRANCH_CAT,  // branch_combine_type
   1921                                    {  // branch_config
   1922                                        0x00,
   1923                                        0,
   1924                                        0x04,
   1925                                    },
   1926                                    {},  // bn_params
   1927                                    0,  // output_num
   1928                                },
   1929                            } };
   1930 
   1931  // Weights and biases need to be specified separately because
   1932  // of the offset.
   1933  AssignLayerWeightsBiases(&cnn_config, weights, bias);
   1934 
   1935  CNN_THREAD_DATA thread_data = { 1, nullptr };  // No worker array is supplied for this run.
   1936 
   1937  RunCNNTest(image_width, image_height, input, expected, &cnn_config,  // image_width is passed again as in_stride.
   1938             image_width, &thread_data, MSE_FLOAT_TOL);
   1939 }
   1940 
   1941 TEST_F(CNNTest, TestBatchNorm) {  // One 7x7, stride-7 conv layer (1 -> 3 channels) with RELU and batch-norm params on a 28x28 input.
   1942  int image_width = 28;
   1943  int image_height = 28;
   1944  int filter_height = 7;
   1945  int filter_width = 7;
   1946  float input[] = {  // 784 values = 28 * 28 (one input plane).
   1947    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1948    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1949    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1950    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1951    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1952    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1953    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1954    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1955    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1956    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1957    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1958    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1959    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1960    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1961    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1962    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1963    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1964    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1965    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1966    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1967    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1968    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1969    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1970    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1971    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1972    0.0f,       0.0f,       0.0117647f,  0.0705882f,  0.0705882f,  0.0705882f,
   1973    0.494118f,  0.533333f,  0.686275f,   0.101961f,   0.65098f,    1.0f,
   1974    0.968627f,  0.498039f,  0.0f,        0.0f,        0.0f,        0.0f,
   1975    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1976    0.0f,       0.0f,       0.117647f,   0.141176f,   0.368627f,   0.603922f,
   1977    0.666667f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
   1978    0.882353f,  0.67451f,   0.992157f,   0.94902f,    0.764706f,   0.25098f,
   1979    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1980    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.192157f,
   1981    0.933333f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.992157f,
   1982    0.992157f,  0.992157f,  0.992157f,   0.984314f,   0.364706f,   0.321569f,
   1983    0.321569f,  0.219608f,  0.152941f,   0.0f,        0.0f,        0.0f,
   1984    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1985    0.0f,       0.0f,       0.0f,        0.0705882f,  0.858824f,   0.992157f,
   1986    0.992157f,  0.992157f,  0.992157f,   0.992157f,   0.776471f,   0.713725f,
   1987    0.968627f,  0.945098f,  0.0f,        0.0f,        0.0f,        0.0f,
   1988    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1989    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1990    0.0f,       0.0f,       0.313725f,   0.611765f,   0.419608f,   0.992157f,
   1991    0.992157f,  0.803922f,  0.0431373f,  0.0f,        0.168627f,   0.603922f,
   1992    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1993    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1994    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1995    0.0f,       0.054902f,  0.00392157f, 0.603922f,   0.992157f,   0.352941f,
   1996    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1997    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1998    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   1999    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2000    0.0f,       0.545098f,  0.992157f,   0.745098f,   0.00784314f, 0.0f,
   2001    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2002    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2003    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2004    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0431373f,
   2005    0.745098f,  0.992157f,  0.27451f,    0.0f,        0.0f,        0.0f,
   2006    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2007    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2008    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2009    0.0f,       0.0f,       0.0f,        0.0f,        0.137255f,   0.945098f,
   2010    0.882353f,  0.627451f,  0.423529f,   0.00392157f, 0.0f,        0.0f,
   2011    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2012    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2013    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2014    0.0f,       0.0f,       0.0f,        0.317647f,   0.941176f,   0.992157f,
   2015    0.992157f,  0.466667f,  0.0980392f,  0.0f,        0.0f,        0.0f,
   2016    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2017    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2018    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2019    0.0f,       0.0f,       0.176471f,   0.729412f,   0.992157f,   0.992157f,
   2020    0.588235f,  0.105882f,  0.0f,        0.0f,        0.0f,        0.0f,
   2021    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2022    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2023    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2024    0.0f,       0.0627451f, 0.364706f,   0.988235f,   0.992157f,   0.733333f,
   2025    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2026    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2027    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2028    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2029    0.0f,       0.976471f,  0.992157f,   0.976471f,   0.25098f,    0.0f,
   2030    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2031    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2032    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2033    0.0f,       0.0f,       0.180392f,   0.509804f,   0.717647f,   0.992157f,
   2034    0.992157f,  0.811765f,  0.00784314f, 0.0f,        0.0f,        0.0f,
   2035    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2036    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2037    0.0f,       0.0f,       0.0f,        0.0f,        0.152941f,   0.580392f,
   2038    0.898039f,  0.992157f,  0.992157f,   0.992157f,   0.980392f,   0.713725f,
   2039    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2040    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2041    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2042    0.0941176f, 0.447059f,  0.866667f,   0.992157f,   0.992157f,   0.992157f,
   2043    0.992157f,  0.788235f,  0.305882f,   0.0f,        0.0f,        0.0f,
   2044    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2045    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2046    0.0f,       0.0f,       0.0901961f,  0.258824f,   0.835294f,   0.992157f,
   2047    0.992157f,  0.992157f,  0.992157f,   0.776471f,   0.317647f,   0.00784314f,
   2048    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2049    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2050    0.0f,       0.0f,       0.0f,        0.0f,        0.0705882f,  0.670588f,
   2051    0.858824f,  0.992157f,  0.992157f,   0.992157f,   0.992157f,   0.764706f,
   2052    0.313725f,  0.0352941f, 0.0f,        0.0f,        0.0f,        0.0f,
   2053    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2054    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2055    0.215686f,  0.67451f,   0.886275f,   0.992157f,   0.992157f,   0.992157f,
   2056    0.992157f,  0.956863f,  0.521569f,   0.0431373f,  0.0f,        0.0f,
   2057    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2058    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2059    0.0f,       0.0f,       0.0f,        0.0f,        0.533333f,   0.992157f,
   2060    0.992157f,  0.992157f,  0.831373f,   0.529412f,   0.517647f,   0.0627451f,
   2061    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2062    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2063    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2064    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2065    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2066    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2067    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2068    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2069    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2070    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2071    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2072    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2073    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2074    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2075    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2076    0.0f,       0.0f,       0.0f,        0.0f,        0.0f,        0.0f,
   2077    0.0f,       0.0f,       0.0f,        0.0f
   2078  };
   2079  float expected[] = {  // 48 values = 3 output channels * 16, stored one channel after another.
   2080    -0.836424f, -0.857365f, -1.62739f,  -1.62739f,  -0.836424f, 5.40742f,
   2081    0.920853f,  -0.692567f, -0.836424f, -0.534405f, -1.62739f,  -0.836424f,
   2082    1.32602f,   1.36312f,   0.112766f,  -0.836424f, -0.192962f, 1.56975f,
   2083    2.45777f,   0.944414f,  -0.192962f, -1.5519f,   -1.5519f,   -0.554006f,
   2084    -0.192962f, 1.4231f,    -1.5519f,   -0.192962f, 1.3661f,    -1.5519f,
   2085    -1.5519f,   -0.192962f, -0.843708f, -0.359025f, -0.843708f, -0.843708f,
   2086    -0.843708f, 4.53065f,   0.0429584f, -0.796804f, -0.843708f, 0.3473f,
   2087    -0.843708f, -0.843708f, -0.114439f, 3.14817f,   0.0811934f, -0.843708f
   2088  };
   2089  float kernel[] = {  // 147 values = 7 * 7 * 1 * 3 (filter size, in_channels, out_channels).
   2090    0.119643f,    -0.237864f,   0.0462892f,   0.0502297f,   -0.0134528f,
   2091    0.146347f,    0.153133f,    0.0513307f,   0.0752369f,   0.0135557f,
   2092    -0.111434f,   0.0941854f,   0.0788362f,   0.0299412f,   0.111762f,
   2093    0.144066f,    0.00431504f,  -0.0177954f,  0.0738092f,   -0.0344215f,
   2094    0.0832582f,   0.053989f,    -0.112691f,   0.0962145f,   0.0186525f,
   2095    -0.00660205f, -0.111962f,   -0.126801f,   -0.231625f,   0.17309f,
   2096    0.0748875f,   -0.179569f,   -0.00513812f, -0.156579f,   -0.147322f,
   2097    0.184168f,    0.189308f,    -0.200359f,   -0.0156733f,  0.140649f,
   2098    0.0858496f,   -0.0263217f,  -0.0740749f,  -0.112563f,   0.107528f,
   2099    0.0609729f,   -0.221625f,   0.0769944f,   -0.00900815f, -0.00136441f,
   2100    -0.0236521f,  -0.0418025f,  -0.00286299f, 0.12241f,     0.0964093f,
   2101    -0.0150897f,  0.0532171f,   0.0625916f,   0.116939f,    0.118024f,
   2102    0.161918f,    -0.00909767f, 0.100897f,    -0.054563f,   -0.175179f,
   2103    -0.0687892f,  0.00734235f,  0.109833f,    -0.113776f,   0.0595405f,
   2104    -0.170255f,   0.0124815f,   -0.0363301f,  -0.0127038f,  0.0445554f,
   2105    -0.0729894f,  0.107428f,    -0.0341417f,  0.132619f,    0.00984557f,
   2106    -0.00443654f, 0.202929f,    0.0945134f,   0.0148725f,   0.00998574f,
   2107    -0.0226449f,  0.0478197f,   -0.0793442f,  0.0707599f,   -0.084225f,
   2108    0.0865795f,   0.071104f,    -0.047894f,   0.0838322f,   0.0635493f,
   2109    -0.00370265f, -0.157247f,   -0.0289622f,  -0.0590963f,  0.13207f,
   2110    0.00468011f,  -0.0345372f,  0.217939f,    0.18861f,     -0.0290393f,
   2111    -0.0440664f,  0.0126197f,   -0.129132f,   -0.124943f,   0.0968156f,
   2112    -0.0853643f,  -0.182305f,   0.00461618f,  -0.147095f,   -0.230282f,
   2113    0.00856019f,  0.0278893f,   -0.0300229f,  0.0417871f,   0.0804717f,
   2114    -0.0768571f,  -0.0397085f,  -0.0601096f,  0.100901f,    -0.0184926f,
   2115    0.0350673f,   0.0971094f,   -0.0171837f,  -0.289644f,   -0.0899041f,
   2116    0.08998f,     -0.160319f,   -0.0195103f,  0.0392167f,   -0.137864f,
   2117    -0.0136294f,  0.0330886f,   -0.0409244f,  -0.092533f,   -0.0427934f,
   2118    -0.191144f,   -0.0969461f,  0.112035f,    0.138611f,    0.128717f,
   2119    0.191184f,    0.197462f
   2120  };
   2121  float bias[] = { 0.186703f, 0.204358f, -0.0230452f };  // 3 values, matching out_channels below.
   2122  // Batch-norm parameters: 3 values each, matching the 3 output channels.
   2123  float bn_gamma[] = { 1.32173f, 1.26171f, 1.21966f };
   2124  float bn_beta[] = { -0.232595f, -0.222652f, -0.232209f };
   2125  float bn_mean[] = { 0.329233f, 0.199894f, 0.12389f };
   2126  float bn_std[] = { 0.311986f, 0.189737f, 0.247104f };
   2127  // Packed positionally in the order gamma, beta, mean, std.
   2128  CNN_BATCHNORM_PARAMS bn_params = {
   2129    bn_gamma,
   2130    bn_beta,
   2131    bn_mean,
   2132    bn_std,
   2133  };
   2134  // The expected data is consistent with normalisation after RELU: -1.62739 == gamma[0] * (0 - mean[0]) / std[0] + beta[0].
   2135  CNN_CONFIG cnn_config = {
   2136    1,  // num_layers
   2137    0,  // is_residue
   2138    0,  // ext_width
   2139    0,  // ext_height
   2140    0,  // strict_bounds
   2141    {  // layer_config
   2142        {
   2143            1,  // in_channels
   2144            filter_width,  // filter_width
   2145            filter_height,  // filter_height
   2146            3,  // out_channels
   2147            7,  // skip_width
   2148            7,  // skip_height
   2149            0,  // max_pool
   2150            kernel,  // weights
   2151            bias,  // bias
   2152            PADDING_VALID,  // pad
   2153            RELU,  // activation
   2154            0,  // deconvolve
   2155            0,  // branch
   2156            BRANCH_NO_COPY,  // branch_copy_type
   2157            BRANCH_NOC,  // branch_combine_type
   2158            {},  // branch_config
   2159            bn_params,  // bn_params
   2160            0,  // output_num
   2161        },
   2162    },
   2163  };
   2164 
   2165  CNN_THREAD_DATA thread_data = { 1, nullptr };  // No worker array is supplied for this run.
   2166 
   2167  RunCNNTest(image_width, image_height, input, expected, &cnn_config,  // image_width is passed again as in_stride.
   2168             image_width, &thread_data, MSE_FLOAT_TOL);
   2169 }
   2170 
   2171 TEST_F(CNNTest, TestMultithreading) {  // Runs one conv config with { 1, nullptr } and then with 4 workers; both runs use the same `expected`.
   2172  int image_height = 2;
   2173  int image_width = 2;
   2174  int filter_height = 3;
   2175  int filter_width = 3;
   2176  // 2x2 input image (4 samples).
   2177  float input[] = {
   2178    -2,
   2179    4,
   2180    1,
   2181    0,
   2182  };
   2183  // 36 weights = 3 * 3 * 1 * 4 (filter size, in_channels, out_channels).
   2184  float weights[] = {
   2185    -4, 2, -2, 0,  -4, 4, -3, -3, -3, -1, 1,  0,  -5, -3, 0, -5, 0, 0,
   2186    -1, 0, 2,  -5, 0,  1, 4,  2,  1,  0,  -2, -1, -5, -3, 2, -2, 1, -5,
   2187  };
   2188  // 4 biases, matching out_channels below.
   2189  float bias[] = {
   2190    -4,
   2191    -3,
   2192    -2,
   2193    3,
   2194  };
   2195  // 16 values; NOTE(review): presumably 4 output channels * 2 * 2 pixels - confirm against RunCNNTest.
   2196  float expected[] = {
   2197    2, 10, -8, -17, -24, 5, -15, 6, -5, -5, 7, -10, 4, 13, 9, -14,
   2198  };
   2199 
   2200  CNN_CONFIG cnn_config = {
   2201    1,  // num_layers
   2202    0,  // is_residue
   2203    0,  // ext_width
   2204    0,  // ext_height
   2205    0,  // strict_bounds
   2206    {  // layer_config
   2207        {
   2208            1,  // in_channels
   2209            filter_width,  // filter_width
   2210            filter_height,  // filter_height
   2211            4,  // out_channels
   2212            1,  // skip_width
   2213            1,  // skip_height
   2214            0,  // max_pool
   2215            weights,  // weights
   2216            bias,  // bias
   2217            PADDING_SAME_ZERO,  // pad
   2218            NONE,  // activation
   2219            0,  // deconvolve
   2220            0,  // branch
   2221            BRANCH_NO_COPY,  // branch_copy_type
   2222            BRANCH_NOC,  // branch_combine_type
   2223            {},  // branch_config
   2224            {},  // bn_params
   2225            0,  // output_num
   2226        },
   2227    },
   2228  };
   2229  // First run: no worker array is supplied.
   2230  CNN_THREAD_DATA thread_data = { 1, nullptr };
   2231 
   2232  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
   2233             image_width, &thread_data, MSE_FLOAT_TOL);
   2234  // Second run: same input, config and expected output, but with four AVxWorker objects.
   2235  const AVxWorkerInterface *const winterface = aom_get_worker_interface();
   2236  AVxWorker workers[4];
   2237 
   2238  for (int i = 0; i < 4; ++i) {
   2239    winterface->init(&workers[i]);
   2240  }
   2241 
   2242  thread_data = { 4, workers };  // The 4 here must stay in step with the size of workers[].
   2243 
   2244  RunCNNTest(image_width, image_height, input, expected, &cnn_config,
   2245             image_width, &thread_data, MSE_FLOAT_TOL);
   2246  // Call end() on every worker that init() was called on above.
   2247  for (int i = 0; i < 4; ++i) {
   2248    winterface->end(&workers[i]);
   2249  }
   2250 }
   2251 
   2252 TEST_F(CNNTest, TestMultiOutput) {
   2253  const int image_dim = 8;
   2254  const int image_ch = 3;
   2255  const int filter_dim = 2;
   2256  const int stride = 2;
   2257  const int num_filters = 2;
   2258 
   2259  const float input_[] = {
   2260    1.7537929121f,     0.134331551012f,    0.123580039877f,   0.957731845246f,
   2261    0.391006834217f,   1.00699352042f,     -0.778177955829f,  -0.814166433059f,
   2262    -0.656374394915f,  0.321967305228f,    -2.19455719176f,   0.708035038966f,
   2263    0.409148822266f,   -0.318254408902f,   0.152450211189f,   -0.250210793369f,
   2264    0.826811563186f,   1.6804156584f,      0.273626975978f,   0.437936241887f,
   2265    -0.329935520167f,  -0.288761611645f,   0.156937008304f,   0.271054157295f,
   2266    -0.0224828854332f, 1.70110336895f,     -0.989066699309f,  1.30863131729f,
   2267    -0.165813705702f,  0.00380178619265f,  -0.0837342367587f, 0.760954783156f,
   2268    -0.413610373524f,  1.17968204175f,     0.720295719536f,   0.308718974472f,
   2269    -1.10091337671f,   0.693160033687f,    -0.0202862320697f, 1.0221927503f,
   2270    -1.24521801881f,   -0.478501952308f,   -1.71648619442f,   -0.182571723636f,
   2271    0.339292649504f,   2.0806519131f,      0.967974033444f,   0.175248672328f,
   2272    0.0658124561472f,  0.795504169496f,    0.750592557361f,   -1.46631013249f,
   2273    -1.79052846838f,   -1.03672179515f,    -0.841985521653f,  1.20995011489f,
   2274    0.140859718215f,   -0.651552622661f,   0.451065110806f,   1.1189443693f,
   2275    0.100213260593f,   -0.834076868118f,   -1.28734321611f,   1.22064420095f,
   2276    -0.364143084361f,  0.750961509335f,    -0.888689074553f,  -0.8253547106f,
   2277    -1.21800999027f,   -0.966670603566f,   1.37384014741f,    0.47281264834f,
   2278    -0.420416235531f,  0.520163906493f,    0.501296589423f,   1.53418976951f,
   2279    0.715234751485f,   0.644551588907f,    0.0763504863375f,  -0.0018541943723f,
   2280    0.322853189656f,   -0.795099723224f,   -0.125177096675f,  1.4476577471f,
   2281    -0.585888410088f,  -1.44391754955f,    -0.610543221933f,  -0.221859179799f,
   2282    0.252060200774f,   -0.86287169623f,    -0.0350246229157f, 1.0932311997f,
   2283    0.899464648842f,   -0.468806951704f,   -0.300861137168f,  1.15776414206f,
   2284    1.03268544738f,    -0.171579585622f,   -0.179136557119f,  -0.354091003368f,
   2285    -0.612298249394f,  -1.20237379258f,    1.54604109659f,    0.130664370287f,
   2286    0.885225111868f,   1.0362799581f,      0.980561720868f,   -0.619379186999f,
   2287    -1.33818929924f,   -0.237233737961f,   -1.89335425073f,   0.567821011321f,
   2288    0.862420368465f,   -1.37380916821f,    0.352190056666f,   0.611261516274f,
   2289    0.393237747152f,   0.894686247967f,    0.190405182149f,   0.264872662911f,
   2290    -0.0657009133797f, 0.0580512653493f,   -0.401825294366f,  0.4106081318f,
   2291    0.49484512188f,    -0.0751103149442f,  -1.43243736382f,   1.79855656009f,
   2292    -1.1075351975f,    0.000354882733011f, -0.950716438608f,  1.27129831688f,
   2293    1.00495189838f,    0.110358656713f,    1.08315032822f,    -0.972676676218f,
   2294    -0.0757668962831f, 1.88932045165f,     -0.0672638136275f, 0.425913010161f,
   2295    -0.781540372017f,  0.976000248609f,    0.687218504122f,   1.31374513445f,
   2296    -0.932658930672f,  -1.25339468479f,    0.422071294078f,   -0.24189927912f,
   2297    0.216906604642f,   -1.88720997548f,    1.99252872889f,    0.353943735777f,
   2298    0.737434784132f,   -1.17848645017f,    1.70424254896f,    0.775297112968f,
   2299    -0.516392797501f,  0.398130609129f,    0.737248101457f,   0.166282500886f,
   2300    1.24699015468f,    0.47116183125f,     1.19091180182f,    -0.372695424578f,
   2301    0.219773209389f,   -0.829467838962f,   -0.52533122724f,   1.98707754595f,
   2302    0.553692606972f,   -0.933228902369f,   1.55427751643f,    -1.08813399144f,
   2303    -0.325686682094f,  0.205091443796f,    -1.70381666435f,   0.466465327942f,
   2304    1.73126863447f,    -0.939133672634f,   1.48318077459f,    -0.599414038168f,
   2305    -1.1583078687f,    0.518116190201f,    0.133571482458f,   0.84958342672f,
   2306    1.02205000597f,    -0.0772082009087f,  -1.69567503859f,   1.4697939436f,
   2307    1.67813743122f,    -0.627911582938f,   0.131380509137f,   -1.35717850726f,
   2308  };
   2309  const float *input[3] = { input_, &input_[image_dim * image_dim],
   2310                            &input_[2 * image_dim * image_dim] };
   2311 
   2312  const float bias[] = { 0.0f, 0.0f };
   2313 
   2314  const float weights_1[] = {
   2315    -0.489547413618f, 0.141916424749f,  -0.279286485585f,  -0.115322211094f,
   2316    0.299572786936f,  0.205289980785f,  -0.536254480088f,  -0.253626313744f,
   2317    -0.422883815849f, -0.169702966298f, -0.540104704793f,  0.495319646763f,
   2318    0.298799079422f,  -0.10054550901f,  -0.306085047056f,  0.171061886165f,
   2319    -0.108058703878f, -0.410734629888f, -0.0640674673049f, -0.386524840979f,
   2320    -0.157203423678f, -0.362138920529f, -0.216206085209f,  0.147502517971f,
   2321  };
   2322 
   2323  const float weights_2[] = {
   2324    0.207580604357f,  0.480821146263f,  -0.29111909562f,   0.47422567493f,
   2325    0.206892553253f,  -0.235067084092f, 0.354516800602f,   -0.212399370252f,
   2326    -0.419071343731f, -0.050350731631f, -0.0516457320279f, -0.0359310500731f,
   2327    0.567044864811f,  -0.060341127522f, 0.0501464839637f,  -0.437785677916f,
   2328  };
   2329 
   2330  const float weights_3[] = {
   2331    -0.0690452401448f, -0.356657338763f,   -0.219464031809f, 0.551288365843f,
   2332    0.181372090853f,   -0.00245268542109f, 0.409000696276f,  -0.593209108763f,
   2333    0.587352566749f,   -0.243720660227f,   0.266232713887f,  -0.00439285245097f,
   2334    0.252883228305f,   0.152646192631f,    0.0918944932026f, 0.398853715057f,
   2335  };
   2336 
   2337  const float weights_4[] = {
   2338    0.207560791573f,   0.194201350401f,   0.227802322443f,  0.206533663345f,
   2339    0.0557331066805f,  0.0224159800424f,  -0.143939197467f, -0.27703361602f,
   2340    0.130643888389f,   -0.269456557461f,  0.186242862864f,  -0.162879944774f,
   2341    -0.145503996718f,  -0.0768822987581f, -0.203127976359f, -0.238119922873f,
   2342    -0.258806479994f,  0.0357957680385f,  -0.1027606976f,   -0.287920082345f,
   2343    0.189047820993f,   0.250711538481f,   -0.272815714175f, -0.0431449742024f,
   2344    0.207261230996f,   -0.0396472677451f, 0.131236557412f,  0.174291832499f,
   2345    -0.251515885765f,  -0.107164007499f,  0.185824534748f,  -0.00561585838161f,
   2346    0.273393799578f,   -0.139563699075f,  -0.263922456031f, -0.118859844081f,
   2347    0.109230982597f,   -0.170170294794f,  0.0123025648515f, -0.0839368964355f,
   2348    -0.0774058234297f, 0.255847138286f,   -0.208430879637f, 0.279170114319f,
   2349    -0.272890330712f,  -0.217725903006f,  -0.295923275459f, -0.17008723953f,
   2350    -0.284281803405f,  0.281406323629f,   0.266910044663f,  -0.209963914338f,
   2351    0.271980962964f,   0.142013581699f,   -0.143896509026f, -0.290509242975f,
   2352    -0.305768180935f,  0.196902832117f,   -0.090424189662f, -0.147460802346f,
   2353    0.217722016651f,   0.12353848977f,    -0.169177363577f, -0.0454230918512f,
   2354  };
   2355 
   2356  const float expected_0[] = {
   2357    -2.04858441055f,  -2.12883075791f,    -0.045177363807f, 0.763949675768f,
   2358    -0.544361512821f, -1.58123168032f,    1.89319847039f,   0.16859080901f,
   2359    -1.16023321135f,  -0.396988107751f,   1.76637090744f,   -1.40434786514f,
   2360    0.908227575669f,  0.817064817605f,    0.215631134908f,  -0.848605613428f,
   2361    -0.106756747018f, 0.0193027166685f,   0.801345615113f,  -0.395407237598f,
   2362    -1.79983795658f,  -1.73054496242f,    0.0584392594454f, -0.388786095569f,
   2363    -0.237269619354f, 0.000843578271263f, -1.24043512104f,  0.487839445893f,
   2364    -0.394259726605f, 0.559632843424f,    -0.527224052291f, -1.53792340282f,
   2365  };
   2366 
   2367  const float expected_1[] = {
   2368    0.0f, 0.0f,           0.0f, 0.0f, 0.4057888292f, 0.325309571755f,
   2369    0.0f, 1.22013465602f,
   2370  };
   2371 
   2372  const float expected_2[] = {
   2373    0.156119444687f,
   2374    0.517385299817f,
   2375  };
   2376 
   2377  const float expected_3[] = {
   2378    0.224177852984f,
   2379    0.503384419034f,
   2380    0.156119444687f,
   2381    0.517385299817f,
   2382  };
   2383 
   2384  const float *expected[] = { expected_0, expected_1, expected_2, expected_3 };
   2385 
   2386  CNN_CONFIG cnn_config = {
   2387    4,  // num_layers
   2388    0,  // is_residue
   2389    0,  // ext_width
   2390    0,  // ext_height
   2391    0,  // strict_bounds
   2392    {
   2393        // layer_config
   2394        {
   2395            image_ch,           // in_channels
   2396            filter_dim,         // filter_width
   2397            filter_dim,         // filter_height
   2398            num_filters,        // out_channels
   2399            stride,             // skip_width
   2400            stride,             // skip_height
   2401            0,                  // max_pool
   2402            weights_1,          // weights
   2403            bias,               // bias
   2404            PADDING_SAME_ZERO,  // pad
   2405            NONE,               // activation
   2406            0,                  // deconvolve
   2407            0,                  // branch
   2408            BRANCH_OUTPUT,      // branch_copy_type
   2409            BRANCH_NOC,         // branch_combine_type
   2410            { 2, 0, 0 },        // branch_config
   2411            {},                 // bn_params
   2412            0,                  // output_num
   2413        },
   2414        {
   2415            num_filters,        // in_channels
   2416            filter_dim,         // filter_width
   2417            filter_dim,         // filter_height
   2418            num_filters,        // out_channels
   2419            stride,             // skip_width
   2420            stride,             // skip_height
   2421            0,                  // max_pool
   2422            weights_2,          // weights
   2423            bias,               // bias
   2424            PADDING_SAME_ZERO,  // pad
   2425            RELU,               // activation
   2426            0,                  // deconvolve
   2427            0,                  // branch
   2428            BRANCH_NO_COPY,     // branch_copy_type
   2429            BRANCH_NOC,         // branch_combine_type
   2430            {},                 // branch_config
   2431            {},                 // bn_params
   2432            1,                  // output_num
   2433        },
   2434        {
   2435            num_filters,        // in_channels
   2436            filter_dim,         // filter_width
   2437            filter_dim,         // filter_height
   2438            num_filters,        // out_channels
   2439            stride,             // skip_width
   2440            stride,             // skip_height
   2441            0,                  // max_pool
   2442            weights_3,          // weights
   2443            bias,               // bias
   2444            PADDING_SAME_ZERO,  // pad
   2445            RELU,               // activation
   2446            0,                  // deconvolve
   2447            0,                  // branch
   2448            BRANCH_NO_COPY,     // branch_copy_type
   2449            BRANCH_NOC,         // branch_combine_type
   2450            {},                 // branch_config
   2451            {},                 // bn_params
   2452            2,                  // output_num
   2453        },
   2454        {
   2455            num_filters,     // in_channels
   2456            2 * filter_dim,  // filter_width
   2457            2 * filter_dim,  // filter_height
   2458            num_filters,     // out_channels
   2459            2 * stride,      // skip_width
   2460            2 * stride,      // skip_height
   2461            0,               // max_pool
   2462            weights_4,       // weights
   2463            bias,            // bias
   2464            PADDING_VALID,   // pad
   2465            RELU,            // activation
   2466            0,               // deconvolve
   2467            1,               // branch
   2468            BRANCH_NO_COPY,  // branch_copy_type
   2469            BRANCH_CAT,      // branch_combine_type
   2470            { 0, 0, 1 },     // branch_config
   2471            {},              // bn_params
   2472            3,               // output_num
   2473        },
   2474    },
   2475  };
   2476 
   2477  CNN_THREAD_DATA thread_data = { 1, nullptr };
   2478 
   2479  const int num_outputs = 4;
   2480  const int output_chs[4] = { filter_dim, filter_dim, filter_dim,
   2481                              2 * filter_dim };
   2482  const int output_dims[4] = { 4, 2, 1, 1 };
   2483  const int output_sizes[4] = {
   2484    output_chs[0] * output_dims[0] * output_dims[0],
   2485    output_chs[1] * output_dims[1] * output_dims[1],
   2486    output_chs[2] * output_dims[2] * output_dims[2],
   2487    output_chs[3] * output_dims[3] * output_dims[3],
   2488  };
   2489  float *const output_ = (float *)aom_malloc(
   2490      sizeof(*output_) *
   2491      (output_sizes[0] + output_sizes[1] + output_sizes[2] + output_sizes[3]));
   2492  ASSERT_NE(output_, nullptr);
   2493  float *output[CNN_MAX_CHANNELS] = { nullptr };
   2494  int ch_ite = 0;
   2495  float *output_ite = output_;
   2496  for (int output_idx = 0; output_idx < num_outputs; output_idx++) {
   2497    for (int channel = 0; channel < output_chs[output_idx]; ++channel) {
   2498      output[ch_ite++] = output_ite;
   2499      output_ite += output_dims[output_idx] * output_dims[output_idx];
   2500    }
   2501  }
   2502  CNN_MULTI_OUT output_struct = { num_outputs, output_chs, output_dims,
   2503                                  output };
   2504 
   2505  RunMultiOutCNNTest(input, image_dim, image_dim, image_dim, &cnn_config,
   2506                     &thread_data, &output_struct, expected, MSE_FLOAT_TOL);
   2507 
   2508  aom_free(output_);
   2509 }
   2510 
   2511 namespace {
   2512 
// Pointer type shared by the convolution kernels exercised by
// CNNConvolveTest (the av1_cnn_convolve_no_maxpool_padding_valid_* functions
// used in the instantiations at the end of this namespace).
using CNNConvolveNoMaxpoolPaddingValidFunc =
    void (*)(const float **input, int in_width, int in_height, int in_stride,
             const CNN_LAYER_CONFIG *layer_config, float **output,
             int out_stride, int start_idx, int cstep, int channel_step);

// Test parameter bundling a reference kernel (ref_func) with the kernel being
// tested (tst_func).
using CNNConvolveTestFuncs =
    libaom_test::FuncParam<CNNConvolveNoMaxpoolPaddingValidFunc>;
   2520 
   2521 class CNNConvolveTest : public ::testing::TestWithParam<CNNConvolveTestFuncs> {
   2522 protected:
   2523  void SetUp() override { params_ = GetParam(); }
   2524 
   2525  void RunCNNConvolveSetup(int run_times) {
   2526    int in_width = 65;
   2527    int in_height = 65;
   2528 
   2529    const CNN_CONFIG *cnn_config = &av1_intra_mode_cnn_partition_cnn_config;
   2530 
   2531    for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
   2532      int out_width = 0, out_height = 0;
   2533      int in_size = in_width * in_height;
   2534      // Get current layer output width and height.
   2535      av1_find_cnn_layer_output_size(in_height, in_width,
   2536                                     &cnn_config->layer_config[layer],
   2537                                     &out_width, &out_height);
   2538 
   2539      int out_size = out_width * out_height;
   2540      float *input[20], *output_ref[20], *output_mod[20];
   2541 
   2542      float *input_data =
   2543          (float *)aom_malloc(sizeof(*input_data) * in_size *
   2544                              cnn_config->layer_config[layer].in_channels);
   2545      float *temp_ptr = input_data;
   2546      ASSERT_NE(temp_ptr, nullptr);
   2547      for (int i = 0; i < cnn_config->layer_config[layer].in_channels; ++i) {
   2548        input[i] = temp_ptr;
   2549        for (int j = 0; j < in_size; j++) {
   2550          *(temp_ptr++) = ((float)rng_.Rand31() - (1 << 30)) / (1u << 31);
   2551        }
   2552      }
   2553 
   2554      float *out_data_ref = (float *)aom_calloc(
   2555          sizeof(*out_data_ref),
   2556          out_size * cnn_config->layer_config[layer].out_channels);
   2557      ASSERT_NE(out_data_ref, nullptr);
   2558      float *out_data_mod = (float *)aom_calloc(
   2559          sizeof(*out_data_mod),
   2560          out_size * cnn_config->layer_config[layer].out_channels);
   2561      ASSERT_NE(out_data_mod, nullptr);
   2562      float *temp_ptr1 = out_data_ref;
   2563      float *temp_ptr2 = out_data_mod;
   2564      for (int i = 0; i < cnn_config->layer_config[layer].out_channels; ++i) {
   2565        output_ref[i] = temp_ptr1;
   2566        output_mod[i] = temp_ptr2;
   2567        temp_ptr1 += out_size;
   2568        temp_ptr2 += out_size;
   2569      }
   2570 
   2571      RunCNNConvolveTest(input, in_width, in_height, out_size,
   2572                         &cnn_config->layer_config[layer], 0, 1, run_times,
   2573                         layer, output_ref, output_mod, out_width);
   2574 
   2575      // Set current layer output width and height as next layer input width and
   2576      // height.
   2577      in_width = out_width;
   2578      in_height = out_height;
   2579 
   2580      aom_free(input_data);
   2581      aom_free(out_data_ref);
   2582      aom_free(out_data_mod);
   2583    }
   2584  }
   2585 
   2586  void RunCNNConvolveTest(float **input, int in_width, int in_height,
   2587                          int out_size, const CNN_LAYER_CONFIG *layer_config,
   2588                          int start_idx, int step, int run_times, int layer,
   2589                          float **output_ref, float **output_mod,
   2590                          int out_stride) {
   2591    const int cstep = layer_config->in_channels * layer_config->out_channels;
   2592    const int channel_step = AOMMAX(step, 1);
   2593    aom_usec_timer timer;
   2594    aom_usec_timer_start(&timer);
   2595    for (int i = 0; i < run_times; ++i) {
   2596      params_.ref_func((const float **)input, in_width, in_height, in_width,
   2597                       layer_config, output_ref, out_stride, start_idx, cstep,
   2598                       channel_step);
   2599    }
   2600    aom_usec_timer_mark(&timer);
   2601    const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
   2602 
   2603    aom_usec_timer_start(&timer);
   2604    for (int i = 0; i < run_times; ++i) {
   2605      params_.tst_func((const float **)input, in_width, in_height, in_width,
   2606                       layer_config, output_mod, out_stride, start_idx, cstep,
   2607                       channel_step);
   2608    }
   2609    aom_usec_timer_mark(&timer);
   2610    const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
   2611 
   2612    if (run_times > 1) {
   2613      printf("layer : %d \n", layer);
   2614      printf("%7.2f/%7.2fns (%3.2f)\n", time1, time2, time1 / time2);
   2615    } else {
   2616      for (int channel = 0; channel < layer_config->out_channels; ++channel) {
   2617        const float *buf_ref = output_ref[channel];
   2618        const float *buf_mod = output_mod[channel];
   2619 
   2620        for (int i = 0; i < out_size; ++i) {
   2621          if (buf_ref[i] < CNN_CONVOLVE_PIXELWISE_FLOAT_TOL) {
   2622            ASSERT_LE(buf_ref[i], CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
   2623                << "Reference output was near-zero, test output was not ("
   2624                << buf_mod[i] << ")";
   2625          } else {
   2626            const float error = buf_ref[i] - buf_mod[i];
   2627            const float relative_error = fabsf(error / buf_ref[i]);
   2628            ASSERT_LE(relative_error, CNN_CONVOLVE_PIXELWISE_FLOAT_TOL)
   2629                << " channel " << channel << " pixel " << i << ": "
   2630                << buf_ref[i] << "/" << buf_mod[i] << std::endl;
   2631          }
   2632        }
   2633      }
   2634    }
   2635  }
   2636 
   2637 private:
   2638  CNNConvolveTestFuncs params_;
   2639  libaom_test::ACMRandom rng_;
   2640 };
// The suite is only instantiated inside the #if blocks below, so on builds
// where neither condition holds there is no instantiation at all. Tell
// GoogleTest that this is acceptable rather than an error.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CNNConvolveTest);

// Correctness: one invocation of each kernel per layer, outputs compared.
TEST_P(CNNConvolveTest, CheckOutput) { RunCNNConvolveSetup(1); }

// Benchmark: 100000 invocations of each kernel per layer, timings printed.
// The DISABLED_ prefix keeps it out of default test runs.
TEST_P(CNNConvolveTest, DISABLED_Speed) { RunCNNConvolveSetup(100000); }

// AVX2 kernel checked against the C implementation; omitted when
// CONFIG_EXCLUDE_SIMD_MISMATCH is set.
#if HAVE_AVX2 && !CONFIG_EXCLUDE_SIMD_MISMATCH
INSTANTIATE_TEST_SUITE_P(AVX2, CNNConvolveTest,
                         ::testing::Values(CNNConvolveTestFuncs(
                             &av1_cnn_convolve_no_maxpool_padding_valid_c,
                             &av1_cnn_convolve_no_maxpool_padding_valid_avx2)));
#endif

// NEON kernel checked against the C implementation.
#if HAVE_NEON
INSTANTIATE_TEST_SUITE_P(NEON, CNNConvolveTest,
                         ::testing::Values(CNNConvolveTestFuncs(
                             &av1_cnn_convolve_no_maxpool_padding_valid_c,
                             &av1_cnn_convolve_no_maxpool_padding_valid_neon)));
#endif
   2660 
   2661 }  // namespace