echo_remover_unittest.cc (9144B)
1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/echo_remover.h" 12 13 #include <algorithm> 14 #include <cstddef> 15 #include <memory> 16 #include <numeric> 17 #include <optional> 18 #include <string> 19 #include <tuple> 20 #include <vector> 21 22 #include "api/audio/echo_canceller3_config.h" 23 #include "api/environment/environment.h" 24 #include "api/environment/environment_factory.h" 25 #include "modules/audio_processing/aec3/aec3_common.h" 26 #include "modules/audio_processing/aec3/block.h" 27 #include "modules/audio_processing/aec3/delay_estimate.h" 28 #include "modules/audio_processing/aec3/echo_path_variability.h" 29 #include "modules/audio_processing/aec3/render_delay_buffer.h" 30 #include "modules/audio_processing/test/echo_canceller_test_tools.h" 31 #include "rtc_base/checks.h" 32 #include "rtc_base/random.h" 33 #include "rtc_base/strings/string_builder.h" 34 #include "test/gtest.h" 35 36 namespace webrtc { 37 namespace { 38 std::string ProduceDebugText(int sample_rate_hz) { 39 StringBuilder ss; 40 ss << "Sample rate: " << sample_rate_hz; 41 return ss.Release(); 42 } 43 44 std::string ProduceDebugText(int sample_rate_hz, int delay) { 45 StringBuilder ss(ProduceDebugText(sample_rate_hz)); 46 ss << ", Delay: " << delay; 47 return ss.Release(); 48 } 49 50 } // namespace 51 52 class EchoRemoverMultiChannel 53 : public ::testing::Test, 54 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; 55 56 INSTANTIATE_TEST_SUITE_P(MultiChannel, 57 EchoRemoverMultiChannel, 58 ::testing::Combine(::testing::Values(1, 2, 8), 59 ::testing::Values(1, 2, 8))); 60 61 // Verifies the basic API call sequence 62 TEST_P(EchoRemoverMultiChannel, BasicApiCalls) { 63 const size_t num_render_channels = std::get<0>(GetParam()); 64 const size_t num_capture_channels = std::get<1>(GetParam()); 65 const Environment env = CreateEnvironment(); 66 std::optional<DelayEstimate> delay_estimate; 67 for (auto rate : {16000, 32000, 48000}) { 68 SCOPED_TRACE(ProduceDebugText(rate)); 69 std::unique_ptr<EchoRemover> remover = EchoRemover::Create( 70 env, EchoCanceller3Config(), rate, num_render_channels, 71 num_capture_channels, /*neural_residual_echo_estimator=*/nullptr); 72 std::unique_ptr<RenderDelayBuffer> render_buffer(RenderDelayBuffer::Create( 73 EchoCanceller3Config(), rate, num_render_channels)); 74 75 Block render(NumBandsForRate(rate), num_render_channels); 76 Block capture(NumBandsForRate(rate), num_capture_channels); 77 for (size_t k = 0; k < 100; ++k) { 78 EchoPathVariability echo_path_variability( 79 k % 3 == 0 ? true : false, 80 k % 5 == 0 ? EchoPathVariability::DelayAdjustment::kNewDetectedDelay 81 : EchoPathVariability::DelayAdjustment::kNone, 82 false); 83 render_buffer->Insert(render); 84 render_buffer->PrepareCaptureProcessing(); 85 86 remover->ProcessCapture(echo_path_variability, k % 2 == 0 ? true : false, 87 delay_estimate, render_buffer->GetRenderBuffer(), 88 nullptr, &capture); 89 } 90 } 91 } 92 93 #if RTC_DCHECK_IS_ON && GTEST_HAS_DEATH_TEST && !defined(WEBRTC_ANDROID) 94 95 // Verifies the check for the samplerate. 96 // TODO(peah): Re-enable the test once the issue with memory leaks during DEATH 97 // tests on test bots has been fixed. 98 TEST(EchoRemoverDeathTest, DISABLED_WrongSampleRate) { 99 EXPECT_DEATH( 100 EchoRemover::Create(CreateEnvironment(), EchoCanceller3Config(), 8001, 1, 101 1, /*neural_residual_echo_estimator=*/nullptr), 102 ""); 103 } 104 105 // Verifies the check for the number of capture bands. 106 // TODO(peah): Re-enable the test once the issue with memory leaks during DEATH 107 // tests on test bots has been fixed.c 108 TEST(EchoRemoverDeathTest, DISABLED_WrongCaptureNumBands) { 109 const Environment env = CreateEnvironment(); 110 std::optional<DelayEstimate> delay_estimate; 111 for (auto rate : {16000, 32000, 48000}) { 112 SCOPED_TRACE(ProduceDebugText(rate)); 113 std::unique_ptr<EchoRemover> remover = 114 EchoRemover::Create(env, EchoCanceller3Config(), rate, 1, 1, 115 /*neural_residual_echo_estimator=*/nullptr); 116 std::unique_ptr<RenderDelayBuffer> render_buffer( 117 RenderDelayBuffer::Create(EchoCanceller3Config(), rate, 1)); 118 Block capture(NumBandsForRate(rate == 48000 ? 16000 : rate + 16000), 1); 119 EchoPathVariability echo_path_variability( 120 false, EchoPathVariability::DelayAdjustment::kNone, false); 121 EXPECT_DEATH(remover->ProcessCapture( 122 echo_path_variability, false, delay_estimate, 123 render_buffer->GetRenderBuffer(), nullptr, &capture), 124 ""); 125 } 126 } 127 128 // Verifies the check for non-null capture block. 129 TEST(EchoRemoverDeathTest, NullCapture) { 130 std::optional<DelayEstimate> delay_estimate; 131 std::unique_ptr<EchoRemover> remover = 132 EchoRemover::Create(CreateEnvironment(), EchoCanceller3Config(), 16000, 1, 133 1, /*neural_residual_echo_estimator=*/nullptr); 134 std::unique_ptr<RenderDelayBuffer> render_buffer( 135 RenderDelayBuffer::Create(EchoCanceller3Config(), 16000, 1)); 136 EchoPathVariability echo_path_variability( 137 false, EchoPathVariability::DelayAdjustment::kNone, false); 138 EXPECT_DEATH(remover->ProcessCapture( 139 echo_path_variability, false, delay_estimate, 140 render_buffer->GetRenderBuffer(), nullptr, nullptr), 141 ""); 142 } 143 144 #endif 145 146 // Performs a sanity check that the echo_remover is able to properly 147 // remove echoes. 148 TEST(EchoRemover, BasicEchoRemoval) { 149 constexpr int kNumBlocksToProcess = 500; 150 const Environment env = CreateEnvironment(); 151 Random random_generator(42U); 152 std::optional<DelayEstimate> delay_estimate; 153 for (size_t num_channels : {1, 2, 4}) { 154 for (auto rate : {16000, 32000, 48000}) { 155 Block x(NumBandsForRate(rate), num_channels); 156 Block y(NumBandsForRate(rate), num_channels); 157 EchoPathVariability echo_path_variability( 158 false, EchoPathVariability::DelayAdjustment::kNone, false); 159 for (size_t delay_samples : {0, 64, 150, 200, 301}) { 160 SCOPED_TRACE(ProduceDebugText(rate, delay_samples)); 161 EchoCanceller3Config config; 162 std::unique_ptr<EchoRemover> remover = 163 EchoRemover::Create(env, config, rate, num_channels, num_channels, 164 /*neural_residual_echo_estimator=*/nullptr); 165 std::unique_ptr<RenderDelayBuffer> render_buffer( 166 RenderDelayBuffer::Create(config, rate, num_channels)); 167 render_buffer->AlignFromDelay(delay_samples / kBlockSize); 168 169 std::vector<std::vector<std::unique_ptr<DelayBuffer<float>>>> 170 delay_buffers(x.NumBands()); 171 for (size_t band = 0; band < delay_buffers.size(); ++band) { 172 delay_buffers[band].resize(x.NumChannels()); 173 } 174 175 for (int band = 0; band < x.NumBands(); ++band) { 176 for (int channel = 0; channel < x.NumChannels(); ++channel) { 177 delay_buffers[band][channel].reset( 178 new DelayBuffer<float>(delay_samples)); 179 } 180 } 181 182 float input_energy = 0.f; 183 float output_energy = 0.f; 184 for (int k = 0; k < kNumBlocksToProcess; ++k) { 185 const bool silence = k < 100 || (k % 100 >= 10); 186 187 for (int band = 0; band < x.NumBands(); ++band) { 188 for (int channel = 0; channel < x.NumChannels(); ++channel) { 189 if (silence) { 190 std::fill(x.begin(band, channel), x.end(band, channel), 0.f); 191 } else { 192 RandomizeSampleVector(&random_generator, x.View(band, channel)); 193 } 194 delay_buffers[band][channel]->Delay(x.View(band, channel), 195 y.View(band, channel)); 196 } 197 } 198 199 if (k > kNumBlocksToProcess / 2) { 200 input_energy = std::inner_product( 201 y.begin(/*band=*/0, /*channel=*/0), 202 y.end(/*band=*/0, /*channel=*/0), 203 y.begin(/*band=*/0, /*channel=*/0), input_energy); 204 } 205 206 render_buffer->Insert(x); 207 render_buffer->PrepareCaptureProcessing(); 208 209 remover->ProcessCapture(echo_path_variability, false, delay_estimate, 210 render_buffer->GetRenderBuffer(), nullptr, 211 &y); 212 213 if (k > kNumBlocksToProcess / 2) { 214 output_energy = std::inner_product( 215 y.begin(/*band=*/0, /*channel=*/0), 216 y.end(/*band=*/0, /*channel=*/0), 217 y.begin(/*band=*/0, /*channel=*/0), output_energy); 218 } 219 } 220 EXPECT_GT(input_energy, 10.f * output_energy); 221 } 222 } 223 } 224 } 225 226 } // namespace webrtc