tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

subtractor.cc (15052B)


      1 /*
      2 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/aec3/subtractor.h"
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <cstddef>
     16 #include <memory>
     17 #include <vector>
     18 
     19 #include "api/array_view.h"
     20 #include "api/audio/echo_canceller3_config.h"
     21 #include "api/environment/environment.h"
     22 #include "api/field_trials_view.h"
     23 #include "modules/audio_processing/aec3/adaptive_fir_filter.h"
     24 #include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h"
     25 #include "modules/audio_processing/aec3/aec3_common.h"
     26 #include "modules/audio_processing/aec3/aec3_fft.h"
     27 #include "modules/audio_processing/aec3/aec_state.h"
     28 #include "modules/audio_processing/aec3/block.h"
     29 #include "modules/audio_processing/aec3/coarse_filter_update_gain.h"
     30 #include "modules/audio_processing/aec3/echo_path_variability.h"
     31 #include "modules/audio_processing/aec3/fft_data.h"
     32 #include "modules/audio_processing/aec3/refined_filter_update_gain.h"
     33 #include "modules/audio_processing/aec3/render_buffer.h"
     34 #include "modules/audio_processing/aec3/render_signal_analyzer.h"
     35 #include "modules/audio_processing/aec3/subtractor_output.h"
     36 #include "modules/audio_processing/logging/apm_data_dumper.h"
     37 #include "rtc_base/checks.h"
     38 #include "rtc_base/numerics/safe_minmax.h"
     39 
     40 namespace webrtc {
     41 
     42 namespace {
     43 
     44 bool UseCoarseFilterResetHangover(const FieldTrialsView& field_trials) {
     45  return !field_trials.IsEnabled(
     46      "WebRTC-Aec3CoarseFilterResetHangoverKillSwitch");
     47 }
     48 
     49 void PredictionError(const Aec3Fft& fft,
     50                     const FftData& S,
     51                     ArrayView<const float> y,
     52                     std::array<float, kBlockSize>* e,
     53                     std::array<float, kBlockSize>* s) {
     54  std::array<float, kFftLength> tmp;
     55  fft.Ifft(S, &tmp);
     56  constexpr float kScale = 1.0f / kFftLengthBy2;
     57  std::transform(y.begin(), y.end(), tmp.begin() + kFftLengthBy2, e->begin(),
     58                 [&](float a, float b) { return a - b * kScale; });
     59 
     60  if (s) {
     61    for (size_t k = 0; k < s->size(); ++k) {
     62      (*s)[k] = kScale * tmp[k + kFftLengthBy2];
     63    }
     64  }
     65 }
     66 
     67 void ScaleFilterOutput(ArrayView<const float> y,
     68                       float factor,
     69                       ArrayView<float> e,
     70                       ArrayView<float> s) {
     71  RTC_DCHECK_EQ(y.size(), e.size());
     72  RTC_DCHECK_EQ(y.size(), s.size());
     73  for (size_t k = 0; k < y.size(); ++k) {
     74    s[k] *= factor;
     75    e[k] = y[k] - s[k];
     76  }
     77 }
     78 
     79 }  // namespace
     80 
     81 Subtractor::Subtractor(const Environment& env,
     82                       const EchoCanceller3Config& config,
     83                       size_t num_render_channels,
     84                       size_t num_capture_channels,
     85                       ApmDataDumper* data_dumper,
     86                       Aec3Optimization optimization)
     87    : fft_(),
     88      data_dumper_(data_dumper),
     89      optimization_(optimization),
     90      config_(config),
     91      num_capture_channels_(num_capture_channels),
     92      use_coarse_filter_reset_hangover_(
     93          UseCoarseFilterResetHangover(env.field_trials())),
     94      refined_filters_(num_capture_channels_),
     95      coarse_filter_(num_capture_channels_),
     96      refined_gains_(num_capture_channels_),
     97      coarse_gains_(num_capture_channels_),
     98      filter_misadjustment_estimators_(num_capture_channels_),
     99      poor_coarse_filter_counters_(num_capture_channels_, 0),
    100      coarse_filter_reset_hangover_(num_capture_channels_, 0),
    101      refined_frequency_responses_(
    102          num_capture_channels_,
    103          std::vector<std::array<float, kFftLengthBy2Plus1>>(
    104              std::max(config_.filter.refined_initial.length_blocks,
    105                       config_.filter.refined.length_blocks),
    106              std::array<float, kFftLengthBy2Plus1>())),
    107      refined_impulse_responses_(
    108          num_capture_channels_,
    109          std::vector<float>(GetTimeDomainLength(std::max(
    110                                 config_.filter.refined_initial.length_blocks,
    111                                 config_.filter.refined.length_blocks)),
    112                             0.f)),
    113      coarse_impulse_responses_(0) {
    114  // Set up the storing of coarse impulse responses if data dumping is
    115  // available.
    116  if (ApmDataDumper::IsAvailable()) {
    117    coarse_impulse_responses_.resize(num_capture_channels_);
    118    const size_t filter_size = GetTimeDomainLength(
    119        std::max(config_.filter.coarse_initial.length_blocks,
    120                 config_.filter.coarse.length_blocks));
    121    for (std::vector<float>& impulse_response : coarse_impulse_responses_) {
    122      impulse_response.resize(filter_size, 0.f);
    123    }
    124  }
    125 
    126  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    127    refined_filters_[ch] = std::make_unique<AdaptiveFirFilter>(
    128        config_.filter.refined.length_blocks,
    129        config_.filter.refined_initial.length_blocks,
    130        config.filter.config_change_duration_blocks, num_render_channels,
    131        optimization, data_dumper_);
    132 
    133    coarse_filter_[ch] = std::make_unique<AdaptiveFirFilter>(
    134        config_.filter.coarse.length_blocks,
    135        config_.filter.coarse_initial.length_blocks,
    136        config.filter.config_change_duration_blocks, num_render_channels,
    137        optimization, data_dumper_);
    138    refined_gains_[ch] = std::make_unique<RefinedFilterUpdateGain>(
    139        config_.filter.refined_initial,
    140        config_.filter.config_change_duration_blocks);
    141    coarse_gains_[ch] = std::make_unique<CoarseFilterUpdateGain>(
    142        config_.filter.coarse_initial,
    143        config.filter.config_change_duration_blocks);
    144  }
    145 
    146  RTC_DCHECK(data_dumper_);
    147  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    148    for (auto& H2_k : refined_frequency_responses_[ch]) {
    149      H2_k.fill(0.f);
    150    }
    151  }
    152 }
    153 
    154 Subtractor::~Subtractor() = default;
    155 
    156 void Subtractor::HandleEchoPathChange(
    157    const EchoPathVariability& echo_path_variability) {
    158  const auto full_reset = [&]() {
    159    for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    160      refined_filters_[ch]->HandleEchoPathChange();
    161      coarse_filter_[ch]->HandleEchoPathChange();
    162      refined_gains_[ch]->HandleEchoPathChange(echo_path_variability);
    163      coarse_gains_[ch]->HandleEchoPathChange();
    164      refined_gains_[ch]->SetConfig(config_.filter.refined_initial, true);
    165      coarse_gains_[ch]->SetConfig(config_.filter.coarse_initial, true);
    166      refined_filters_[ch]->SetSizePartitions(
    167          config_.filter.refined_initial.length_blocks, true);
    168      coarse_filter_[ch]->SetSizePartitions(
    169          config_.filter.coarse_initial.length_blocks, true);
    170    }
    171  };
    172 
    173  if (echo_path_variability.delay_change !=
    174      EchoPathVariability::DelayAdjustment::kNone) {
    175    full_reset();
    176  }
    177 
    178  if (echo_path_variability.gain_change) {
    179    for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    180      refined_gains_[ch]->HandleEchoPathChange(echo_path_variability);
    181    }
    182  }
    183 }
    184 
    185 void Subtractor::ExitInitialState() {
    186  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    187    refined_gains_[ch]->SetConfig(config_.filter.refined, false);
    188    coarse_gains_[ch]->SetConfig(config_.filter.coarse, false);
    189    refined_filters_[ch]->SetSizePartitions(
    190        config_.filter.refined.length_blocks, false);
    191    coarse_filter_[ch]->SetSizePartitions(config_.filter.coarse.length_blocks,
    192                                          false);
    193  }
    194 }
    195 
    196 void Subtractor::Process(const RenderBuffer& render_buffer,
    197                         const Block& capture,
    198                         const RenderSignalAnalyzer& render_signal_analyzer,
    199                         const AecState& aec_state,
    200                         ArrayView<SubtractorOutput> outputs) {
    201  RTC_DCHECK_EQ(num_capture_channels_, capture.NumChannels());
    202 
    203  // Compute the render powers.
    204  const bool same_filter_sizes = refined_filters_[0]->SizePartitions() ==
    205                                 coarse_filter_[0]->SizePartitions();
    206  std::array<float, kFftLengthBy2Plus1> X2_refined;
    207  std::array<float, kFftLengthBy2Plus1> X2_coarse_data;
    208  auto& X2_coarse = same_filter_sizes ? X2_refined : X2_coarse_data;
    209  if (same_filter_sizes) {
    210    render_buffer.SpectralSum(refined_filters_[0]->SizePartitions(),
    211                              &X2_refined);
    212  } else if (refined_filters_[0]->SizePartitions() >
    213             coarse_filter_[0]->SizePartitions()) {
    214    render_buffer.SpectralSums(coarse_filter_[0]->SizePartitions(),
    215                               refined_filters_[0]->SizePartitions(),
    216                               &X2_coarse, &X2_refined);
    217  } else {
    218    render_buffer.SpectralSums(refined_filters_[0]->SizePartitions(),
    219                               coarse_filter_[0]->SizePartitions(), &X2_refined,
    220                               &X2_coarse);
    221  }
    222 
    223  // Process all capture channels
    224  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    225    SubtractorOutput& output = outputs[ch];
    226    ArrayView<const float> y = capture.View(/*band=*/0, ch);
    227    FftData& E_refined = output.E_refined;
    228    FftData E_coarse;
    229    std::array<float, kBlockSize>& e_refined = output.e_refined;
    230    std::array<float, kBlockSize>& e_coarse = output.e_coarse;
    231 
    232    FftData S;
    233    FftData& G = S;
    234 
    235    // Form the outputs of the refined and coarse filters.
    236    refined_filters_[ch]->Filter(render_buffer, &S);
    237    PredictionError(fft_, S, y, &e_refined, &output.s_refined);
    238 
    239    coarse_filter_[ch]->Filter(render_buffer, &S);
    240    PredictionError(fft_, S, y, &e_coarse, &output.s_coarse);
    241 
    242    // Compute the signal powers in the subtractor output.
    243    output.ComputeMetrics(y);
    244 
    245    // Adjust the filter if needed.
    246    bool refined_filters_adjusted = false;
    247    filter_misadjustment_estimators_[ch].Update(output);
    248    if (filter_misadjustment_estimators_[ch].IsAdjustmentNeeded()) {
    249      float scale = filter_misadjustment_estimators_[ch].GetMisadjustment();
    250      refined_filters_[ch]->ScaleFilter(scale);
    251      for (auto& h_k : refined_impulse_responses_[ch]) {
    252        h_k *= scale;
    253      }
    254      ScaleFilterOutput(y, scale, e_refined, output.s_refined);
    255      filter_misadjustment_estimators_[ch].Reset();
    256      refined_filters_adjusted = true;
    257    }
    258 
    259    // Compute the FFts of the refined and coarse filter outputs.
    260    fft_.ZeroPaddedFft(e_refined, Aec3Fft::Window::kHanning, &E_refined);
    261    fft_.ZeroPaddedFft(e_coarse, Aec3Fft::Window::kHanning, &E_coarse);
    262 
    263    // Compute spectra for future use.
    264    E_coarse.Spectrum(optimization_, output.E2_coarse);
    265    E_refined.Spectrum(optimization_, output.E2_refined);
    266 
    267    // Update the refined filter.
    268    if (!refined_filters_adjusted) {
    269      // Do not allow the performance of the coarse filter to affect the
    270      // adaptation speed of the refined filter just after the coarse filter has
    271      // been reset.
    272      const bool disallow_leakage_diverged =
    273          coarse_filter_reset_hangover_[ch] > 0 &&
    274          use_coarse_filter_reset_hangover_;
    275 
    276      std::array<float, kFftLengthBy2Plus1> erl;
    277      ComputeErl(optimization_, refined_frequency_responses_[ch], erl);
    278      refined_gains_[ch]->Compute(X2_refined, render_signal_analyzer, output,
    279                                  erl, refined_filters_[ch]->SizePartitions(),
    280                                  aec_state.SaturatedCapture(),
    281                                  disallow_leakage_diverged, &G);
    282    } else {
    283      G.re.fill(0.f);
    284      G.im.fill(0.f);
    285    }
    286    refined_filters_[ch]->Adapt(render_buffer, G,
    287                                &refined_impulse_responses_[ch]);
    288    refined_filters_[ch]->ComputeFrequencyResponse(
    289        &refined_frequency_responses_[ch]);
    290 
    291    if (ch == 0) {
    292      data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.re);
    293      data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.im);
    294    }
    295 
    296    // Update the coarse filter.
    297    poor_coarse_filter_counters_[ch] =
    298        output.e2_refined < output.e2_coarse
    299            ? poor_coarse_filter_counters_[ch] + 1
    300            : 0;
    301    if (poor_coarse_filter_counters_[ch] < 5) {
    302      coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_coarse,
    303                                 coarse_filter_[ch]->SizePartitions(),
    304                                 aec_state.SaturatedCapture(), &G);
    305      coarse_filter_reset_hangover_[ch] =
    306          std::max(coarse_filter_reset_hangover_[ch] - 1, 0);
    307    } else {
    308      poor_coarse_filter_counters_[ch] = 0;
    309      coarse_filter_[ch]->SetFilter(refined_filters_[ch]->SizePartitions(),
    310                                    refined_filters_[ch]->GetFilter());
    311      coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_refined,
    312                                 coarse_filter_[ch]->SizePartitions(),
    313                                 aec_state.SaturatedCapture(), &G);
    314      coarse_filter_reset_hangover_[ch] =
    315          config_.filter.coarse_reset_hangover_blocks;
    316    }
    317 
    318    if (ApmDataDumper::IsAvailable()) {
    319      RTC_DCHECK_LT(ch, coarse_impulse_responses_.size());
    320      coarse_filter_[ch]->Adapt(render_buffer, G,
    321                                &coarse_impulse_responses_[ch]);
    322    } else {
    323      coarse_filter_[ch]->Adapt(render_buffer, G);
    324    }
    325 
    326    if (ch == 0) {
    327      data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.re);
    328      data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.im);
    329      filter_misadjustment_estimators_[ch].Dump(data_dumper_);
    330      DumpFilters();
    331    }
    332 
    333    std::for_each(e_refined.begin(), e_refined.end(),
    334                  [](float& a) { a = SafeClamp(a, -32768.f, 32767.f); });
    335 
    336    if (ch == 0) {
    337      data_dumper_->DumpWav("aec3_refined_filters_output", kBlockSize,
    338                            &e_refined[0], 16000, 1);
    339      data_dumper_->DumpWav("aec3_coarse_filter_output", kBlockSize,
    340                            &e_coarse[0], 16000, 1);
    341    }
    342  }
    343 }
    344 
    345 void Subtractor::FilterMisadjustmentEstimator::Update(
    346    const SubtractorOutput& output) {
    347  e2_acum_ += output.e2_refined;
    348  y2_acum_ += output.y2;
    349  if (++n_blocks_acum_ == n_blocks_) {
    350    if (y2_acum_ > n_blocks_ * 200.f * 200.f * kBlockSize) {
    351      float update = (e2_acum_ / y2_acum_);
    352      if (e2_acum_ > n_blocks_ * 7500.f * 7500.f * kBlockSize) {
    353        // Duration equal to blockSizeMs * n_blocks_ * 4.
    354        overhang_ = 4;
    355      } else {
    356        overhang_ = std::max(overhang_ - 1, 0);
    357      }
    358 
    359      if ((update < inv_misadjustment_) || (overhang_ > 0)) {
    360        inv_misadjustment_ += 0.1f * (update - inv_misadjustment_);
    361      }
    362    }
    363    e2_acum_ = 0.f;
    364    y2_acum_ = 0.f;
    365    n_blocks_acum_ = 0;
    366  }
    367 }
    368 
    369 void Subtractor::FilterMisadjustmentEstimator::Reset() {
    370  e2_acum_ = 0.f;
    371  y2_acum_ = 0.f;
    372  n_blocks_acum_ = 0;
    373  inv_misadjustment_ = 0.f;
    374  overhang_ = 0.f;
    375 }
    376 
    377 void Subtractor::FilterMisadjustmentEstimator::Dump(
    378    ApmDataDumper* data_dumper) const {
    379  data_dumper->DumpRaw("aec3_inv_misadjustment_factor", inv_misadjustment_);
    380 }
    381 
    382 }  // namespace webrtc