txfm_graph.cc (32582B)
1 /* 2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include "tools/txfm_analyzer/txfm_graph.h" 13 14 #include <stdio.h> 15 #include <stdlib.h> 16 #include <math.h> 17 18 typedef struct Node Node; 19 20 void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type, 21 const int txfm_size) { 22 if (type == TYPE_DCT) 23 snprintf(str_fun_name, str_buf_size, "fdct%d_new", txfm_size); 24 else if (type == TYPE_ADST) 25 snprintf(str_fun_name, str_buf_size, "fadst%d_new", txfm_size); 26 else if (type == TYPE_IDCT) 27 snprintf(str_fun_name, str_buf_size, "idct%d_new", txfm_size); 28 else if (type == TYPE_IADST) 29 snprintf(str_fun_name, str_buf_size, "iadst%d_new", txfm_size); 30 } 31 32 void get_txfm_type_name(char *str_fun_name, int str_buf_size, 33 const TYPE_TXFM type, const int txfm_size) { 34 if (type == TYPE_DCT) 35 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size); 36 else if (type == TYPE_ADST) 37 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size); 38 else if (type == TYPE_IDCT) 39 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size); 40 else if (type == TYPE_IADST) 41 snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size); 42 } 43 44 void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0, 45 const TYPE_TXFM type1, const int txfm_size0, 46 const int txfm_size1) { 47 if (type0 == TYPE_DCT && type1 == TYPE_DCT) 48 snprintf(buf, buf_size, "_dct_dct_%dx%d", txfm_size1, txfm_size0); 49 else if (type0 == TYPE_DCT && type1 == TYPE_ADST) 50 snprintf(buf, buf_size, "_dct_adst_%dx%d", txfm_size1, txfm_size0); 51 else if (type0 == TYPE_ADST && type1 == TYPE_ADST) 52 snprintf(buf, buf_size, "_adst_adst_%dx%d", txfm_size1, txfm_size0); 53 else if (type0 == TYPE_ADST && type1 == TYPE_DCT) 54 snprintf(buf, buf_size, "_adst_dct_%dx%d", txfm_size1, txfm_size0); 55 } 56 57 TYPE_TXFM get_inv_type(TYPE_TXFM type) { 58 if (type == TYPE_DCT) 59 return TYPE_IDCT; 60 else if (type == TYPE_ADST) 61 return TYPE_IADST; 62 else if (type == TYPE_IDCT) 63 return TYPE_DCT; 64 else if (type == TYPE_IADST) 65 return TYPE_ADST; 66 else 67 return TYPE_LAST; 68 } 69 70 void reference_dct_1d(double *in, double *out, int size) { 71 const double kInvSqrt2 = 0.707106781186547524400844362104; 72 for (int k = 0; k < size; k++) { 73 out[k] = 0; // initialize out[k] 74 for (int n = 0; n < size; n++) { 75 out[k] += in[n] * cos(PI * (2 * n + 1) * k / (2 * size)); 76 } 77 if (k == 0) out[k] = out[k] * kInvSqrt2; 78 } 79 } 80 81 void reference_dct_2d(double *in, double *out, int size) { 82 double *tempOut = new double[size * size]; 83 // dct each row: in -> out 84 for (int r = 0; r < size; r++) { 85 reference_dct_1d(in + r * size, out + r * size, size); 86 } 87 88 for (int r = 0; r < size; r++) { 89 // out ->tempOut 90 for (int c = 0; c < size; c++) { 91 tempOut[r * size + c] = out[c * size + r]; 92 } 93 } 94 for (int r = 0; r < size; r++) { 95 reference_dct_1d(tempOut + r * size, out + r * size, size); 96 } 97 delete[] tempOut; 98 } 99 100 void reference_adst_1d(double *in, double *out, int size) { 101 for (int k = 0; k < size; k++) { 102 out[k] = 0; // initialize out[k] 103 for (int n = 0; n < size; n++) { 104 out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size)); 105 } 106 } 107 } 108 109 void reference_hybrid_2d(double *in, double *out, int size, int type0, 110 int type1) { 111 double *tempOut = new double[size * size]; 112 // dct each row: in -> out 113 for (int r = 0; r < size; r++) { 114 if (type0 == TYPE_DCT) 115 reference_dct_1d(in + r * size, out + r * size, size); 116 else 117 reference_adst_1d(in + r * size, out + r * size, size); 118 } 119 120 for (int r = 0; r < size; r++) { 121 // out ->tempOut 122 for (int c = 0; c < size; c++) { 123 tempOut[r * size + c] = out[c * size + r]; 124 } 125 } 126 for (int r = 0; r < size; r++) { 127 if (type1 == TYPE_DCT) 128 reference_dct_1d(tempOut + r * size, out + r * size, size); 129 else 130 reference_adst_1d(tempOut + r * size, out + r * size, size); 131 } 132 delete[] tempOut; 133 } 134 135 void reference_hybrid_2d_new(double *in, double *out, int size0, int size1, 136 int type0, int type1) { 137 double *tempOut = new double[size0 * size1]; 138 // dct each row: in -> out 139 for (int r = 0; r < size1; r++) { 140 if (type0 == TYPE_DCT) 141 reference_dct_1d(in + r * size0, out + r * size0, size0); 142 else 143 reference_adst_1d(in + r * size0, out + r * size0, size0); 144 } 145 146 for (int r = 0; r < size1; r++) { 147 // out ->tempOut 148 for (int c = 0; c < size0; c++) { 149 tempOut[c * size1 + r] = out[r * size0 + c]; 150 } 151 } 152 for (int r = 0; r < size0; r++) { 153 if (type1 == TYPE_DCT) 154 reference_dct_1d(tempOut + r * size1, out + r * size1, size1); 155 else 156 reference_adst_1d(tempOut + r * size1, out + r * size1, size1); 157 } 158 delete[] tempOut; 159 } 160 161 unsigned int get_max_bit(unsigned int x) { 162 int max_bit = -1; 163 while (x) { 164 x = x >> 1; 165 max_bit++; 166 } 167 return max_bit; 168 } 169 170 unsigned int bitwise_reverse(unsigned int x, int max_bit) { 171 x = ((x >> 16) & 0x0000ffff) | ((x & 0x0000ffff) << 16); 172 x = ((x >> 8) & 0x00ff00ff) | ((x & 0x00ff00ff) << 8); 173 x = ((x >> 4) & 0x0f0f0f0f) | ((x & 0x0f0f0f0f) << 4); 174 x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2); 175 x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1); 176 x = x >> (31 - max_bit); 177 return x; 178 } 179 180 int get_idx(int ri, int ci, int cSize) { return ri * cSize + ci; } 181 182 void add_node(Node *node, int stage_num, int node_num, int stage_idx, 183 int node_idx, int in, double w) { 184 int outIdx = get_idx(stage_idx, node_idx, node_num); 185 int inIdx = get_idx(stage_idx - 1, in, node_num); 186 int idx = node[outIdx].inNodeNum; 187 if (idx < 2) { 188 node[outIdx].inNode[idx] = &node[inIdx]; 189 node[outIdx].inNodeIdx[idx] = in; 190 node[outIdx].inWeight[idx] = w; 191 idx++; 192 node[outIdx].inNodeNum = idx; 193 } else { 194 printf("Error: inNode is full"); 195 } 196 } 197 198 void connect_node(Node *node, int stage_num, int node_num, int stage_idx, 199 int node_idx, int in0, double w0, int in1, double w1) { 200 int outIdx = get_idx(stage_idx, node_idx, node_num); 201 int inIdx0 = get_idx(stage_idx - 1, in0, node_num); 202 int inIdx1 = get_idx(stage_idx - 1, in1, node_num); 203 204 int idx = 0; 205 // if(w0 != 0) { 206 node[outIdx].inNode[idx] = &node[inIdx0]; 207 node[outIdx].inNodeIdx[idx] = in0; 208 node[outIdx].inWeight[idx] = w0; 209 idx++; 210 //} 211 212 // if(w1 != 0) { 213 node[outIdx].inNode[idx] = &node[inIdx1]; 214 node[outIdx].inNodeIdx[idx] = in1; 215 node[outIdx].inWeight[idx] = w1; 216 idx++; 217 //} 218 219 node[outIdx].inNodeNum = idx; 220 } 221 222 void propagate(Node *node, int stage_num, int node_num, int stage_idx) { 223 for (int ni = 0; ni < node_num; ni++) { 224 int outIdx = get_idx(stage_idx, ni, node_num); 225 node[outIdx].value = 0; 226 for (int k = 0; k < node[outIdx].inNodeNum; k++) { 227 node[outIdx].value += 228 node[outIdx].inNode[k]->value * node[outIdx].inWeight[k]; 229 } 230 } 231 } 232 233 int64_t round_shift(int64_t value, int bit) { 234 if (bit > 0) { 235 if (value < 0) { 236 return -round_shift(-value, bit); 237 } else { 238 return (value + (1 << (bit - 1))) >> bit; 239 } 240 } else { 241 return value << (-bit); 242 } 243 } 244 245 void round_shift_array(int32_t *arr, int size, int bit) { 246 if (bit == 0) { 247 return; 248 } else { 249 for (int i = 0; i < size; i++) { 250 arr[i] = round_shift(arr[i], bit); 251 } 252 } 253 } 254 255 void graph_reset_visited(Node *node, int stage_num, int node_num) { 256 for (int si = 0; si < stage_num; si++) { 257 for (int ni = 0; ni < node_num; ni++) { 258 int idx = get_idx(si, ni, node_num); 259 node[idx].visited = 0; 260 } 261 } 262 } 263 264 void estimate_value(Node *node, int stage_num, int node_num, int stage_idx, 265 int node_idx, int estimate_bit) { 266 if (stage_idx > 0) { 267 int outIdx = get_idx(stage_idx, node_idx, node_num); 268 int64_t out = 0; 269 node[outIdx].value = 0; 270 for (int k = 0; k < node[outIdx].inNodeNum; k++) { 271 int64_t w = round(node[outIdx].inWeight[k] * (1 << estimate_bit)); 272 int64_t v = round(node[outIdx].inNode[k]->value); 273 out += v * w; 274 } 275 node[outIdx].value = round_shift(out, estimate_bit); 276 } 277 } 278 279 void amplify_value(Node *node, int stage_num, int node_num, int stage_idx, 280 int node_idx, int amplify_bit) { 281 int outIdx = get_idx(stage_idx, node_idx, node_num); 282 node[outIdx].value = round_shift(round(node[outIdx].value), -amplify_bit); 283 } 284 285 void propagate_estimate_amlify(Node *node, int stage_num, int node_num, 286 int stage_idx, int amplify_bit, 287 int estimate_bit) { 288 for (int ni = 0; ni < node_num; ni++) { 289 estimate_value(node, stage_num, node_num, stage_idx, ni, estimate_bit); 290 amplify_value(node, stage_num, node_num, stage_idx, ni, amplify_bit); 291 } 292 } 293 294 void init_graph(Node *node, int stage_num, int node_num) { 295 for (int si = 0; si < stage_num; si++) { 296 for (int ni = 0; ni < node_num; ni++) { 297 int outIdx = get_idx(si, ni, node_num); 298 node[outIdx].stageIdx = si; 299 node[outIdx].nodeIdx = ni; 300 node[outIdx].value = 0; 301 node[outIdx].inNodeNum = 0; 302 if (si >= 1) { 303 connect_node(node, stage_num, node_num, si, ni, ni, 1, ni, 0); 304 } 305 } 306 } 307 } 308 309 void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx, 310 int node_idx, int N, int star) { 311 for (int i = 0; i < N / 2; i++) { 312 int out = node_idx + i; 313 int in1 = node_idx + N - 1 - i; 314 if (star == 1) { 315 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1, 316 1); 317 } else { 318 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1, 319 1); 320 } 321 } 322 for (int i = N / 2; i < N; i++) { 323 int out = node_idx + i; 324 int in1 = node_idx + N - 1 - i; 325 if (star == 1) { 326 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1, 327 1); 328 } else { 329 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1, 330 1); 331 } 332 } 333 } 334 335 void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx, 336 int node_idx, int N) { 337 int max_bit = get_max_bit(N - 1); 338 for (int i = 0; i < N; i++) { 339 int out = node_idx + bitwise_reverse(i, max_bit); 340 int in = node_idx + i; 341 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 342 } 343 } 344 345 void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx, 346 int node_idx, int N) { 347 int max_bit = get_max_bit(N); 348 for (int ni = 0; ni < N / 2; ni++) { 349 int ai = bitwise_reverse(N + ni, max_bit); 350 int out = node_idx + ni; 351 int in1 = node_idx + N - ni - 1; 352 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 353 sin(PI * ai / (2 * 2 * N)), in1, cos(PI * ai / (2 * 2 * N))); 354 } 355 for (int ni = N / 2; ni < N; ni++) { 356 int ai = bitwise_reverse(N + ni, max_bit); 357 int out = node_idx + ni; 358 int in1 = node_idx + N - ni - 1; 359 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 360 cos(PI * ai / (2 * 2 * N)), in1, -sin(PI * ai / (2 * 2 * N))); 361 } 362 } 363 364 void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx, 365 int node_idx, int N) { 366 for (int ni = 0; ni < N / 4; ni++) { 367 int out = node_idx + ni; 368 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0); 369 } 370 371 for (int ni = N / 4; ni < N / 2; ni++) { 372 int out = node_idx + ni; 373 int in1 = node_idx + N - ni - 1; 374 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 375 -cos(PI / 4), in1, cos(-PI / 4)); 376 } 377 378 for (int ni = N / 2; ni < N * 3 / 4; ni++) { 379 int out = node_idx + ni; 380 int in1 = node_idx + N - ni - 1; 381 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 382 cos(-PI / 4), in1, cos(PI / 4)); 383 } 384 385 for (int ni = N * 3 / 4; ni < N; ni++) { 386 int out = node_idx + ni; 387 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0); 388 } 389 } 390 391 void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx, 392 int node_idx, int idx, int N) { 393 // TODO(angiebird): Simplify and clarify this function 394 395 int i = 2 * N / (1 << (idx / 2)); 396 int max_bit = 397 get_max_bit(i / 2) - 1; // the max_bit counts on i/2 instead of N here 398 int N_over_i = 2 << (idx / 2); 399 400 for (int nj = 0; nj < N / 2; nj += N_over_i) { 401 int j = nj / (N_over_i); 402 int kj = bitwise_reverse(i / 4 + j, max_bit); 403 404 // I_N/2i --- 0 405 int offset = nj; 406 for (int ni = 0; ni < N_over_i / 4; ni++) { 407 int out = node_idx + offset + ni; 408 int in = out; 409 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 410 } 411 412 // -C_Kj/i --- S_Kj/i 413 offset += N_over_i / 4; 414 for (int ni = 0; ni < N_over_i / 4; ni++) { 415 int out = node_idx + offset + ni; 416 int in0 = out; 417 double w0 = -cos(kj * PI / i); 418 int in1 = N - (offset + ni) - 1 + node_idx; 419 double w1 = sin(kj * PI / i); 420 connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1, 421 w1); 422 } 423 424 // S_kj/i --- -C_Kj/i 425 offset += N_over_i / 4; 426 for (int ni = 0; ni < N_over_i / 4; ni++) { 427 int out = node_idx + offset + ni; 428 int in0 = out; 429 double w0 = -sin(kj * PI / i); 430 int in1 = N - (offset + ni) - 1 + node_idx; 431 double w1 = -cos(kj * PI / i); 432 connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1, 433 w1); 434 } 435 436 // I_N/2i --- 0 437 offset += N_over_i / 4; 438 for (int ni = 0; ni < N_over_i / 4; ni++) { 439 int out = node_idx + offset + ni; 440 int in = out; 441 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 442 } 443 } 444 445 for (int nj = N / 2; nj < N; nj += N_over_i) { 446 int j = nj / N_over_i; 447 int kj = bitwise_reverse(i / 4 + j, max_bit); 448 449 // I_N/2i --- 0 450 int offset = nj; 451 for (int ni = 0; ni < N_over_i / 4; ni++) { 452 int out = node_idx + offset + ni; 453 int in = out; 454 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 455 } 456 457 // C_kj/i --- -S_Kj/i 458 offset += N_over_i / 4; 459 for (int ni = 0; ni < N_over_i / 4; ni++) { 460 int out = node_idx + offset + ni; 461 int in0 = out; 462 double w0 = cos(kj * PI / i); 463 int in1 = N - (offset + ni) - 1 + node_idx; 464 double w1 = -sin(kj * PI / i); 465 connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1, 466 w1); 467 } 468 469 // S_kj/i --- C_Kj/i 470 offset += N_over_i / 4; 471 for (int ni = 0; ni < N_over_i / 4; ni++) { 472 int out = node_idx + offset + ni; 473 int in0 = out; 474 double w0 = sin(kj * PI / i); 475 int in1 = N - (offset + ni) - 1 + node_idx; 476 double w1 = cos(kj * PI / i); 477 connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1, 478 w1); 479 } 480 481 // I_N/2i --- 0 482 offset += N_over_i / 4; 483 for (int ni = 0; ni < N_over_i / 4; ni++) { 484 int out = node_idx + offset + ni; 485 int in = out; 486 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 487 } 488 } 489 } 490 491 void gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx, 492 int node_idx, int idx, int N) { 493 int B_size = 1 << ((idx + 1) / 2); 494 for (int ni = 0; ni < N; ni += B_size) { 495 gen_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni, B_size, 496 (ni / B_size) % 2); 497 } 498 } 499 500 void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx, 501 int node_idx, int N) { 502 int max_idx = 2 * (get_max_bit(N) + 1) - 3; 503 for (int idx = 0; idx < max_idx; idx++) { 504 int s = stage_idx + max_idx - idx - 1; 505 if (idx == 0) { 506 // type 1 507 gen_type1_graph(node, stage_num, node_num, s, node_idx, N); 508 } else if (idx == max_idx - 1) { 509 // type 2 510 gen_type2_graph(node, stage_num, node_num, s, node_idx, N); 511 } else if ((idx + 1) % 2 == 0) { 512 // type 4 513 gen_type4_graph(node, stage_num, node_num, s, node_idx, idx, N); 514 } else if ((idx + 1) % 2 == 1) { 515 // type 3 516 gen_type3_graph(node, stage_num, node_num, s, node_idx, idx, N); 517 } else { 518 printf("check gen_R_graph()\n"); 519 } 520 } 521 } 522 523 void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx, 524 int node_idx, int N) { 525 if (N > 2) { 526 gen_B_graph(node, stage_num, node_num, stage_idx, node_idx, N, 0); 527 gen_DCT_graph(node, stage_num, node_num, stage_idx + 1, node_idx, N / 2); 528 gen_R_graph(node, stage_num, node_num, stage_idx + 1, node_idx + N / 2, 529 N / 2); 530 } else { 531 // generate dct_2 532 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx, 533 cos(PI / 4), node_idx + 1, cos(PI / 4)); 534 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1, 535 node_idx + 1, -cos(PI / 4), node_idx, cos(PI / 4)); 536 } 537 } 538 539 int get_dct_stage_num(int size) { return 2 * get_max_bit(size); } 540 541 void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx, 542 int node_idx, int dct_node_num) { 543 gen_DCT_graph(node, stage_num, node_num, stage_idx, node_idx, dct_node_num); 544 int dct_stage_num = get_dct_stage_num(dct_node_num); 545 gen_P_graph(node, stage_num, node_num, stage_idx + dct_stage_num - 2, 546 node_idx, dct_node_num); 547 } 548 549 void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx, 550 int node_idx, int adst_idx) { 551 int size = 1 << (adst_idx + 1); 552 for (int ni = 0; ni < size / 2; ni++) { 553 int nOut = node_idx + ni; 554 int nIn = nOut + size / 2; 555 connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, 1, nIn, 556 1); 557 } 558 for (int ni = size / 2; ni < size; ni++) { 559 int nOut = node_idx + ni; 560 int nIn = nOut - size / 2; 561 connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, -1, nIn, 562 1); 563 } 564 } 565 566 void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx, 567 int node_idx, int adst_idx, int adst_node_num) { 568 int size = 1 << (adst_idx + 1); 569 for (int ni = 0; ni < adst_node_num; ni += size) { 570 gen_adst_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni, 571 adst_idx); 572 } 573 } 574 575 void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx, 576 int node_idx, double freq) { 577 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx, 578 cos(freq * PI), node_idx + 1, sin(freq * PI)); 579 connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1, 580 node_idx + 1, -cos(freq * PI), node_idx, sin(freq * PI)); 581 } 582 583 void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx, 584 int node_idx, int adst_idx) { 585 int size = 1 << (adst_idx); 586 for (int i = 0; i < size / 2; i++) { 587 int ni = i * 2; 588 double fi = (1 + 4 * i) * 1.0 / (1 << (adst_idx + 1)); 589 gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi); 590 } 591 } 592 593 void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx, 594 int node_idx, int adst_idx, int adst_node_num) { 595 int size = 1 << (adst_idx); 596 for (int i = 0; i < adst_node_num / size; i++) { 597 if (i % 2 == 1) { 598 int ni = i * size; 599 gen_adst_E_graph(node, stage_num, node_num, stage_idx, node_idx + ni, 600 adst_idx); 601 } 602 } 603 } 604 void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx, 605 int node_idx, int adst_node_num) { 606 for (int i = 0; i < adst_node_num / 2; i++) { 607 int ni = i * 2; 608 double fi = (1 + 4 * i) * 1.0 / (4 * adst_node_num); 609 gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi); 610 } 611 } 612 void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx, 613 int node_idx, int adst_node_num) { 614 // reverse order when idx is 1, 3, 5, 7 ... 615 // example of adst_node_num = 8: 616 // 0 1 2 3 4 5 6 7 617 // --> 0 7 2 5 4 3 6 1 618 for (int ni = 0; ni < adst_node_num; ni++) { 619 if (ni % 2 == 0) { 620 int out = node_idx + ni; 621 connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 622 0); 623 } else { 624 int out = node_idx + ni; 625 int in = node_idx + adst_node_num - ni; 626 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 627 } 628 } 629 } 630 void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx, 631 int node_idx, int adst_node_num) { 632 // reverse order 633 // 0 1 2 3 --> 3 2 1 0 634 for (int ni = 0; ni < adst_node_num; ni++) { 635 int out = node_idx + ni; 636 int in = node_idx + adst_node_num - ni - 1; 637 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 638 } 639 } 640 641 int get_Q_out2in(int adst_node_num, int out) { 642 int in; 643 if (out % 2 == 0) { 644 in = out; 645 } else { 646 in = adst_node_num - out; 647 } 648 return in; 649 } 650 651 int get_Ibar_out2in(int adst_node_num, int out) { 652 return adst_node_num - out - 1; 653 } 654 655 void gen_adst_IbarQ_graph(Node *node, int stage_num, int node_num, 656 int stage_idx, int node_idx, int adst_node_num) { 657 // in -> Ibar -> Q -> out 658 for (int ni = 0; ni < adst_node_num; ni++) { 659 int out = node_idx + ni; 660 int in = node_idx + 661 get_Ibar_out2in(adst_node_num, get_Q_out2in(adst_node_num, ni)); 662 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 663 } 664 } 665 666 void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx, 667 int node_idx, int adst_node_num) { 668 // reverse order 669 for (int ni = 0; ni < adst_node_num; ni++) { 670 int out = node_idx + ni; 671 int in = out; 672 if (ni % 2 == 0) { 673 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 674 } else { 675 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, -1, in, 676 0); 677 } 678 } 679 } 680 681 int get_hadamard_idx(int x, int adst_node_num) { 682 int max_bit = get_max_bit(adst_node_num - 1); 683 x = bitwise_reverse(x, max_bit); 684 685 // gray code 686 int c = x & 1; 687 int p = x & 1; 688 int y = c; 689 690 for (int i = 1; i <= max_bit; i++) { 691 p = c; 692 c = (x >> i) & 1; 693 y += (c ^ p) << i; 694 } 695 return y; 696 } 697 698 void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx, 699 int node_idx, int adst_node_num) { 700 for (int ni = 0; ni < adst_node_num; ni++) { 701 int out = node_idx + ni; 702 int in = node_idx + get_hadamard_idx(ni, adst_node_num); 703 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0); 704 } 705 } 706 707 void gen_adst_HtD_graph(Node *node, int stage_num, int node_num, int stage_idx, 708 int node_idx, int adst_node_num) { 709 for (int ni = 0; ni < adst_node_num; ni++) { 710 int out = node_idx + ni; 711 int in = node_idx + get_hadamard_idx(ni, adst_node_num); 712 double inW; 713 if (ni % 2 == 0) 714 inW = 1; 715 else 716 inW = -1; 717 connect_node(node, stage_num, node_num, stage_idx + 1, out, in, inW, in, 0); 718 } 719 } 720 721 int get_adst_stage_num(int adst_node_num) { 722 return 2 * get_max_bit(adst_node_num) + 2; 723 } 724 725 int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx, 726 int node_idx, int adst_node_num) { 727 int max_bit = get_max_bit(adst_node_num); 728 int si = 0; 729 gen_adst_IbarQ_graph(node, stage_num, node_num, stage_idx + si, node_idx, 730 adst_node_num); 731 si++; 732 gen_adst_VJ_graph(node, stage_num, node_num, stage_idx + si, node_idx, 733 adst_node_num); 734 si++; 735 for (int adst_idx = max_bit - 1; adst_idx >= 1; adst_idx--) { 736 gen_adst_U_graph(node, stage_num, node_num, stage_idx + si, node_idx, 737 adst_idx, adst_node_num); 738 si++; 739 gen_adst_V_graph(node, stage_num, node_num, stage_idx + si, node_idx, 740 adst_idx, adst_node_num); 741 si++; 742 } 743 gen_adst_HtD_graph(node, stage_num, node_num, stage_idx + si, node_idx, 744 adst_node_num); 745 si++; 746 return si + 1; 747 } 748 749 int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx, 750 int node_idx, int adst_node_num) { 751 int hybrid_stage_num = get_hybrid_stage_num(TYPE_ADST, adst_node_num); 752 // generate a adst tempNode 753 Node *tempNode = new Node[hybrid_stage_num * adst_node_num]; 754 init_graph(tempNode, hybrid_stage_num, adst_node_num); 755 int si = gen_iadst_graph(tempNode, hybrid_stage_num, adst_node_num, 0, 0, 756 adst_node_num); 757 758 // tempNode's inverse graph to node[stage_idx][node_idx] 759 gen_inv_graph(tempNode, hybrid_stage_num, adst_node_num, node, stage_num, 760 node_num, stage_idx, node_idx); 761 delete[] tempNode; 762 return si; 763 } 764 765 void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx, 766 int node_idx, int dct_node_num) { 767 for (int first = 0; first < dct_node_num; first++) { 768 for (int second = 0; second < dct_node_num; second++) { 769 // int sIn = stage_idx; 770 int sOut = stage_idx + 1; 771 int nIn = node_idx + first * dct_node_num + second; 772 int nOut = node_idx + second * dct_node_num + first; 773 774 connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0); 775 } 776 } 777 } 778 779 void connect_layer_2d_new(Node *node, int stage_num, int node_num, 780 int stage_idx, int node_idx, int dct_node_num0, 781 int dct_node_num1) { 782 for (int i = 0; i < dct_node_num1; i++) { 783 for (int j = 0; j < dct_node_num0; j++) { 784 // int sIn = stage_idx; 785 int sOut = stage_idx + 1; 786 int nIn = node_idx + i * dct_node_num0 + j; 787 int nOut = node_idx + j * dct_node_num1 + i; 788 789 connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0); 790 } 791 } 792 } 793 794 void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx, 795 int node_idx, int dct_node_num) { 796 int dct_stage_num = get_dct_stage_num(dct_node_num); 797 // put 2 layers of dct_node_num DCTs on the graph 798 for (int ni = 0; ni < dct_node_num; ni++) { 799 gen_DCT_graph_1d(node, stage_num, node_num, stage_idx, 800 node_idx + ni * dct_node_num, dct_node_num); 801 gen_DCT_graph_1d(node, stage_num, node_num, stage_idx + dct_stage_num, 802 node_idx + ni * dct_node_num, dct_node_num); 803 } 804 // connect first layer and second layer 805 connect_layer_2d(node, stage_num, node_num, stage_idx + dct_stage_num - 1, 806 node_idx, dct_node_num); 807 } 808 809 int get_hybrid_stage_num(int type, int hybrid_node_num) { 810 if (type == TYPE_DCT || type == TYPE_IDCT) { 811 return get_dct_stage_num(hybrid_node_num); 812 } else if (type == TYPE_ADST || type == TYPE_IADST) { 813 return get_adst_stage_num(hybrid_node_num); 814 } 815 return 0; 816 } 817 818 int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num) { 819 int stage_num = 0; 820 stage_num += get_hybrid_stage_num(type0, hybrid_node_num); 821 stage_num += get_hybrid_stage_num(type1, hybrid_node_num); 822 return stage_num; 823 } 824 825 int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0, 826 int hybrid_node_num1) { 827 int stage_num = 0; 828 stage_num += get_hybrid_stage_num(type0, hybrid_node_num0); 829 stage_num += get_hybrid_stage_num(type1, hybrid_node_num1); 830 return stage_num; 831 } 832 833 int get_hybrid_amplify_factor(int type, int hybrid_node_num) { 834 return get_max_bit(hybrid_node_num) - 1; 835 } 836 837 void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx, 838 int node_idx, int hybrid_node_num, int type) { 839 if (type == TYPE_DCT) { 840 gen_DCT_graph_1d(node, stage_num, node_num, stage_idx, node_idx, 841 hybrid_node_num); 842 } else if (type == TYPE_ADST) { 843 gen_adst_graph(node, stage_num, node_num, stage_idx, node_idx, 844 hybrid_node_num); 845 } else if (type == TYPE_IDCT) { 846 int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num); 847 // generate a dct tempNode 848 Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num]; 849 init_graph(tempNode, hybrid_stage_num, hybrid_node_num); 850 gen_DCT_graph_1d(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0, 851 hybrid_node_num); 852 853 // tempNode's inverse graph to node[stage_idx][node_idx] 854 gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num, 855 node_num, stage_idx, node_idx); 856 delete[] tempNode; 857 } else if (type == TYPE_IADST) { 858 int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num); 859 // generate a adst tempNode 860 Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num]; 861 init_graph(tempNode, hybrid_stage_num, hybrid_node_num); 862 gen_adst_graph(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0, 863 hybrid_node_num); 864 865 // tempNode's inverse graph to node[stage_idx][node_idx] 866 gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num, 867 node_num, stage_idx, node_idx); 868 delete[] tempNode; 869 } 870 } 871 872 void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx, 873 int node_idx, int hybrid_node_num, int type0, 874 int type1) { 875 int hybrid_stage_num = get_hybrid_stage_num(type0, hybrid_node_num); 876 877 for (int ni = 0; ni < hybrid_node_num; ni++) { 878 gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx, 879 node_idx + ni * hybrid_node_num, hybrid_node_num, 880 type0); 881 gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx + hybrid_stage_num, 882 node_idx + ni * hybrid_node_num, hybrid_node_num, 883 type1); 884 } 885 886 // connect first layer and second layer 887 connect_layer_2d(node, stage_num, node_num, stage_idx + hybrid_stage_num - 1, 888 node_idx, hybrid_node_num); 889 } 890 891 void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num, 892 int stage_idx, int node_idx, int hybrid_node_num0, 893 int hybrid_node_num1, int type0, int type1) { 894 int hybrid_stage_num0 = get_hybrid_stage_num(type0, hybrid_node_num0); 895 896 for (int ni = 0; ni < hybrid_node_num1; ni++) { 897 gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx, 898 node_idx + ni * hybrid_node_num0, hybrid_node_num0, 899 type0); 900 } 901 for (int ni = 0; ni < hybrid_node_num0; ni++) { 902 gen_hybrid_graph_1d( 903 node, stage_num, node_num, stage_idx + hybrid_stage_num0, 904 node_idx + ni * hybrid_node_num1, hybrid_node_num1, type1); 905 } 906 907 // connect first layer and second layer 908 connect_layer_2d_new(node, stage_num, node_num, 909 stage_idx + hybrid_stage_num0 - 1, node_idx, 910 hybrid_node_num0, hybrid_node_num1); 911 } 912 913 void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode, 914 int inv_stage_num, int inv_node_num, int inv_stage_idx, 915 int inv_node_idx) { 916 // clean up inNodeNum in invNode because of add_node 917 for (int si = 1 + inv_stage_idx; si < inv_stage_idx + stage_num; si++) { 918 for (int ni = inv_node_idx; ni < inv_node_idx + node_num; ni++) { 919 int idx = get_idx(si, ni, inv_node_num); 920 invNode[idx].inNodeNum = 0; 921 } 922 } 923 // generate inverse graph of node on invNode 924 for (int si = 1; si < stage_num; si++) { 925 for (int ni = 0; ni < node_num; ni++) { 926 int invSi = stage_num - si; 927 int idx = get_idx(si, ni, node_num); 928 for (int k = 0; k < node[idx].inNodeNum; k++) { 929 int invNi = node[idx].inNodeIdx[k]; 930 add_node(invNode, inv_stage_num, inv_node_num, invSi + inv_stage_idx, 931 invNi + inv_node_idx, ni + inv_node_idx, 932 node[idx].inWeight[k]); 933 } 934 } 935 } 936 }