tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

fft_test.cc (8778B)


      1 /*
      2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
      3 *
      4 * This source code is subject to the terms of the BSD 2 Clause License and
      5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
      6 * was not distributed with this source code in the LICENSE file, you can
      7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
      8 * Media Patent License 1.0 was not distributed with this source code in the
      9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
     10 */
     11 
     12 #include <math.h>
     13 
     14 #include <algorithm>
     15 #include <complex>
     16 #include <ostream>
     17 #include <vector>
     18 
     19 #include "aom_dsp/fft_common.h"
     20 #include "aom_mem/aom_mem.h"
     21 #include "av1/common/common.h"
     22 #include "config/aom_dsp_rtcd.h"
     23 #include "gtest/gtest.h"
     24 #include "test/acm_random.h"
     25 
     26 namespace {
     27 
     28 using tform_fun_t = void (*)(const float *input, float *temp, float *output);
     29 
     30 // Simple 1D FFT implementation
     31 template <typename InputType>
     32 void fft(const InputType *data, std::complex<float> *result, int n) {
     33  if (n == 1) {
     34    result[0] = data[0];
     35    return;
     36  }
     37  std::vector<InputType> temp(n);
     38  for (int k = 0; k < n / 2; ++k) {
     39    temp[k] = data[2 * k];
     40    temp[n / 2 + k] = data[2 * k + 1];
     41  }
     42  fft(&temp[0], result, n / 2);
     43  fft(&temp[n / 2], result + n / 2, n / 2);
     44  for (int k = 0; k < n / 2; ++k) {
     45    std::complex<float> w = std::complex<float>((float)cos(2. * PI * k / n),
     46                                                (float)-sin(2. * PI * k / n));
     47    std::complex<float> a = result[k];
     48    std::complex<float> b = result[n / 2 + k];
     49    result[k] = a + w * b;
     50    result[n / 2 + k] = a - w * b;
     51  }
     52 }
     53 
     54 void transpose(std::vector<std::complex<float> > *data, int n) {
     55  for (int y = 0; y < n; ++y) {
     56    for (int x = y + 1; x < n; ++x) {
     57      std::swap((*data)[y * n + x], (*data)[x * n + y]);
     58    }
     59  }
     60 }
     61 
     62 // Simple 2D FFT implementation
     63 template <class InputType>
     64 std::vector<std::complex<float> > fft2d(const InputType *input, int n) {
     65  std::vector<std::complex<float> > rowfft(n * n);
     66  std::vector<std::complex<float> > result(n * n);
     67  for (int y = 0; y < n; ++y) {
     68    fft(input + y * n, &rowfft[y * n], n);
     69  }
     70  transpose(&rowfft, n);
     71  for (int y = 0; y < n; ++y) {
     72    fft(&rowfft[y * n], &result[y * n], n);
     73  }
     74  transpose(&result, n);
     75  return result;
     76 }
     77 
     78 struct FFTTestArg {
     79  int n;
     80  void (*fft)(const float *input, float *temp, float *output);
     81  FFTTestArg(int n_in, tform_fun_t fft_in) : n(n_in), fft(fft_in) {}
     82 };
     83 
     84 std::ostream &operator<<(std::ostream &os, const FFTTestArg &test_arg) {
     85  return os << "fft_arg { n:" << test_arg.n
     86            << " fft:" << reinterpret_cast<const void *>(test_arg.fft) << " }";
     87 }
     88 
     89 class FFT2DTest : public ::testing::TestWithParam<FFTTestArg> {
     90 protected:
     91  void SetUp() override {
     92    int n = GetParam().n;
     93    input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n);
     94    temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n);
     95    output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n * 2);
     96    ASSERT_NE(input_, nullptr);
     97    ASSERT_NE(temp_, nullptr);
     98    ASSERT_NE(output_, nullptr);
     99    memset(input_, 0, sizeof(*input_) * n * n);
    100    memset(temp_, 0, sizeof(*temp_) * n * n);
    101    memset(output_, 0, sizeof(*output_) * n * n * 2);
    102  }
    103  void TearDown() override {
    104    aom_free(input_);
    105    aom_free(temp_);
    106    aom_free(output_);
    107  }
    108  float *input_;
    109  float *temp_;
    110  float *output_;
    111 };
    112 
    113 TEST_P(FFT2DTest, Correct) {
    114  int n = GetParam().n;
    115  for (int i = 0; i < n * n; ++i) {
    116    input_[i] = 1;
    117    std::vector<std::complex<float> > expected = fft2d<float>(&input_[0], n);
    118    GetParam().fft(&input_[0], &temp_[0], &output_[0]);
    119    for (int y = 0; y < n; ++y) {
    120      for (int x = 0; x < (n / 2) + 1; ++x) {
    121        EXPECT_NEAR(expected[y * n + x].real(), output_[2 * (y * n + x)], 1e-5);
    122        EXPECT_NEAR(expected[y * n + x].imag(), output_[2 * (y * n + x) + 1],
    123                    1e-5);
    124      }
    125    }
    126    input_[i] = 0;
    127  }
    128 }
    129 
    130 TEST_P(FFT2DTest, Benchmark) {
    131  int n = GetParam().n;
    132  float sum = 0;
    133  const int num_trials = 1000 * (64 - n);
    134  for (int i = 0; i < num_trials; ++i) {
    135    input_[i % (n * n)] = 1;
    136    GetParam().fft(&input_[0], &temp_[0], &output_[0]);
    137    sum += output_[0];
    138    input_[i % (n * n)] = 0;
    139  }
    140  EXPECT_NEAR(sum, num_trials, 1e-3);
    141 }
    142 
    143 INSTANTIATE_TEST_SUITE_P(C, FFT2DTest,
    144                         ::testing::Values(FFTTestArg(2, aom_fft2x2_float_c),
    145                                           FFTTestArg(4, aom_fft4x4_float_c),
    146                                           FFTTestArg(8, aom_fft8x8_float_c),
    147                                           FFTTestArg(16, aom_fft16x16_float_c),
    148                                           FFTTestArg(32,
    149                                                      aom_fft32x32_float_c)));
    150 #if AOM_ARCH_X86 || AOM_ARCH_X86_64
    151 #if HAVE_SSE2
    152 INSTANTIATE_TEST_SUITE_P(
    153    SSE2, FFT2DTest,
    154    ::testing::Values(FFTTestArg(4, aom_fft4x4_float_sse2),
    155                      FFTTestArg(8, aom_fft8x8_float_sse2),
    156                      FFTTestArg(16, aom_fft16x16_float_sse2),
    157                      FFTTestArg(32, aom_fft32x32_float_sse2)));
    158 #endif  // HAVE_SSE2
    159 #if HAVE_AVX2
    160 INSTANTIATE_TEST_SUITE_P(
    161    AVX2, FFT2DTest,
    162    ::testing::Values(FFTTestArg(8, aom_fft8x8_float_avx2),
    163                      FFTTestArg(16, aom_fft16x16_float_avx2),
    164                      FFTTestArg(32, aom_fft32x32_float_avx2)));
    165 #endif  // HAVE_AVX2
    166 #endif  // AOM_ARCH_X86 || AOM_ARCH_X86_64
    167 
    168 struct IFFTTestArg {
    169  int n;
    170  tform_fun_t ifft;
    171  IFFTTestArg(int n_in, tform_fun_t ifft_in) : n(n_in), ifft(ifft_in) {}
    172 };
    173 
    174 std::ostream &operator<<(std::ostream &os, const IFFTTestArg &test_arg) {
    175  return os << "ifft_arg { n:" << test_arg.n
    176            << " fft:" << reinterpret_cast<const void *>(test_arg.ifft) << " }";
    177 }
    178 
    179 class IFFT2DTest : public ::testing::TestWithParam<IFFTTestArg> {
    180 protected:
    181  void SetUp() override {
    182    int n = GetParam().n;
    183    input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n * 2);
    184    temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n * 2);
    185    output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n);
    186    ASSERT_NE(input_, nullptr);
    187    ASSERT_NE(temp_, nullptr);
    188    ASSERT_NE(output_, nullptr);
    189    memset(input_, 0, sizeof(*input_) * n * n * 2);
    190    memset(temp_, 0, sizeof(*temp_) * n * n * 2);
    191    memset(output_, 0, sizeof(*output_) * n * n);
    192  }
    193  void TearDown() override {
    194    aom_free(input_);
    195    aom_free(temp_);
    196    aom_free(output_);
    197  }
    198  float *input_;
    199  float *temp_;
    200  float *output_;
    201 };
    202 
    203 TEST_P(IFFT2DTest, Correctness) {
    204  int n = GetParam().n;
    205  ASSERT_GE(n, 2);
    206  std::vector<float> expected(n * n);
    207  std::vector<float> actual(n * n);
    208  // Do forward transform then invert to make sure we get back expected
    209  for (int y = 0; y < n; ++y) {
    210    for (int x = 0; x < n; ++x) {
    211      expected[y * n + x] = 1;
    212      std::vector<std::complex<float> > input_c = fft2d(&expected[0], n);
    213      for (int i = 0; i < n * n; ++i) {
    214        input_[2 * i + 0] = input_c[i].real();
    215        input_[2 * i + 1] = input_c[i].imag();
    216      }
    217      GetParam().ifft(&input_[0], &temp_[0], &output_[0]);
    218 
    219      for (int yy = 0; yy < n; ++yy) {
    220        for (int xx = 0; xx < n; ++xx) {
    221          EXPECT_NEAR(expected[yy * n + xx], output_[yy * n + xx] / (n * n),
    222                      1e-5);
    223        }
    224      }
    225      expected[y * n + x] = 0;
    226    }
    227  }
    228 }
    229 
    230 TEST_P(IFFT2DTest, Benchmark) {
    231  int n = GetParam().n;
    232  float sum = 0;
    233  const int num_trials = 1000 * (64 - n);
    234  for (int i = 0; i < num_trials; ++i) {
    235    input_[i % (n * n)] = 1;
    236    GetParam().ifft(&input_[0], &temp_[0], &output_[0]);
    237    sum += output_[0];
    238    input_[i % (n * n)] = 0;
    239  }
    240  EXPECT_GE(sum, num_trials / 2);
    241 }
    242 INSTANTIATE_TEST_SUITE_P(
    243    C, IFFT2DTest,
    244    ::testing::Values(IFFTTestArg(2, aom_ifft2x2_float_c),
    245                      IFFTTestArg(4, aom_ifft4x4_float_c),
    246                      IFFTTestArg(8, aom_ifft8x8_float_c),
    247                      IFFTTestArg(16, aom_ifft16x16_float_c),
    248                      IFFTTestArg(32, aom_ifft32x32_float_c)));
    249 #if AOM_ARCH_X86 || AOM_ARCH_X86_64
    250 #if HAVE_SSE2
    251 INSTANTIATE_TEST_SUITE_P(
    252    SSE2, IFFT2DTest,
    253    ::testing::Values(IFFTTestArg(4, aom_ifft4x4_float_sse2),
    254                      IFFTTestArg(8, aom_ifft8x8_float_sse2),
    255                      IFFTTestArg(16, aom_ifft16x16_float_sse2),
    256                      IFFTTestArg(32, aom_ifft32x32_float_sse2)));
    257 #endif  // HAVE_SSE2
    258 
    259 #if HAVE_AVX2
    260 INSTANTIATE_TEST_SUITE_P(
    261    AVX2, IFFT2DTest,
    262    ::testing::Values(IFFTTestArg(8, aom_ifft8x8_float_avx2),
    263                      IFFTTestArg(16, aom_ifft16x16_float_avx2),
    264                      IFFTTestArg(32, aom_ifft32x32_float_avx2)));
    265 #endif  // HAVE_AVX2
    266 #endif  // AOM_ARCH_X86 || AOM_ARCH_X86_64
    267 
    268 }  // namespace