erl_estimator_unittest.cc (3874B)
1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/aec3/erl_estimator.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <cstddef> 16 #include <string> 17 #include <tuple> 18 #include <vector> 19 20 #include "modules/audio_processing/aec3/aec3_common.h" 21 #include "rtc_base/strings/string_builder.h" 22 #include "test/gtest.h" 23 24 namespace webrtc { 25 26 namespace { 27 std::string ProduceDebugText(size_t num_render_channels, 28 size_t num_capture_channels) { 29 StringBuilder ss; 30 ss << "Render channels: " << num_render_channels; 31 ss << ", Capture channels: " << num_capture_channels; 32 return ss.Release(); 33 } 34 35 void VerifyErl(const std::array<float, kFftLengthBy2Plus1>& erl, 36 float erl_time_domain, 37 float reference) { 38 std::for_each(erl.begin(), erl.end(), 39 [reference](float a) { EXPECT_NEAR(reference, a, 0.001); }); 40 EXPECT_NEAR(reference, erl_time_domain, 0.001); 41 } 42 43 } // namespace 44 45 class ErlEstimatorMultiChannel 46 : public ::testing::Test, 47 public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {}; 48 49 INSTANTIATE_TEST_SUITE_P(MultiChannel, 50 ErlEstimatorMultiChannel, 51 ::testing::Combine(::testing::Values(1, 2, 8), 52 ::testing::Values(1, 2, 8))); 53 54 // Verifies that the correct ERL estimates are achieved. 55 TEST_P(ErlEstimatorMultiChannel, Estimates) { 56 const size_t num_render_channels = std::get<0>(GetParam()); 57 const size_t num_capture_channels = std::get<1>(GetParam()); 58 SCOPED_TRACE(ProduceDebugText(num_render_channels, num_capture_channels)); 59 std::vector<std::array<float, kFftLengthBy2Plus1>> X2(num_render_channels); 60 for (auto& X2_ch : X2) { 61 X2_ch.fill(0.f); 62 } 63 std::vector<std::array<float, kFftLengthBy2Plus1>> Y2(num_capture_channels); 64 for (auto& Y2_ch : Y2) { 65 Y2_ch.fill(0.f); 66 } 67 std::vector<bool> converged_filters(num_capture_channels, false); 68 const size_t converged_idx = num_capture_channels - 1; 69 converged_filters[converged_idx] = true; 70 71 ErlEstimator estimator(0); 72 73 // Verifies that the ERL estimate is properly reduced to lower values. 74 for (auto& X2_ch : X2) { 75 X2_ch.fill(500 * 1000.f * 1000.f); 76 } 77 Y2[converged_idx].fill(10 * X2[0][0]); 78 for (size_t k = 0; k < 200; ++k) { 79 estimator.Update(converged_filters, X2, Y2); 80 } 81 VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); 82 83 // Verifies that the ERL is not immediately increased when the ERL in the 84 // data increases. 85 Y2[converged_idx].fill(10000 * X2[0][0]); 86 for (size_t k = 0; k < 998; ++k) { 87 estimator.Update(converged_filters, X2, Y2); 88 } 89 VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 10.f); 90 91 // Verifies that the rate of increase is 3 dB. 92 estimator.Update(converged_filters, X2, Y2); 93 VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 20.f); 94 95 // Verifies that the maximum ERL is achieved when there are no low RLE 96 // estimates. 97 for (size_t k = 0; k < 1000; ++k) { 98 estimator.Update(converged_filters, X2, Y2); 99 } 100 VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); 101 102 // Verifies that the ERL estimate is is not updated for low-level signals 103 for (auto& X2_ch : X2) { 104 X2_ch.fill(1000.f * 1000.f); 105 } 106 Y2[converged_idx].fill(10 * X2[0][0]); 107 for (size_t k = 0; k < 200; ++k) { 108 estimator.Update(converged_filters, X2, Y2); 109 } 110 VerifyErl(estimator.Erl(), estimator.ErlTimeDomain(), 1000.f); 111 } 112 } // namespace webrtc