subtractor.cc (15052B)
1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/subtractor.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <cstddef> 16 #include <memory> 17 #include <vector> 18 19 #include "api/array_view.h" 20 #include "api/audio/echo_canceller3_config.h" 21 #include "api/environment/environment.h" 22 #include "api/field_trials_view.h" 23 #include "modules/audio_processing/aec3/adaptive_fir_filter.h" 24 #include "modules/audio_processing/aec3/adaptive_fir_filter_erl.h" 25 #include "modules/audio_processing/aec3/aec3_common.h" 26 #include "modules/audio_processing/aec3/aec3_fft.h" 27 #include "modules/audio_processing/aec3/aec_state.h" 28 #include "modules/audio_processing/aec3/block.h" 29 #include "modules/audio_processing/aec3/coarse_filter_update_gain.h" 30 #include "modules/audio_processing/aec3/echo_path_variability.h" 31 #include "modules/audio_processing/aec3/fft_data.h" 32 #include "modules/audio_processing/aec3/refined_filter_update_gain.h" 33 #include "modules/audio_processing/aec3/render_buffer.h" 34 #include "modules/audio_processing/aec3/render_signal_analyzer.h" 35 #include "modules/audio_processing/aec3/subtractor_output.h" 36 #include "modules/audio_processing/logging/apm_data_dumper.h" 37 #include "rtc_base/checks.h" 38 #include "rtc_base/numerics/safe_minmax.h" 39 40 namespace webrtc { 41 42 namespace { 43 44 bool UseCoarseFilterResetHangover(const FieldTrialsView& field_trials) { 45 return !field_trials.IsEnabled( 46 "WebRTC-Aec3CoarseFilterResetHangoverKillSwitch"); 47 } 48 49 void PredictionError(const Aec3Fft& fft, 50 const FftData& S, 51 ArrayView<const float> y, 52 std::array<float, kBlockSize>* e, 53 std::array<float, kBlockSize>* s) { 54 std::array<float, kFftLength> tmp; 55 fft.Ifft(S, &tmp); 56 constexpr float kScale = 1.0f / kFftLengthBy2; 57 std::transform(y.begin(), y.end(), tmp.begin() + kFftLengthBy2, e->begin(), 58 [&](float a, float b) { return a - b * kScale; }); 59 60 if (s) { 61 for (size_t k = 0; k < s->size(); ++k) { 62 (*s)[k] = kScale * tmp[k + kFftLengthBy2]; 63 } 64 } 65 } 66 67 void ScaleFilterOutput(ArrayView<const float> y, 68 float factor, 69 ArrayView<float> e, 70 ArrayView<float> s) { 71 RTC_DCHECK_EQ(y.size(), e.size()); 72 RTC_DCHECK_EQ(y.size(), s.size()); 73 for (size_t k = 0; k < y.size(); ++k) { 74 s[k] *= factor; 75 e[k] = y[k] - s[k]; 76 } 77 } 78 79 } // namespace 80 81 Subtractor::Subtractor(const Environment& env, 82 const EchoCanceller3Config& config, 83 size_t num_render_channels, 84 size_t num_capture_channels, 85 ApmDataDumper* data_dumper, 86 Aec3Optimization optimization) 87 : fft_(), 88 data_dumper_(data_dumper), 89 optimization_(optimization), 90 config_(config), 91 num_capture_channels_(num_capture_channels), 92 use_coarse_filter_reset_hangover_( 93 UseCoarseFilterResetHangover(env.field_trials())), 94 refined_filters_(num_capture_channels_), 95 coarse_filter_(num_capture_channels_), 96 refined_gains_(num_capture_channels_), 97 coarse_gains_(num_capture_channels_), 98 filter_misadjustment_estimators_(num_capture_channels_), 99 poor_coarse_filter_counters_(num_capture_channels_, 0), 100 coarse_filter_reset_hangover_(num_capture_channels_, 0), 101 refined_frequency_responses_( 102 num_capture_channels_, 103 std::vector<std::array<float, kFftLengthBy2Plus1>>( 104 std::max(config_.filter.refined_initial.length_blocks, 105 config_.filter.refined.length_blocks), 106 std::array<float, kFftLengthBy2Plus1>())), 107 refined_impulse_responses_( 108 num_capture_channels_, 109 std::vector<float>(GetTimeDomainLength(std::max( 110 config_.filter.refined_initial.length_blocks, 111 config_.filter.refined.length_blocks)), 112 0.f)), 113 coarse_impulse_responses_(0) { 114 // Set up the storing of coarse impulse responses if data dumping is 115 // available. 116 if (ApmDataDumper::IsAvailable()) { 117 coarse_impulse_responses_.resize(num_capture_channels_); 118 const size_t filter_size = GetTimeDomainLength( 119 std::max(config_.filter.coarse_initial.length_blocks, 120 config_.filter.coarse.length_blocks)); 121 for (std::vector<float>& impulse_response : coarse_impulse_responses_) { 122 impulse_response.resize(filter_size, 0.f); 123 } 124 } 125 126 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 127 refined_filters_[ch] = std::make_unique<AdaptiveFirFilter>( 128 config_.filter.refined.length_blocks, 129 config_.filter.refined_initial.length_blocks, 130 config.filter.config_change_duration_blocks, num_render_channels, 131 optimization, data_dumper_); 132 133 coarse_filter_[ch] = std::make_unique<AdaptiveFirFilter>( 134 config_.filter.coarse.length_blocks, 135 config_.filter.coarse_initial.length_blocks, 136 config.filter.config_change_duration_blocks, num_render_channels, 137 optimization, data_dumper_); 138 refined_gains_[ch] = std::make_unique<RefinedFilterUpdateGain>( 139 config_.filter.refined_initial, 140 config_.filter.config_change_duration_blocks); 141 coarse_gains_[ch] = std::make_unique<CoarseFilterUpdateGain>( 142 config_.filter.coarse_initial, 143 config.filter.config_change_duration_blocks); 144 } 145 146 RTC_DCHECK(data_dumper_); 147 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 148 for (auto& H2_k : refined_frequency_responses_[ch]) { 149 H2_k.fill(0.f); 150 } 151 } 152 } 153 154 Subtractor::~Subtractor() = default; 155 156 void Subtractor::HandleEchoPathChange( 157 const EchoPathVariability& echo_path_variability) { 158 const auto full_reset = [&]() { 159 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 160 refined_filters_[ch]->HandleEchoPathChange(); 161 coarse_filter_[ch]->HandleEchoPathChange(); 162 refined_gains_[ch]->HandleEchoPathChange(echo_path_variability); 163 coarse_gains_[ch]->HandleEchoPathChange(); 164 refined_gains_[ch]->SetConfig(config_.filter.refined_initial, true); 165 coarse_gains_[ch]->SetConfig(config_.filter.coarse_initial, true); 166 refined_filters_[ch]->SetSizePartitions( 167 config_.filter.refined_initial.length_blocks, true); 168 coarse_filter_[ch]->SetSizePartitions( 169 config_.filter.coarse_initial.length_blocks, true); 170 } 171 }; 172 173 if (echo_path_variability.delay_change != 174 EchoPathVariability::DelayAdjustment::kNone) { 175 full_reset(); 176 } 177 178 if (echo_path_variability.gain_change) { 179 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 180 refined_gains_[ch]->HandleEchoPathChange(echo_path_variability); 181 } 182 } 183 } 184 185 void Subtractor::ExitInitialState() { 186 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 187 refined_gains_[ch]->SetConfig(config_.filter.refined, false); 188 coarse_gains_[ch]->SetConfig(config_.filter.coarse, false); 189 refined_filters_[ch]->SetSizePartitions( 190 config_.filter.refined.length_blocks, false); 191 coarse_filter_[ch]->SetSizePartitions(config_.filter.coarse.length_blocks, 192 false); 193 } 194 } 195 196 void Subtractor::Process(const RenderBuffer& render_buffer, 197 const Block& capture, 198 const RenderSignalAnalyzer& render_signal_analyzer, 199 const AecState& aec_state, 200 ArrayView<SubtractorOutput> outputs) { 201 RTC_DCHECK_EQ(num_capture_channels_, capture.NumChannels()); 202 203 // Compute the render powers. 204 const bool same_filter_sizes = refined_filters_[0]->SizePartitions() == 205 coarse_filter_[0]->SizePartitions(); 206 std::array<float, kFftLengthBy2Plus1> X2_refined; 207 std::array<float, kFftLengthBy2Plus1> X2_coarse_data; 208 auto& X2_coarse = same_filter_sizes ? X2_refined : X2_coarse_data; 209 if (same_filter_sizes) { 210 render_buffer.SpectralSum(refined_filters_[0]->SizePartitions(), 211 &X2_refined); 212 } else if (refined_filters_[0]->SizePartitions() > 213 coarse_filter_[0]->SizePartitions()) { 214 render_buffer.SpectralSums(coarse_filter_[0]->SizePartitions(), 215 refined_filters_[0]->SizePartitions(), 216 &X2_coarse, &X2_refined); 217 } else { 218 render_buffer.SpectralSums(refined_filters_[0]->SizePartitions(), 219 coarse_filter_[0]->SizePartitions(), &X2_refined, 220 &X2_coarse); 221 } 222 223 // Process all capture channels 224 for (size_t ch = 0; ch < num_capture_channels_; ++ch) { 225 SubtractorOutput& output = outputs[ch]; 226 ArrayView<const float> y = capture.View(/*band=*/0, ch); 227 FftData& E_refined = output.E_refined; 228 FftData E_coarse; 229 std::array<float, kBlockSize>& e_refined = output.e_refined; 230 std::array<float, kBlockSize>& e_coarse = output.e_coarse; 231 232 FftData S; 233 FftData& G = S; 234 235 // Form the outputs of the refined and coarse filters. 236 refined_filters_[ch]->Filter(render_buffer, &S); 237 PredictionError(fft_, S, y, &e_refined, &output.s_refined); 238 239 coarse_filter_[ch]->Filter(render_buffer, &S); 240 PredictionError(fft_, S, y, &e_coarse, &output.s_coarse); 241 242 // Compute the signal powers in the subtractor output. 243 output.ComputeMetrics(y); 244 245 // Adjust the filter if needed. 246 bool refined_filters_adjusted = false; 247 filter_misadjustment_estimators_[ch].Update(output); 248 if (filter_misadjustment_estimators_[ch].IsAdjustmentNeeded()) { 249 float scale = filter_misadjustment_estimators_[ch].GetMisadjustment(); 250 refined_filters_[ch]->ScaleFilter(scale); 251 for (auto& h_k : refined_impulse_responses_[ch]) { 252 h_k *= scale; 253 } 254 ScaleFilterOutput(y, scale, e_refined, output.s_refined); 255 filter_misadjustment_estimators_[ch].Reset(); 256 refined_filters_adjusted = true; 257 } 258 259 // Compute the FFts of the refined and coarse filter outputs. 260 fft_.ZeroPaddedFft(e_refined, Aec3Fft::Window::kHanning, &E_refined); 261 fft_.ZeroPaddedFft(e_coarse, Aec3Fft::Window::kHanning, &E_coarse); 262 263 // Compute spectra for future use. 264 E_coarse.Spectrum(optimization_, output.E2_coarse); 265 E_refined.Spectrum(optimization_, output.E2_refined); 266 267 // Update the refined filter. 268 if (!refined_filters_adjusted) { 269 // Do not allow the performance of the coarse filter to affect the 270 // adaptation speed of the refined filter just after the coarse filter has 271 // been reset. 272 const bool disallow_leakage_diverged = 273 coarse_filter_reset_hangover_[ch] > 0 && 274 use_coarse_filter_reset_hangover_; 275 276 std::array<float, kFftLengthBy2Plus1> erl; 277 ComputeErl(optimization_, refined_frequency_responses_[ch], erl); 278 refined_gains_[ch]->Compute(X2_refined, render_signal_analyzer, output, 279 erl, refined_filters_[ch]->SizePartitions(), 280 aec_state.SaturatedCapture(), 281 disallow_leakage_diverged, &G); 282 } else { 283 G.re.fill(0.f); 284 G.im.fill(0.f); 285 } 286 refined_filters_[ch]->Adapt(render_buffer, G, 287 &refined_impulse_responses_[ch]); 288 refined_filters_[ch]->ComputeFrequencyResponse( 289 &refined_frequency_responses_[ch]); 290 291 if (ch == 0) { 292 data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.re); 293 data_dumper_->DumpRaw("aec3_subtractor_G_refined", G.im); 294 } 295 296 // Update the coarse filter. 297 poor_coarse_filter_counters_[ch] = 298 output.e2_refined < output.e2_coarse 299 ? poor_coarse_filter_counters_[ch] + 1 300 : 0; 301 if (poor_coarse_filter_counters_[ch] < 5) { 302 coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_coarse, 303 coarse_filter_[ch]->SizePartitions(), 304 aec_state.SaturatedCapture(), &G); 305 coarse_filter_reset_hangover_[ch] = 306 std::max(coarse_filter_reset_hangover_[ch] - 1, 0); 307 } else { 308 poor_coarse_filter_counters_[ch] = 0; 309 coarse_filter_[ch]->SetFilter(refined_filters_[ch]->SizePartitions(), 310 refined_filters_[ch]->GetFilter()); 311 coarse_gains_[ch]->Compute(X2_coarse, render_signal_analyzer, E_refined, 312 coarse_filter_[ch]->SizePartitions(), 313 aec_state.SaturatedCapture(), &G); 314 coarse_filter_reset_hangover_[ch] = 315 config_.filter.coarse_reset_hangover_blocks; 316 } 317 318 if (ApmDataDumper::IsAvailable()) { 319 RTC_DCHECK_LT(ch, coarse_impulse_responses_.size()); 320 coarse_filter_[ch]->Adapt(render_buffer, G, 321 &coarse_impulse_responses_[ch]); 322 } else { 323 coarse_filter_[ch]->Adapt(render_buffer, G); 324 } 325 326 if (ch == 0) { 327 data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.re); 328 data_dumper_->DumpRaw("aec3_subtractor_G_coarse", G.im); 329 filter_misadjustment_estimators_[ch].Dump(data_dumper_); 330 DumpFilters(); 331 } 332 333 std::for_each(e_refined.begin(), e_refined.end(), 334 [](float& a) { a = SafeClamp(a, -32768.f, 32767.f); }); 335 336 if (ch == 0) { 337 data_dumper_->DumpWav("aec3_refined_filters_output", kBlockSize, 338 &e_refined[0], 16000, 1); 339 data_dumper_->DumpWav("aec3_coarse_filter_output", kBlockSize, 340 &e_coarse[0], 16000, 1); 341 } 342 } 343 } 344 345 void Subtractor::FilterMisadjustmentEstimator::Update( 346 const SubtractorOutput& output) { 347 e2_acum_ += output.e2_refined; 348 y2_acum_ += output.y2; 349 if (++n_blocks_acum_ == n_blocks_) { 350 if (y2_acum_ > n_blocks_ * 200.f * 200.f * kBlockSize) { 351 float update = (e2_acum_ / y2_acum_); 352 if (e2_acum_ > n_blocks_ * 7500.f * 7500.f * kBlockSize) { 353 // Duration equal to blockSizeMs * n_blocks_ * 4. 354 overhang_ = 4; 355 } else { 356 overhang_ = std::max(overhang_ - 1, 0); 357 } 358 359 if ((update < inv_misadjustment_) || (overhang_ > 0)) { 360 inv_misadjustment_ += 0.1f * (update - inv_misadjustment_); 361 } 362 } 363 e2_acum_ = 0.f; 364 y2_acum_ = 0.f; 365 n_blocks_acum_ = 0; 366 } 367 } 368 369 void Subtractor::FilterMisadjustmentEstimator::Reset() { 370 e2_acum_ = 0.f; 371 y2_acum_ = 0.f; 372 n_blocks_acum_ = 0; 373 inv_misadjustment_ = 0.f; 374 overhang_ = 0.f; 375 } 376 377 void Subtractor::FilterMisadjustmentEstimator::Dump( 378 ApmDataDumper* data_dumper) const { 379 data_dumper->DumpRaw("aec3_inv_misadjustment_factor", inv_misadjustment_); 380 } 381 382 } // namespace webrtc