tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

cnn.c (49059B)


      1 /*
      2 * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <string.h>

#include "aom_dsp/aom_dsp_common.h"
#include "av1/common/av1_common_int.h"
#include "av1/encoder/cnn.h"
     19 
     20 #define CLAMPINDEX(a, hi) ((a) < 0 ? 0 : ((a) >= (hi) ? ((hi) - 1) : (a)))
     21 
// Work description for one worker of a multi-threaded convolution.
typedef struct {
  const float **input;  // one base pointer per input channel
  int in_width;
  int in_height;
  int in_stride;  // row stride of each input channel buffer
  const CNN_LAYER_CONFIG *layer_config;
  float **output;  // one base pointer per output channel
  int out_stride;  // row stride of each output channel buffer
  // NOTE(review): start_idx/th_step partition work between threads; whether
  // they index output channels or columns depends on the convolution variant
  // (see av1_cnn_convolve callees) -- confirm against the threading caller.
  int start_idx;
  int th_step;
} CONVOLVE_OPS;
     33 
     34 static inline float softsign(float x) { return x / (fabsf(x) + 1.0f); }
     35 
     36 static inline float relu(float x) { return (x < 0) ? 0 : x; }
     37 
// A multi-channel 2D float tensor. When the tensor owns its storage, all
// channels live in one allocation based at buf[0] and buf[c] points at
// offset c * width * height inside it.
typedef struct {
  int allocsize;  // elements allocated at buf[0]; 0 => buffers are borrowed
  int channels;
  int width, height, stride;
  float *buf[CNN_MAX_CHANNELS];  // per-channel base pointers
} TENSOR;
     44 
     45 static void init_tensor(TENSOR *tensor) { memset(tensor, 0, sizeof(*tensor)); }
     46 
     47 static void free_tensor(TENSOR *tensor) {
     48  if (tensor->allocsize) {
     49    aom_free(tensor->buf[0]);
     50    tensor->buf[0] = NULL;
     51    tensor->allocsize = 0;
     52  }
     53 }
     54 
     55 static bool realloc_tensor(TENSOR *tensor, int channels, int width,
     56                           int height) {
     57  const int newallocsize = channels * width * height;
     58  if (tensor->allocsize < newallocsize) {
     59    free_tensor(tensor);
     60    tensor->buf[0] =
     61        (float *)aom_malloc(sizeof(*tensor->buf[0]) * newallocsize);
     62    if (!tensor->buf[0]) return false;
     63    tensor->allocsize = newallocsize;
     64  }
     65  tensor->width = width;
     66  tensor->height = height;
     67  tensor->stride = width;
     68  tensor->channels = channels;
     69  for (int c = 1; c < channels; ++c)
     70    tensor->buf[c] = &tensor->buf[0][c * width * height];
     71  return true;
     72 }
     73 
     74 static void copy_tensor(const TENSOR *src, int copy_channels, int dst_offset,
     75                        TENSOR *dst) {
     76  assert(src->width == dst->width);
     77  assert(src->height == dst->height);
     78  assert(copy_channels <= src->channels);
     79  if (src->stride == dst->width && dst->stride == dst->width) {
     80    for (int c = 0; c < copy_channels; ++c) {
     81      memcpy(dst->buf[dst_offset + c], src->buf[c],
     82             sizeof(*dst->buf[0]) * src->width * src->height);
     83    }
     84  } else {
     85    for (int c = 0; c < copy_channels; ++c) {
     86      for (int r = 0; r < dst->height; ++r) {
     87        memcpy(&dst->buf[dst_offset + c][r * dst->stride],
     88               &src->buf[c][r * src->stride],
     89               dst->width * sizeof(*dst->buf[c]));
     90      }
     91    }
     92  }
     93 }
     94 
     95 static void assign_tensor(TENSOR *tensor, float *buf[CNN_MAX_CHANNELS],
     96                          int channels, int width, int height, int stride) {
     97  tensor->allocsize = 0;
     98  tensor->channels = channels;
     99  tensor->width = width;
    100  tensor->height = height;
    101  tensor->stride = stride;
    102  if (buf) {
    103    for (int c = 0; c < channels; ++c) tensor->buf[c] = buf[c];
    104  } else {
    105    for (int c = 0; c < channels; ++c) tensor->buf[c] = NULL;
    106  }
    107 }
    108 
    109 static void swap_tensor(TENSOR *t1, TENSOR *t2) {
    110  TENSOR t = *t1;
    111  *t1 = *t2;
    112  *t2 = t;
    113 }
    114 
    115 // The concatenated tensor goes into dst with first the channels in
    116 // original dst followed by the channels in the src
    117 static bool concat_tensor(const TENSOR *src, TENSOR *dst) {
    118  assert(src->width == dst->width);
    119  assert(src->height == dst->height);
    120 
    121  const int dst_channels = dst->channels;
    122  const int channels = dst->channels + src->channels;
    123  const int newallocsize = channels * dst->width * dst->height;
    124  if (dst->allocsize < newallocsize) {
    125    TENSOR t;
    126    init_tensor(&t);
    127    // allocate new buffers and copy first the dst channels
    128    if (!realloc_tensor(&t, channels, dst->width, dst->height)) return false;
    129    copy_tensor(dst, dst->channels, 0, &t);
    130    // Swap the tensors and free the old buffers
    131    swap_tensor(dst, &t);
    132    free_tensor(&t);
    133  }
    134  for (int c = 1; c < channels; ++c)
    135    dst->buf[c] = &dst->buf[0][c * dst->width * dst->height];
    136  // Copy the channels in src after the first dst_channels channels.
    137  copy_tensor(src, src->channels, dst_channels, dst);
    138  return true;
    139 }
    140 
    141 #ifndef NDEBUG
    142 static int check_tensor_equal_dims(TENSOR *t1, TENSOR *t2) {
    143  return (t1->width == t2->width && t1->height == t2->height);
    144 }
    145 
    146 static int check_tensor_equal_size(TENSOR *t1, TENSOR *t2) {
    147  return (t1->channels == t2->channels && t1->width == t2->width &&
    148          t1->height == t2->height);
    149 }
    150 #endif  // NDEBUG
    151 
    152 void av1_find_cnn_layer_output_size(int in_width, int in_height,
    153                                    const CNN_LAYER_CONFIG *layer_config,
    154                                    int *out_width, int *out_height) {
    155  assert(layer_config->skip_width > 0);
    156  assert(layer_config->skip_height > 0);
    157  if (!layer_config->deconvolve) {
    158    switch (layer_config->pad) {
    159      case PADDING_SAME_ZERO:
    160      case PADDING_SAME_REPLICATE:
    161        *out_width = (in_width + layer_config->skip_width - 1) /
    162                     layer_config->skip_width;
    163        *out_height = (in_height + layer_config->skip_height - 1) /
    164                      layer_config->skip_height;
    165        break;
    166      case PADDING_VALID:
    167        *out_width =
    168            (in_width - layer_config->filter_width + layer_config->skip_width) /
    169            layer_config->skip_width;
    170        *out_height = (in_height - layer_config->filter_height +
    171                       layer_config->skip_height) /
    172                      layer_config->skip_height;
    173        break;
    174      default: assert(0 && "Unknown padding type");
    175    }
    176  } else {
    177    switch (layer_config->pad) {
    178      case PADDING_SAME_ZERO:
    179      case PADDING_SAME_REPLICATE:
    180        *out_width = in_width * layer_config->skip_width;
    181        *out_height = in_height * layer_config->skip_height;
    182        break;
    183      case PADDING_VALID:
    184        *out_width = (in_width - 1) * layer_config->skip_width +
    185                     layer_config->filter_width;
    186        *out_height = (in_height - 1) * layer_config->skip_height +
    187                      layer_config->filter_height;
    188        break;
    189      default: assert(0 && "Unknown padding type");
    190    }
    191  }
    192 }
    193 
// Propagate per-branch output channel counts for one layer. For every branch
// in the input_to_branches mask, record how many channels it will receive
// (the layer's input, output, or combined output depending on
// branch_copy_type); then add to this layer's own branch any channels
// concatenated from branches in the branches_to_combine mask.
static void find_cnn_out_channels(const CNN_LAYER_CONFIG *layer_config,
                                  int channels_per_branch[]) {
  int branch = layer_config->branch;
  const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
      if (layer_config->branch_copy_type == BRANCH_INPUT) {
        // Branch b gets a copy of this layer's input.
        channels_per_branch[b] = layer_config->in_channels;
      } else if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
        // Branch b gets a copy of this layer's output.
        channels_per_branch[b] = layer_config->out_channels;
      } else if (layer_config->branch_copy_type == BRANCH_COMBINED) {
        // Branch b gets the output concatenated with the combined branches.
        channels_per_branch[b] = layer_config->out_channels;
        for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
          if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
            assert(channels_per_branch[c] > 0);
            channels_per_branch[b] += channels_per_branch[c];
          }
        }
      }
    }
  }
  // This layer's own branch outputs out_channels plus anything concatenated
  // in from the combined branches.
  channels_per_branch[branch] = layer_config->out_channels;
  for (int c = 0; c < CNN_MAX_BRANCHES; ++c) {
    if ((branch_config->branches_to_combine & (1 << c)) && c != branch) {
      assert(channels_per_branch[c] > 0);
      channels_per_branch[branch] += channels_per_branch[c];
    }
  }
}
    223 
    224 #if CONFIG_DEBUG
    225 static inline int cnn_has_at_least_one_output(const CNN_CONFIG *cnn_config) {
    226  const int num_layers = cnn_config->num_layers;
    227  const CNN_LAYER_CONFIG *layer_configs = cnn_config->layer_config;
    228 
    229  for (int idx = 0; idx < num_layers; idx++) {
    230    if (layer_configs[idx].output_num != -1) {
    231      return 1;
    232    }
    233  }
    234  return 0;
    235 }
    236 #endif
    237 
// Compute width, height, and channel count for every output layer of the
// network described by cnn_config, given an in_width x in_height input
// (extended by ext_width/ext_height on each side). Results are written to
// out_width/out_height/out_channels indexed by each layer's output_num.
void av1_find_cnn_output_size(int in_width, int in_height,
                              const CNN_CONFIG *cnn_config, int *out_width,
                              int *out_height, int *out_channels) {
  int channels_per_branch[CNN_MAX_BRANCHES] = { 0 };
  int i_width[CNN_MAX_BRANCHES] = { 0 };
  int i_height[CNN_MAX_BRANCHES] = { 0 };
  // Branch 0 starts with the (extended) input size.
  i_width[0] = in_width + cnn_config->ext_width * 2;
  i_height[0] = in_height + cnn_config->ext_height * 2;

#if CONFIG_DEBUG
  assert(cnn_has_at_least_one_output(cnn_config));
#endif

  for (int i = 0; i < cnn_config->num_layers; ++i) {
    const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[i];
    const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
    const int branch = layer_config->branch;
    int o_width = 0, o_height = 0;

    // Branches fed from this layer's *input* inherit the pre-layer size.
    if (layer_config->branch_copy_type == BRANCH_INPUT) {
      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
        if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
          assert(i_width[branch] > 0 && i_height[branch] > 0);
          i_width[b] = i_width[branch];
          i_height[b] = i_height[branch];
        }
      }
    }

    av1_find_cnn_layer_output_size(i_width[branch], i_height[branch],
                                   layer_config, &o_width, &o_height);
    i_width[branch] = o_width;
    i_height[branch] = o_height;

    // Branches fed from this layer's *output* inherit the post-layer size.
    if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
        if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
          i_width[b] = o_width;
          i_height[b] = o_height;
        }
      }
    }

    find_cnn_out_channels(layer_config, channels_per_branch);

    const int output_num = layer_config->output_num;
    if (output_num != -1) {  // Current layer is an output layer
      out_width[output_num] = o_width;
      out_height[output_num] = o_height;
      out_channels[output_num] = channels_per_branch[layer_config->branch];
    }
  }
}
    291 
// Returns the offset (in input samples) of the first strided sample position
// for a "same"-padded convolution, capped at half the filter width.
// NOTE(review): the formula appears to centre the sampled positions within
// the input given width % stride -- confirm against the callers' expected
// output grid before relying on this description.
static inline int get_start_shift_convolve(int width, int filt_width,
                                           int stride) {
  const int mod = (width % stride);
  const int filt_off = (filt_width - 1) / 2;
  const int dif = (mod ? mod - 1 : stride - 1);
  return AOMMIN((dif + (filt_width % 2)) / 2, filt_off);
}
    299 
    300 void av1_cnn_add_c(float **output, int channels, int width, int height,
    301                   int stride, const float **add) {
    302  for (int c = 0; c < channels; ++c) {
    303    for (int i = 0; i < height; ++i)
    304      for (int j = 0; j < width; ++j)
    305        output[c][i * stride + j] += add[c][i * stride + j];
    306  }
    307 }
    308 
    309 void av1_cnn_activate_c(float **output, int channels, int width, int height,
    310                        int stride, ACTIVATION layer_activation) {
    311  if (layer_activation == RELU) {
    312    for (int c = 0; c < channels; ++c) {
    313      for (int i = 0; i < height; ++i)
    314        for (int j = 0; j < width; ++j)
    315          output[c][i * stride + j] = relu(output[c][i * stride + j]);
    316    }
    317  } else if (layer_activation == SOFTSIGN) {
    318    for (int c = 0; c < channels; ++c) {
    319      for (int i = 0; i < height; ++i)
    320        for (int j = 0; j < width; ++j)
    321          output[c][i * stride + j] = softsign(output[c][i * stride + j]);
    322    }
    323  } else if (layer_activation == SIGMOID) {
    324    assert(0 && "Sigmoid has not been supported in CNN.");  // TO DO
    325  } else if (layer_activation != NONE) {
    326    assert(0 && "Unknown activation type");
    327  }
    328 }
    329 
// For every branch selected in the layer's input_to_branches mask (other
// than the layer's own branch), resize that branch's output tensor and copy
// the active tensor's channels into it. channels_to_copy limits the copy
// when positive; otherwise all channels are copied. Returns false on
// allocation failure.
static bool copy_active_tensor_to_branches(const TENSOR *layer_active_tensor,
                                           const CNN_LAYER_CONFIG *layer_config,
                                           int branch, TENSOR branch_output[]) {
  const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    if ((branch_config->input_to_branches & (1 << b)) && b != branch) {
      // Copy layer's active tensor to output tensor of branch b if set in
      // mask. The output becomes the input of the first layer of the branch
      // because the layer of the branch is not the first layer.
      int copy_channels = branch_config->channels_to_copy > 0
                              ? branch_config->channels_to_copy
                              : layer_active_tensor->channels;
      if (!realloc_tensor(&branch_output[b], copy_channels,
                          layer_active_tensor->width,
                          layer_active_tensor->height)) {
        return false;
      }
      copy_tensor(layer_active_tensor, copy_channels, 0, &branch_output[b]);
    }
  }
  return true;
}
    352 
// CNNConvolve specific to maxpool set as 1, either skip_width or skip_height
// greater than 1 and padding equal to PADDING_SAME_ZERO.
// For each output position (u, v), convolves at every position of the
// skip_height x skip_width pooling window and keeps the maximum response;
// filter taps that fall outside the image contribute zero.
static void convolve_maxpool_padding_zero(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    const int cstep, const int filter_width_half,
    const int filter_height_half) {
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = 0, u = 0; h < in_height; h += layer_config->skip_height, ++u) {
      for (int w = 0, v = 0; w < in_width; w += layer_config->skip_width, ++v) {
        // Pooling window, clipped to the image.
        for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
             ++hh) {
          for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
               ++ww) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              // Successive taps of the same (in, out) channel pair are
              // cstep apart in the weights array.
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int ii = hh + l - filter_height_half;
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int jj = ww + m - filter_width_half;
                  // Zero padding: taps outside the image are skipped.
                  if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
                    continue;
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            const float a = sum;
            // First window element initialises the max; the rest update it.
            if (h == hh && w == ww)
              output[i][u * out_stride + v] = a;
            else
              output[i][u * out_stride + v] =
                  AOMMAX(output[i][u * out_stride + v], a);
          }
        }
      }
    }
  }
}
    394 
// CNNConvolve specific to maxpool set as 1, either skip_width or skip_height
// greater than 1 and padding equal to PADDING_SAME_REPLICATE.
// Same pooling structure as convolve_maxpool_padding_zero, but taps outside
// the image are clamped to the nearest edge sample instead of reading zero.
static void convolve_maxpool_padding_replicate(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    const int cstep, const int filter_width_half,
    const int filter_height_half) {
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = 0, u = 0; h < in_height; h += layer_config->skip_height, ++u) {
      for (int w = 0, v = 0; w < in_width; w += layer_config->skip_width, ++v) {
        // Pooling window, clipped to the image.
        for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
             ++hh) {
          for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
               ++ww) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                // Replicate padding: clamp the row index to the image.
                const int ii =
                    CLAMPINDEX(hh + l - filter_height_half, in_height);
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int jj =
                      CLAMPINDEX(ww + m - filter_width_half, in_width);
                  assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            const float a = sum;
            // First window element initialises the max; the rest update it.
            if (h == hh && w == ww)
              output[i][u * out_stride + v] = a;
            else
              output[i][u * out_stride + v] =
                  AOMMAX(output[i][u * out_stride + v], a);
          }
        }
      }
    }
  }
}
    437 
// CNNConvolve specific to maxpool set as 1, either skip_width or skip_height
// greater than 1 and padding equal to PADDING_VALID.
// Same pooling structure as the padded variants, but only positions where
// the whole filter fits inside the image are produced, so no tap needs
// clamping or zeroing.
static void convolve_maxpool_padding_valid(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    const int cstep) {
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = 0, u = 0; h < in_height - layer_config->filter_height + 1;
         h += layer_config->skip_height, ++u) {
      for (int w = 0, v = 0; w < in_width - layer_config->filter_width + 1;
           w += layer_config->skip_width, ++v) {
        // Pooling window, clipped to the image.
        for (int hh = h; hh < AOMMIN(in_height, h + layer_config->skip_height);
             ++hh) {
          for (int ww = w; ww < AOMMIN(in_width, w + layer_config->skip_width);
               ++ww) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int ii = hh + l;
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int jj = ww + m;
                  assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            const float a = sum;
            // First window element initialises the max; the rest update it.
            if (h == hh && w == ww)
              output[i][u * out_stride + v] = a;
            else
              output[i][u * out_stride + v] =
                  AOMMAX(output[i][u * out_stride + v], a);
          }
        }
      }
    }
  }
}
    479 
// CNNConvolve specific to maxpool set as 0 with filter_height and filter_width
// equal to 1.
// Degenerates to a per-pixel (1x1) convolution: each output sample is the
// bias plus a dot product over input channels at one input position.
// start_idx/step partition the output columns between parallel workers.
static void convolve_element_wise(const float **input, int in_width,
                                  int in_height, int in_stride,
                                  const CNN_LAYER_CONFIG *const layer_config,
                                  float **output, int out_stride, int start_idx,
                                  int step) {
  const int start_h = get_start_shift_convolve(
      in_height, layer_config->filter_height, layer_config->skip_height);
  const int start_w =
      get_start_shift_convolve(in_width, layer_config->filter_width,
                               layer_config->skip_width) +
      start_idx * layer_config->skip_width;
  const int out_w_step = AOMMAX(step, 1);
  const int in_w_step = layer_config->skip_width * out_w_step;
  for (int i = 0; i < layer_config->out_channels; ++i) {
    for (int h = start_h, u = 0; h < in_height;
         h += layer_config->skip_height, ++u) {
      const int in_h = h * in_stride;
      const int out_h = u * out_stride + start_idx;
      for (int w = start_w, out_index = out_h; w < in_width;
           w += in_w_step, out_index += out_w_step) {
        float sum = layer_config->bias[i];
        // 1x1 filter: the weight index reduces to (k, i) with no tap offset.
        for (int k = 0; k < layer_config->in_channels; ++k) {
          sum += layer_config->weights[k * layer_config->out_channels + i] *
                 input[k][in_h + w];
        }
        output[i][out_index] = sum;
      }
    }
  }
}
    512 
// CNNConvolve specific to maxpool set as 0 and padding equal to
// PADDING_SAME_ZERO.
// Zero padding is implemented without per-tap bounds checks: the inner loops
// only visit in-image taps, and the weight offset is advanced across the
// skipped (zero) margin taps via the precomputed *_cstep amounts so weights
// stay aligned with positions. start_idx/channel_step partition the output
// channels between parallel workers.
static void convolve_no_maxpool_padding_zero(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int filter_width_half,
    const int filter_height_half, const int ii_shift, const int jj_shift,
    const int channel_step) {
  const int start_h = get_start_shift_convolve(
      in_height, layer_config->filter_height, layer_config->skip_height);
  const int start_w = get_start_shift_convolve(
      in_width, layer_config->filter_width, layer_config->skip_width);
  const int end_ii_shift = filter_height_half + 1;
  const int end_jj_shift = filter_width_half + 1;
  // *_filter_margin stores the number of pixels along a dimension in the
  // intersection of the complement of the image in the extended image
  // and the filter.
  const int top_filter_margin = layer_config->filter_width * ii_shift;
  const int right_filter_margin = end_jj_shift - in_width;
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    for (int h = start_h, u = 0; h < in_height;
         h += layer_config->skip_height, ++u) {
      const int out_h = u * out_stride;
      // Weight offset skipped for filter rows above the image.
      const int top_cstep =
          AOMMAX(0, top_filter_margin - h * layer_config->filter_width) *
              cstep +
          i;
      const int start_ii = AOMMAX(0, h - ii_shift);
      const int end_ii = AOMMIN(in_height, h + end_ii_shift);
      for (int w = start_w, out_index = out_h; w < in_width;
           w += layer_config->skip_width, ++out_index) {
        // Weight offsets skipped for filter columns left/right of the image.
        const int left_cstep = AOMMAX(0, jj_shift - w) * cstep;
        const int right_cstep = AOMMAX(0, right_filter_margin + w) * cstep;
        const int start_jj = AOMMAX(0, w - jj_shift);
        const int end_jj = AOMMIN(in_width, w + end_jj_shift);
        float sum = layer_config->bias[i];
        for (int k = 0; k < layer_config->in_channels; ++k) {
          int off = k * layer_config->out_channels + top_cstep;
          for (int ii = start_ii; ii < end_ii; ++ii) {
            off += left_cstep;
            for (int jj = start_jj; jj < end_jj; ++jj, off += cstep) {
              sum += layer_config->weights[off] * input[k][ii * in_stride + jj];
            }
            off += right_cstep;
          }
        }
        output[i][out_index] = sum;
      }
    }
  }
}
    564 
// CNNConvolve specific to maxpool set as 0 and padding equal to
// PADDING_SAME_REPLICATE.
// Taps outside the image are clamped to the nearest edge sample.
// start_idx/channel_step partition the output channels between workers.
static void convolve_no_maxpool_padding_replicate(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *const layer_config, float **output, int out_stride,
    int start_idx, const int cstep, const int ii_shift, const int jj_shift,
    const int channel_step) {
  // h and w are shifted to an offset coordinate system to reduce in-loop
  // computation.
  const int start_h =
      get_start_shift_convolve(in_height, layer_config->filter_height,
                               layer_config->skip_height) -
      ii_shift;
  const int start_w =
      get_start_shift_convolve(in_width, layer_config->filter_width,
                               layer_config->skip_width) -
      jj_shift;
  const int end_h = in_height - ii_shift;
  const int end_w = in_width - jj_shift;
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    for (int h = start_h, u = 0; h < end_h;
         h += layer_config->skip_height, ++u) {
      const int out_h = u * out_stride;
      const int upper_ii_index = layer_config->filter_height + h;
      for (int w = start_w, out_index = out_h; w < end_w;
           w += layer_config->skip_width, ++out_index) {
        const int upper_jj_index = layer_config->filter_width + w;
        float sum = layer_config->bias[i];
        for (int k = 0; k < layer_config->in_channels; ++k) {
          int off = k * layer_config->out_channels + i;
          for (int ii = h; ii < upper_ii_index; ++ii) {
            // Replicate padding: clamp both indices to the image.
            const int clamped_ii = CLAMPINDEX(ii, in_height);
            for (int jj = w; jj < upper_jj_index; ++jj) {
              const int clamped_jj = CLAMPINDEX(jj, in_width);
              assert(clamped_ii >= 0 && clamped_ii < in_height &&
                     clamped_jj >= 0 && clamped_jj < in_width);
              sum += layer_config->weights[off] *
                     input[k][clamped_ii * in_stride + clamped_jj];
              off += cstep;
            }
          }
        }
        output[i][out_index] = sum;
      }
    }
  }
}
    612 
// CNNConvolve specific to maxpool set as 0 and padding equal to
// PADDING_VALID.
// Only positions where the whole filter fits inside the image are produced,
// so no tap needs bounds handling. start_idx/channel_step partition the
// output channels between parallel workers.
void av1_cnn_convolve_no_maxpool_padding_valid_c(
    const float **input, int in_width, int in_height, int in_stride,
    const CNN_LAYER_CONFIG *layer_config, float **output, int out_stride,
    int start_idx, int cstep, int channel_step) {
  assert((layer_config->skip_height == 1 && layer_config->skip_width == 1) ||
         !layer_config->maxpool);
  assert(layer_config->filter_height > 1 || layer_config->filter_width > 1);
  assert(layer_config->pad == PADDING_VALID);
  for (int i = start_idx; i < layer_config->out_channels; i += channel_step) {
    for (int h = 0, u = 0; h < in_height - layer_config->filter_height + 1;
         h += layer_config->skip_height, ++u) {
      const int out_h = u * out_stride;
      const int upper_ii_index = layer_config->filter_height + h;
      for (int w = 0, out_index = out_h;
           w < in_width - layer_config->filter_width + 1;
           w += layer_config->skip_width, ++out_index) {
        const int upper_jj_index = layer_config->filter_width + w;
        float sum = layer_config->bias[i];
        for (int k = 0; k < layer_config->in_channels; ++k) {
          // Successive taps of the same (in, out) channel pair are cstep
          // apart in the weights array.
          int off = k * layer_config->out_channels + i;
          for (int ii = h; ii < upper_ii_index; ++ii) {
            for (int jj = w; jj < upper_jj_index; ++jj) {
              assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
              sum += layer_config->weights[off] * input[k][ii * in_stride + jj];
              off += cstep;
            }
          }
        }
        output[i][out_index] = sum;
      }
    }
  }
}
    648 
    649 static void av1_cnn_convolve(const float **input, int in_width, int in_height,
    650                             int in_stride,
    651                             const CNN_LAYER_CONFIG *layer_config,
    652                             float **output, int out_stride, int start_idx,
    653                             int step) {
    654  assert(!layer_config->deconvolve);
    655  const int cstep = layer_config->in_channels * layer_config->out_channels;
    656  const int filter_height_half = layer_config->filter_height >> 1;
    657  const int filter_width_half = layer_config->filter_width >> 1;
    658  const int channel_step = AOMMAX(step, 1);
    659 
    660  if (layer_config->maxpool &&
    661      (layer_config->skip_height > 1 || layer_config->skip_width > 1)) {
    662    switch (layer_config->pad) {
    663      case PADDING_SAME_ZERO:
    664        convolve_maxpool_padding_zero(input, in_width, in_height, in_stride,
    665                                      layer_config, output, out_stride, cstep,
    666                                      filter_width_half, filter_height_half);
    667        break;
    668      case PADDING_SAME_REPLICATE:
    669        convolve_maxpool_padding_replicate(
    670            input, in_width, in_height, in_stride, layer_config, output,
    671            out_stride, cstep, filter_width_half, filter_height_half);
    672        break;
    673      case PADDING_VALID:
    674        convolve_maxpool_padding_valid(input, in_width, in_height, in_stride,
    675                                       layer_config, output, out_stride, cstep);
    676        break;
    677      default: assert(0 && "Unknown padding type");
    678    }
    679  } else {
    680    // Results in element-wise matrix multiplication.
    681    if (layer_config->filter_height == 1 && layer_config->filter_width == 1) {
    682      convolve_element_wise(input, in_width, in_height, in_stride, layer_config,
    683                            output, out_stride, start_idx, step);
    684      return;
    685    }
    686    const int ii_shift =
    687        filter_height_half - (layer_config->filter_height - 1) % 2;
    688    const int jj_shift =
    689        filter_width_half - (layer_config->filter_width - 1) % 2;
    690    switch (layer_config->pad) {
    691      case PADDING_SAME_ZERO:
    692        convolve_no_maxpool_padding_zero(
    693            input, in_width, in_height, in_stride, layer_config, output,
    694            out_stride, start_idx, cstep, filter_width_half, filter_height_half,
    695            ii_shift, jj_shift, channel_step);
    696        break;
    697      case PADDING_SAME_REPLICATE:
    698        convolve_no_maxpool_padding_replicate(
    699            input, in_width, in_height, in_stride, layer_config, output,
    700            out_stride, start_idx, cstep, ii_shift, jj_shift, channel_step);
    701        break;
    702      case PADDING_VALID:
    703        av1_cnn_convolve_no_maxpool_padding_valid(
    704            input, in_width, in_height, in_stride, layer_config, output,
    705            out_stride, start_idx, cstep, channel_step);
    706        break;
    707      default: assert(0 && "Unknown padding type");
    708    }
    709  }
    710 }
    711 
    712 static int convolve_layer(void *arg1, void *arg2) {
    713  const CONVOLVE_OPS *convolve_ops = arg1;
    714  (void)arg2;
    715  av1_cnn_convolve(
    716      convolve_ops->input, convolve_ops->in_width, convolve_ops->in_height,
    717      convolve_ops->in_stride, convolve_ops->layer_config, convolve_ops->output,
    718      convolve_ops->out_stride, convolve_ops->start_idx, convolve_ops->th_step);
    719  return 1;
    720 }
    721 
    722 static void convolve_layer_mt(const float **input, int in_width, int in_height,
    723                              int in_stride,
    724                              const CNN_LAYER_CONFIG *layer_config,
    725                              const CNN_THREAD_DATA *thread_data,
    726                              float **output, int out_stride) {
    727  const AVxWorkerInterface *const winterface = aom_get_worker_interface();
    728  const int num_workers = thread_data->num_workers;
    729  assert(thread_data->workers);
    730 
    731  CONVOLVE_OPS convolve_ops[CNN_MAX_THREADS];
    732  for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
    733    AVxWorker *const worker = &thread_data->workers[th];
    734    winterface->reset(worker);
    735 
    736    CONVOLVE_OPS convolve_op = { input,      in_width,     in_height,
    737                                 in_stride,  layer_config, output,
    738                                 out_stride, th,           num_workers };
    739    convolve_ops[th] = convolve_op;
    740    worker->hook = convolve_layer;
    741    worker->data1 = &(convolve_ops[th]);
    742    worker->data2 = NULL;
    743 
    744    // Start convolving.
    745    if (th == num_workers - 1) {
    746      winterface->execute(worker);
    747    } else {
    748      winterface->launch(worker);
    749    }
    750  }
    751 
    752  // Wait until all workers have finished.
    753  for (int th = 0; th < AOMMIN(num_workers, CNN_MAX_THREADS); ++th) {
    754    winterface->sync(&thread_data->workers[th]);
    755  }
    756 }
    757 
    758 static inline int get_start_shift_deconvolve(int filt_width, int stride) {
    759  const int dif = AOMMAX(filt_width - stride, 0);
    760  return dif / 2;
    761 }
    762 
    763 void av1_cnn_batchnorm_c(float **image, int channels, int width, int height,
    764                         int stride, const float *gamma, const float *beta,
    765                         const float *mean, const float *std) {
    766  assert(gamma && beta && beta && std && "batchnorm has null parameter!");
    767  for (int ch = 0; ch < channels; ch++) {
    768    const float ch_gamma = gamma[ch];
    769    const float ch_beta = beta[ch];
    770    const float ch_mean = mean[ch];
    771    const float ch_std = std[ch];
    772    float *image_row = image[ch];
    773 
    774    for (int row = 0; row < height; row++) {
    775      for (int col = 0; col < width; col++) {
    776        image_row[col] =
    777            ch_gamma * (image_row[col] - ch_mean) / ch_std + ch_beta;
    778      }
    779      image_row += stride;
    780    }
    781  }
    782 }
    783 
// Deconvolution (transposed convolution): conceptually upsamples `input` by
// skip_width x skip_height and correlates the result with the layer filters.
// For each output sample (u, v), a filter tap (l, m) contributes only when
// the back-projected position (h, w) lands exactly on the upsampling grid
// (h % skip == 0) and maps to a valid input sample. Weights are stored
// interleaved: tap t of the filter for (in-channel k, out-channel i) lives at
// weights[k * out_channels + i + t * cstep].
// NOTE(review): for negative h or w, C truncating division/modulo is relied
// upon — (h % skip) != 0 rejects most negative positions and the later
// ii/jj bounds check (or CLAMPINDEX) handles the rest.
void av1_cnn_deconvolve_c(const float **input, int in_width, int in_height,
                         int in_stride, const CNN_LAYER_CONFIG *layer_config,
                         float **output, int out_stride) {
  assert(layer_config->deconvolve);

  // Distance (in weights[]) between successive taps of one filter.
  const int cstep = layer_config->in_channels * layer_config->out_channels;

  int out_width = 0;
  int out_height = 0;
  av1_find_cnn_layer_output_size(in_width, in_height, layer_config, &out_width,
                                 &out_height);
  switch (layer_config->pad) {
    case PADDING_SAME_ZERO:
      for (int i = 0; i < layer_config->out_channels; ++i) {
        for (int u = 0; u < out_height; ++u) {
          for (int v = 0; v < out_width; ++v) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                // Back-project the output row through tap l, including the
                // centering shift for filters wider than the stride.
                const int h =
                    u - l +
                    get_start_shift_deconvolve(layer_config->filter_height,
                                               layer_config->skip_height);
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int w =
                      v - m +
                      get_start_shift_deconvolve(layer_config->filter_width,
                                                 layer_config->skip_width);
                  // Skip taps that fall between upsampling grid points.
                  if ((h % layer_config->skip_height) != 0 ||
                      (w % layer_config->skip_width) != 0)
                    continue;
                  const int ii = h / layer_config->skip_height;
                  const int jj = w / layer_config->skip_width;
                  // Zero padding: out-of-range samples contribute nothing.
                  if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
                    continue;
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            output[i][u * out_stride + v] = sum;
          }
        }
      }
      break;
    case PADDING_SAME_REPLICATE:
      for (int i = 0; i < layer_config->out_channels; ++i) {
        for (int u = 0; u < out_height; ++u) {
          for (int v = 0; v < out_width; ++v) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                const int h =
                    u - l +
                    get_start_shift_deconvolve(layer_config->filter_height,
                                               layer_config->skip_height);
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int w =
                      v - m +
                      get_start_shift_deconvolve(layer_config->filter_width,
                                                 layer_config->skip_width);
                  // Skip taps that fall between upsampling grid points.
                  if ((h % layer_config->skip_height) != 0 ||
                      (w % layer_config->skip_width) != 0)
                    continue;
                  // Replicate padding: clamp out-of-range samples to the
                  // nearest edge sample instead of dropping them.
                  const int ii =
                      CLAMPINDEX(h / layer_config->skip_height, in_height);
                  const int jj =
                      CLAMPINDEX(w / layer_config->skip_width, in_width);
                  assert(ii >= 0 && ii < in_height && jj >= 0 && jj < in_width);
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            output[i][u * out_stride + v] = sum;
          }
        }
      }
      break;
    case PADDING_VALID:
      for (int i = 0; i < layer_config->out_channels; ++i) {
        for (int u = 0; u < out_height; ++u) {
          for (int v = 0; v < out_width; ++v) {
            float sum = layer_config->bias[i];
            for (int k = 0; k < layer_config->in_channels; ++k) {
              int off = k * layer_config->out_channels + i;
              for (int l = 0; l < layer_config->filter_height; ++l) {
                // Valid padding: no centering shift is applied.
                const int h = u - l;
                for (int m = 0; m < layer_config->filter_width;
                     ++m, off += cstep) {
                  const int w = v - m;
                  if ((h % layer_config->skip_height) != 0 ||
                      (w % layer_config->skip_width) != 0)
                    continue;
                  const int ii = h / layer_config->skip_height;
                  const int jj = w / layer_config->skip_width;
                  if (ii < 0 || ii >= in_height || jj < 0 || jj >= in_width)
                    continue;
                  sum += layer_config->weights[off] *
                         input[k][ii * in_stride + jj];
                }
              }
            }
            output[i][u * out_stride + v] = sum;
          }
        }
      }
      break;
    default: assert(0 && "Unknown padding type");
  }
}
    899 
// Runs the full CNN described by `cnn_config` on `input` and writes results
// into the caller-provided buffers of `output_struct`. Each branch keeps a
// ping-pong tensor pair: tensor1[b] is the current layer's input and
// tensor2[b] its output; the pair is swapped before each non-first layer.
// Layers with output_num != -1 write straight into the caller's output
// buffers instead of scratch storage. Returns false if an internal
// allocation fails; all scratch tensors are freed on every path.
bool av1_cnn_predict_c(const float **input, int in_width, int in_height,
                      int in_stride, const CNN_CONFIG *cnn_config,
                      const CNN_THREAD_DATA *thread_data,
                      CNN_MULTI_OUT *output_struct) {
  bool success = false;
  TENSOR tensor1[CNN_MAX_BRANCHES] = { { 0 } };
  TENSOR tensor2[CNN_MAX_BRANCHES] = { { 0 } };

  // The per-output channel pointer arrays are packed back to back in
  // output_struct->output_buffer; carve out the start of each one.
  float **output[CNN_MAX_BRANCHES];
  const int *out_chs = output_struct->output_channels;
  output[0] = output_struct->output_buffer;
  for (int out_idx = 1; out_idx < output_struct->num_outputs; out_idx++) {
    output[out_idx] = output[out_idx - 1] + out_chs[out_idx - 1];
  }

  int i_width = in_width;
  int i_height = in_height;
  int o_width = 0, o_height = 0;
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    init_tensor(&tensor1[b]);
    init_tensor(&tensor2[b]);
  }

  const int *out_stride = output_struct->output_strides;
  for (int layer = 0; layer < cnn_config->num_layers; ++layer) {
    const CNN_LAYER_CONFIG *layer_config = &cnn_config->layer_config[layer];
    const int branch = layer_config->branch;
    const CNN_BRANCH_CONFIG *branch_config = &layer_config->branch_config;

    // Allocate input tensor
    if (layer == 0) {       // First layer
      assert(branch == 0);  // First layer must be primary branch
      // The first layer reads the caller's planes directly (no copy).
      assign_tensor(&tensor1[branch], (float **)input,
                    layer_config->in_channels, in_width, in_height, in_stride);
    } else {  // Non-first layer
      // Swap tensor1 and tensor2: the previous layer's output becomes this
      // layer's input.
      swap_tensor(&tensor1[branch], &tensor2[branch]);

      i_width = tensor1[branch].width;
      i_height = tensor1[branch].height;
    }

    // Allocate output tensor
    av1_find_cnn_layer_output_size(i_width, i_height, layer_config, &o_width,
                                   &o_height);
    const int output_num = layer_config->output_num;
    if (output_num == -1) {  // Non-output layer
      if (!realloc_tensor(&tensor2[branch], layer_config->out_channels, o_width,
                          o_height)) {
        goto Error;
      }
    } else {  // Output layer
      // Point the output tensor at the caller's buffer for this output id.
      free_tensor(&tensor2[branch]);
      assign_tensor(&tensor2[branch], output[output_num],
                    layer_config->out_channels, o_width, o_height,
                    out_stride[output_num]);
    }

    // If we are combining branches make sure that the branch to combine
    // is different from the current branch.
    assert(IMPLIES(layer_config->branch_combine_type != BRANCH_NOC,
                   !(branch_config->branches_to_combine & (1 << branch))));

    // Copy this layer's input to the target branches before convolving.
    if (layer_config->branch_copy_type == BRANCH_INPUT) {
      if (!copy_active_tensor_to_branches(&tensor1[branch], layer_config,
                                          branch, tensor2)) {
        goto Error;
      }
    }
    // Check consistency of input and output channels
    assert(tensor1[branch].channels == layer_config->in_channels);
    assert(tensor2[branch].channels == layer_config->out_channels);

    // Convolve/Deconvolve
    if (!cnn_config->layer_config[layer].deconvolve) {
      if (thread_data->num_workers > 1) {
        convolve_layer_mt((const float **)tensor1[branch].buf,
                          tensor1[branch].width, tensor1[branch].height,
                          tensor1[branch].stride, layer_config, thread_data,
                          tensor2[branch].buf, tensor2[branch].stride);
      } else {
        av1_cnn_convolve((const float **)tensor1[branch].buf,
                         tensor1[branch].width, tensor1[branch].height,
                         tensor1[branch].stride, layer_config,
                         tensor2[branch].buf, tensor2[branch].stride, 0, 1);
      }
    } else {
      av1_cnn_deconvolve((const float **)tensor1[branch].buf,
                         tensor1[branch].width, tensor1[branch].height,
                         tensor1[branch].stride, layer_config,
                         tensor2[branch].buf, tensor2[branch].stride);
    }

    // Copy the raw convolution output (pre-activation) to target branches.
    if (layer_config->branch_copy_type == BRANCH_OUTPUT) {
      if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
                                          branch, tensor2)) {
        goto Error;
      }
    }

    // Add tensors from other branches if needed
    if (layer_config->branch_combine_type == BRANCH_ADD) {
      for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
        if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
          assert(check_tensor_equal_size(&tensor2[b], &tensor2[branch]));
          av1_cnn_add(tensor2[branch].buf, tensor2[branch].channels,
                      tensor2[branch].width, tensor2[branch].height,
                      tensor2[branch].stride, (const float **)tensor2[b].buf);
        }
      }
    }

    // Non-linearity
    av1_cnn_activate(tensor2[branch].buf, tensor2[branch].channels,
                     tensor2[branch].width, tensor2[branch].height,
                     tensor2[branch].stride, layer_config->activation);

    // Batch normalization only runs when the layer supplies parameters.
    if (layer_config->bn_params.bn_gamma) {
      av1_cnn_batchnorm(
          tensor2[branch].buf, tensor2[branch].channels, tensor2[branch].width,
          tensor2[branch].height, tensor2[branch].stride,
          layer_config->bn_params.bn_gamma, layer_config->bn_params.bn_beta,
          layer_config->bn_params.bn_mean, layer_config->bn_params.bn_std);
    }

    // Concatenate tensors
    if (layer_config->branch_combine_type == BRANCH_CAT) {
      if (output_num == -1) {  // Non-output layer
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            assert(tensor2[b].channels > 0);
            if (!concat_tensor(&tensor2[b], &tensor2[branch])) goto Error;
          }
        }
      } else {  // Output layer
        // The output tensor aliases the caller's buffer and cannot be
        // reallocated, so first re-assign it with the combined channel
        // count, then copy the other branches' channels in behind the
        // existing ones.
        const int existing_channels = tensor2[branch].channels;
        int num_chs = existing_channels;
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            // Needed only to assign the new channel buffers
            num_chs += tensor2[b].channels;
          }
        }
        assign_tensor(&tensor2[branch], output[output_num], num_chs, o_width,
                      o_height, out_stride[output_num]);

        num_chs = existing_channels;
        for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
          if ((branch_config->branches_to_combine & (1 << b)) && b != branch) {
            assert(check_tensor_equal_dims(&tensor2[b], &tensor2[branch]));
            // Needed only to assign the new channel buffers
            copy_tensor(&tensor2[b], tensor2[b].channels, num_chs,
                        &tensor2[branch]);
            num_chs += tensor2[b].channels;
          }
        }
      }
    }

    // Copy the fully combined result to the target branches.
    if (layer_config->branch_copy_type == BRANCH_COMBINED) {
      if (!copy_active_tensor_to_branches(&tensor2[branch], layer_config,
                                          branch, tensor2)) {
        goto Error;
      }
    }
  }

  success = true;
Error:
  // Free scratch tensors on both success and failure paths; free_tensor()
  // is a no-op for tensors that merely alias caller-owned buffers.
  for (int b = 0; b < CNN_MAX_BRANCHES; ++b) {
    free_tensor(&tensor1[b]);
    free_tensor(&tensor2[b]);
  }
  return success;
}
   1077 
   1078 // Assume output already has proper allocation
   1079 // Assume input image buffers all have same resolution and strides
   1080 bool av1_cnn_predict_img_multi_out(uint8_t **dgd, int width, int height,
   1081                                   int stride, const CNN_CONFIG *cnn_config,
   1082                                   const CNN_THREAD_DATA *thread_data,
   1083                                   CNN_MULTI_OUT *output) {
   1084  const float max_val = 255.0;
   1085 
   1086  const int in_width = width + 2 * cnn_config->ext_width;
   1087  const int in_height = height + 2 * cnn_config->ext_height;
   1088  const int in_channels = cnn_config->layer_config[0].in_channels;
   1089  float *inputs[CNN_MAX_CHANNELS];
   1090  float *input_ =
   1091      (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
   1092  if (!input_) return false;
   1093  const int in_stride = in_width;
   1094 
   1095  for (int c = 0; c < in_channels; ++c) {
   1096    inputs[c] = input_ + c * in_stride * in_height;
   1097    float *input =
   1098        inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;
   1099 
   1100    if (cnn_config->strict_bounds) {
   1101      for (int i = 0; i < height; ++i)
   1102        for (int j = 0; j < width; ++j)
   1103          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
   1104      // extend left and right
   1105      for (int i = 0; i < height; ++i) {
   1106        for (int j = -cnn_config->ext_width; j < 0; ++j)
   1107          input[i * in_stride + j] = input[i * in_stride];
   1108        for (int j = width; j < width + cnn_config->ext_width; ++j)
   1109          input[i * in_stride + j] = input[i * in_stride + width - 1];
   1110      }
   1111      // extend top and bottom
   1112      for (int i = -cnn_config->ext_height; i < 0; ++i)
   1113        memcpy(&input[i * in_stride - cnn_config->ext_width],
   1114               &input[-cnn_config->ext_width], in_width * sizeof(*input));
   1115      for (int i = height; i < height + cnn_config->ext_height; ++i)
   1116        memcpy(&input[i * in_stride - cnn_config->ext_width],
   1117               &input[(height - 1) * in_stride - cnn_config->ext_width],
   1118               in_width * sizeof(*input));
   1119    } else {
   1120      for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
   1121           ++i)
   1122        for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
   1123             ++j)
   1124          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
   1125    }
   1126  }
   1127  bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
   1128                                 in_stride, cnn_config, thread_data, output);
   1129 
   1130  aom_free(input_);
   1131  return success;
   1132 }
   1133 
   1134 // Assume output already has proper allocation
   1135 // Assume input image buffers all have same resolution and strides
   1136 bool av1_cnn_predict_img_multi_out_highbd(uint16_t **dgd, int width, int height,
   1137                                          int stride,
   1138                                          const CNN_CONFIG *cnn_config,
   1139                                          const CNN_THREAD_DATA *thread_data,
   1140                                          int bit_depth,
   1141                                          CNN_MULTI_OUT *output) {
   1142  const float max_val = (float)((1 << bit_depth) - 1);
   1143 
   1144  const int in_width = width + 2 * cnn_config->ext_width;
   1145  const int in_height = height + 2 * cnn_config->ext_height;
   1146  const int in_channels = cnn_config->layer_config[0].in_channels;
   1147  float *inputs[CNN_MAX_CHANNELS];
   1148  float *input_ =
   1149      (float *)aom_malloc(in_width * in_height * in_channels * sizeof(*input_));
   1150  if (!input_) return false;
   1151  const int in_stride = in_width;
   1152 
   1153  for (int c = 0; c < in_channels; ++c) {
   1154    inputs[c] = input_ + c * in_stride * in_height;
   1155    float *input =
   1156        inputs[c] + cnn_config->ext_height * in_stride + cnn_config->ext_width;
   1157 
   1158    if (cnn_config->strict_bounds) {
   1159      for (int i = 0; i < height; ++i)
   1160        for (int j = 0; j < width; ++j)
   1161          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
   1162      // extend left and right
   1163      for (int i = 0; i < height; ++i) {
   1164        for (int j = -cnn_config->ext_width; j < 0; ++j)
   1165          input[i * in_stride + j] = input[i * in_stride];
   1166        for (int j = width; j < width + cnn_config->ext_width; ++j)
   1167          input[i * in_stride + j] = input[i * in_stride + width - 1];
   1168      }
   1169      // extend top and bottom
   1170      for (int i = -cnn_config->ext_height; i < 0; ++i)
   1171        memcpy(&input[i * in_stride - cnn_config->ext_width],
   1172               &input[-cnn_config->ext_width], in_width * sizeof(*input));
   1173      for (int i = height; i < height + cnn_config->ext_height; ++i)
   1174        memcpy(&input[i * in_stride - cnn_config->ext_width],
   1175               &input[(height - 1) * in_stride - cnn_config->ext_width],
   1176               in_width * sizeof(*input));
   1177    } else {
   1178      for (int i = -cnn_config->ext_height; i < height + cnn_config->ext_height;
   1179           ++i)
   1180        for (int j = -cnn_config->ext_width; j < width + cnn_config->ext_width;
   1181             ++j)
   1182          input[i * in_stride + j] = (float)dgd[c][i * stride + j] / max_val;
   1183    }
   1184  }
   1185 
   1186  bool success = av1_cnn_predict((const float **)inputs, in_width, in_height,
   1187                                 in_stride, cnn_config, thread_data, output);
   1188 
   1189  aom_free(input_);
   1190  return success;
   1191 }