aec_state_unittest.cc (12109B)
1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/aec_state.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <cstddef> 16 #include <memory> 17 #include <optional> 18 #include <tuple> 19 #include <vector> 20 21 #include "api/audio/echo_canceller3_config.h" 22 #include "api/environment/environment_factory.h" 23 #include "modules/audio_processing/aec3/aec3_common.h" 24 #include "modules/audio_processing/aec3/aec3_fft.h" 25 #include "modules/audio_processing/aec3/block.h" 26 #include "modules/audio_processing/aec3/delay_estimate.h" 27 #include "modules/audio_processing/aec3/echo_path_variability.h" 28 #include "modules/audio_processing/aec3/render_delay_buffer.h" 29 #include "modules/audio_processing/aec3/subtractor_output.h" 30 #include "modules/audio_processing/logging/apm_data_dumper.h" 31 #include "test/gtest.h" 32 33 namespace webrtc { 34 namespace { 35 36 void RunNormalUsageTest(size_t num_render_channels, 37 size_t num_capture_channels) { 38 // TODO(bugs.webrtc.org/10913): Test with different content in different 39 // channels. 40 constexpr int kSampleRateHz = 48000; 41 constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); 42 ApmDataDumper data_dumper(42); 43 EchoCanceller3Config config; 44 AecState state(CreateEnvironment(), config, num_capture_channels); 45 std::optional<DelayEstimate> delay_estimate = 46 DelayEstimate(DelayEstimate::Quality::kRefined, 10); 47 std::unique_ptr<RenderDelayBuffer> render_delay_buffer( 48 RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); 49 std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( 50 num_capture_channels); 51 std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); 52 Block x(kNumBands, num_render_channels); 53 EchoPathVariability echo_path_variability( 54 false, EchoPathVariability::DelayAdjustment::kNone, false); 55 std::vector<std::array<float, kBlockSize>> y(num_capture_channels); 56 std::vector<SubtractorOutput> subtractor_output(num_capture_channels); 57 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 58 subtractor_output[ch].Reset(); 59 subtractor_output[ch].s_refined.fill(100.f); 60 subtractor_output[ch].e_refined.fill(100.f); 61 y[ch].fill(1000.f); 62 E2_refined[ch].fill(0.f); 63 Y2[ch].fill(0.f); 64 } 65 Aec3Fft fft; 66 std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> 67 converged_filter_frequency_response( 68 num_capture_channels, 69 std::vector<std::array<float, kFftLengthBy2Plus1>>(10)); 70 for (auto& v_ch : converged_filter_frequency_response) { 71 for (auto& v : v_ch) { 72 v.fill(0.01f); 73 } 74 } 75 std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> 76 diverged_filter_frequency_response = converged_filter_frequency_response; 77 converged_filter_frequency_response[0][2].fill(100.f); 78 converged_filter_frequency_response[0][2][0] = 1.f; 79 std::vector<std::vector<float>> impulse_response( 80 num_capture_channels, 81 std::vector<float>( 82 GetTimeDomainLength(config.filter.refined.length_blocks), 0.f)); 83 84 // Verify that linear AEC usability is true when the filter is converged 85 for (size_t band = 0; band < kNumBands; ++band) { 86 for (size_t ch = 0; ch < num_render_channels; ++ch) { 87 std::fill(x.begin(band, ch), x.end(band, ch), 101.f); 88 } 89 } 90 for (int k = 0; k < 3000; ++k) { 91 render_delay_buffer->Insert(x); 92 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 93 subtractor_output[ch].ComputeMetrics(y[ch]); 94 } 95 state.Update(delay_estimate, converged_filter_frequency_response, 96 impulse_response, *render_delay_buffer->GetRenderBuffer(), 97 E2_refined, Y2, subtractor_output); 98 } 99 EXPECT_TRUE(state.UsableLinearEstimate()); 100 101 // Verify that linear AEC usability becomes false after an echo path 102 // change is reported 103 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 104 subtractor_output[ch].ComputeMetrics(y[ch]); 105 } 106 state.HandleEchoPathChange(EchoPathVariability( 107 false, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false)); 108 state.Update(delay_estimate, converged_filter_frequency_response, 109 impulse_response, *render_delay_buffer->GetRenderBuffer(), 110 E2_refined, Y2, subtractor_output); 111 EXPECT_FALSE(state.UsableLinearEstimate()); 112 113 // Verify that the active render detection works as intended. 114 for (size_t ch = 0; ch < num_render_channels; ++ch) { 115 std::fill(x.begin(0, ch), x.end(0, ch), 101.f); 116 } 117 render_delay_buffer->Insert(x); 118 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 119 subtractor_output[ch].ComputeMetrics(y[ch]); 120 } 121 state.HandleEchoPathChange(EchoPathVariability( 122 true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, false)); 123 state.Update(delay_estimate, converged_filter_frequency_response, 124 impulse_response, *render_delay_buffer->GetRenderBuffer(), 125 E2_refined, Y2, subtractor_output); 126 EXPECT_FALSE(state.ActiveRender()); 127 128 for (int k = 0; k < 1000; ++k) { 129 render_delay_buffer->Insert(x); 130 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 131 subtractor_output[ch].ComputeMetrics(y[ch]); 132 } 133 state.Update(delay_estimate, converged_filter_frequency_response, 134 impulse_response, *render_delay_buffer->GetRenderBuffer(), 135 E2_refined, Y2, subtractor_output); 136 } 137 EXPECT_TRUE(state.ActiveRender()); 138 139 // Verify that the ERL is properly estimated 140 for (int band = 0; band < x.NumBands(); ++band) { 141 for (int channel = 0; channel < x.NumChannels(); ++channel) { 142 std::fill(x.begin(band, channel), x.end(band, channel), 0.0f); 143 } 144 } 145 146 for (size_t ch = 0; ch < num_render_channels; ++ch) { 147 x.View(/*band=*/0, ch)[0] = 5000.f; 148 } 149 for (size_t k = 0; 150 k < render_delay_buffer->GetRenderBuffer()->GetFftBuffer().size(); ++k) { 151 render_delay_buffer->Insert(x); 152 if (k == 0) { 153 render_delay_buffer->Reset(); 154 } 155 render_delay_buffer->PrepareCaptureProcessing(); 156 } 157 158 for (auto& Y2_ch : Y2) { 159 Y2_ch.fill(10.f * 10000.f * 10000.f); 160 } 161 for (size_t k = 0; k < 1000; ++k) { 162 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 163 subtractor_output[ch].ComputeMetrics(y[ch]); 164 } 165 state.Update(delay_estimate, converged_filter_frequency_response, 166 impulse_response, *render_delay_buffer->GetRenderBuffer(), 167 E2_refined, Y2, subtractor_output); 168 } 169 170 ASSERT_TRUE(state.UsableLinearEstimate()); 171 const std::array<float, kFftLengthBy2Plus1>& erl = state.Erl(); 172 EXPECT_EQ(erl[0], erl[1]); 173 for (size_t k = 1; k < erl.size() - 1; ++k) { 174 EXPECT_NEAR(k % 2 == 0 ? 10.f : 1000.f, erl[k], 0.1); 175 } 176 EXPECT_EQ(erl[erl.size() - 2], erl[erl.size() - 1]); 177 178 // Verify that the ERLE is properly estimated 179 for (auto& E2_refined_ch : E2_refined) { 180 E2_refined_ch.fill(1.f * 10000.f * 10000.f); 181 } 182 for (auto& Y2_ch : Y2) { 183 Y2_ch.fill(10.f * E2_refined[0][0]); 184 } 185 for (size_t k = 0; k < 1000; ++k) { 186 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 187 subtractor_output[ch].ComputeMetrics(y[ch]); 188 } 189 state.Update(delay_estimate, converged_filter_frequency_response, 190 impulse_response, *render_delay_buffer->GetRenderBuffer(), 191 E2_refined, Y2, subtractor_output); 192 } 193 ASSERT_TRUE(state.UsableLinearEstimate()); 194 { 195 // Note that the render spectrum is built so it does not have energy in 196 // the odd bands but just in the even bands. 197 const auto& erle = state.Erle(/*onset_compensated=*/true)[0]; 198 EXPECT_EQ(erle[0], erle[1]); 199 constexpr size_t kLowFrequencyLimit = 32; 200 for (size_t k = 2; k < kLowFrequencyLimit; k = k + 2) { 201 EXPECT_NEAR(4.f, erle[k], 0.1); 202 } 203 for (size_t k = kLowFrequencyLimit; k < erle.size() - 1; k = k + 2) { 204 EXPECT_NEAR(1.5f, erle[k], 0.1); 205 } 206 EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]); 207 } 208 for (auto& E2_refined_ch : E2_refined) { 209 E2_refined_ch.fill(1.f * 10000.f * 10000.f); 210 } 211 for (auto& Y2_ch : Y2) { 212 Y2_ch.fill(5.f * E2_refined[0][0]); 213 } 214 for (size_t k = 0; k < 1000; ++k) { 215 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 216 subtractor_output[ch].ComputeMetrics(y[ch]); 217 } 218 state.Update(delay_estimate, converged_filter_frequency_response, 219 impulse_response, *render_delay_buffer->GetRenderBuffer(), 220 E2_refined, Y2, subtractor_output); 221 } 222 223 ASSERT_TRUE(state.UsableLinearEstimate()); 224 { 225 const auto& erle = state.Erle(/*onset_compensated=*/true)[0]; 226 EXPECT_EQ(erle[0], erle[1]); 227 constexpr size_t kLowFrequencyLimit = 32; 228 for (size_t k = 1; k < kLowFrequencyLimit; ++k) { 229 EXPECT_NEAR(k % 2 == 0 ? 4.f : 1.f, erle[k], 0.1); 230 } 231 for (size_t k = kLowFrequencyLimit; k < erle.size() - 1; ++k) { 232 EXPECT_NEAR(k % 2 == 0 ? 1.5f : 1.f, erle[k], 0.1); 233 } 234 EXPECT_EQ(erle[erle.size() - 2], erle[erle.size() - 1]); 235 } 236 } 237 238 } // namespace 239 240 class AecStateMultiChannel 241 : public ::testing::Test, 242 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; 243 244 INSTANTIATE_TEST_SUITE_P(MultiChannel, 245 AecStateMultiChannel, 246 ::testing::Combine(::testing::Values(1, 2, 8), 247 ::testing::Values(1, 2, 8))); 248 249 // Verify the general functionality of AecState 250 TEST_P(AecStateMultiChannel, NormalUsage) { 251 const size_t num_render_channels = std::get<0>(GetParam()); 252 const size_t num_capture_channels = std::get<1>(GetParam()); 253 RunNormalUsageTest(num_render_channels, num_capture_channels); 254 } 255 256 // Verifies the delay for a converged filter is correctly identified. 257 TEST(AecState, ConvergedFilterDelay) { 258 constexpr int kFilterLengthBlocks = 10; 259 constexpr size_t kNumCaptureChannels = 1; 260 EchoCanceller3Config config; 261 AecState state(CreateEnvironment(), config, kNumCaptureChannels); 262 std::unique_ptr<RenderDelayBuffer> render_delay_buffer( 263 RenderDelayBuffer::Create(config, 48000, 1)); 264 std::optional<DelayEstimate> delay_estimate; 265 std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( 266 kNumCaptureChannels); 267 std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(kNumCaptureChannels); 268 std::array<float, kBlockSize> x; 269 EchoPathVariability echo_path_variability( 270 false, EchoPathVariability::DelayAdjustment::kNone, false); 271 std::vector<SubtractorOutput> subtractor_output(kNumCaptureChannels); 272 for (auto& output : subtractor_output) { 273 output.Reset(); 274 output.s_refined.fill(100.f); 275 } 276 std::array<float, kBlockSize> y; 277 x.fill(0.f); 278 y.fill(0.f); 279 280 std::vector<std::vector<std::array<float, kFftLengthBy2Plus1>>> 281 frequency_response(kNumCaptureChannels, 282 std::vector<std::array<float, kFftLengthBy2Plus1>>( 283 kFilterLengthBlocks)); 284 for (auto& v_ch : frequency_response) { 285 for (auto& v : v_ch) { 286 v.fill(0.01f); 287 } 288 } 289 290 std::vector<std::vector<float>> impulse_response( 291 kNumCaptureChannels, 292 std::vector<float>( 293 GetTimeDomainLength(config.filter.refined.length_blocks), 0.f)); 294 295 // Verify that the filter delay for a converged filter is properly 296 // identified. 297 for (int k = 0; k < kFilterLengthBlocks; ++k) { 298 for (auto& ir : impulse_response) { 299 std::fill(ir.begin(), ir.end(), 0.f); 300 ir[k * kBlockSize + 1] = 1.f; 301 } 302 303 state.HandleEchoPathChange(echo_path_variability); 304 subtractor_output[0].ComputeMetrics(y); 305 state.Update(delay_estimate, frequency_response, impulse_response, 306 *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2, 307 subtractor_output); 308 } 309 } 310 311 } // namespace webrtc