fft_test.cc (8778B)
1 /* 2 * Copyright (c) 2018, Alliance for Open Media. All rights reserved. 3 * 4 * This source code is subject to the terms of the BSD 2 Clause License and 5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License 6 * was not distributed with this source code in the LICENSE file, you can 7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open 8 * Media Patent License 1.0 was not distributed with this source code in the 9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent. 10 */ 11 12 #include <math.h> 13 14 #include <algorithm> 15 #include <complex> 16 #include <ostream> 17 #include <vector> 18 19 #include "aom_dsp/fft_common.h" 20 #include "aom_mem/aom_mem.h" 21 #include "av1/common/common.h" 22 #include "config/aom_dsp_rtcd.h" 23 #include "gtest/gtest.h" 24 #include "test/acm_random.h" 25 26 namespace { 27 28 using tform_fun_t = void (*)(const float *input, float *temp, float *output); 29 30 // Simple 1D FFT implementation 31 template <typename InputType> 32 void fft(const InputType *data, std::complex<float> *result, int n) { 33 if (n == 1) { 34 result[0] = data[0]; 35 return; 36 } 37 std::vector<InputType> temp(n); 38 for (int k = 0; k < n / 2; ++k) { 39 temp[k] = data[2 * k]; 40 temp[n / 2 + k] = data[2 * k + 1]; 41 } 42 fft(&temp[0], result, n / 2); 43 fft(&temp[n / 2], result + n / 2, n / 2); 44 for (int k = 0; k < n / 2; ++k) { 45 std::complex<float> w = std::complex<float>((float)cos(2. * PI * k / n), 46 (float)-sin(2. * PI * k / n)); 47 std::complex<float> a = result[k]; 48 std::complex<float> b = result[n / 2 + k]; 49 result[k] = a + w * b; 50 result[n / 2 + k] = a - w * b; 51 } 52 } 53 54 void transpose(std::vector<std::complex<float> > *data, int n) { 55 for (int y = 0; y < n; ++y) { 56 for (int x = y + 1; x < n; ++x) { 57 std::swap((*data)[y * n + x], (*data)[x * n + y]); 58 } 59 } 60 } 61 62 // Simple 2D FFT implementation 63 template <class InputType> 64 std::vector<std::complex<float> > fft2d(const InputType *input, int n) { 65 std::vector<std::complex<float> > rowfft(n * n); 66 std::vector<std::complex<float> > result(n * n); 67 for (int y = 0; y < n; ++y) { 68 fft(input + y * n, &rowfft[y * n], n); 69 } 70 transpose(&rowfft, n); 71 for (int y = 0; y < n; ++y) { 72 fft(&rowfft[y * n], &result[y * n], n); 73 } 74 transpose(&result, n); 75 return result; 76 } 77 78 struct FFTTestArg { 79 int n; 80 void (*fft)(const float *input, float *temp, float *output); 81 FFTTestArg(int n_in, tform_fun_t fft_in) : n(n_in), fft(fft_in) {} 82 }; 83 84 std::ostream &operator<<(std::ostream &os, const FFTTestArg &test_arg) { 85 return os << "fft_arg { n:" << test_arg.n 86 << " fft:" << reinterpret_cast<const void *>(test_arg.fft) << " }"; 87 } 88 89 class FFT2DTest : public ::testing::TestWithParam<FFTTestArg> { 90 protected: 91 void SetUp() override { 92 int n = GetParam().n; 93 input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n); 94 temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n); 95 output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n * 2); 96 ASSERT_NE(input_, nullptr); 97 ASSERT_NE(temp_, nullptr); 98 ASSERT_NE(output_, nullptr); 99 memset(input_, 0, sizeof(*input_) * n * n); 100 memset(temp_, 0, sizeof(*temp_) * n * n); 101 memset(output_, 0, sizeof(*output_) * n * n * 2); 102 } 103 void TearDown() override { 104 aom_free(input_); 105 aom_free(temp_); 106 aom_free(output_); 107 } 108 float *input_; 109 float *temp_; 110 float *output_; 111 }; 112 113 TEST_P(FFT2DTest, Correct) { 114 int n = GetParam().n; 115 for (int i = 0; i < n * n; ++i) { 116 input_[i] = 1; 117 std::vector<std::complex<float> > expected = fft2d<float>(&input_[0], n); 118 GetParam().fft(&input_[0], &temp_[0], &output_[0]); 119 for (int y = 0; y < n; ++y) { 120 for (int x = 0; x < (n / 2) + 1; ++x) { 121 EXPECT_NEAR(expected[y * n + x].real(), output_[2 * (y * n + x)], 1e-5); 122 EXPECT_NEAR(expected[y * n + x].imag(), output_[2 * (y * n + x) + 1], 123 1e-5); 124 } 125 } 126 input_[i] = 0; 127 } 128 } 129 130 TEST_P(FFT2DTest, Benchmark) { 131 int n = GetParam().n; 132 float sum = 0; 133 const int num_trials = 1000 * (64 - n); 134 for (int i = 0; i < num_trials; ++i) { 135 input_[i % (n * n)] = 1; 136 GetParam().fft(&input_[0], &temp_[0], &output_[0]); 137 sum += output_[0]; 138 input_[i % (n * n)] = 0; 139 } 140 EXPECT_NEAR(sum, num_trials, 1e-3); 141 } 142 143 INSTANTIATE_TEST_SUITE_P(C, FFT2DTest, 144 ::testing::Values(FFTTestArg(2, aom_fft2x2_float_c), 145 FFTTestArg(4, aom_fft4x4_float_c), 146 FFTTestArg(8, aom_fft8x8_float_c), 147 FFTTestArg(16, aom_fft16x16_float_c), 148 FFTTestArg(32, 149 aom_fft32x32_float_c))); 150 #if AOM_ARCH_X86 || AOM_ARCH_X86_64 151 #if HAVE_SSE2 152 INSTANTIATE_TEST_SUITE_P( 153 SSE2, FFT2DTest, 154 ::testing::Values(FFTTestArg(4, aom_fft4x4_float_sse2), 155 FFTTestArg(8, aom_fft8x8_float_sse2), 156 FFTTestArg(16, aom_fft16x16_float_sse2), 157 FFTTestArg(32, aom_fft32x32_float_sse2))); 158 #endif // HAVE_SSE2 159 #if HAVE_AVX2 160 INSTANTIATE_TEST_SUITE_P( 161 AVX2, FFT2DTest, 162 ::testing::Values(FFTTestArg(8, aom_fft8x8_float_avx2), 163 FFTTestArg(16, aom_fft16x16_float_avx2), 164 FFTTestArg(32, aom_fft32x32_float_avx2))); 165 #endif // HAVE_AVX2 166 #endif // AOM_ARCH_X86 || AOM_ARCH_X86_64 167 168 struct IFFTTestArg { 169 int n; 170 tform_fun_t ifft; 171 IFFTTestArg(int n_in, tform_fun_t ifft_in) : n(n_in), ifft(ifft_in) {} 172 }; 173 174 std::ostream &operator<<(std::ostream &os, const IFFTTestArg &test_arg) { 175 return os << "ifft_arg { n:" << test_arg.n 176 << " fft:" << reinterpret_cast<const void *>(test_arg.ifft) << " }"; 177 } 178 179 class IFFT2DTest : public ::testing::TestWithParam<IFFTTestArg> { 180 protected: 181 void SetUp() override { 182 int n = GetParam().n; 183 input_ = (float *)aom_memalign(32, sizeof(*input_) * n * n * 2); 184 temp_ = (float *)aom_memalign(32, sizeof(*temp_) * n * n * 2); 185 output_ = (float *)aom_memalign(32, sizeof(*output_) * n * n); 186 ASSERT_NE(input_, nullptr); 187 ASSERT_NE(temp_, nullptr); 188 ASSERT_NE(output_, nullptr); 189 memset(input_, 0, sizeof(*input_) * n * n * 2); 190 memset(temp_, 0, sizeof(*temp_) * n * n * 2); 191 memset(output_, 0, sizeof(*output_) * n * n); 192 } 193 void TearDown() override { 194 aom_free(input_); 195 aom_free(temp_); 196 aom_free(output_); 197 } 198 float *input_; 199 float *temp_; 200 float *output_; 201 }; 202 203 TEST_P(IFFT2DTest, Correctness) { 204 int n = GetParam().n; 205 ASSERT_GE(n, 2); 206 std::vector<float> expected(n * n); 207 std::vector<float> actual(n * n); 208 // Do forward transform then invert to make sure we get back expected 209 for (int y = 0; y < n; ++y) { 210 for (int x = 0; x < n; ++x) { 211 expected[y * n + x] = 1; 212 std::vector<std::complex<float> > input_c = fft2d(&expected[0], n); 213 for (int i = 0; i < n * n; ++i) { 214 input_[2 * i + 0] = input_c[i].real(); 215 input_[2 * i + 1] = input_c[i].imag(); 216 } 217 GetParam().ifft(&input_[0], &temp_[0], &output_[0]); 218 219 for (int yy = 0; yy < n; ++yy) { 220 for (int xx = 0; xx < n; ++xx) { 221 EXPECT_NEAR(expected[yy * n + xx], output_[yy * n + xx] / (n * n), 222 1e-5); 223 } 224 } 225 expected[y * n + x] = 0; 226 } 227 } 228 } 229 230 TEST_P(IFFT2DTest, Benchmark) { 231 int n = GetParam().n; 232 float sum = 0; 233 const int num_trials = 1000 * (64 - n); 234 for (int i = 0; i < num_trials; ++i) { 235 input_[i % (n * n)] = 1; 236 GetParam().ifft(&input_[0], &temp_[0], &output_[0]); 237 sum += output_[0]; 238 input_[i % (n * n)] = 0; 239 } 240 EXPECT_GE(sum, num_trials / 2); 241 } 242 INSTANTIATE_TEST_SUITE_P( 243 C, IFFT2DTest, 244 ::testing::Values(IFFTTestArg(2, aom_ifft2x2_float_c), 245 IFFTTestArg(4, aom_ifft4x4_float_c), 246 IFFTTestArg(8, aom_ifft8x8_float_c), 247 IFFTTestArg(16, aom_ifft16x16_float_c), 248 IFFTTestArg(32, aom_ifft32x32_float_c))); 249 #if AOM_ARCH_X86 || AOM_ARCH_X86_64 250 #if HAVE_SSE2 251 INSTANTIATE_TEST_SUITE_P( 252 SSE2, IFFT2DTest, 253 ::testing::Values(IFFTTestArg(4, aom_ifft4x4_float_sse2), 254 IFFTTestArg(8, aom_ifft8x8_float_sse2), 255 IFFTTestArg(16, aom_ifft16x16_float_sse2), 256 IFFTTestArg(32, aom_ifft32x32_float_sse2))); 257 #endif // HAVE_SSE2 258 259 #if HAVE_AVX2 260 INSTANTIATE_TEST_SUITE_P( 261 AVX2, IFFT2DTest, 262 ::testing::Values(IFFTTestArg(8, aom_ifft8x8_float_avx2), 263 IFFTTestArg(16, aom_ifft16x16_float_avx2), 264 IFFTTestArg(32, aom_ifft32x32_float_avx2))); 265 #endif // HAVE_AVX2 266 #endif // AOM_ARCH_X86 || AOM_ARCH_X86_64 267 268 } // namespace