subtractor_unittest.cc (14115B)
1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/subtractor.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <cstddef> 16 #include <memory> 17 #include <numeric> 18 #include <optional> 19 #include <string> 20 #include <tuple> 21 #include <vector> 22 23 #include "api/array_view.h" 24 #include "api/audio/echo_canceller3_config.h" 25 #include "api/environment/environment.h" 26 #include "api/environment/environment_factory.h" 27 #include "modules/audio_processing/aec3/aec3_common.h" 28 #include "modules/audio_processing/aec3/aec3_fft.h" 29 #include "modules/audio_processing/aec3/aec_state.h" 30 #include "modules/audio_processing/aec3/block.h" 31 #include "modules/audio_processing/aec3/delay_estimate.h" 32 #include "modules/audio_processing/aec3/echo_path_variability.h" 33 #include "modules/audio_processing/aec3/render_delay_buffer.h" 34 #include "modules/audio_processing/aec3/render_signal_analyzer.h" 35 #include "modules/audio_processing/aec3/subtractor_output.h" 36 #include "modules/audio_processing/test/echo_canceller_test_tools.h" 37 #include "modules/audio_processing/utility/cascaded_biquad_filter.h" 38 #include "rtc_base/checks.h" 39 #include "rtc_base/random.h" 40 #include "rtc_base/strings/string_builder.h" 41 #include "test/gtest.h" 42 43 namespace webrtc { 44 namespace { 45 46 std::vector<float> RunSubtractorTest( 47 const Environment& env, 48 size_t num_render_channels, 49 size_t num_capture_channels, 50 int num_blocks_to_process, 51 int delay_samples, 52 int refined_filter_length_blocks, 53 int coarse_filter_length_blocks, 54 bool uncorrelated_inputs, 55 const std::vector<int>& blocks_with_echo_path_changes) { 56 ApmDataDumper data_dumper(42); 57 constexpr int kSampleRateHz = 48000; 58 constexpr size_t kNumBands = NumBandsForRate(kSampleRateHz); 59 EchoCanceller3Config config; 60 config.filter.refined.length_blocks = refined_filter_length_blocks; 61 config.filter.coarse.length_blocks = coarse_filter_length_blocks; 62 63 Subtractor subtractor(env, config, num_render_channels, num_capture_channels, 64 &data_dumper, DetectOptimization()); 65 std::optional<DelayEstimate> delay_estimate; 66 Block x(kNumBands, num_render_channels); 67 Block y(/*num_bands=*/1, num_capture_channels); 68 std::array<float, kBlockSize> x_old; 69 std::vector<SubtractorOutput> output(num_capture_channels); 70 config.delay.default_delay = 1; 71 std::unique_ptr<RenderDelayBuffer> render_delay_buffer( 72 RenderDelayBuffer::Create(config, kSampleRateHz, num_render_channels)); 73 RenderSignalAnalyzer render_signal_analyzer(config); 74 Random random_generator(42U); 75 Aec3Fft fft; 76 std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); 77 std::vector<std::array<float, kFftLengthBy2Plus1>> E2_refined( 78 num_capture_channels); 79 std::array<float, kFftLengthBy2Plus1> E2_coarse; 80 AecState aec_state(env, config, num_capture_channels); 81 x_old.fill(0.f); 82 for (auto& Y2_ch : Y2) { 83 Y2_ch.fill(0.f); 84 } 85 for (auto& E2_refined_ch : E2_refined) { 86 E2_refined_ch.fill(0.f); 87 } 88 E2_coarse.fill(0.f); 89 90 std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> delay_buffer( 91 num_capture_channels); 92 for (size_t capture_ch = 0; capture_ch < num_capture_channels; ++capture_ch) { 93 delay_buffer[capture_ch].resize(num_render_channels); 94 for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) { 95 delay_buffer[capture_ch][render_ch] = 96 std::make_unique<DelayBuffer<float>>(delay_samples); 97 } 98 } 99 100 // [B,A] = butter(2,100/8000,'high') 101 constexpr std::array<CascadedBiQuadFilter::BiQuadCoefficients, 1> 102 kHighPassFilterCoefficients = {{ 103 {.b = {0.97261f, -1.94523f, 0.97261f}, .a = {-1.94448f, 0.94598f}}, 104 }}; 105 std::vector<std::unique_ptr<CascadedBiQuadFilter>> x_hp_filter( 106 num_render_channels); 107 for (size_t ch = 0; ch < num_render_channels; ++ch) { 108 x_hp_filter[ch] = std::make_unique<CascadedBiQuadFilter>( 109 ArrayView<const CascadedBiQuadFilter::BiQuadCoefficients>( 110 kHighPassFilterCoefficients)); 111 } 112 std::vector<std::unique_ptr<CascadedBiQuadFilter>> y_hp_filter( 113 num_capture_channels); 114 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 115 y_hp_filter[ch] = std::make_unique<CascadedBiQuadFilter>( 116 ArrayView<const CascadedBiQuadFilter::BiQuadCoefficients>( 117 kHighPassFilterCoefficients)); 118 } 119 120 for (int block_num = 0; block_num < num_blocks_to_process; ++block_num) { 121 for (size_t render_ch = 0; render_ch < num_render_channels; ++render_ch) { 122 RandomizeSampleVector(&random_generator, x.View(/*band=*/0, render_ch)); 123 } 124 if (uncorrelated_inputs) { 125 for (size_t capture_ch = 0; capture_ch < num_capture_channels; 126 ++capture_ch) { 127 RandomizeSampleVector(&random_generator, 128 y.View(/*band=*/0, capture_ch)); 129 } 130 } else { 131 for (size_t capture_ch = 0; capture_ch < num_capture_channels; 132 ++capture_ch) { 133 ArrayView<float> y_view = y.View(/*band=*/0, capture_ch); 134 for (size_t render_ch = 0; render_ch < num_render_channels; 135 ++render_ch) { 136 std::array<float, kBlockSize> y_channel; 137 delay_buffer[capture_ch][render_ch]->Delay( 138 x.View(/*band=*/0, render_ch), y_channel); 139 for (size_t k = 0; k < kBlockSize; ++k) { 140 y_view[k] += y_channel[k] / num_render_channels; 141 } 142 } 143 } 144 } 145 for (size_t ch = 0; ch < num_render_channels; ++ch) { 146 x_hp_filter[ch]->Process(x.View(/*band=*/0, ch)); 147 } 148 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 149 y_hp_filter[ch]->Process(y.View(/*band=*/0, ch)); 150 } 151 152 render_delay_buffer->Insert(x); 153 if (block_num == 0) { 154 render_delay_buffer->Reset(); 155 } 156 render_delay_buffer->PrepareCaptureProcessing(); 157 render_signal_analyzer.Update(*render_delay_buffer->GetRenderBuffer(), 158 aec_state.MinDirectPathFilterDelay()); 159 160 // Handle echo path changes. 161 if (std::find(blocks_with_echo_path_changes.begin(), 162 blocks_with_echo_path_changes.end(), 163 block_num) != blocks_with_echo_path_changes.end()) { 164 subtractor.HandleEchoPathChange(EchoPathVariability( 165 true, EchoPathVariability::DelayAdjustment::kNewDetectedDelay, 166 false)); 167 } 168 subtractor.Process(*render_delay_buffer->GetRenderBuffer(), y, 169 render_signal_analyzer, aec_state, output); 170 171 aec_state.HandleEchoPathChange(EchoPathVariability( 172 false, EchoPathVariability::DelayAdjustment::kNone, false)); 173 aec_state.Update(delay_estimate, subtractor.FilterFrequencyResponses(), 174 subtractor.FilterImpulseResponses(), 175 *render_delay_buffer->GetRenderBuffer(), E2_refined, Y2, 176 output); 177 } 178 179 std::vector<float> results(num_capture_channels); 180 for (size_t ch = 0; ch < num_capture_channels; ++ch) { 181 const float output_power = std::inner_product( 182 output[ch].e_refined.begin(), output[ch].e_refined.end(), 183 output[ch].e_refined.begin(), 0.f); 184 const float y_power = 185 std::inner_product(y.begin(/*band=*/0, ch), y.end(/*band=*/0, ch), 186 y.begin(/*band=*/0, ch), 0.f); 187 if (y_power == 0.f) { 188 ADD_FAILURE(); 189 results[ch] = -1.f; 190 } 191 results[ch] = output_power / y_power; 192 } 193 return results; 194 } 195 196 std::string ProduceDebugText(size_t num_render_channels, 197 size_t num_capture_channels, 198 size_t delay, 199 int filter_length_blocks) { 200 StringBuilder ss; 201 ss << "delay: " << delay << ", "; 202 ss << "filter_length_blocks:" << filter_length_blocks << ", "; 203 ss << "num_render_channels:" << num_render_channels << ", "; 204 ss << "num_capture_channels:" << num_capture_channels; 205 return ss.Release(); 206 } 207 208 } // namespace 209 210 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) 211 212 // Verifies that the check for non data dumper works. 213 TEST(SubtractorDeathTest, NullDataDumper) { 214 EXPECT_DEATH(Subtractor(CreateEnvironment(), EchoCanceller3Config(), 1, 1, 215 nullptr, DetectOptimization()), 216 ""); 217 } 218 219 #endif 220 221 // Verifies that the subtractor is able to converge on correlated data. 222 TEST(Subtractor, Convergence) { 223 const Environment env = CreateEnvironment(); 224 std::vector<int> blocks_with_echo_path_changes; 225 for (size_t filter_length_blocks : {12, 20, 30}) { 226 for (size_t delay_samples : {0, 64, 150, 200, 301}) { 227 SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks)); 228 std::vector<float> echo_to_nearend_powers = RunSubtractorTest( 229 env, 1, 1, 2500, delay_samples, filter_length_blocks, 230 filter_length_blocks, false, blocks_with_echo_path_changes); 231 232 for (float echo_to_nearend_power : echo_to_nearend_powers) { 233 EXPECT_GT(0.1f, echo_to_nearend_power); 234 } 235 } 236 } 237 } 238 239 // Verifies that the subtractor is able to handle the case when the refined 240 // filter is longer than the coarse filter. 241 TEST(Subtractor, RefinedFilterLongerThanCoarseFilter) { 242 const Environment env = CreateEnvironment(); 243 std::vector<int> blocks_with_echo_path_changes; 244 std::vector<float> echo_to_nearend_powers = RunSubtractorTest( 245 env, 1, 1, 400, 64, 20, 15, false, blocks_with_echo_path_changes); 246 for (float echo_to_nearend_power : echo_to_nearend_powers) { 247 EXPECT_GT(0.5f, echo_to_nearend_power); 248 } 249 } 250 251 // Verifies that the subtractor is able to handle the case when the coarse 252 // filter is longer than the refined filter. 253 TEST(Subtractor, CoarseFilterLongerThanRefinedFilter) { 254 const Environment env = CreateEnvironment(); 255 std::vector<int> blocks_with_echo_path_changes; 256 std::vector<float> echo_to_nearend_powers = RunSubtractorTest( 257 env, 1, 1, 400, 64, 15, 20, false, blocks_with_echo_path_changes); 258 for (float echo_to_nearend_power : echo_to_nearend_powers) { 259 EXPECT_GT(0.5f, echo_to_nearend_power); 260 } 261 } 262 263 // Verifies that the subtractor does not converge on uncorrelated signals. 264 TEST(Subtractor, NonConvergenceOnUncorrelatedSignals) { 265 const Environment env = CreateEnvironment(); 266 std::vector<int> blocks_with_echo_path_changes; 267 for (size_t filter_length_blocks : {12, 20, 30}) { 268 for (size_t delay_samples : {0, 64, 150, 200, 301}) { 269 SCOPED_TRACE(ProduceDebugText(1, 1, delay_samples, filter_length_blocks)); 270 271 std::vector<float> echo_to_nearend_powers = RunSubtractorTest( 272 env, 1, 1, 3000, delay_samples, filter_length_blocks, 273 filter_length_blocks, true, blocks_with_echo_path_changes); 274 for (float echo_to_nearend_power : echo_to_nearend_powers) { 275 EXPECT_NEAR(1.f, echo_to_nearend_power, 0.1); 276 } 277 } 278 } 279 } 280 281 class SubtractorMultiChannelUpToEightRender 282 : public ::testing::Test, 283 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; 284 285 #if defined(NDEBUG) 286 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel, 287 SubtractorMultiChannelUpToEightRender, 288 ::testing::Combine(::testing::Values(1, 2, 8), 289 ::testing::Values(1, 2, 4))); 290 #else 291 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel, 292 SubtractorMultiChannelUpToEightRender, 293 ::testing::Combine(::testing::Values(1, 2), 294 ::testing::Values(1, 2))); 295 #endif 296 297 // Verifies that the subtractor is able to converge on correlated data. 298 TEST_P(SubtractorMultiChannelUpToEightRender, Convergence) { 299 const size_t num_render_channels = std::get<0>(GetParam()); 300 const size_t num_capture_channels = std::get<1>(GetParam()); 301 const Environment env = CreateEnvironment(); 302 303 std::vector<int> blocks_with_echo_path_changes; 304 size_t num_blocks_to_process = 2500 * num_render_channels; 305 std::vector<float> echo_to_nearend_powers = RunSubtractorTest( 306 env, num_render_channels, num_capture_channels, num_blocks_to_process, 64, 307 20, 20, false, blocks_with_echo_path_changes); 308 309 for (float echo_to_nearend_power : echo_to_nearend_powers) { 310 EXPECT_GT(0.1f, echo_to_nearend_power); 311 } 312 } 313 314 class SubtractorMultiChannelUpToFourRender 315 : public ::testing::Test, 316 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; 317 318 #if defined(NDEBUG) 319 INSTANTIATE_TEST_SUITE_P(NonDebugMultiChannel, 320 SubtractorMultiChannelUpToFourRender, 321 ::testing::Combine(::testing::Values(1, 2, 4), 322 ::testing::Values(1, 2, 4))); 323 #else 324 INSTANTIATE_TEST_SUITE_P(DebugMultiChannel, 325 SubtractorMultiChannelUpToFourRender, 326 ::testing::Combine(::testing::Values(1, 2), 327 ::testing::Values(1, 2))); 328 #endif 329 330 // Verifies that the subtractor does not converge on uncorrelated signals. 331 TEST_P(SubtractorMultiChannelUpToFourRender, 332 NonConvergenceOnUncorrelatedSignals) { 333 const size_t num_render_channels = std::get<0>(GetParam()); 334 const size_t num_capture_channels = std::get<1>(GetParam()); 335 const Environment env = CreateEnvironment(); 336 337 std::vector<int> blocks_with_echo_path_changes; 338 size_t num_blocks_to_process = 5000 * num_render_channels; 339 std::vector<float> echo_to_nearend_powers = RunSubtractorTest( 340 env, num_render_channels, num_capture_channels, num_blocks_to_process, 64, 341 20, 20, true, blocks_with_echo_path_changes); 342 for (float echo_to_nearend_power : echo_to_nearend_powers) { 343 EXPECT_LT(.8f, echo_to_nearend_power); 344 EXPECT_NEAR(1.f, echo_to_nearend_power, 0.25f); 345 } 346 } 347 } // namespace webrtc