tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

txfm_graph.cc (32582B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include "tools/txfm_analyzer/txfm_graph.h"
     13 
     14 #include <stdio.h>
     15 #include <stdlib.h>
     16 #include <math.h>
     17 
     18 typedef struct Node Node;
     19 
     20 void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type,
     21                  const int txfm_size) {
     22  if (type == TYPE_DCT)
     23    snprintf(str_fun_name, str_buf_size, "fdct%d_new", txfm_size);
     24  else if (type == TYPE_ADST)
     25    snprintf(str_fun_name, str_buf_size, "fadst%d_new", txfm_size);
     26  else if (type == TYPE_IDCT)
     27    snprintf(str_fun_name, str_buf_size, "idct%d_new", txfm_size);
     28  else if (type == TYPE_IADST)
     29    snprintf(str_fun_name, str_buf_size, "iadst%d_new", txfm_size);
     30 }
     31 
     32 void get_txfm_type_name(char *str_fun_name, int str_buf_size,
     33                        const TYPE_TXFM type, const int txfm_size) {
     34  if (type == TYPE_DCT)
     35    snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
     36  else if (type == TYPE_ADST)
     37    snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
     38  else if (type == TYPE_IDCT)
     39    snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
     40  else if (type == TYPE_IADST)
     41    snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
     42 }
     43 
     44 void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0,
     45                             const TYPE_TXFM type1, const int txfm_size0,
     46                             const int txfm_size1) {
     47  if (type0 == TYPE_DCT && type1 == TYPE_DCT)
     48    snprintf(buf, buf_size, "_dct_dct_%dx%d", txfm_size1, txfm_size0);
     49  else if (type0 == TYPE_DCT && type1 == TYPE_ADST)
     50    snprintf(buf, buf_size, "_dct_adst_%dx%d", txfm_size1, txfm_size0);
     51  else if (type0 == TYPE_ADST && type1 == TYPE_ADST)
     52    snprintf(buf, buf_size, "_adst_adst_%dx%d", txfm_size1, txfm_size0);
     53  else if (type0 == TYPE_ADST && type1 == TYPE_DCT)
     54    snprintf(buf, buf_size, "_adst_dct_%dx%d", txfm_size1, txfm_size0);
     55 }
     56 
     57 TYPE_TXFM get_inv_type(TYPE_TXFM type) {
     58  if (type == TYPE_DCT)
     59    return TYPE_IDCT;
     60  else if (type == TYPE_ADST)
     61    return TYPE_IADST;
     62  else if (type == TYPE_IDCT)
     63    return TYPE_DCT;
     64  else if (type == TYPE_IADST)
     65    return TYPE_ADST;
     66  else
     67    return TYPE_LAST;
     68 }
     69 
     70 void reference_dct_1d(double *in, double *out, int size) {
     71  const double kInvSqrt2 = 0.707106781186547524400844362104;
     72  for (int k = 0; k < size; k++) {
     73    out[k] = 0;  // initialize out[k]
     74    for (int n = 0; n < size; n++) {
     75      out[k] += in[n] * cos(PI * (2 * n + 1) * k / (2 * size));
     76    }
     77    if (k == 0) out[k] = out[k] * kInvSqrt2;
     78  }
     79 }
     80 
     81 void reference_dct_2d(double *in, double *out, int size) {
     82  double *tempOut = new double[size * size];
     83  // dct each row: in -> out
     84  for (int r = 0; r < size; r++) {
     85    reference_dct_1d(in + r * size, out + r * size, size);
     86  }
     87 
     88  for (int r = 0; r < size; r++) {
     89    // out ->tempOut
     90    for (int c = 0; c < size; c++) {
     91      tempOut[r * size + c] = out[c * size + r];
     92    }
     93  }
     94  for (int r = 0; r < size; r++) {
     95    reference_dct_1d(tempOut + r * size, out + r * size, size);
     96  }
     97  delete[] tempOut;
     98 }
     99 
    100 void reference_adst_1d(double *in, double *out, int size) {
    101  for (int k = 0; k < size; k++) {
    102    out[k] = 0;  // initialize out[k]
    103    for (int n = 0; n < size; n++) {
    104      out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size));
    105    }
    106  }
    107 }
    108 
    109 void reference_hybrid_2d(double *in, double *out, int size, int type0,
    110                         int type1) {
    111  double *tempOut = new double[size * size];
    112  // dct each row: in -> out
    113  for (int r = 0; r < size; r++) {
    114    if (type0 == TYPE_DCT)
    115      reference_dct_1d(in + r * size, out + r * size, size);
    116    else
    117      reference_adst_1d(in + r * size, out + r * size, size);
    118  }
    119 
    120  for (int r = 0; r < size; r++) {
    121    // out ->tempOut
    122    for (int c = 0; c < size; c++) {
    123      tempOut[r * size + c] = out[c * size + r];
    124    }
    125  }
    126  for (int r = 0; r < size; r++) {
    127    if (type1 == TYPE_DCT)
    128      reference_dct_1d(tempOut + r * size, out + r * size, size);
    129    else
    130      reference_adst_1d(tempOut + r * size, out + r * size, size);
    131  }
    132  delete[] tempOut;
    133 }
    134 
    135 void reference_hybrid_2d_new(double *in, double *out, int size0, int size1,
    136                             int type0, int type1) {
    137  double *tempOut = new double[size0 * size1];
    138  // dct each row: in -> out
    139  for (int r = 0; r < size1; r++) {
    140    if (type0 == TYPE_DCT)
    141      reference_dct_1d(in + r * size0, out + r * size0, size0);
    142    else
    143      reference_adst_1d(in + r * size0, out + r * size0, size0);
    144  }
    145 
    146  for (int r = 0; r < size1; r++) {
    147    // out ->tempOut
    148    for (int c = 0; c < size0; c++) {
    149      tempOut[c * size1 + r] = out[r * size0 + c];
    150    }
    151  }
    152  for (int r = 0; r < size0; r++) {
    153    if (type1 == TYPE_DCT)
    154      reference_dct_1d(tempOut + r * size1, out + r * size1, size1);
    155    else
    156      reference_adst_1d(tempOut + r * size1, out + r * size1, size1);
    157  }
    158  delete[] tempOut;
    159 }
    160 
    161 unsigned int get_max_bit(unsigned int x) {
    162  int max_bit = -1;
    163  while (x) {
    164    x = x >> 1;
    165    max_bit++;
    166  }
    167  return max_bit;
    168 }
    169 
    170 unsigned int bitwise_reverse(unsigned int x, int max_bit) {
    171  x = ((x >> 16) & 0x0000ffff) | ((x & 0x0000ffff) << 16);
    172  x = ((x >> 8) & 0x00ff00ff) | ((x & 0x00ff00ff) << 8);
    173  x = ((x >> 4) & 0x0f0f0f0f) | ((x & 0x0f0f0f0f) << 4);
    174  x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2);
    175  x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1);
    176  x = x >> (31 - max_bit);
    177  return x;
    178 }
    179 
    180 int get_idx(int ri, int ci, int cSize) { return ri * cSize + ci; }
    181 
    182 void add_node(Node *node, int stage_num, int node_num, int stage_idx,
    183              int node_idx, int in, double w) {
    184  int outIdx = get_idx(stage_idx, node_idx, node_num);
    185  int inIdx = get_idx(stage_idx - 1, in, node_num);
    186  int idx = node[outIdx].inNodeNum;
    187  if (idx < 2) {
    188    node[outIdx].inNode[idx] = &node[inIdx];
    189    node[outIdx].inNodeIdx[idx] = in;
    190    node[outIdx].inWeight[idx] = w;
    191    idx++;
    192    node[outIdx].inNodeNum = idx;
    193  } else {
    194    printf("Error: inNode is full");
    195  }
    196 }
    197 
    198 void connect_node(Node *node, int stage_num, int node_num, int stage_idx,
    199                  int node_idx, int in0, double w0, int in1, double w1) {
    200  int outIdx = get_idx(stage_idx, node_idx, node_num);
    201  int inIdx0 = get_idx(stage_idx - 1, in0, node_num);
    202  int inIdx1 = get_idx(stage_idx - 1, in1, node_num);
    203 
    204  int idx = 0;
    205  // if(w0 != 0) {
    206  node[outIdx].inNode[idx] = &node[inIdx0];
    207  node[outIdx].inNodeIdx[idx] = in0;
    208  node[outIdx].inWeight[idx] = w0;
    209  idx++;
    210  //}
    211 
    212  // if(w1 != 0) {
    213  node[outIdx].inNode[idx] = &node[inIdx1];
    214  node[outIdx].inNodeIdx[idx] = in1;
    215  node[outIdx].inWeight[idx] = w1;
    216  idx++;
    217  //}
    218 
    219  node[outIdx].inNodeNum = idx;
    220 }
    221 
    222 void propagate(Node *node, int stage_num, int node_num, int stage_idx) {
    223  for (int ni = 0; ni < node_num; ni++) {
    224    int outIdx = get_idx(stage_idx, ni, node_num);
    225    node[outIdx].value = 0;
    226    for (int k = 0; k < node[outIdx].inNodeNum; k++) {
    227      node[outIdx].value +=
    228          node[outIdx].inNode[k]->value * node[outIdx].inWeight[k];
    229    }
    230  }
    231 }
    232 
    233 int64_t round_shift(int64_t value, int bit) {
    234  if (bit > 0) {
    235    if (value < 0) {
    236      return -round_shift(-value, bit);
    237    } else {
    238      return (value + (1 << (bit - 1))) >> bit;
    239    }
    240  } else {
    241    return value << (-bit);
    242  }
    243 }
    244 
    245 void round_shift_array(int32_t *arr, int size, int bit) {
    246  if (bit == 0) {
    247    return;
    248  } else {
    249    for (int i = 0; i < size; i++) {
    250      arr[i] = round_shift(arr[i], bit);
    251    }
    252  }
    253 }
    254 
    255 void graph_reset_visited(Node *node, int stage_num, int node_num) {
    256  for (int si = 0; si < stage_num; si++) {
    257    for (int ni = 0; ni < node_num; ni++) {
    258      int idx = get_idx(si, ni, node_num);
    259      node[idx].visited = 0;
    260    }
    261  }
    262 }
    263 
    264 void estimate_value(Node *node, int stage_num, int node_num, int stage_idx,
    265                    int node_idx, int estimate_bit) {
    266  if (stage_idx > 0) {
    267    int outIdx = get_idx(stage_idx, node_idx, node_num);
    268    int64_t out = 0;
    269    node[outIdx].value = 0;
    270    for (int k = 0; k < node[outIdx].inNodeNum; k++) {
    271      int64_t w = round(node[outIdx].inWeight[k] * (1 << estimate_bit));
    272      int64_t v = round(node[outIdx].inNode[k]->value);
    273      out += v * w;
    274    }
    275    node[outIdx].value = round_shift(out, estimate_bit);
    276  }
    277 }
    278 
    279 void amplify_value(Node *node, int stage_num, int node_num, int stage_idx,
    280                   int node_idx, int amplify_bit) {
    281  int outIdx = get_idx(stage_idx, node_idx, node_num);
    282  node[outIdx].value = round_shift(round(node[outIdx].value), -amplify_bit);
    283 }
    284 
    285 void propagate_estimate_amlify(Node *node, int stage_num, int node_num,
    286                               int stage_idx, int amplify_bit,
    287                               int estimate_bit) {
    288  for (int ni = 0; ni < node_num; ni++) {
    289    estimate_value(node, stage_num, node_num, stage_idx, ni, estimate_bit);
    290    amplify_value(node, stage_num, node_num, stage_idx, ni, amplify_bit);
    291  }
    292 }
    293 
    294 void init_graph(Node *node, int stage_num, int node_num) {
    295  for (int si = 0; si < stage_num; si++) {
    296    for (int ni = 0; ni < node_num; ni++) {
    297      int outIdx = get_idx(si, ni, node_num);
    298      node[outIdx].stageIdx = si;
    299      node[outIdx].nodeIdx = ni;
    300      node[outIdx].value = 0;
    301      node[outIdx].inNodeNum = 0;
    302      if (si >= 1) {
    303        connect_node(node, stage_num, node_num, si, ni, ni, 1, ni, 0);
    304      }
    305    }
    306  }
    307 }
    308 
    309 void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
    310                 int node_idx, int N, int star) {
    311  for (int i = 0; i < N / 2; i++) {
    312    int out = node_idx + i;
    313    int in1 = node_idx + N - 1 - i;
    314    if (star == 1) {
    315      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
    316                   1);
    317    } else {
    318      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
    319                   1);
    320    }
    321  }
    322  for (int i = N / 2; i < N; i++) {
    323    int out = node_idx + i;
    324    int in1 = node_idx + N - 1 - i;
    325    if (star == 1) {
    326      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
    327                   1);
    328    } else {
    329      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
    330                   1);
    331    }
    332  }
    333 }
    334 
    335 void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx,
    336                 int node_idx, int N) {
    337  int max_bit = get_max_bit(N - 1);
    338  for (int i = 0; i < N; i++) {
    339    int out = node_idx + bitwise_reverse(i, max_bit);
    340    int in = node_idx + i;
    341    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    342  }
    343 }
    344 
    345 void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx,
    346                     int node_idx, int N) {
    347  int max_bit = get_max_bit(N);
    348  for (int ni = 0; ni < N / 2; ni++) {
    349    int ai = bitwise_reverse(N + ni, max_bit);
    350    int out = node_idx + ni;
    351    int in1 = node_idx + N - ni - 1;
    352    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
    353                 sin(PI * ai / (2 * 2 * N)), in1, cos(PI * ai / (2 * 2 * N)));
    354  }
    355  for (int ni = N / 2; ni < N; ni++) {
    356    int ai = bitwise_reverse(N + ni, max_bit);
    357    int out = node_idx + ni;
    358    int in1 = node_idx + N - ni - 1;
    359    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
    360                 cos(PI * ai / (2 * 2 * N)), in1, -sin(PI * ai / (2 * 2 * N)));
    361  }
    362 }
    363 
    364 void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx,
    365                     int node_idx, int N) {
    366  for (int ni = 0; ni < N / 4; ni++) {
    367    int out = node_idx + ni;
    368    connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
    369  }
    370 
    371  for (int ni = N / 4; ni < N / 2; ni++) {
    372    int out = node_idx + ni;
    373    int in1 = node_idx + N - ni - 1;
    374    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
    375                 -cos(PI / 4), in1, cos(-PI / 4));
    376  }
    377 
    378  for (int ni = N / 2; ni < N * 3 / 4; ni++) {
    379    int out = node_idx + ni;
    380    int in1 = node_idx + N - ni - 1;
    381    connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
    382                 cos(-PI / 4), in1, cos(PI / 4));
    383  }
    384 
    385  for (int ni = N * 3 / 4; ni < N; ni++) {
    386    int out = node_idx + ni;
    387    connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
    388  }
    389 }
    390 
    391 void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx,
    392                     int node_idx, int idx, int N) {
    393  // TODO(angiebird): Simplify and clarify this function
    394 
    395  int i = 2 * N / (1 << (idx / 2));
    396  int max_bit =
    397      get_max_bit(i / 2) - 1;  // the max_bit counts on i/2 instead of N here
    398  int N_over_i = 2 << (idx / 2);
    399 
    400  for (int nj = 0; nj < N / 2; nj += N_over_i) {
    401    int j = nj / (N_over_i);
    402    int kj = bitwise_reverse(i / 4 + j, max_bit);
    403 
    404    // I_N/2i   --- 0
    405    int offset = nj;
    406    for (int ni = 0; ni < N_over_i / 4; ni++) {
    407      int out = node_idx + offset + ni;
    408      int in = out;
    409      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    410    }
    411 
    412    // -C_Kj/i --- S_Kj/i
    413    offset += N_over_i / 4;
    414    for (int ni = 0; ni < N_over_i / 4; ni++) {
    415      int out = node_idx + offset + ni;
    416      int in0 = out;
    417      double w0 = -cos(kj * PI / i);
    418      int in1 = N - (offset + ni) - 1 + node_idx;
    419      double w1 = sin(kj * PI / i);
    420      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
    421                   w1);
    422    }
    423 
    424    // S_kj/i  --- -C_Kj/i
    425    offset += N_over_i / 4;
    426    for (int ni = 0; ni < N_over_i / 4; ni++) {
    427      int out = node_idx + offset + ni;
    428      int in0 = out;
    429      double w0 = -sin(kj * PI / i);
    430      int in1 = N - (offset + ni) - 1 + node_idx;
    431      double w1 = -cos(kj * PI / i);
    432      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
    433                   w1);
    434    }
    435 
    436    // I_N/2i   --- 0
    437    offset += N_over_i / 4;
    438    for (int ni = 0; ni < N_over_i / 4; ni++) {
    439      int out = node_idx + offset + ni;
    440      int in = out;
    441      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    442    }
    443  }
    444 
    445  for (int nj = N / 2; nj < N; nj += N_over_i) {
    446    int j = nj / N_over_i;
    447    int kj = bitwise_reverse(i / 4 + j, max_bit);
    448 
    449    // I_N/2i --- 0
    450    int offset = nj;
    451    for (int ni = 0; ni < N_over_i / 4; ni++) {
    452      int out = node_idx + offset + ni;
    453      int in = out;
    454      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    455    }
    456 
    457    // C_kj/i --- -S_Kj/i
    458    offset += N_over_i / 4;
    459    for (int ni = 0; ni < N_over_i / 4; ni++) {
    460      int out = node_idx + offset + ni;
    461      int in0 = out;
    462      double w0 = cos(kj * PI / i);
    463      int in1 = N - (offset + ni) - 1 + node_idx;
    464      double w1 = -sin(kj * PI / i);
    465      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
    466                   w1);
    467    }
    468 
    469    // S_kj/i --- C_Kj/i
    470    offset += N_over_i / 4;
    471    for (int ni = 0; ni < N_over_i / 4; ni++) {
    472      int out = node_idx + offset + ni;
    473      int in0 = out;
    474      double w0 = sin(kj * PI / i);
    475      int in1 = N - (offset + ni) - 1 + node_idx;
    476      double w1 = cos(kj * PI / i);
    477      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
    478                   w1);
    479    }
    480 
    481    // I_N/2i --- 0
    482    offset += N_over_i / 4;
    483    for (int ni = 0; ni < N_over_i / 4; ni++) {
    484      int out = node_idx + offset + ni;
    485      int in = out;
    486      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    487    }
    488  }
    489 }
    490 
    491 void gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx,
    492                     int node_idx, int idx, int N) {
    493  int B_size = 1 << ((idx + 1) / 2);
    494  for (int ni = 0; ni < N; ni += B_size) {
    495    gen_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni, B_size,
    496                (ni / B_size) % 2);
    497  }
    498 }
    499 
    500 void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx,
    501                 int node_idx, int N) {
    502  int max_idx = 2 * (get_max_bit(N) + 1) - 3;
    503  for (int idx = 0; idx < max_idx; idx++) {
    504    int s = stage_idx + max_idx - idx - 1;
    505    if (idx == 0) {
    506      // type 1
    507      gen_type1_graph(node, stage_num, node_num, s, node_idx, N);
    508    } else if (idx == max_idx - 1) {
    509      // type 2
    510      gen_type2_graph(node, stage_num, node_num, s, node_idx, N);
    511    } else if ((idx + 1) % 2 == 0) {
    512      // type 4
    513      gen_type4_graph(node, stage_num, node_num, s, node_idx, idx, N);
    514    } else if ((idx + 1) % 2 == 1) {
    515      // type 3
    516      gen_type3_graph(node, stage_num, node_num, s, node_idx, idx, N);
    517    } else {
    518      printf("check gen_R_graph()\n");
    519    }
    520  }
    521 }
    522 
    523 void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx,
    524                   int node_idx, int N) {
    525  if (N > 2) {
    526    gen_B_graph(node, stage_num, node_num, stage_idx, node_idx, N, 0);
    527    gen_DCT_graph(node, stage_num, node_num, stage_idx + 1, node_idx, N / 2);
    528    gen_R_graph(node, stage_num, node_num, stage_idx + 1, node_idx + N / 2,
    529                N / 2);
    530  } else {
    531    // generate dct_2
    532    connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
    533                 cos(PI / 4), node_idx + 1, cos(PI / 4));
    534    connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
    535                 node_idx + 1, -cos(PI / 4), node_idx, cos(PI / 4));
    536  }
    537 }
    538 
    539 int get_dct_stage_num(int size) { return 2 * get_max_bit(size); }
    540 
    541 void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
    542                      int node_idx, int dct_node_num) {
    543  gen_DCT_graph(node, stage_num, node_num, stage_idx, node_idx, dct_node_num);
    544  int dct_stage_num = get_dct_stage_num(dct_node_num);
    545  gen_P_graph(node, stage_num, node_num, stage_idx + dct_stage_num - 2,
    546              node_idx, dct_node_num);
    547 }
    548 
    549 void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
    550                      int node_idx, int adst_idx) {
    551  int size = 1 << (adst_idx + 1);
    552  for (int ni = 0; ni < size / 2; ni++) {
    553    int nOut = node_idx + ni;
    554    int nIn = nOut + size / 2;
    555    connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, 1, nIn,
    556                 1);
    557  }
    558  for (int ni = size / 2; ni < size; ni++) {
    559    int nOut = node_idx + ni;
    560    int nIn = nOut - size / 2;
    561    connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, -1, nIn,
    562                 1);
    563  }
    564 }
    565 
    566 void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx,
    567                      int node_idx, int adst_idx, int adst_node_num) {
    568  int size = 1 << (adst_idx + 1);
    569  for (int ni = 0; ni < adst_node_num; ni += size) {
    570    gen_adst_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
    571                     adst_idx);
    572  }
    573 }
    574 
    575 void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx,
    576                      int node_idx, double freq) {
    577  connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
    578               cos(freq * PI), node_idx + 1, sin(freq * PI));
    579  connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
    580               node_idx + 1, -cos(freq * PI), node_idx, sin(freq * PI));
    581 }
    582 
    583 void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx,
    584                      int node_idx, int adst_idx) {
    585  int size = 1 << (adst_idx);
    586  for (int i = 0; i < size / 2; i++) {
    587    int ni = i * 2;
    588    double fi = (1 + 4 * i) * 1.0 / (1 << (adst_idx + 1));
    589    gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
    590  }
    591 }
    592 
    593 void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx,
    594                      int node_idx, int adst_idx, int adst_node_num) {
    595  int size = 1 << (adst_idx);
    596  for (int i = 0; i < adst_node_num / size; i++) {
    597    if (i % 2 == 1) {
    598      int ni = i * size;
    599      gen_adst_E_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
    600                       adst_idx);
    601    }
    602  }
    603 }
    604 void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx,
    605                       int node_idx, int adst_node_num) {
    606  for (int i = 0; i < adst_node_num / 2; i++) {
    607    int ni = i * 2;
    608    double fi = (1 + 4 * i) * 1.0 / (4 * adst_node_num);
    609    gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
    610  }
    611 }
    612 void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx,
    613                      int node_idx, int adst_node_num) {
    614  // reverse order when idx is 1, 3, 5, 7 ...
    615  // example of adst_node_num = 8:
    616  //   0 1 2 3 4 5 6 7
    617  // --> 0 7 2 5 4 3 6 1
    618  for (int ni = 0; ni < adst_node_num; ni++) {
    619    if (ni % 2 == 0) {
    620      int out = node_idx + ni;
    621      connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out,
    622                   0);
    623    } else {
    624      int out = node_idx + ni;
    625      int in = node_idx + adst_node_num - ni;
    626      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    627    }
    628  }
    629 }
    630 void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx,
    631                         int node_idx, int adst_node_num) {
    632  // reverse order
    633  // 0 1 2 3 --> 3 2 1 0
    634  for (int ni = 0; ni < adst_node_num; ni++) {
    635    int out = node_idx + ni;
    636    int in = node_idx + adst_node_num - ni - 1;
    637    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    638  }
    639 }
    640 
    641 int get_Q_out2in(int adst_node_num, int out) {
    642  int in;
    643  if (out % 2 == 0) {
    644    in = out;
    645  } else {
    646    in = adst_node_num - out;
    647  }
    648  return in;
    649 }
    650 
    651 int get_Ibar_out2in(int adst_node_num, int out) {
    652  return adst_node_num - out - 1;
    653 }
    654 
    655 void gen_adst_IbarQ_graph(Node *node, int stage_num, int node_num,
    656                          int stage_idx, int node_idx, int adst_node_num) {
    657  // in -> Ibar -> Q -> out
    658  for (int ni = 0; ni < adst_node_num; ni++) {
    659    int out = node_idx + ni;
    660    int in = node_idx +
    661             get_Ibar_out2in(adst_node_num, get_Q_out2in(adst_node_num, ni));
    662    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    663  }
    664 }
    665 
    666 void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx,
    667                      int node_idx, int adst_node_num) {
    668  // reverse order
    669  for (int ni = 0; ni < adst_node_num; ni++) {
    670    int out = node_idx + ni;
    671    int in = out;
    672    if (ni % 2 == 0) {
    673      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    674    } else {
    675      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, -1, in,
    676                   0);
    677    }
    678  }
    679 }
    680 
    681 int get_hadamard_idx(int x, int adst_node_num) {
    682  int max_bit = get_max_bit(adst_node_num - 1);
    683  x = bitwise_reverse(x, max_bit);
    684 
    685  // gray code
    686  int c = x & 1;
    687  int p = x & 1;
    688  int y = c;
    689 
    690  for (int i = 1; i <= max_bit; i++) {
    691    p = c;
    692    c = (x >> i) & 1;
    693    y += (c ^ p) << i;
    694  }
    695  return y;
    696 }
    697 
    698 void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx,
    699                       int node_idx, int adst_node_num) {
    700  for (int ni = 0; ni < adst_node_num; ni++) {
    701    int out = node_idx + ni;
    702    int in = node_idx + get_hadamard_idx(ni, adst_node_num);
    703    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    704  }
    705 }
    706 
    707 void gen_adst_HtD_graph(Node *node, int stage_num, int node_num, int stage_idx,
    708                        int node_idx, int adst_node_num) {
    709  for (int ni = 0; ni < adst_node_num; ni++) {
    710    int out = node_idx + ni;
    711    int in = node_idx + get_hadamard_idx(ni, adst_node_num);
    712    double inW;
    713    if (ni % 2 == 0)
    714      inW = 1;
    715    else
    716      inW = -1;
    717    connect_node(node, stage_num, node_num, stage_idx + 1, out, in, inW, in, 0);
    718  }
    719 }
    720 
    721 int get_adst_stage_num(int adst_node_num) {
    722  return 2 * get_max_bit(adst_node_num) + 2;
    723 }
    724 
    725 int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx,
    726                    int node_idx, int adst_node_num) {
    727  int max_bit = get_max_bit(adst_node_num);
    728  int si = 0;
    729  gen_adst_IbarQ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
    730                       adst_node_num);
    731  si++;
    732  gen_adst_VJ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
    733                    adst_node_num);
    734  si++;
    735  for (int adst_idx = max_bit - 1; adst_idx >= 1; adst_idx--) {
    736    gen_adst_U_graph(node, stage_num, node_num, stage_idx + si, node_idx,
    737                     adst_idx, adst_node_num);
    738    si++;
    739    gen_adst_V_graph(node, stage_num, node_num, stage_idx + si, node_idx,
    740                     adst_idx, adst_node_num);
    741    si++;
    742  }
    743  gen_adst_HtD_graph(node, stage_num, node_num, stage_idx + si, node_idx,
    744                     adst_node_num);
    745  si++;
    746  return si + 1;
    747 }
    748 
    749 int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx,
    750                   int node_idx, int adst_node_num) {
    751  int hybrid_stage_num = get_hybrid_stage_num(TYPE_ADST, adst_node_num);
    752  // generate a adst tempNode
    753  Node *tempNode = new Node[hybrid_stage_num * adst_node_num];
    754  init_graph(tempNode, hybrid_stage_num, adst_node_num);
    755  int si = gen_iadst_graph(tempNode, hybrid_stage_num, adst_node_num, 0, 0,
    756                           adst_node_num);
    757 
    758  // tempNode's inverse graph to node[stage_idx][node_idx]
    759  gen_inv_graph(tempNode, hybrid_stage_num, adst_node_num, node, stage_num,
    760                node_num, stage_idx, node_idx);
    761  delete[] tempNode;
    762  return si;
    763 }
    764 
    765 void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx,
    766                      int node_idx, int dct_node_num) {
    767  for (int first = 0; first < dct_node_num; first++) {
    768    for (int second = 0; second < dct_node_num; second++) {
    769      // int sIn = stage_idx;
    770      int sOut = stage_idx + 1;
    771      int nIn = node_idx + first * dct_node_num + second;
    772      int nOut = node_idx + second * dct_node_num + first;
    773 
    774      connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
    775    }
    776  }
    777 }
    778 
    779 void connect_layer_2d_new(Node *node, int stage_num, int node_num,
    780                          int stage_idx, int node_idx, int dct_node_num0,
    781                          int dct_node_num1) {
    782  for (int i = 0; i < dct_node_num1; i++) {
    783    for (int j = 0; j < dct_node_num0; j++) {
    784      // int sIn = stage_idx;
    785      int sOut = stage_idx + 1;
    786      int nIn = node_idx + i * dct_node_num0 + j;
    787      int nOut = node_idx + j * dct_node_num1 + i;
    788 
    789      connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
    790    }
    791  }
    792 }
    793 
    794 void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
    795                      int node_idx, int dct_node_num) {
    796  int dct_stage_num = get_dct_stage_num(dct_node_num);
    797  // put 2 layers of dct_node_num DCTs on the graph
    798  for (int ni = 0; ni < dct_node_num; ni++) {
    799    gen_DCT_graph_1d(node, stage_num, node_num, stage_idx,
    800                     node_idx + ni * dct_node_num, dct_node_num);
    801    gen_DCT_graph_1d(node, stage_num, node_num, stage_idx + dct_stage_num,
    802                     node_idx + ni * dct_node_num, dct_node_num);
    803  }
    804  // connect first layer and second layer
    805  connect_layer_2d(node, stage_num, node_num, stage_idx + dct_stage_num - 1,
    806                   node_idx, dct_node_num);
    807 }
    808 
    809 int get_hybrid_stage_num(int type, int hybrid_node_num) {
    810  if (type == TYPE_DCT || type == TYPE_IDCT) {
    811    return get_dct_stage_num(hybrid_node_num);
    812  } else if (type == TYPE_ADST || type == TYPE_IADST) {
    813    return get_adst_stage_num(hybrid_node_num);
    814  }
    815  return 0;
    816 }
    817 
    818 int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num) {
    819  int stage_num = 0;
    820  stage_num += get_hybrid_stage_num(type0, hybrid_node_num);
    821  stage_num += get_hybrid_stage_num(type1, hybrid_node_num);
    822  return stage_num;
    823 }
    824 
    825 int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0,
    826                                int hybrid_node_num1) {
    827  int stage_num = 0;
    828  stage_num += get_hybrid_stage_num(type0, hybrid_node_num0);
    829  stage_num += get_hybrid_stage_num(type1, hybrid_node_num1);
    830  return stage_num;
    831 }
    832 
    833 int get_hybrid_amplify_factor(int type, int hybrid_node_num) {
    834  return get_max_bit(hybrid_node_num) - 1;
    835 }
    836 
    837 void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
    838                         int node_idx, int hybrid_node_num, int type) {
    839  if (type == TYPE_DCT) {
    840    gen_DCT_graph_1d(node, stage_num, node_num, stage_idx, node_idx,
    841                     hybrid_node_num);
    842  } else if (type == TYPE_ADST) {
    843    gen_adst_graph(node, stage_num, node_num, stage_idx, node_idx,
    844                   hybrid_node_num);
    845  } else if (type == TYPE_IDCT) {
    846    int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
    847    // generate a dct tempNode
    848    Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
    849    init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
    850    gen_DCT_graph_1d(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
    851                     hybrid_node_num);
    852 
    853    // tempNode's inverse graph to node[stage_idx][node_idx]
    854    gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
    855                  node_num, stage_idx, node_idx);
    856    delete[] tempNode;
    857  } else if (type == TYPE_IADST) {
    858    int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
    859    // generate a adst tempNode
    860    Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
    861    init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
    862    gen_adst_graph(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
    863                   hybrid_node_num);
    864 
    865    // tempNode's inverse graph to node[stage_idx][node_idx]
    866    gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
    867                  node_num, stage_idx, node_idx);
    868    delete[] tempNode;
    869  }
    870 }
    871 
    872 void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
    873                         int node_idx, int hybrid_node_num, int type0,
    874                         int type1) {
    875  int hybrid_stage_num = get_hybrid_stage_num(type0, hybrid_node_num);
    876 
    877  for (int ni = 0; ni < hybrid_node_num; ni++) {
    878    gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
    879                        node_idx + ni * hybrid_node_num, hybrid_node_num,
    880                        type0);
    881    gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx + hybrid_stage_num,
    882                        node_idx + ni * hybrid_node_num, hybrid_node_num,
    883                        type1);
    884  }
    885 
    886  // connect first layer and second layer
    887  connect_layer_2d(node, stage_num, node_num, stage_idx + hybrid_stage_num - 1,
    888                   node_idx, hybrid_node_num);
    889 }
    890 
    891 void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num,
    892                             int stage_idx, int node_idx, int hybrid_node_num0,
    893                             int hybrid_node_num1, int type0, int type1) {
    894  int hybrid_stage_num0 = get_hybrid_stage_num(type0, hybrid_node_num0);
    895 
    896  for (int ni = 0; ni < hybrid_node_num1; ni++) {
    897    gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
    898                        node_idx + ni * hybrid_node_num0, hybrid_node_num0,
    899                        type0);
    900  }
    901  for (int ni = 0; ni < hybrid_node_num0; ni++) {
    902    gen_hybrid_graph_1d(
    903        node, stage_num, node_num, stage_idx + hybrid_stage_num0,
    904        node_idx + ni * hybrid_node_num1, hybrid_node_num1, type1);
    905  }
    906 
    907  // connect first layer and second layer
    908  connect_layer_2d_new(node, stage_num, node_num,
    909                       stage_idx + hybrid_stage_num0 - 1, node_idx,
    910                       hybrid_node_num0, hybrid_node_num1);
    911 }
    912 
    913 void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode,
    914                   int inv_stage_num, int inv_node_num, int inv_stage_idx,
    915                   int inv_node_idx) {
    916  // clean up inNodeNum in invNode because of add_node
    917  for (int si = 1 + inv_stage_idx; si < inv_stage_idx + stage_num; si++) {
    918    for (int ni = inv_node_idx; ni < inv_node_idx + node_num; ni++) {
    919      int idx = get_idx(si, ni, inv_node_num);
    920      invNode[idx].inNodeNum = 0;
    921    }
    922  }
    923  // generate inverse graph of node on invNode
    924  for (int si = 1; si < stage_num; si++) {
    925    for (int ni = 0; ni < node_num; ni++) {
    926      int invSi = stage_num - si;
    927      int idx = get_idx(si, ni, node_num);
    928      for (int k = 0; k < node[idx].inNodeNum; k++) {
    929        int invNi = node[idx].inNodeIdx[k];
    930        add_node(invNode, inv_stage_num, inv_node_num, invSi + inv_stage_idx,
    931                 invNi + inv_node_idx, ni + inv_node_idx,
    932                 node[idx].inWeight[k]);
    933      }
    934    }
    935  }
    936 }