neural_feature_extractor.cc (3554B)
1 /* 2 * Copyright (c) 2025 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/neural_feature_extractor.h" 12 13 #include <algorithm> 14 #include <cmath> 15 #include <cstring> 16 #include <vector> 17 18 #include "api/array_view.h" 19 #include "common_audio/window_generator.h" 20 #include "rtc_base/checks.h" 21 #include "third_party/pffft/src/pffft.h" 22 23 namespace webrtc { 24 25 namespace { 26 // Trained moodel expects [-1,1]-scaled signals while AEC3 and APM scale 27 // floating point signals up by 32768 to match 16-bit fixed-point formats, so we 28 // convert to [-1,1] scale here. 29 constexpr float kScale = 1.0f / 32768; 30 // Exponent used to compress the power spectra. 31 constexpr float kSpectrumCompressionExponent = 0.15f; 32 33 std::vector<float> GetSqrtHanningWindow(int frame_size, float scale) { 34 std::vector<float> window(frame_size); 35 WindowGenerator::Hanning(frame_size, window.data()); 36 std::transform(window.begin(), window.end(), window.begin(), 37 [scale](float x) { return scale * std::sqrt(x); }); 38 return window; 39 } 40 41 } // namespace 42 43 void TimeDomainFeatureExtractor::PushFeaturesToModelInput( 44 std::vector<float>& frame, 45 ArrayView<float> input) { 46 // Shift down overlap from previous frames. 47 std::copy(input.begin() + frame.size(), input.end(), input.begin()); 48 std::transform(frame.begin(), frame.end(), input.end() - frame.size(), 49 [](float x) { return x * kScale; }); 50 frame.clear(); 51 } 52 53 FrequencyDomainFeatureExtractor::FrequencyDomainFeatureExtractor(int step_size) 54 : step_size_(step_size), 55 frame_size_(2 * step_size_), 56 sqrt_hanning_(GetSqrtHanningWindow(frame_size_, kScale)), 57 data_(static_cast<float*>( 58 pffft_aligned_malloc(frame_size_ * sizeof(float)))), 59 spectrum_(static_cast<float*>( 60 pffft_aligned_malloc(frame_size_ * sizeof(float)))), 61 pffft_setup_(pffft_new_setup(frame_size_, PFFFT_REAL)) { 62 std::memset(data_, 0, sizeof(float) * frame_size_); 63 std::memset(spectrum_, 0, sizeof(float) * frame_size_); 64 } 65 66 FrequencyDomainFeatureExtractor::~FrequencyDomainFeatureExtractor() { 67 pffft_destroy_setup(pffft_setup_); 68 pffft_aligned_free(spectrum_); 69 pffft_aligned_free(data_); 70 } 71 72 void FrequencyDomainFeatureExtractor::PushFeaturesToModelInput( 73 std::vector<float>& frame, 74 ArrayView<float> input) { 75 std::memcpy(data_ + step_size_, frame.data(), sizeof(float) * step_size_); 76 for (int k = 0; k < frame_size_; ++k) { 77 data_[k] *= sqrt_hanning_[k]; 78 } 79 pffft_transform_ordered(pffft_setup_, data_, spectrum_, nullptr, 80 PFFFT_FORWARD); 81 RTC_CHECK_EQ(input.size(), step_size_ + 1); 82 input[0] = spectrum_[0] * spectrum_[0]; 83 input[step_size_] = spectrum_[1] * spectrum_[1]; 84 for (int k = 1; k < step_size_; k++) { 85 input[k] = spectrum_[2 * k] * spectrum_[2 * k] + 86 spectrum_[2 * k + 1] * spectrum_[2 * k + 1]; 87 } 88 // Compress the power spectra. 89 std::transform(input.begin(), input.end(), input.begin(), [](float a) { 90 return std::pow(a, kSpectrumCompressionExponent); 91 }); 92 // Saving the current frame as it is used when computing the next FFT. 93 std::memcpy(data_, frame.data(), sizeof(float) * step_size_); 94 } 95 96 } // namespace webrtc