stationarity_estimator.cc (8510B)
1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/stationarity_estimator.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <atomic> 16 #include <cstddef> 17 18 #include "api/array_view.h" 19 #include "modules/audio_processing/aec3/aec3_common.h" 20 #include "modules/audio_processing/aec3/spectrum_buffer.h" 21 #include "modules/audio_processing/logging/apm_data_dumper.h" 22 #include "rtc_base/checks.h" 23 24 namespace webrtc { 25 26 namespace { 27 constexpr float kMinNoisePower = 10.f; 28 constexpr int kHangoverBlocks = kNumBlocksPerSecond / 20; 29 constexpr int kNBlocksAverageInitPhase = 20; 30 constexpr int kNBlocksInitialPhase = kNumBlocksPerSecond * 2.; 31 } // namespace 32 33 StationarityEstimator::StationarityEstimator() 34 : data_dumper_(new ApmDataDumper(instance_count_.fetch_add(1) + 1)) { 35 Reset(); 36 } 37 38 StationarityEstimator::~StationarityEstimator() = default; 39 40 void StationarityEstimator::Reset() { 41 noise_.Reset(); 42 hangovers_.fill(0); 43 stationarity_flags_.fill(false); 44 } 45 46 // Update just the noise estimator. Usefull until the delay is known 47 void StationarityEstimator::UpdateNoiseEstimator( 48 ArrayView<const std::array<float, kFftLengthBy2Plus1>> spectrum) { 49 noise_.Update(spectrum); 50 data_dumper_->DumpRaw("aec3_stationarity_noise_spectrum", noise_.Spectrum()); 51 data_dumper_->DumpRaw("aec3_stationarity_is_block_stationary", 52 IsBlockStationary()); 53 } 54 55 void StationarityEstimator::UpdateStationarityFlags( 56 const SpectrumBuffer& spectrum_buffer, 57 ArrayView<const float> render_reverb_contribution_spectrum, 58 int idx_current, 59 int num_lookahead) { 60 std::array<int, kWindowLength> indexes; 61 int num_lookahead_bounded = std::min(num_lookahead, kWindowLength - 1); 62 int idx = idx_current; 63 64 if (num_lookahead_bounded < kWindowLength - 1) { 65 int num_lookback = (kWindowLength - 1) - num_lookahead_bounded; 66 idx = spectrum_buffer.OffsetIndex(idx_current, num_lookback); 67 } 68 // For estimating the stationarity properties of the current frame, the 69 // power for each band is accumulated for several consecutive spectra in the 70 // method EstimateBandStationarity. 71 // In order to avoid getting the indexes of the spectra for every band with 72 // its associated overhead, those indexes are stored in an array and then use 73 // when the estimation is done. 74 indexes[0] = idx; 75 for (size_t k = 1; k < indexes.size(); ++k) { 76 indexes[k] = spectrum_buffer.DecIndex(indexes[k - 1]); 77 } 78 RTC_DCHECK_EQ( 79 spectrum_buffer.DecIndex(indexes[kWindowLength - 1]), 80 spectrum_buffer.OffsetIndex(idx_current, -(num_lookahead_bounded + 1))); 81 82 for (size_t k = 0; k < stationarity_flags_.size(); ++k) { 83 stationarity_flags_[k] = EstimateBandStationarity( 84 spectrum_buffer, render_reverb_contribution_spectrum, indexes, k); 85 } 86 UpdateHangover(); 87 SmoothStationaryPerFreq(); 88 } 89 90 bool StationarityEstimator::IsBlockStationary() const { 91 float acum_stationarity = 0.f; 92 RTC_DCHECK_EQ(stationarity_flags_.size(), kFftLengthBy2Plus1); 93 for (size_t band = 0; band < stationarity_flags_.size(); ++band) { 94 bool st = IsBandStationary(band); 95 acum_stationarity += static_cast<float>(st); 96 } 97 return ((acum_stationarity * (1.f / kFftLengthBy2Plus1)) > 0.75f); 98 } 99 100 bool StationarityEstimator::EstimateBandStationarity( 101 const SpectrumBuffer& spectrum_buffer, 102 ArrayView<const float> average_reverb, 103 const std::array<int, kWindowLength>& indexes, 104 size_t band) const { 105 constexpr float kThrStationarity = 10.f; 106 float acum_power = 0.f; 107 const int num_render_channels = 108 static_cast<int>(spectrum_buffer.buffer[0].size()); 109 const float one_by_num_channels = 1.f / num_render_channels; 110 for (auto idx : indexes) { 111 for (int ch = 0; ch < num_render_channels; ++ch) { 112 acum_power += spectrum_buffer.buffer[idx][ch][band] * one_by_num_channels; 113 } 114 } 115 acum_power += average_reverb[band]; 116 float noise = kWindowLength * GetStationarityPowerBand(band); 117 RTC_CHECK_LT(0.f, noise); 118 bool stationary = acum_power < kThrStationarity * noise; 119 data_dumper_->DumpRaw("aec3_stationarity_long_ratio", acum_power / noise); 120 return stationary; 121 } 122 123 bool StationarityEstimator::AreAllBandsStationary() { 124 for (auto b : stationarity_flags_) { 125 if (!b) 126 return false; 127 } 128 return true; 129 } 130 131 void StationarityEstimator::UpdateHangover() { 132 bool reduce_hangover = AreAllBandsStationary(); 133 for (size_t k = 0; k < stationarity_flags_.size(); ++k) { 134 if (!stationarity_flags_[k]) { 135 hangovers_[k] = kHangoverBlocks; 136 } else if (reduce_hangover) { 137 hangovers_[k] = std::max(hangovers_[k] - 1, 0); 138 } 139 } 140 } 141 142 void StationarityEstimator::SmoothStationaryPerFreq() { 143 std::array<bool, kFftLengthBy2Plus1> all_ahead_stationary_smooth; 144 for (size_t k = 1; k < kFftLengthBy2Plus1 - 1; ++k) { 145 all_ahead_stationary_smooth[k] = stationarity_flags_[k - 1] && 146 stationarity_flags_[k] && 147 stationarity_flags_[k + 1]; 148 } 149 150 all_ahead_stationary_smooth[0] = all_ahead_stationary_smooth[1]; 151 all_ahead_stationary_smooth[kFftLengthBy2Plus1 - 1] = 152 all_ahead_stationary_smooth[kFftLengthBy2Plus1 - 2]; 153 154 stationarity_flags_ = all_ahead_stationary_smooth; 155 } 156 157 std::atomic<int> StationarityEstimator::instance_count_(0); 158 159 StationarityEstimator::NoiseSpectrum::NoiseSpectrum() { 160 Reset(); 161 } 162 163 StationarityEstimator::NoiseSpectrum::~NoiseSpectrum() = default; 164 165 void StationarityEstimator::NoiseSpectrum::Reset() { 166 block_counter_ = 0; 167 noise_spectrum_.fill(kMinNoisePower); 168 } 169 170 void StationarityEstimator::NoiseSpectrum::Update( 171 ArrayView<const std::array<float, kFftLengthBy2Plus1>> spectrum) { 172 RTC_DCHECK_LE(1, spectrum[0].size()); 173 const int num_render_channels = static_cast<int>(spectrum.size()); 174 175 std::array<float, kFftLengthBy2Plus1> avg_spectrum_data; 176 ArrayView<const float> avg_spectrum; 177 if (num_render_channels == 1) { 178 avg_spectrum = spectrum[0]; 179 } else { 180 // For multiple channels, average the channel spectra before passing to the 181 // noise spectrum estimator. 182 avg_spectrum = avg_spectrum_data; 183 std::copy(spectrum[0].begin(), spectrum[0].end(), 184 avg_spectrum_data.begin()); 185 for (int ch = 1; ch < num_render_channels; ++ch) { 186 for (size_t k = 1; k < kFftLengthBy2Plus1; ++k) { 187 avg_spectrum_data[k] += spectrum[ch][k]; 188 } 189 } 190 191 const float one_by_num_channels = 1.f / num_render_channels; 192 for (size_t k = 1; k < kFftLengthBy2Plus1; ++k) { 193 avg_spectrum_data[k] *= one_by_num_channels; 194 } 195 } 196 197 ++block_counter_; 198 float alpha = GetAlpha(); 199 for (size_t k = 0; k < kFftLengthBy2Plus1; ++k) { 200 if (block_counter_ <= kNBlocksAverageInitPhase) { 201 noise_spectrum_[k] += (1.f / kNBlocksAverageInitPhase) * avg_spectrum[k]; 202 } else { 203 noise_spectrum_[k] = 204 UpdateBandBySmoothing(avg_spectrum[k], noise_spectrum_[k], alpha); 205 } 206 } 207 } 208 209 float StationarityEstimator::NoiseSpectrum::GetAlpha() const { 210 constexpr float kAlpha = 0.004f; 211 constexpr float kAlphaInit = 0.04f; 212 constexpr float kTiltAlpha = (kAlphaInit - kAlpha) / kNBlocksInitialPhase; 213 214 if (block_counter_ > (kNBlocksInitialPhase + kNBlocksAverageInitPhase)) { 215 return kAlpha; 216 } else { 217 return kAlphaInit - 218 kTiltAlpha * (block_counter_ - kNBlocksAverageInitPhase); 219 } 220 } 221 222 float StationarityEstimator::NoiseSpectrum::UpdateBandBySmoothing( 223 float power_band, 224 float power_band_noise, 225 float alpha) const { 226 float power_band_noise_updated = power_band_noise; 227 if (power_band_noise < power_band) { 228 RTC_DCHECK_GT(power_band, 0.f); 229 float alpha_inc = alpha * (power_band_noise / power_band); 230 if (block_counter_ > kNBlocksInitialPhase) { 231 if (10.f * power_band_noise < power_band) { 232 alpha_inc *= 0.1f; 233 } 234 } 235 power_band_noise_updated += alpha_inc * (power_band - power_band_noise); 236 } else { 237 power_band_noise_updated += alpha * (power_band - power_band_noise); 238 power_band_noise_updated = 239 std::max(power_band_noise_updated, kMinNoisePower); 240 } 241 return power_band_noise_updated; 242 } 243 244 } // namespace webrtc