tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

subband_erle_estimator.cc (9350B)


      1 /*
      2 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/aec3/subband_erle_estimator.h"
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <cstddef>
     16 #include <functional>
     17 #include <memory>
     18 #include <vector>
     19 
     20 #include "api/array_view.h"
     21 #include "api/audio/echo_canceller3_config.h"
     22 #include "api/environment/environment.h"
     23 #include "api/field_trials_view.h"
     24 #include "modules/audio_processing/aec3/aec3_common.h"
     25 #include "modules/audio_processing/logging/apm_data_dumper.h"
     26 #include "rtc_base/checks.h"
     27 #include "rtc_base/numerics/safe_minmax.h"
     28 
     29 namespace webrtc {
     30 
     31 namespace {
     32 
     33 constexpr float kX2BandEnergyThreshold = 44015068.0f;
     34 constexpr int kBlocksToHoldErle = 100;
     35 constexpr int kBlocksForOnsetDetection = kBlocksToHoldErle + 150;
     36 constexpr int kPointsToAccumulate = 6;
     37 
     38 std::array<float, kFftLengthBy2Plus1> SetMaxErleBands(float max_erle_l,
     39                                                      float max_erle_h) {
     40  std::array<float, kFftLengthBy2Plus1> max_erle;
     41  std::fill(max_erle.begin(), max_erle.begin() + kFftLengthBy2 / 2, max_erle_l);
     42  std::fill(max_erle.begin() + kFftLengthBy2 / 2, max_erle.end(), max_erle_h);
     43  return max_erle;
     44 }
     45 
     46 bool EnableMinErleDuringOnsets(const FieldTrialsView& field_trials) {
     47  return !field_trials.IsEnabled("WebRTC-Aec3MinErleDuringOnsetsKillSwitch");
     48 }
     49 
     50 }  // namespace
     51 
     52 SubbandErleEstimator::SubbandErleEstimator(const Environment& env,
     53                                           const EchoCanceller3Config& config,
     54                                           size_t num_capture_channels)
     55    : use_onset_detection_(config.erle.onset_detection),
     56      min_erle_(config.erle.min),
     57      max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)),
     58      use_min_erle_during_onsets_(
     59          EnableMinErleDuringOnsets(env.field_trials())),
     60      accum_spectra_(num_capture_channels),
     61      erle_(num_capture_channels),
     62      erle_onset_compensated_(num_capture_channels),
     63      erle_unbounded_(num_capture_channels),
     64      erle_during_onsets_(num_capture_channels),
     65      coming_onset_(num_capture_channels),
     66      hold_counters_(num_capture_channels) {
     67  Reset();
     68 }
     69 
     70 SubbandErleEstimator::~SubbandErleEstimator() = default;
     71 
     72 void SubbandErleEstimator::Reset() {
     73  const size_t num_capture_channels = erle_.size();
     74  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
     75    erle_[ch].fill(min_erle_);
     76    erle_onset_compensated_[ch].fill(min_erle_);
     77    erle_unbounded_[ch].fill(min_erle_);
     78    erle_during_onsets_[ch].fill(min_erle_);
     79    coming_onset_[ch].fill(true);
     80    hold_counters_[ch].fill(0);
     81  }
     82  ResetAccumulatedSpectra();
     83 }
     84 
     85 void SubbandErleEstimator::Update(
     86    ArrayView<const float, kFftLengthBy2Plus1> X2,
     87    ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
     88    ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
     89    const std::vector<bool>& converged_filters) {
     90  UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters);
     91  UpdateBands(converged_filters);
     92 
     93  if (use_onset_detection_) {
     94    DecreaseErlePerBandForLowRenderSignals();
     95  }
     96 
     97  const size_t num_capture_channels = erle_.size();
     98  for (size_t ch = 0; ch < num_capture_channels; ++ch) {
     99    auto& erle = erle_[ch];
    100    erle[0] = erle[1];
    101    erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1];
    102 
    103    auto& erle_oc = erle_onset_compensated_[ch];
    104    erle_oc[0] = erle_oc[1];
    105    erle_oc[kFftLengthBy2] = erle_oc[kFftLengthBy2 - 1];
    106 
    107    auto& erle_u = erle_unbounded_[ch];
    108    erle_u[0] = erle_u[1];
    109    erle_u[kFftLengthBy2] = erle_u[kFftLengthBy2 - 1];
    110  }
    111 }
    112 
    113 void SubbandErleEstimator::Dump(
    114    const std::unique_ptr<ApmDataDumper>& data_dumper) const {
    115  data_dumper->DumpRaw("aec3_erle_onset", ErleDuringOnsets()[0]);
    116 }
    117 
    118 void SubbandErleEstimator::UpdateBands(
    119    const std::vector<bool>& converged_filters) {
    120  const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
    121  for (int ch = 0; ch < num_capture_channels; ++ch) {
    122    // Note that the use of the converged_filter flag already imposed
    123    // a minimum of the erle that can be estimated as that flag would
    124    // be false if the filter is performing poorly.
    125    if (!converged_filters[ch]) {
    126      continue;
    127    }
    128 
    129    if (accum_spectra_.num_points[ch] != kPointsToAccumulate) {
    130      continue;
    131    }
    132 
    133    std::array<float, kFftLengthBy2> new_erle;
    134    std::array<bool, kFftLengthBy2> is_erle_updated;
    135    is_erle_updated.fill(false);
    136 
    137    for (size_t k = 1; k < kFftLengthBy2; ++k) {
    138      if (accum_spectra_.E2[ch][k] > 0.f) {
    139        new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k];
    140        is_erle_updated[k] = true;
    141      }
    142    }
    143 
    144    if (use_onset_detection_) {
    145      for (size_t k = 1; k < kFftLengthBy2; ++k) {
    146        if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) {
    147          if (coming_onset_[ch][k]) {
    148            coming_onset_[ch][k] = false;
    149            if (!use_min_erle_during_onsets_) {
    150              float alpha =
    151                  new_erle[k] < erle_during_onsets_[ch][k] ? 0.3f : 0.15f;
    152              erle_during_onsets_[ch][k] = SafeClamp(
    153                  erle_during_onsets_[ch][k] +
    154                      alpha * (new_erle[k] - erle_during_onsets_[ch][k]),
    155                  min_erle_, max_erle_[k]);
    156            }
    157          }
    158          hold_counters_[ch][k] = kBlocksForOnsetDetection;
    159        }
    160      }
    161    }
    162 
    163    auto update_erle_band = [](float& erle, float new_erle,
    164                               bool low_render_energy, float min_erle,
    165                               float max_erle) {
    166      float alpha = 0.05f;
    167      if (new_erle < erle) {
    168        alpha = low_render_energy ? 0.f : 0.1f;
    169      }
    170      erle = SafeClamp(erle + alpha * (new_erle - erle), min_erle, max_erle);
    171    };
    172 
    173    for (size_t k = 1; k < kFftLengthBy2; ++k) {
    174      if (is_erle_updated[k]) {
    175        const bool low_render_energy = accum_spectra_.low_render_energy[ch][k];
    176        update_erle_band(erle_[ch][k], new_erle[k], low_render_energy,
    177                         min_erle_, max_erle_[k]);
    178        if (use_onset_detection_) {
    179          update_erle_band(erle_onset_compensated_[ch][k], new_erle[k],
    180                           low_render_energy, min_erle_, max_erle_[k]);
    181        }
    182 
    183        // Virtually unbounded ERLE.
    184        constexpr float kUnboundedErleMax = 100000.0f;
    185        update_erle_band(erle_unbounded_[ch][k], new_erle[k], low_render_energy,
    186                         min_erle_, kUnboundedErleMax);
    187      }
    188    }
    189  }
    190 }
    191 
    192 void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() {
    193  const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size());
    194  for (int ch = 0; ch < num_capture_channels; ++ch) {
    195    for (size_t k = 1; k < kFftLengthBy2; ++k) {
    196      --hold_counters_[ch][k];
    197      if (hold_counters_[ch][k] <=
    198          (kBlocksForOnsetDetection - kBlocksToHoldErle)) {
    199        if (erle_onset_compensated_[ch][k] > erle_during_onsets_[ch][k]) {
    200          erle_onset_compensated_[ch][k] =
    201              std::max(erle_during_onsets_[ch][k],
    202                       0.97f * erle_onset_compensated_[ch][k]);
    203          RTC_DCHECK_LE(min_erle_, erle_onset_compensated_[ch][k]);
    204        }
    205        if (hold_counters_[ch][k] <= 0) {
    206          coming_onset_[ch][k] = true;
    207          hold_counters_[ch][k] = 0;
    208        }
    209      }
    210    }
    211  }
    212 }
    213 
    214 void SubbandErleEstimator::ResetAccumulatedSpectra() {
    215  for (size_t ch = 0; ch < erle_during_onsets_.size(); ++ch) {
    216    accum_spectra_.Y2[ch].fill(0.f);
    217    accum_spectra_.E2[ch].fill(0.f);
    218    accum_spectra_.num_points[ch] = 0;
    219    accum_spectra_.low_render_energy[ch].fill(false);
    220  }
    221 }
    222 
    223 void SubbandErleEstimator::UpdateAccumulatedSpectra(
    224    ArrayView<const float, kFftLengthBy2Plus1> X2,
    225    ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2,
    226    ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2,
    227    const std::vector<bool>& converged_filters) {
    228  auto& st = accum_spectra_;
    229  RTC_DCHECK_EQ(st.E2.size(), E2.size());
    230  RTC_DCHECK_EQ(st.E2.size(), E2.size());
    231  const int num_capture_channels = static_cast<int>(Y2.size());
    232  for (int ch = 0; ch < num_capture_channels; ++ch) {
    233    // Note that the use of the converged_filter flag already imposed
    234    // a minimum of the erle that can be estimated as that flag would
    235    // be false if the filter is performing poorly.
    236    if (!converged_filters[ch]) {
    237      continue;
    238    }
    239 
    240    if (st.num_points[ch] == kPointsToAccumulate) {
    241      st.num_points[ch] = 0;
    242      st.Y2[ch].fill(0.f);
    243      st.E2[ch].fill(0.f);
    244      st.low_render_energy[ch].fill(false);
    245    }
    246 
    247    std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(),
    248                   st.Y2[ch].begin(), std::plus<float>());
    249    std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(),
    250                   st.E2[ch].begin(), std::plus<float>());
    251 
    252    for (size_t k = 0; k < X2.size(); ++k) {
    253      st.low_render_energy[ch][k] =
    254          st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold;
    255    }
    256 
    257    ++st.num_points[ch];
    258  }
    259 }
    260 
    261 }  // namespace webrtc