tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

neural_feature_extractor.cc (3554B)


      1 /*
      2 *  Copyright (c) 2025 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/aec3/neural_feature_extractor.h"
     12 
     13 #include <algorithm>
     14 #include <cmath>
     15 #include <cstring>
     16 #include <vector>
     17 
     18 #include "api/array_view.h"
     19 #include "common_audio/window_generator.h"
     20 #include "rtc_base/checks.h"
     21 #include "third_party/pffft/src/pffft.h"
     22 
     23 namespace webrtc {
     24 
     25 namespace {
     26 // Trained moodel expects [-1,1]-scaled signals while AEC3 and APM scale
     27 // floating point signals up by 32768 to match 16-bit fixed-point formats, so we
     28 // convert to [-1,1] scale here.
     29 constexpr float kScale = 1.0f / 32768;
     30 // Exponent used to compress the power spectra.
     31 constexpr float kSpectrumCompressionExponent = 0.15f;
     32 
     33 std::vector<float> GetSqrtHanningWindow(int frame_size, float scale) {
     34  std::vector<float> window(frame_size);
     35  WindowGenerator::Hanning(frame_size, window.data());
     36  std::transform(window.begin(), window.end(), window.begin(),
     37                 [scale](float x) { return scale * std::sqrt(x); });
     38  return window;
     39 }
     40 
     41 }  // namespace
     42 
     43 void TimeDomainFeatureExtractor::PushFeaturesToModelInput(
     44    std::vector<float>& frame,
     45    ArrayView<float> input) {
     46  // Shift down overlap from previous frames.
     47  std::copy(input.begin() + frame.size(), input.end(), input.begin());
     48  std::transform(frame.begin(), frame.end(), input.end() - frame.size(),
     49                 [](float x) { return x * kScale; });
     50  frame.clear();
     51 }
     52 
     53 FrequencyDomainFeatureExtractor::FrequencyDomainFeatureExtractor(int step_size)
     54    : step_size_(step_size),
     55      frame_size_(2 * step_size_),
     56      sqrt_hanning_(GetSqrtHanningWindow(frame_size_, kScale)),
     57      data_(static_cast<float*>(
     58          pffft_aligned_malloc(frame_size_ * sizeof(float)))),
     59      spectrum_(static_cast<float*>(
     60          pffft_aligned_malloc(frame_size_ * sizeof(float)))),
     61      pffft_setup_(pffft_new_setup(frame_size_, PFFFT_REAL)) {
     62  std::memset(data_, 0, sizeof(float) * frame_size_);
     63  std::memset(spectrum_, 0, sizeof(float) * frame_size_);
     64 }
     65 
     66 FrequencyDomainFeatureExtractor::~FrequencyDomainFeatureExtractor() {
     67  pffft_destroy_setup(pffft_setup_);
     68  pffft_aligned_free(spectrum_);
     69  pffft_aligned_free(data_);
     70 }
     71 
     72 void FrequencyDomainFeatureExtractor::PushFeaturesToModelInput(
     73    std::vector<float>& frame,
     74    ArrayView<float> input) {
     75  std::memcpy(data_ + step_size_, frame.data(), sizeof(float) * step_size_);
     76  for (int k = 0; k < frame_size_; ++k) {
     77    data_[k] *= sqrt_hanning_[k];
     78  }
     79  pffft_transform_ordered(pffft_setup_, data_, spectrum_, nullptr,
     80                          PFFFT_FORWARD);
     81  RTC_CHECK_EQ(input.size(), step_size_ + 1);
     82  input[0] = spectrum_[0] * spectrum_[0];
     83  input[step_size_] = spectrum_[1] * spectrum_[1];
     84  for (int k = 1; k < step_size_; k++) {
     85    input[k] = spectrum_[2 * k] * spectrum_[2 * k] +
     86               spectrum_[2 * k + 1] * spectrum_[2 * k + 1];
     87  }
     88  // Compress the power spectra.
     89  std::transform(input.begin(), input.end(), input.begin(), [](float a) {
     90    return std::pow(a, kSpectrumCompressionExponent);
     91  });
     92  // Saving the current frame as it is used when computing the next FFT.
     93  std::memcpy(data_, frame.data(), sizeof(float) * step_size_);
     94 }
     95 
     96 }  // namespace webrtc