subband_erle_estimator.cc (9350B)
1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/subband_erle_estimator.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <cstddef> 16 #include <functional> 17 #include <memory> 18 #include <vector> 19 20 #include "api/array_view.h" 21 #include "api/audio/echo_canceller3_config.h" 22 #include "api/environment/environment.h" 23 #include "api/field_trials_view.h" 24 #include "modules/audio_processing/aec3/aec3_common.h" 25 #include "modules/audio_processing/logging/apm_data_dumper.h" 26 #include "rtc_base/checks.h" 27 #include "rtc_base/numerics/safe_minmax.h" 28 29 namespace webrtc { 30 31 namespace { 32 33 constexpr float kX2BandEnergyThreshold = 44015068.0f; 34 constexpr int kBlocksToHoldErle = 100; 35 constexpr int kBlocksForOnsetDetection = kBlocksToHoldErle + 150; 36 constexpr int kPointsToAccumulate = 6; 37 38 std::array<float, kFftLengthBy2Plus1> SetMaxErleBands(float max_erle_l, 39 float max_erle_h) { 40 std::array<float, kFftLengthBy2Plus1> max_erle; 41 std::fill(max_erle.begin(), max_erle.begin() + kFftLengthBy2 / 2, max_erle_l); 42 std::fill(max_erle.begin() + kFftLengthBy2 / 2, max_erle.end(), max_erle_h); 43 return max_erle; 44 } 45 46 bool EnableMinErleDuringOnsets(const FieldTrialsView& field_trials) { 47 return !field_trials.IsEnabled("WebRTC-Aec3MinErleDuringOnsetsKillSwitch"); 48 } 49 50 } // namespace 51 52 SubbandErleEstimator::SubbandErleEstimator(const Environment& env, 53 const EchoCanceller3Config& config, 54 size_t num_capture_channels) 55 : use_onset_detection_(config.erle.onset_detection), 56 min_erle_(config.erle.min), 57 max_erle_(SetMaxErleBands(config.erle.max_l, config.erle.max_h)), 58 use_min_erle_during_onsets_( 59 EnableMinErleDuringOnsets(env.field_trials())), 60 accum_spectra_(num_capture_channels), 61 erle_(num_capture_channels), 62 erle_onset_compensated_(num_capture_channels), 63 erle_unbounded_(num_capture_channels), 64 erle_during_onsets_(num_capture_channels), 65 coming_onset_(num_capture_channels), 66 hold_counters_(num_capture_channels) { 67 Reset(); 68 } 69 70 SubbandErleEstimator::~SubbandErleEstimator() = default; 71 72 void SubbandErleEstimator::Reset() { 73 const size_t num_capture_channels = erle_.size(); 74 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 75 erle_[ch].fill(min_erle_); 76 erle_onset_compensated_[ch].fill(min_erle_); 77 erle_unbounded_[ch].fill(min_erle_); 78 erle_during_onsets_[ch].fill(min_erle_); 79 coming_onset_[ch].fill(true); 80 hold_counters_[ch].fill(0); 81 } 82 ResetAccumulatedSpectra(); 83 } 84 85 void SubbandErleEstimator::Update( 86 ArrayView<const float, kFftLengthBy2Plus1> X2, 87 ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, 88 ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, 89 const std::vector<bool>& converged_filters) { 90 UpdateAccumulatedSpectra(X2, Y2, E2, converged_filters); 91 UpdateBands(converged_filters); 92 93 if (use_onset_detection_) { 94 DecreaseErlePerBandForLowRenderSignals(); 95 } 96 97 const size_t num_capture_channels = erle_.size(); 98 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 99 auto& erle = erle_[ch]; 100 erle[0] = erle[1]; 101 erle[kFftLengthBy2] = erle[kFftLengthBy2 - 1]; 102 103 auto& erle_oc = erle_onset_compensated_[ch]; 104 erle_oc[0] = erle_oc[1]; 105 erle_oc[kFftLengthBy2] = erle_oc[kFftLengthBy2 - 1]; 106 107 auto& erle_u = erle_unbounded_[ch]; 108 erle_u[0] = erle_u[1]; 109 erle_u[kFftLengthBy2] = erle_u[kFftLengthBy2 - 1]; 110 } 111 } 112 113 void SubbandErleEstimator::Dump( 114 const std::unique_ptr<ApmDataDumper>& data_dumper) const { 115 data_dumper->DumpRaw("aec3_erle_onset", ErleDuringOnsets()[0]); 116 } 117 118 void SubbandErleEstimator::UpdateBands( 119 const std::vector<bool>& converged_filters) { 120 const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size()); 121 for (int ch = 0; ch < num_capture_channels; ++ch) { 122 // Note that the use of the converged_filter flag already imposed 123 // a minimum of the erle that can be estimated as that flag would 124 // be false if the filter is performing poorly. 125 if (!converged_filters[ch]) { 126 continue; 127 } 128 129 if (accum_spectra_.num_points[ch] != kPointsToAccumulate) { 130 continue; 131 } 132 133 std::array<float, kFftLengthBy2> new_erle; 134 std::array<bool, kFftLengthBy2> is_erle_updated; 135 is_erle_updated.fill(false); 136 137 for (size_t k = 1; k < kFftLengthBy2; ++k) { 138 if (accum_spectra_.E2[ch][k] > 0.f) { 139 new_erle[k] = accum_spectra_.Y2[ch][k] / accum_spectra_.E2[ch][k]; 140 is_erle_updated[k] = true; 141 } 142 } 143 144 if (use_onset_detection_) { 145 for (size_t k = 1; k < kFftLengthBy2; ++k) { 146 if (is_erle_updated[k] && !accum_spectra_.low_render_energy[ch][k]) { 147 if (coming_onset_[ch][k]) { 148 coming_onset_[ch][k] = false; 149 if (!use_min_erle_during_onsets_) { 150 float alpha = 151 new_erle[k] < erle_during_onsets_[ch][k] ? 0.3f : 0.15f; 152 erle_during_onsets_[ch][k] = SafeClamp( 153 erle_during_onsets_[ch][k] + 154 alpha * (new_erle[k] - erle_during_onsets_[ch][k]), 155 min_erle_, max_erle_[k]); 156 } 157 } 158 hold_counters_[ch][k] = kBlocksForOnsetDetection; 159 } 160 } 161 } 162 163 auto update_erle_band = [](float& erle, float new_erle, 164 bool low_render_energy, float min_erle, 165 float max_erle) { 166 float alpha = 0.05f; 167 if (new_erle < erle) { 168 alpha = low_render_energy ? 0.f : 0.1f; 169 } 170 erle = SafeClamp(erle + alpha * (new_erle - erle), min_erle, max_erle); 171 }; 172 173 for (size_t k = 1; k < kFftLengthBy2; ++k) { 174 if (is_erle_updated[k]) { 175 const bool low_render_energy = accum_spectra_.low_render_energy[ch][k]; 176 update_erle_band(erle_[ch][k], new_erle[k], low_render_energy, 177 min_erle_, max_erle_[k]); 178 if (use_onset_detection_) { 179 update_erle_band(erle_onset_compensated_[ch][k], new_erle[k], 180 low_render_energy, min_erle_, max_erle_[k]); 181 } 182 183 // Virtually unbounded ERLE. 184 constexpr float kUnboundedErleMax = 100000.0f; 185 update_erle_band(erle_unbounded_[ch][k], new_erle[k], low_render_energy, 186 min_erle_, kUnboundedErleMax); 187 } 188 } 189 } 190 } 191 192 void SubbandErleEstimator::DecreaseErlePerBandForLowRenderSignals() { 193 const int num_capture_channels = static_cast<int>(accum_spectra_.Y2.size()); 194 for (int ch = 0; ch < num_capture_channels; ++ch) { 195 for (size_t k = 1; k < kFftLengthBy2; ++k) { 196 --hold_counters_[ch][k]; 197 if (hold_counters_[ch][k] <= 198 (kBlocksForOnsetDetection - kBlocksToHoldErle)) { 199 if (erle_onset_compensated_[ch][k] > erle_during_onsets_[ch][k]) { 200 erle_onset_compensated_[ch][k] = 201 std::max(erle_during_onsets_[ch][k], 202 0.97f * erle_onset_compensated_[ch][k]); 203 RTC_DCHECK_LE(min_erle_, erle_onset_compensated_[ch][k]); 204 } 205 if (hold_counters_[ch][k] <= 0) { 206 coming_onset_[ch][k] = true; 207 hold_counters_[ch][k] = 0; 208 } 209 } 210 } 211 } 212 } 213 214 void SubbandErleEstimator::ResetAccumulatedSpectra() { 215 for (size_t ch = 0; ch < erle_during_onsets_.size(); ++ch) { 216 accum_spectra_.Y2[ch].fill(0.f); 217 accum_spectra_.E2[ch].fill(0.f); 218 accum_spectra_.num_points[ch] = 0; 219 accum_spectra_.low_render_energy[ch].fill(false); 220 } 221 } 222 223 void SubbandErleEstimator::UpdateAccumulatedSpectra( 224 ArrayView<const float, kFftLengthBy2Plus1> X2, 225 ArrayView<const std::array<float, kFftLengthBy2Plus1>> Y2, 226 ArrayView<const std::array<float, kFftLengthBy2Plus1>> E2, 227 const std::vector<bool>& converged_filters) { 228 auto& st = accum_spectra_; 229 RTC_DCHECK_EQ(st.E2.size(), E2.size()); 230 RTC_DCHECK_EQ(st.E2.size(), E2.size()); 231 const int num_capture_channels = static_cast<int>(Y2.size()); 232 for (int ch = 0; ch < num_capture_channels; ++ch) { 233 // Note that the use of the converged_filter flag already imposed 234 // a minimum of the erle that can be estimated as that flag would 235 // be false if the filter is performing poorly. 236 if (!converged_filters[ch]) { 237 continue; 238 } 239 240 if (st.num_points[ch] == kPointsToAccumulate) { 241 st.num_points[ch] = 0; 242 st.Y2[ch].fill(0.f); 243 st.E2[ch].fill(0.f); 244 st.low_render_energy[ch].fill(false); 245 } 246 247 std::transform(Y2[ch].begin(), Y2[ch].end(), st.Y2[ch].begin(), 248 st.Y2[ch].begin(), std::plus<float>()); 249 std::transform(E2[ch].begin(), E2[ch].end(), st.E2[ch].begin(), 250 st.E2[ch].begin(), std::plus<float>()); 251 252 for (size_t k = 0; k < X2.size(); ++k) { 253 st.low_render_energy[ch][k] = 254 st.low_render_energy[ch][k] || X2[k] < kX2BandEnergyThreshold; 255 } 256 257 ++st.num_points[ch]; 258 } 259 } 260 261 } // namespace webrtc