tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

suppression_gain.cc (19518B)


      1 /*
      2 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/aec3/suppression_gain.h"
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <atomic>
     16 #include <cmath>
     17 #include <cstddef>
     18 #include <memory>
     19 #include <numeric>
     20 #include <optional>
     21 
     22 #include "api/array_view.h"
     23 #include "api/audio/echo_canceller3_config.h"
     24 #include "modules/audio_processing/aec3/aec3_common.h"
     25 #include "modules/audio_processing/aec3/aec_state.h"
     26 #include "modules/audio_processing/aec3/block.h"
     27 #include "modules/audio_processing/aec3/dominant_nearend_detector.h"
     28 #include "modules/audio_processing/aec3/moving_average.h"
     29 #include "modules/audio_processing/aec3/render_signal_analyzer.h"
     30 #include "modules/audio_processing/aec3/subband_nearend_detector.h"
     31 #include "modules/audio_processing/aec3/vector_math.h"
     32 #include "modules/audio_processing/logging/apm_data_dumper.h"
     33 #include "rtc_base/checks.h"
     34 
     35 namespace webrtc {
     36 namespace {
     37 
     38 void LimitLowFrequencyGains(std::array<float, kFftLengthBy2Plus1>* gain) {
     39  // Limit the low frequency gains to avoid the impact of the high-pass filter
     40  // on the lower-frequency gain influencing the overall achieved gain.
     41  (*gain)[0] = (*gain)[1] = std::min((*gain)[1], (*gain)[2]);
     42 }
     43 
     44 void LimitHighFrequencyGains(const EchoCanceller3Config::Suppressor& config,
     45                             std::array<float, kFftLengthBy2Plus1>* gain) {
     46  // Limit the high frequency gains to avoid echo leakage due to an imperfect
     47  // filter.
     48  const int limiting_gain_band =
     49      config.high_frequency_suppression.limiting_gain_band;
     50  const int bands_in_limiting_gain =
     51      config.high_frequency_suppression.bands_in_limiting_gain;
     52  if (bands_in_limiting_gain > 0) {
     53    RTC_DCHECK_GE(limiting_gain_band, 0);
     54    RTC_DCHECK_LE(limiting_gain_band + bands_in_limiting_gain, gain->size());
     55    float min_upper_gain = 1.f;
     56    for (int band = limiting_gain_band;
     57         band < limiting_gain_band + bands_in_limiting_gain; ++band) {
     58      min_upper_gain = std::min(min_upper_gain, (*gain)[band]);
     59    }
     60    std::for_each(
     61        gain->begin() + limiting_gain_band + 1, gain->end(),
     62        [min_upper_gain](float& a) { a = std::min(a, min_upper_gain); });
     63  }
     64  (*gain)[kFftLengthBy2] = (*gain)[kFftLengthBy2Minus1];
     65 
     66  if (config.conservative_hf_suppression) {
     67    // Limits the gain in the frequencies for which the adaptive filter has not
     68    // converged.
     69    // TODO(peah): Make adaptive to take the actual filter error into account.
     70    constexpr size_t kUpperAccurateBandPlus1 = 29;
     71 
     72    constexpr float oneByBandsInSum =
     73        1 / static_cast<float>(kUpperAccurateBandPlus1 - 20);
     74    const float hf_gain_bound =
     75        std::accumulate(gain->begin() + 20,
     76                        gain->begin() + kUpperAccurateBandPlus1, 0.f) *
     77        oneByBandsInSum;
     78 
     79    std::for_each(
     80        gain->begin() + kUpperAccurateBandPlus1, gain->end(),
     81        [hf_gain_bound](float& a) { a = std::min(a, hf_gain_bound); });
     82  }
     83 }
     84 
     85 // Scales the echo according to assessed audibility at the other end.
     86 void WeightEchoForAudibility(const EchoCanceller3Config& config,
     87                             ArrayView<const float> echo,
     88                             ArrayView<float> weighted_echo) {
     89  RTC_DCHECK_EQ(kFftLengthBy2Plus1, echo.size());
     90  RTC_DCHECK_EQ(kFftLengthBy2Plus1, weighted_echo.size());
     91 
     92  auto weigh = [](float threshold, float normalizer, size_t begin, size_t end,
     93                  ArrayView<const float> echo, ArrayView<float> weighted_echo) {
     94    for (size_t k = begin; k < end; ++k) {
     95      if (echo[k] < threshold) {
     96        float tmp = (threshold - echo[k]) * normalizer;
     97        weighted_echo[k] = echo[k] * std::max(0.f, 1.f - tmp * tmp);
     98      } else {
     99        weighted_echo[k] = echo[k];
    100      }
    101    }
    102  };
    103 
    104  float threshold = config.echo_audibility.floor_power *
    105                    config.echo_audibility.audibility_threshold_lf;
    106  float normalizer = 1.f / (threshold - config.echo_audibility.floor_power);
    107  weigh(threshold, normalizer, 0, 3, echo, weighted_echo);
    108 
    109  threshold = config.echo_audibility.floor_power *
    110              config.echo_audibility.audibility_threshold_mf;
    111  normalizer = 1.f / (threshold - config.echo_audibility.floor_power);
    112  weigh(threshold, normalizer, 3, 7, echo, weighted_echo);
    113 
    114  threshold = config.echo_audibility.floor_power *
    115              config.echo_audibility.audibility_threshold_hf;
    116  normalizer = 1.f / (threshold - config.echo_audibility.floor_power);
    117  weigh(threshold, normalizer, 7, kFftLengthBy2Plus1, echo, weighted_echo);
    118 }
    119 
    120 }  // namespace
    121 
    // Counter shared by all SuppressionGain objects. The constructor increments
    // it atomically and passes the incremented value to its ApmDataDumper.
    122 std::atomic<int> SuppressionGain::instance_count_(0);
    123 
    124 float SuppressionGain::UpperBandsGain(
    125    ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum,
    126    ArrayView<const std::array<float, kFftLengthBy2Plus1>>
    127        comfort_noise_spectrum,
    128    const std::optional<int>& narrow_peak_band,
    129    bool saturated_echo,
    130    const Block& render,
    131    const std::array<float, kFftLengthBy2Plus1>& low_band_gain) const {
    132  RTC_DCHECK_LT(0, render.NumBands());
    133  if (render.NumBands() == 1) {
    134    return 1.f;
    135  }
    136  const int num_render_channels = render.NumChannels();
    137 
    138  if (narrow_peak_band &&
    139      (*narrow_peak_band > static_cast<int>(kFftLengthBy2Plus1 - 10))) {
    140    return 0.001f;
    141  }
    142 
    143  constexpr size_t kLowBandGainLimit = kFftLengthBy2 / 2;
    144  const float gain_below_8_khz = *std::min_element(
    145      low_band_gain.begin() + kLowBandGainLimit, low_band_gain.end());
    146 
    147  // Always attenuate the upper bands when there is saturated echo.
    148  if (saturated_echo) {
    149    return std::min(0.001f, gain_below_8_khz);
    150  }
    151 
    152  // Compute the upper and lower band energies.
    153  const auto sum_of_squares = [](float a, float b) { return a + b * b; };
    154  float low_band_energy = 0.f;
    155  for (int ch = 0; ch < num_render_channels; ++ch) {
    156    const float channel_energy =
    157        std::accumulate(render.begin(/*band=*/0, ch),
    158                        render.end(/*band=*/0, ch), 0.0f, sum_of_squares);
    159    low_band_energy = std::max(low_band_energy, channel_energy);
    160  }
    161  float high_band_energy = 0.f;
    162  for (int k = 1; k < render.NumBands(); ++k) {
    163    for (int ch = 0; ch < num_render_channels; ++ch) {
    164      const float energy = std::accumulate(
    165          render.begin(k, ch), render.end(k, ch), 0.f, sum_of_squares);
    166      high_band_energy = std::max(high_band_energy, energy);
    167    }
    168  }
    169 
    170  // If there is more power in the lower frequencies than the upper frequencies,
    171  // or if the power in upper frequencies is low, do not bound the gain in the
    172  // upper bands.
    173  float anti_howling_gain;
    174  const float activation_threshold =
    175      kBlockSize * config_.suppressor.high_bands_suppression
    176                       .anti_howling_activation_threshold;
    177  if (high_band_energy < std::max(low_band_energy, activation_threshold)) {
    178    anti_howling_gain = 1.f;
    179  } else {
    180    // In all other cases, bound the gain for upper frequencies.
    181    RTC_DCHECK_LE(low_band_energy, high_band_energy);
    182    RTC_DCHECK_NE(0.f, high_band_energy);
    183    anti_howling_gain =
    184        config_.suppressor.high_bands_suppression.anti_howling_gain *
    185        sqrtf(low_band_energy / high_band_energy);
    186  }
    187 
    188  float gain_bound = 1.f;
    189  if (!dominant_nearend_detector_->IsNearendState()) {
    190    // Bound the upper gain during significant echo activity.
    191    const auto& cfg = config_.suppressor.high_bands_suppression;
    192    auto low_frequency_energy = [](ArrayView<const float> spectrum) {
    193      RTC_DCHECK_LE(16, spectrum.size());
    194      return std::accumulate(spectrum.begin() + 1, spectrum.begin() + 16, 0.f);
    195    };
    196    for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    197      const float echo_sum = low_frequency_energy(echo_spectrum[ch]);
    198      const float noise_sum = low_frequency_energy(comfort_noise_spectrum[ch]);
    199      if (echo_sum > cfg.enr_threshold * noise_sum) {
    200        gain_bound = cfg.max_gain_during_echo;
    201        break;
    202      }
    203    }
    204  }
    205 
    206  // Choose the gain as the minimum of the lower and upper gains.
    207  return std::min(std::min(gain_below_8_khz, anti_howling_gain), gain_bound);
    208 }
    209 
    210 // Computes the gain to reduce the echo to a non audible level.
    211 void SuppressionGain::GainToNoAudibleEcho(
    212    const std::array<float, kFftLengthBy2Plus1>& nearend,
    213    const std::array<float, kFftLengthBy2Plus1>& echo,
    214    const std::array<float, kFftLengthBy2Plus1>& masker,
    215    std::array<float, kFftLengthBy2Plus1>* gain) const {
    216  const auto& p = dominant_nearend_detector_->IsNearendState() ? nearend_params_
    217                                                               : normal_params_;
    218  for (size_t k = 0; k < gain->size(); ++k) {
    219    float enr = echo[k] / (nearend[k] + 1.f);  // Echo-to-nearend ratio.
    220    float emr = echo[k] / (masker[k] + 1.f);   // Echo-to-masker (noise) ratio.
    221    float g = 1.0f;
    222    if (enr > p.enr_transparent_[k] && emr > p.emr_transparent_[k]) {
    223      g = (p.enr_suppress_[k] - enr) /
    224          (p.enr_suppress_[k] - p.enr_transparent_[k]);
    225      g = std::max(g, p.emr_transparent_[k] / emr);
    226    }
    227    (*gain)[k] = g;
    228  }
    229 }
    230 
    231 // Compute the minimum gain as the attenuating gain to put the signal just
    232 // above the zero sample values.
    233 void SuppressionGain::GetMinGain(ArrayView<const float> weighted_residual_echo,
    234                                 ArrayView<const float> last_nearend,
    235                                 ArrayView<const float> last_echo,
    236                                 bool low_noise_render,
    237                                 bool saturated_echo,
    238                                 ArrayView<float> min_gain) const {
    239  if (!saturated_echo) {
    240    const float min_echo_power =
    241        low_noise_render ? config_.echo_audibility.low_render_limit
    242                         : config_.echo_audibility.normal_render_limit;
    243 
    244    for (size_t k = 0; k < min_gain.size(); ++k) {
    245      min_gain[k] = weighted_residual_echo[k] > 0.f
    246                        ? min_echo_power / weighted_residual_echo[k]
    247                        : 1.f;
    248      min_gain[k] = std::min(min_gain[k], 1.f);
    249    }
    250 
    251    if (!initial_state_ ||
    252        config_.suppressor.lf_smoothing_during_initial_phase) {
    253      const float& dec = dominant_nearend_detector_->IsNearendState()
    254                             ? nearend_params_.max_dec_factor_lf
    255                             : normal_params_.max_dec_factor_lf;
    256 
    257      for (int k = 0; k <= config_.suppressor.last_lf_smoothing_band; ++k) {
    258        // Make sure the gains of the low frequencies do not decrease too
    259        // quickly after strong nearend.
    260        if (last_nearend[k] > last_echo[k] ||
    261            k <= config_.suppressor.last_permanent_lf_smoothing_band) {
    262          min_gain[k] = std::max(min_gain[k], last_gain_[k] * dec);
    263          min_gain[k] = std::min(min_gain[k], 1.f);
    264        }
    265      }
    266    }
    267  } else {
    268    std::fill(min_gain.begin(), min_gain.end(), 0.f);
    269  }
    270 }
    271 
    272 // Compute the maximum gain by limiting the gain increase from the previous
    273 // gain.
    274 void SuppressionGain::GetMaxGain(ArrayView<float> max_gain) const {
    275  const auto& inc = dominant_nearend_detector_->IsNearendState()
    276                        ? nearend_params_.max_inc_factor
    277                        : normal_params_.max_inc_factor;
    278  const auto& floor = config_.suppressor.floor_first_increase;
    279  for (size_t k = 0; k < max_gain.size(); ++k) {
    280    max_gain[k] = std::min(std::max(last_gain_[k] * inc, floor), 1.f);
    281  }
    282 }
    283 
    284 void SuppressionGain::LowerBandGain(
    285    bool low_noise_render,
    286    const AecState& aec_state,
    287    ArrayView<const std::array<float, kFftLengthBy2Plus1>> suppressor_input,
    288    ArrayView<const std::array<float, kFftLengthBy2Plus1>> residual_echo,
    289    ArrayView<const std::array<float, kFftLengthBy2Plus1>> comfort_noise,
    290    bool clock_drift,
    291    std::array<float, kFftLengthBy2Plus1>* gain) {
    292  gain->fill(1.f);
    293  const bool saturated_echo = aec_state.SaturatedEcho();
    294  std::array<float, kFftLengthBy2Plus1> max_gain;
    295  GetMaxGain(max_gain);
    296 
    297  for (size_t ch = 0; ch < num_capture_channels_; ++ch) {
    298    std::array<float, kFftLengthBy2Plus1> G;
    299    std::array<float, kFftLengthBy2Plus1> nearend;
    300    nearend_smoothers_[ch].Average(suppressor_input[ch], nearend);
    301 
    302    // Weight echo power in terms of audibility.
    303    std::array<float, kFftLengthBy2Plus1> weighted_residual_echo;
    304    WeightEchoForAudibility(config_, residual_echo[ch], weighted_residual_echo);
    305 
    306    std::array<float, kFftLengthBy2Plus1> min_gain;
    307    GetMinGain(weighted_residual_echo, last_nearend_[ch], last_echo_[ch],
    308               low_noise_render, saturated_echo, min_gain);
    309 
    310    GainToNoAudibleEcho(nearend, weighted_residual_echo, comfort_noise[0], &G);
    311 
    312    // Clamp gains.
    313    for (size_t k = 0; k < gain->size(); ++k) {
    314      G[k] = std::max(std::min(G[k], max_gain[k]), min_gain[k]);
    315      (*gain)[k] = std::min((*gain)[k], G[k]);
    316    }
    317 
    318    // Store data required for the gain computation of the next block.
    319    std::copy(nearend.begin(), nearend.end(), last_nearend_[ch].begin());
    320    std::copy(weighted_residual_echo.begin(), weighted_residual_echo.end(),
    321              last_echo_[ch].begin());
    322  }
    323 
    324  LimitLowFrequencyGains(gain);
    325  // Use conservative high-frequency gains during clock-drift or when not in
    326  // dominant nearend.
    327  if (!dominant_nearend_detector_->IsNearendState() || clock_drift ||
    328      config_.suppressor.conservative_hf_suppression) {
    329    LimitHighFrequencyGains(config_.suppressor, gain);
    330  }
    331 
    332  // Store computed gains.
    333  std::copy(gain->begin(), gain->end(), last_gain_.begin());
    334 
    335  // Transform gains to amplitude domain.
    336  aec3::VectorMath(optimization_).Sqrt(*gain);
    337 }
    338 
    339 SuppressionGain::SuppressionGain(const EchoCanceller3Config& config,
    340                                 Aec3Optimization optimization,
    341                                 int /* sample_rate_hz */,
    342                                 size_t num_capture_channels)
    343    : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)),
    344      optimization_(optimization),
    345      config_(config),
    346      num_capture_channels_(num_capture_channels),
    347      state_change_duration_blocks_(
    348          static_cast<int>(config_.filter.config_change_duration_blocks)),
    349      last_nearend_(num_capture_channels_, {0}),
    350      last_echo_(num_capture_channels_, {0}),
    351      nearend_smoothers_(
    352          num_capture_channels_,
    353          aec3::MovingAverage(kFftLengthBy2Plus1,
    354                              config.suppressor.nearend_average_blocks)),
    355      nearend_params_(config_.suppressor.last_lf_band,
    356                      config_.suppressor.first_hf_band,
    357                      config_.suppressor.nearend_tuning),
    358      normal_params_(config_.suppressor.last_lf_band,
    359                     config_.suppressor.first_hf_band,
    360                     config_.suppressor.normal_tuning),
    361      use_unbounded_echo_spectrum_(config.suppressor.dominant_nearend_detection
    362                                       .use_unbounded_echo_spectrum) {
    363  RTC_DCHECK_LT(0, state_change_duration_blocks_);
    364  last_gain_.fill(1.f);
    365  if (config_.suppressor.use_subband_nearend_detection) {
    366    dominant_nearend_detector_ = std::make_unique<SubbandNearendDetector>(
    367        config_.suppressor.subband_nearend_detection, num_capture_channels_);
    368  } else {
    369    dominant_nearend_detector_ = std::make_unique<DominantNearendDetector>(
    370        config_.suppressor.dominant_nearend_detection, num_capture_channels_);
    371  }
    372  RTC_DCHECK(dominant_nearend_detector_);
    373 }
    374 
    // No custom teardown: the compiler-generated destructor is used, defined out
    // of line in this source file.
    375 SuppressionGain::~SuppressionGain() = default;
    376 
    377 void SuppressionGain::GetGain(
    378    ArrayView<const std::array<float, kFftLengthBy2Plus1>> nearend_spectrum,
    379    ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum,
    380    ArrayView<const std::array<float, kFftLengthBy2Plus1>>
    381        residual_echo_spectrum,
    382    ArrayView<const std::array<float, kFftLengthBy2Plus1>>
    383        residual_echo_spectrum_unbounded,
    384    ArrayView<const std::array<float, kFftLengthBy2Plus1>>
    385        comfort_noise_spectrum,
    386    const RenderSignalAnalyzer& render_signal_analyzer,
    387    const AecState& aec_state,
    388    const Block& render,
    389    bool clock_drift,
    390    float* high_bands_gain,
    391    std::array<float, kFftLengthBy2Plus1>* low_band_gain) {
    392  RTC_DCHECK(high_bands_gain);
    393  RTC_DCHECK(low_band_gain);
    394 
    395  // Choose residual echo spectrum for dominant nearend detection.
    396  const auto echo = use_unbounded_echo_spectrum_
    397                        ? residual_echo_spectrum_unbounded
    398                        : residual_echo_spectrum;
    399 
    400  // Update the nearend state selection.
    401  dominant_nearend_detector_->Update(nearend_spectrum, echo,
    402                                     comfort_noise_spectrum, initial_state_);
    403 
    404  // Compute gain for the lower band.
    405  bool low_noise_render = low_render_detector_.Detect(render);
    406  LowerBandGain(low_noise_render, aec_state, nearend_spectrum,
    407                residual_echo_spectrum, comfort_noise_spectrum, clock_drift,
    408                low_band_gain);
    409 
    410  // Compute the gain for the upper bands.
    411  const std::optional<int> narrow_peak_band =
    412      render_signal_analyzer.NarrowPeakBand();
    413 
    414  *high_bands_gain =
    415      UpperBandsGain(echo_spectrum, comfort_noise_spectrum, narrow_peak_band,
    416                     aec_state.SaturatedEcho(), render, *low_band_gain);
    417 
    418  data_dumper_->DumpRaw("aec3_dominant_nearend",
    419                        dominant_nearend_detector_->IsNearendState());
    420 }
    421 
    422 void SuppressionGain::SetInitialState(bool state) {
    423  initial_state_ = state;
    424  if (state) {
    425    initial_state_change_counter_ = state_change_duration_blocks_;
    426  } else {
    427    initial_state_change_counter_ = 0;
    428  }
    429 }
    430 
    431 // Detects when the render signal can be considered to have low power and
    432 // consist of stationary noise.
    433 bool SuppressionGain::LowNoiseRenderDetector::Detect(const Block& render) {
    434  float x2_sum = 0.f;
    435  float x2_max = 0.f;
    436  for (int ch = 0; ch < render.NumChannels(); ++ch) {
    437    for (float x_k : render.View(/*band=*/0, ch)) {
    438      const float x2 = x_k * x_k;
    439      x2_sum += x2;
    440      x2_max = std::max(x2_max, x2);
    441    }
    442  }
    443  x2_sum = x2_sum / render.NumChannels();
    444 
    445  constexpr float kThreshold = 50.f * 50.f * 64.f;
    446  const bool low_noise_render =
    447      average_power_ < kThreshold && x2_max < 3 * average_power_;
    448  average_power_ = average_power_ * 0.9f + x2_sum * 0.1f;
    449  return low_noise_render;
    450 }
    451 
    452 SuppressionGain::GainParameters::GainParameters(
    453    int last_lf_band,
    454    int first_hf_band,
    455    const EchoCanceller3Config::Suppressor::Tuning& tuning)
    456    : max_inc_factor(tuning.max_inc_factor),
    457      max_dec_factor_lf(tuning.max_dec_factor_lf) {
    458  // Compute per-band masking thresholds.
    459  RTC_DCHECK_LT(last_lf_band, first_hf_band);
    460  auto& lf = tuning.mask_lf;
    461  auto& hf = tuning.mask_hf;
    462  RTC_DCHECK_LT(lf.enr_transparent, lf.enr_suppress);
    463  RTC_DCHECK_LT(hf.enr_transparent, hf.enr_suppress);
    464  for (int k = 0; k < static_cast<int>(kFftLengthBy2Plus1); k++) {
    465    float a;
    466    if (k <= last_lf_band) {
    467      a = 0.f;
    468    } else if (k < first_hf_band) {
    469      a = (k - last_lf_band) / static_cast<float>(first_hf_band - last_lf_band);
    470    } else {
    471      a = 1.f;
    472    }
    473    enr_transparent_[k] = (1 - a) * lf.enr_transparent + a * hf.enr_transparent;
    474    enr_suppress_[k] = (1 - a) * lf.enr_suppress + a * hf.enr_suppress;
    475    emr_transparent_[k] = (1 - a) * lf.emr_transparent + a * hf.emr_transparent;
    476  }
    477 }
    478 
    479 }  // namespace webrtc