suppression_gain.cc (19518B)
1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/suppression_gain.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <atomic> 16 #include <cmath> 17 #include <cstddef> 18 #include <memory> 19 #include <numeric> 20 #include <optional> 21 22 #include "api/array_view.h" 23 #include "api/audio/echo_canceller3_config.h" 24 #include "modules/audio_processing/aec3/aec3_common.h" 25 #include "modules/audio_processing/aec3/aec_state.h" 26 #include "modules/audio_processing/aec3/block.h" 27 #include "modules/audio_processing/aec3/dominant_nearend_detector.h" 28 #include "modules/audio_processing/aec3/moving_average.h" 29 #include "modules/audio_processing/aec3/render_signal_analyzer.h" 30 #include "modules/audio_processing/aec3/subband_nearend_detector.h" 31 #include "modules/audio_processing/aec3/vector_math.h" 32 #include "modules/audio_processing/logging/apm_data_dumper.h" 33 #include "rtc_base/checks.h" 34 35 namespace webrtc { 36 namespace { 37 38 void LimitLowFrequencyGains(std::array<float, kFftLengthBy2Plus1>* gain) { 39 // Limit the low frequency gains to avoid the impact of the high-pass filter 40 // on the lower-frequency gain influencing the overall achieved gain. 41 (*gain)[0] = (*gain)[1] = std::min((*gain)[1], (*gain)[2]); 42 } 43 44 void LimitHighFrequencyGains(const EchoCanceller3Config::Suppressor& config, 45 std::array<float, kFftLengthBy2Plus1>* gain) { 46 // Limit the high frequency gains to avoid echo leakage due to an imperfect 47 // filter. 48 const int limiting_gain_band = 49 config.high_frequency_suppression.limiting_gain_band; 50 const int bands_in_limiting_gain = 51 config.high_frequency_suppression.bands_in_limiting_gain; 52 if (bands_in_limiting_gain > 0) { 53 RTC_DCHECK_GE(limiting_gain_band, 0); 54 RTC_DCHECK_LE(limiting_gain_band + bands_in_limiting_gain, gain->size()); 55 float min_upper_gain = 1.f; 56 for (int band = limiting_gain_band; 57 band < limiting_gain_band + bands_in_limiting_gain; ++band) { 58 min_upper_gain = std::min(min_upper_gain, (*gain)[band]); 59 } 60 std::for_each( 61 gain->begin() + limiting_gain_band + 1, gain->end(), 62 [min_upper_gain](float& a) { a = std::min(a, min_upper_gain); }); 63 } 64 (*gain)[kFftLengthBy2] = (*gain)[kFftLengthBy2Minus1]; 65 66 if (config.conservative_hf_suppression) { 67 // Limits the gain in the frequencies for which the adaptive filter has not 68 // converged. 69 // TODO(peah): Make adaptive to take the actual filter error into account. 70 constexpr size_t kUpperAccurateBandPlus1 = 29; 71 72 constexpr float oneByBandsInSum = 73 1 / static_cast<float>(kUpperAccurateBandPlus1 - 20); 74 const float hf_gain_bound = 75 std::accumulate(gain->begin() + 20, 76 gain->begin() + kUpperAccurateBandPlus1, 0.f) * 77 oneByBandsInSum; 78 79 std::for_each( 80 gain->begin() + kUpperAccurateBandPlus1, gain->end(), 81 [hf_gain_bound](float& a) { a = std::min(a, hf_gain_bound); }); 82 } 83 } 84 85 // Scales the echo according to assessed audibility at the other end. 86 void WeightEchoForAudibility(const EchoCanceller3Config& config, 87 ArrayView<const float> echo, 88 ArrayView<float> weighted_echo) { 89 RTC_DCHECK_EQ(kFftLengthBy2Plus1, echo.size()); 90 RTC_DCHECK_EQ(kFftLengthBy2Plus1, weighted_echo.size()); 91 92 auto weigh = [](float threshold, float normalizer, size_t begin, size_t end, 93 ArrayView<const float> echo, ArrayView<float> weighted_echo) { 94 for (size_t k = begin; k < end; ++k) { 95 if (echo[k] < threshold) { 96 float tmp = (threshold - echo[k]) * normalizer; 97 weighted_echo[k] = echo[k] * std::max(0.f, 1.f - tmp * tmp); 98 } else { 99 weighted_echo[k] = echo[k]; 100 } 101 } 102 }; 103 104 float threshold = config.echo_audibility.floor_power * 105 config.echo_audibility.audibility_threshold_lf; 106 float normalizer = 1.f / (threshold - config.echo_audibility.floor_power); 107 weigh(threshold, normalizer, 0, 3, echo, weighted_echo); 108 109 threshold = config.echo_audibility.floor_power * 110 config.echo_audibility.audibility_threshold_mf; 111 normalizer = 1.f / (threshold - config.echo_audibility.floor_power); 112 weigh(threshold, normalizer, 3, 7, echo, weighted_echo); 113 114 threshold = config.echo_audibility.floor_power * 115 config.echo_audibility.audibility_threshold_hf; 116 normalizer = 1.f / (threshold - config.echo_audibility.floor_power); 117 weigh(threshold, normalizer, 7, kFftLengthBy2Plus1, echo, weighted_echo); 118 } 119 120 } // namespace 121 122 std::atomic<int> SuppressionGain::instance_count_(0); 123 124 float SuppressionGain::UpperBandsGain( 125 ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum, 126 ArrayView<const std::array<float, kFftLengthBy2Plus1>> 127 comfort_noise_spectrum, 128 const std::optional<int>& narrow_peak_band, 129 bool saturated_echo, 130 const Block& render, 131 const std::array<float, kFftLengthBy2Plus1>& low_band_gain) const { 132 RTC_DCHECK_LT(0, render.NumBands()); 133 if (render.NumBands() == 1) { 134 return 1.f; 135 } 136 const int num_render_channels = render.NumChannels(); 137 138 if (narrow_peak_band && 139 (*narrow_peak_band > static_cast<int>(kFftLengthBy2Plus1 - 10))) { 140 return 0.001f; 141 } 142 143 constexpr size_t kLowBandGainLimit = kFftLengthBy2 / 2; 144 const float gain_below_8_khz = *std::min_element( 145 low_band_gain.begin() + kLowBandGainLimit, low_band_gain.end()); 146 147 // Always attenuate the upper bands when there is saturated echo. 148 if (saturated_echo) { 149 return std::min(0.001f, gain_below_8_khz); 150 } 151 152 // Compute the upper and lower band energies. 153 const auto sum_of_squares = [](float a, float b) { return a + b * b; }; 154 float low_band_energy = 0.f; 155 for (int ch = 0; ch < num_render_channels; ++ch) { 156 const float channel_energy = 157 std::accumulate(render.begin(/*band=*/0, ch), 158 render.end(/*band=*/0, ch), 0.0f, sum_of_squares); 159 low_band_energy = std::max(low_band_energy, channel_energy); 160 } 161 float high_band_energy = 0.f; 162 for (int k = 1; k < render.NumBands(); ++k) { 163 for (int ch = 0; ch < num_render_channels; ++ch) { 164 const float energy = std::accumulate( 165 render.begin(k, ch), render.end(k, ch), 0.f, sum_of_squares); 166 high_band_energy = std::max(high_band_energy, energy); 167 } 168 } 169 170 // If there is more power in the lower frequencies than the upper frequencies, 171 // or if the power in upper frequencies is low, do not bound the gain in the 172 // upper bands. 173 float anti_howling_gain; 174 const float activation_threshold = 175 kBlockSize * config_.suppressor.high_bands_suppression 176 .anti_howling_activation_threshold; 177 if (high_band_energy < std::max(low_band_energy, activation_threshold)) { 178 anti_howling_gain = 1.f; 179 } else { 180 // In all other cases, bound the gain for upper frequencies. 181 RTC_DCHECK_LE(low_band_energy, high_band_energy); 182 RTC_DCHECK_NE(0.f, high_band_energy); 183 anti_howling_gain = 184 config_.suppressor.high_bands_suppression.anti_howling_gain * 185 sqrtf(low_band_energy / high_band_energy); 186 } 187 188 float gain_bound = 1.f; 189 if (!dominant_nearend_detector_->IsNearendState()) { 190 // Bound the upper gain during significant echo activity. 191 const auto& cfg = config_.suppressor.high_bands_suppression; 192 auto low_frequency_energy = [](ArrayView<const float> spectrum) { 193 RTC_DCHECK_LE(16, spectrum.size()); 194 return std::accumulate(spectrum.begin() + 1, spectrum.begin() + 16, 0.f); 195 }; 196 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 197 const float echo_sum = low_frequency_energy(echo_spectrum[ch]); 198 const float noise_sum = low_frequency_energy(comfort_noise_spectrum[ch]); 199 if (echo_sum > cfg.enr_threshold * noise_sum) { 200 gain_bound = cfg.max_gain_during_echo; 201 break; 202 } 203 } 204 } 205 206 // Choose the gain as the minimum of the lower and upper gains. 207 return std::min(std::min(gain_below_8_khz, anti_howling_gain), gain_bound); 208 } 209 210 // Computes the gain to reduce the echo to a non audible level. 211 void SuppressionGain::GainToNoAudibleEcho( 212 const std::array<float, kFftLengthBy2Plus1>& nearend, 213 const std::array<float, kFftLengthBy2Plus1>& echo, 214 const std::array<float, kFftLengthBy2Plus1>& masker, 215 std::array<float, kFftLengthBy2Plus1>* gain) const { 216 const auto& p = dominant_nearend_detector_->IsNearendState() ? nearend_params_ 217 : normal_params_; 218 for (size_t k = 0; k < gain->size(); ++k) { 219 float enr = echo[k] / (nearend[k] + 1.f); // Echo-to-nearend ratio. 220 float emr = echo[k] / (masker[k] + 1.f); // Echo-to-masker (noise) ratio. 221 float g = 1.0f; 222 if (enr > p.enr_transparent_[k] && emr > p.emr_transparent_[k]) { 223 g = (p.enr_suppress_[k] - enr) / 224 (p.enr_suppress_[k] - p.enr_transparent_[k]); 225 g = std::max(g, p.emr_transparent_[k] / emr); 226 } 227 (*gain)[k] = g; 228 } 229 } 230 231 // Compute the minimum gain as the attenuating gain to put the signal just 232 // above the zero sample values. 233 void SuppressionGain::GetMinGain(ArrayView<const float> weighted_residual_echo, 234 ArrayView<const float> last_nearend, 235 ArrayView<const float> last_echo, 236 bool low_noise_render, 237 bool saturated_echo, 238 ArrayView<float> min_gain) const { 239 if (!saturated_echo) { 240 const float min_echo_power = 241 low_noise_render ? config_.echo_audibility.low_render_limit 242 : config_.echo_audibility.normal_render_limit; 243 244 for (size_t k = 0; k < min_gain.size(); ++k) { 245 min_gain[k] = weighted_residual_echo[k] > 0.f 246 ? min_echo_power / weighted_residual_echo[k] 247 : 1.f; 248 min_gain[k] = std::min(min_gain[k], 1.f); 249 } 250 251 if (!initial_state_ || 252 config_.suppressor.lf_smoothing_during_initial_phase) { 253 const float& dec = dominant_nearend_detector_->IsNearendState() 254 ? nearend_params_.max_dec_factor_lf 255 : normal_params_.max_dec_factor_lf; 256 257 for (int k = 0; k <= config_.suppressor.last_lf_smoothing_band; ++k) { 258 // Make sure the gains of the low frequencies do not decrease too 259 // quickly after strong nearend. 260 if (last_nearend[k] > last_echo[k] || 261 k <= config_.suppressor.last_permanent_lf_smoothing_band) { 262 min_gain[k] = std::max(min_gain[k], last_gain_[k] * dec); 263 min_gain[k] = std::min(min_gain[k], 1.f); 264 } 265 } 266 } 267 } else { 268 std::fill(min_gain.begin(), min_gain.end(), 0.f); 269 } 270 } 271 272 // Compute the maximum gain by limiting the gain increase from the previous 273 // gain. 274 void SuppressionGain::GetMaxGain(ArrayView<float> max_gain) const { 275 const auto& inc = dominant_nearend_detector_->IsNearendState() 276 ? nearend_params_.max_inc_factor 277 : normal_params_.max_inc_factor; 278 const auto& floor = config_.suppressor.floor_first_increase; 279 for (size_t k = 0; k < max_gain.size(); ++k) { 280 max_gain[k] = std::min(std::max(last_gain_[k] * inc, floor), 1.f); 281 } 282 } 283 284 void SuppressionGain::LowerBandGain( 285 bool low_noise_render, 286 const AecState& aec_state, 287 ArrayView<const std::array<float, kFftLengthBy2Plus1>> suppressor_input, 288 ArrayView<const std::array<float, kFftLengthBy2Plus1>> residual_echo, 289 ArrayView<const std::array<float, kFftLengthBy2Plus1>> comfort_noise, 290 bool clock_drift, 291 std::array<float, kFftLengthBy2Plus1>* gain) { 292 gain->fill(1.f); 293 const bool saturated_echo = aec_state.SaturatedEcho(); 294 std::array<float, kFftLengthBy2Plus1> max_gain; 295 GetMaxGain(max_gain); 296 297 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 298 std::array<float, kFftLengthBy2Plus1> G; 299 std::array<float, kFftLengthBy2Plus1> nearend; 300 nearend_smoothers_[ch].Average(suppressor_input[ch], nearend); 301 302 // Weight echo power in terms of audibility. 303 std::array<float, kFftLengthBy2Plus1> weighted_residual_echo; 304 WeightEchoForAudibility(config_, residual_echo[ch], weighted_residual_echo); 305 306 std::array<float, kFftLengthBy2Plus1> min_gain; 307 GetMinGain(weighted_residual_echo, last_nearend_[ch], last_echo_[ch], 308 low_noise_render, saturated_echo, min_gain); 309 310 GainToNoAudibleEcho(nearend, weighted_residual_echo, comfort_noise[0], &G); 311 312 // Clamp gains. 313 for (size_t k = 0; k < gain->size(); ++k) { 314 G[k] = std::max(std::min(G[k], max_gain[k]), min_gain[k]); 315 (*gain)[k] = std::min((*gain)[k], G[k]); 316 } 317 318 // Store data required for the gain computation of the next block. 319 std::copy(nearend.begin(), nearend.end(), last_nearend_[ch].begin()); 320 std::copy(weighted_residual_echo.begin(), weighted_residual_echo.end(), 321 last_echo_[ch].begin()); 322 } 323 324 LimitLowFrequencyGains(gain); 325 // Use conservative high-frequency gains during clock-drift or when not in 326 // dominant nearend. 327 if (!dominant_nearend_detector_->IsNearendState() || clock_drift || 328 config_.suppressor.conservative_hf_suppression) { 329 LimitHighFrequencyGains(config_.suppressor, gain); 330 } 331 332 // Store computed gains. 333 std::copy(gain->begin(), gain->end(), last_gain_.begin()); 334 335 // Transform gains to amplitude domain. 336 aec3::VectorMath(optimization_).Sqrt(*gain); 337 } 338 339 SuppressionGain::SuppressionGain(const EchoCanceller3Config& config, 340 Aec3Optimization optimization, 341 int /* sample_rate_hz */, 342 size_t num_capture_channels) 343 : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)), 344 optimization_(optimization), 345 config_(config), 346 num_capture_channels_(num_capture_channels), 347 state_change_duration_blocks_( 348 static_cast<int>(config_.filter.config_change_duration_blocks)), 349 last_nearend_(num_capture_channels_, {0}), 350 last_echo_(num_capture_channels_, {0}), 351 nearend_smoothers_( 352 num_capture_channels_, 353 aec3::MovingAverage(kFftLengthBy2Plus1, 354 config.suppressor.nearend_average_blocks)), 355 nearend_params_(config_.suppressor.last_lf_band, 356 config_.suppressor.first_hf_band, 357 config_.suppressor.nearend_tuning), 358 normal_params_(config_.suppressor.last_lf_band, 359 config_.suppressor.first_hf_band, 360 config_.suppressor.normal_tuning), 361 use_unbounded_echo_spectrum_(config.suppressor.dominant_nearend_detection 362 .use_unbounded_echo_spectrum) { 363 RTC_DCHECK_LT(0, state_change_duration_blocks_); 364 last_gain_.fill(1.f); 365 if (config_.suppressor.use_subband_nearend_detection) { 366 dominant_nearend_detector_ = std::make_unique<SubbandNearendDetector>( 367 config_.suppressor.subband_nearend_detection, num_capture_channels_); 368 } else { 369 dominant_nearend_detector_ = std::make_unique<DominantNearendDetector>( 370 config_.suppressor.dominant_nearend_detection, num_capture_channels_); 371 } 372 RTC_DCHECK(dominant_nearend_detector_); 373 } 374 375 SuppressionGain::~SuppressionGain() = default; 376 377 void SuppressionGain::GetGain( 378 ArrayView<const std::array<float, kFftLengthBy2Plus1>> nearend_spectrum, 379 ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum, 380 ArrayView<const std::array<float, kFftLengthBy2Plus1>> 381 residual_echo_spectrum, 382 ArrayView<const std::array<float, kFftLengthBy2Plus1>> 383 residual_echo_spectrum_unbounded, 384 ArrayView<const std::array<float, kFftLengthBy2Plus1>> 385 comfort_noise_spectrum, 386 const RenderSignalAnalyzer& render_signal_analyzer, 387 const AecState& aec_state, 388 const Block& render, 389 bool clock_drift, 390 float* high_bands_gain, 391 std::array<float, kFftLengthBy2Plus1>* low_band_gain) { 392 RTC_DCHECK(high_bands_gain); 393 RTC_DCHECK(low_band_gain); 394 395 // Choose residual echo spectrum for dominant nearend detection. 396 const auto echo = use_unbounded_echo_spectrum_ 397 ? residual_echo_spectrum_unbounded 398 : residual_echo_spectrum; 399 400 // Update the nearend state selection. 401 dominant_nearend_detector_->Update(nearend_spectrum, echo, 402 comfort_noise_spectrum, initial_state_); 403 404 // Compute gain for the lower band. 405 bool low_noise_render = low_render_detector_.Detect(render); 406 LowerBandGain(low_noise_render, aec_state, nearend_spectrum, 407 residual_echo_spectrum, comfort_noise_spectrum, clock_drift, 408 low_band_gain); 409 410 // Compute the gain for the upper bands. 411 const std::optional<int> narrow_peak_band = 412 render_signal_analyzer.NarrowPeakBand(); 413 414 *high_bands_gain = 415 UpperBandsGain(echo_spectrum, comfort_noise_spectrum, narrow_peak_band, 416 aec_state.SaturatedEcho(), render, *low_band_gain); 417 418 data_dumper_->DumpRaw("aec3_dominant_nearend", 419 dominant_nearend_detector_->IsNearendState()); 420 } 421 422 void SuppressionGain::SetInitialState(bool state) { 423 initial_state_ = state; 424 if (state) { 425 initial_state_change_counter_ = state_change_duration_blocks_; 426 } else { 427 initial_state_change_counter_ = 0; 428 } 429 } 430 431 // Detects when the render signal can be considered to have low power and 432 // consist of stationary noise. 433 bool SuppressionGain::LowNoiseRenderDetector::Detect(const Block& render) { 434 float x2_sum = 0.f; 435 float x2_max = 0.f; 436 for (int ch = 0; ch < render.NumChannels(); ++ch) { 437 for (float x_k : render.View(/*band=*/0, ch)) { 438 const float x2 = x_k * x_k; 439 x2_sum += x2; 440 x2_max = std::max(x2_max, x2); 441 } 442 } 443 x2_sum = x2_sum / render.NumChannels(); 444 445 constexpr float kThreshold = 50.f * 50.f * 64.f; 446 const bool low_noise_render = 447 average_power_ < kThreshold && x2_max < 3 * average_power_; 448 average_power_ = average_power_ * 0.9f + x2_sum * 0.1f; 449 return low_noise_render; 450 } 451 452 SuppressionGain::GainParameters::GainParameters( 453 int last_lf_band, 454 int first_hf_band, 455 const EchoCanceller3Config::Suppressor::Tuning& tuning) 456 : max_inc_factor(tuning.max_inc_factor), 457 max_dec_factor_lf(tuning.max_dec_factor_lf) { 458 // Compute per-band masking thresholds. 459 RTC_DCHECK_LT(last_lf_band, first_hf_band); 460 auto& lf = tuning.mask_lf; 461 auto& hf = tuning.mask_hf; 462 RTC_DCHECK_LT(lf.enr_transparent, lf.enr_suppress); 463 RTC_DCHECK_LT(hf.enr_transparent, hf.enr_suppress); 464 for (int k = 0; k < static_cast<int>(kFftLengthBy2Plus1); k++) { 465 float a; 466 if (k <= last_lf_band) { 467 a = 0.f; 468 } else if (k < first_hf_band) { 469 a = (k - last_lf_band) / static_cast<float>(first_hf_band - last_lf_band); 470 } else { 471 a = 1.f; 472 } 473 enr_transparent_[k] = (1 - a) * lf.enr_transparent + a * hf.enr_transparent; 474 enr_suppress_[k] = (1 - a) * lf.enr_suppress + a * hf.enr_suppress; 475 emr_transparent_[k] = (1 - a) * lf.emr_transparent + a * hf.emr_transparent; 476 } 477 } 478 479 } // namespace webrtc