pffft_wrapper.cc (4488B)
1 /* 2 * Copyright (c) 2019 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/utility/pffft_wrapper.h" 12 13 #include <cstddef> 14 #include <memory> 15 16 #include "api/array_view.h" 17 #include "rtc_base/checks.h" 18 #include "third_party/pffft/src/pffft.h" 19 20 namespace webrtc { 21 namespace { 22 23 size_t GetBufferSize(size_t fft_size, Pffft::FftType fft_type) { 24 return fft_size * (fft_type == Pffft::FftType::kReal ? 1 : 2); 25 } 26 27 float* AllocatePffftBuffer(size_t size) { 28 return static_cast<float*>(pffft_aligned_malloc(size * sizeof(float))); 29 } 30 31 } // namespace 32 33 Pffft::FloatBuffer::FloatBuffer(size_t fft_size, FftType fft_type) 34 : size_(GetBufferSize(fft_size, fft_type)), 35 data_(AllocatePffftBuffer(size_)) {} 36 37 Pffft::FloatBuffer::~FloatBuffer() { 38 pffft_aligned_free(data_); 39 } 40 41 ArrayView<const float> Pffft::FloatBuffer::GetConstView() const { 42 return {data_, size_}; 43 } 44 45 ArrayView<float> Pffft::FloatBuffer::GetView() { 46 return {data_, size_}; 47 } 48 49 Pffft::Pffft(size_t fft_size, FftType fft_type) 50 : fft_size_(fft_size), 51 fft_type_(fft_type), 52 pffft_status_(pffft_new_setup( 53 fft_size_, 54 fft_type == Pffft::FftType::kReal ? PFFFT_REAL : PFFFT_COMPLEX)), 55 scratch_buffer_( 56 AllocatePffftBuffer(GetBufferSize(fft_size_, fft_type_))) { 57 RTC_DCHECK(pffft_status_); 58 RTC_DCHECK(scratch_buffer_); 59 } 60 61 Pffft::~Pffft() { 62 pffft_destroy_setup(pffft_status_); 63 pffft_aligned_free(scratch_buffer_); 64 } 65 66 bool Pffft::IsValidFftSize(size_t fft_size, FftType fft_type) { 67 if (fft_size == 0) { 68 return false; 69 } 70 // PFFFT only supports transforms for inputs of length N of the form 71 // N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT 72 // and a >= 4 for the complex FFT. 73 constexpr int kFactors[] = {2, 3, 5}; 74 int factorization[] = {0, 0, 0}; 75 int n = static_cast<int>(fft_size); 76 for (int i = 0; i < 3; ++i) { 77 while (n % kFactors[i] == 0) { 78 n = n / kFactors[i]; 79 factorization[i]++; 80 } 81 } 82 int a_min = (fft_type == Pffft::FftType::kReal) ? 5 : 4; 83 return factorization[0] >= a_min && n == 1; 84 } 85 86 bool Pffft::IsSimdEnabled() { 87 return pffft_simd_size() > 1; 88 } 89 90 std::unique_ptr<Pffft::FloatBuffer> Pffft::CreateBuffer() const { 91 // Cannot use make_unique from absl because Pffft is the only friend of 92 // Pffft::FloatBuffer. 93 std::unique_ptr<Pffft::FloatBuffer> buffer( 94 new Pffft::FloatBuffer(fft_size_, fft_type_)); 95 return buffer; 96 } 97 98 void Pffft::ForwardTransform(const FloatBuffer& in, 99 FloatBuffer* out, 100 bool ordered) { 101 RTC_DCHECK_EQ(in.size(), GetBufferSize(fft_size_, fft_type_)); 102 RTC_DCHECK_EQ(in.size(), out->size()); 103 RTC_DCHECK(scratch_buffer_); 104 if (ordered) { 105 pffft_transform_ordered(pffft_status_, in.const_data(), out->data(), 106 scratch_buffer_, PFFFT_FORWARD); 107 } else { 108 pffft_transform(pffft_status_, in.const_data(), out->data(), 109 scratch_buffer_, PFFFT_FORWARD); 110 } 111 } 112 113 void Pffft::BackwardTransform(const FloatBuffer& in, 114 FloatBuffer* out, 115 bool ordered) { 116 RTC_DCHECK_EQ(in.size(), GetBufferSize(fft_size_, fft_type_)); 117 RTC_DCHECK_EQ(in.size(), out->size()); 118 RTC_DCHECK(scratch_buffer_); 119 if (ordered) { 120 pffft_transform_ordered(pffft_status_, in.const_data(), out->data(), 121 scratch_buffer_, PFFFT_BACKWARD); 122 } else { 123 pffft_transform(pffft_status_, in.const_data(), out->data(), 124 scratch_buffer_, PFFFT_BACKWARD); 125 } 126 } 127 128 void Pffft::FrequencyDomainConvolve(const FloatBuffer& fft_x, 129 const FloatBuffer& fft_y, 130 FloatBuffer* out, 131 float scaling) { 132 RTC_DCHECK_EQ(fft_x.size(), GetBufferSize(fft_size_, fft_type_)); 133 RTC_DCHECK_EQ(fft_x.size(), fft_y.size()); 134 RTC_DCHECK_EQ(fft_x.size(), out->size()); 135 pffft_zconvolve_accumulate(pffft_status_, fft_x.const_data(), 136 fft_y.const_data(), out->data(), scaling); 137 } 138 139 } // namespace webrtc