tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

erl_estimator.cc (5125B)


      1 /*
      2 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/aec3/erl_estimator.h"
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <cstddef>
     16 #include <iterator>
     17 #include <numeric>
     18 #include <vector>
     19 
     20 #include "api/array_view.h"
     21 #include "modules/audio_processing/aec3/aec3_common.h"
     22 #include "rtc_base/checks.h"
     23 
     24 namespace webrtc {
     25 
     26 namespace {
     27 
     28 constexpr float kMinErl = 0.01f;
     29 constexpr float kMaxErl = 1000.f;
     30 
     31 }  // namespace
     32 
     33 ErlEstimator::ErlEstimator(size_t startup_phase_length_blocks_)
     34    : startup_phase_length_blocks__(startup_phase_length_blocks_) {
     35  erl_.fill(kMaxErl);
     36  hold_counters_.fill(0);
     37  erl_time_domain_ = kMaxErl;
     38  hold_counter_time_domain_ = 0;
     39 }
     40 
     41 ErlEstimator::~ErlEstimator() = default;
     42 
     43 void ErlEstimator::Reset() {
     44  blocks_since_reset_ = 0;
     45 }
     46 
     47 void ErlEstimator::Update(
     48    const std::vector<bool>& converged_filters,
     49    ArrayView<const std::array<float, kFftLengthBy2Plus1>> render_spectra,
     50    ArrayView<const std::array<float, kFftLengthBy2Plus1>> capture_spectra) {
     51  const size_t num_capture_channels = converged_filters.size();
     52  RTC_DCHECK_EQ(capture_spectra.size(), num_capture_channels);
     53 
     54  // Corresponds to WGN of power -46 dBFS.
     55  constexpr float kX2Min = 44015068.0f;
     56 
     57  const auto first_converged_iter =
     58      std::find(converged_filters.begin(), converged_filters.end(), true);
     59  const bool any_filter_converged =
     60      first_converged_iter != converged_filters.end();
     61 
     62  if (++blocks_since_reset_ < startup_phase_length_blocks__ ||
     63      !any_filter_converged) {
     64    return;
     65  }
     66 
     67  // Use the maximum spectrum across capture and the maximum across render.
     68  std::array<float, kFftLengthBy2Plus1> max_capture_spectrum_data;
     69  std::array<float, kFftLengthBy2Plus1> max_capture_spectrum =
     70      capture_spectra[/*channel=*/0];
     71  if (num_capture_channels > 1) {
     72    // Initialize using the first channel with a converged filter.
     73    const size_t first_converged =
     74        std::distance(converged_filters.begin(), first_converged_iter);
     75    RTC_DCHECK_GE(first_converged, 0);
     76    RTC_DCHECK_LT(first_converged, num_capture_channels);
     77    max_capture_spectrum_data = capture_spectra[first_converged];
     78 
     79    for (size_t ch = first_converged + 1; ch < num_capture_channels; ++ch) {
     80      if (!converged_filters[ch]) {
     81        continue;
     82      }
     83      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
     84        max_capture_spectrum_data[k] =
     85            std::max(max_capture_spectrum_data[k], capture_spectra[ch][k]);
     86      }
     87    }
     88    max_capture_spectrum = max_capture_spectrum_data;
     89  }
     90 
     91  const size_t num_render_channels = render_spectra.size();
     92  std::array<float, kFftLengthBy2Plus1> max_render_spectrum_data;
     93  ArrayView<const float, kFftLengthBy2Plus1> max_render_spectrum =
     94      render_spectra[/*channel=*/0];
     95  if (num_render_channels > 1) {
     96    std::copy(render_spectra[0].begin(), render_spectra[0].end(),
     97              max_render_spectrum_data.begin());
     98    for (size_t ch = 1; ch < num_render_channels; ++ch) {
     99      for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) {
    100        max_render_spectrum_data[k] =
    101            std::max(max_render_spectrum_data[k], render_spectra[ch][k]);
    102      }
    103    }
    104    max_render_spectrum = max_render_spectrum_data;
    105  }
    106 
    107  const auto& X2 = max_render_spectrum;
    108  const auto& Y2 = max_capture_spectrum;
    109 
    110  // Update the estimates in a maximum statistics manner.
    111  for (size_t k = 1; k < kFftLengthBy2; ++k) {
    112    if (X2[k] > kX2Min) {
    113      const float new_erl = Y2[k] / X2[k];
    114      if (new_erl < erl_[k]) {
    115        hold_counters_[k - 1] = 1000;
    116        erl_[k] += 0.1f * (new_erl - erl_[k]);
    117        erl_[k] = std::max(erl_[k], kMinErl);
    118      }
    119    }
    120  }
    121 
    122  std::for_each(hold_counters_.begin(), hold_counters_.end(),
    123                [](int& a) { --a; });
    124  std::transform(hold_counters_.begin(), hold_counters_.end(), erl_.begin() + 1,
    125                 erl_.begin() + 1, [](int a, float b) {
    126                   return a > 0 ? b : std::min(kMaxErl, 2.f * b);
    127                 });
    128 
    129  erl_[0] = erl_[1];
    130  erl_[kFftLengthBy2] = erl_[kFftLengthBy2 - 1];
    131 
    132  // Compute ERL over all frequency bins.
    133  const float X2_sum = std::accumulate(X2.begin(), X2.end(), 0.0f);
    134 
    135  if (X2_sum > kX2Min * X2.size()) {
    136    const float Y2_sum = std::accumulate(Y2.begin(), Y2.end(), 0.0f);
    137    const float new_erl = Y2_sum / X2_sum;
    138    if (new_erl < erl_time_domain_) {
    139      hold_counter_time_domain_ = 1000;
    140      erl_time_domain_ += 0.1f * (new_erl - erl_time_domain_);
    141      erl_time_domain_ = std::max(erl_time_domain_, kMinErl);
    142    }
    143  }
    144 
    145  --hold_counter_time_domain_;
    146  erl_time_domain_ = (hold_counter_time_domain_ > 0)
    147                         ? erl_time_domain_
    148                         : std::min(kMaxErl, 2.f * erl_time_domain_);
    149 }
    150 
    151 }  // namespace webrtc