vad_wrapper_unittest.cc (6938B)
1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/agc2/vad_wrapper.h" 12 13 #include <limits> 14 #include <memory> 15 #include <tuple> 16 #include <utility> 17 #include <vector> 18 19 #include "api/array_view.h" 20 #include "api/audio/audio_view.h" 21 #include "modules/audio_processing/agc2/agc2_common.h" 22 #include "rtc_base/checks.h" 23 #include "rtc_base/numerics/safe_compare.h" 24 #include "test/gmock.h" 25 #include "test/gtest.h" 26 27 namespace webrtc { 28 namespace { 29 30 using ::testing::AnyNumber; 31 using ::testing::Return; 32 using ::testing::ReturnRoundRobin; 33 using ::testing::Truly; 34 35 constexpr int kNumFramesPerSecond = 100; 36 37 constexpr int kNoVadPeriodicReset = 38 kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs); 39 40 constexpr int kSampleRate8kHz = 8000; 41 42 class MockVad : public VoiceActivityDetectorWrapper::MonoVad { 43 public: 44 MOCK_METHOD(int, SampleRateHz, (), (const, override)); 45 MOCK_METHOD(void, Reset, (), (override)); 46 MOCK_METHOD(float, Analyze, (ArrayView<const float> frame), (override)); 47 }; 48 49 // Checks that the ctor and `Initialize()` read the sample rate of the wrapped 50 // VAD. 51 TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) { 52 auto vad = std::make_unique<MockVad>(); 53 EXPECT_CALL(*vad, SampleRateHz) 54 .Times(1) 55 .WillRepeatedly(Return(kSampleRate8kHz)); 56 EXPECT_CALL(*vad, Reset).Times(AnyNumber()); 57 auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>( 58 kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz); 59 } 60 61 // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that 62 // repeatedly returns the next value from `speech_probabilities` and that 63 // restarts from the beginning when after the last element is returned. 64 std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper( 65 int vad_reset_period_ms, 66 int sample_rate_hz, 67 const std::vector<float>& speech_probabilities, 68 int expected_vad_reset_calls) { 69 auto vad = std::make_unique<MockVad>(); 70 EXPECT_CALL(*vad, SampleRateHz) 71 .Times(AnyNumber()) 72 .WillRepeatedly(Return(sample_rate_hz)); 73 if (expected_vad_reset_calls >= 0) { 74 EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls); 75 } 76 EXPECT_CALL(*vad, Analyze) 77 .Times(AnyNumber()) 78 .WillRepeatedly(ReturnRoundRobin(speech_probabilities)); 79 return std::make_unique<VoiceActivityDetectorWrapper>( 80 vad_reset_period_ms, std::move(vad), kSampleRate8kHz); 81 } 82 83 // 10 ms mono frame. 84 struct FrameWithView { 85 // Ctor. Initializes the frame samples with `value`. 86 explicit FrameWithView(int sample_rate_hz) 87 : samples(CheckedDivExact(sample_rate_hz, kNumFramesPerSecond), 0.0f), 88 view(samples.data(), samples.size(), /*num_channels=*/1) {} 89 std::vector<float> samples; 90 const DeinterleavedView<const float> view; 91 }; 92 93 // Checks that the expected speech probabilities are returned. 94 TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) { 95 const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f, 96 0.44f, 0.525f, 0.858f, 0.314f, 97 0.653f, 0.965f, 0.413f, 0.0f}; 98 auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz, 99 speech_probabilities, 100 /*expected_vad_reset_calls=*/1); 101 FrameWithView frame(kSampleRate8kHz); 102 for (int i = 0; SafeLt(i, speech_probabilities.size()); ++i) { 103 SCOPED_TRACE(i); 104 EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view)); 105 } 106 } 107 108 // Checks that the VAD is not periodically reset. 109 TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) { 110 constexpr int kNumFrames = 19; 111 auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz, 112 /*speech_probabilities=*/{1.0f}, 113 /*expected_vad_reset_calls=*/1); 114 FrameWithView frame(kSampleRate8kHz); 115 for (int i = 0; i < kNumFrames; ++i) { 116 vad_wrapper->Analyze(frame.view); 117 } 118 } 119 120 class VadPeriodResetParametrization 121 : public ::testing::TestWithParam<std::tuple<int, int>> { 122 protected: 123 int num_frames() const { return std::get<0>(GetParam()); } 124 int vad_reset_period_frames() const { return std::get<1>(GetParam()); } 125 }; 126 127 // Checks that the VAD is periodically reset with the expected period. 128 TEST_P(VadPeriodResetParametrization, VadPeriodicReset) { 129 auto vad_wrapper = CreateMockVadWrapper( 130 /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs, 131 kSampleRate8kHz, 132 /*speech_probabilities=*/{1.0f}, 133 /*expected_vad_reset_calls=*/1 + 134 num_frames() / vad_reset_period_frames()); 135 FrameWithView frame(kSampleRate8kHz); 136 for (int i = 0; i < num_frames(); ++i) { 137 vad_wrapper->Analyze(frame.view); 138 } 139 } 140 141 INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper, 142 VadPeriodResetParametrization, 143 ::testing::Combine(::testing::Values(1, 19, 123), 144 ::testing::Values(2, 5, 20, 53))); 145 146 class VadResamplingParametrization 147 : public ::testing::TestWithParam<std::tuple<int, int>> { 148 protected: 149 int input_sample_rate_hz() const { return std::get<0>(GetParam()); } 150 int vad_sample_rate_hz() const { return std::get<1>(GetParam()); } 151 }; 152 153 // Checks that regardless of the input audio sample rate, the wrapped VAD 154 // analyzes frames having the expected size, that is according to its internal 155 // sample rate. 156 TEST_P(VadResamplingParametrization, CheckResampledFrameSize) { 157 auto vad = std::make_unique<MockVad>(); 158 EXPECT_CALL(*vad, SampleRateHz) 159 .Times(AnyNumber()) 160 .WillRepeatedly(Return(vad_sample_rate_hz())); 161 EXPECT_CALL(*vad, Reset).Times(1); 162 EXPECT_CALL(*vad, Analyze(Truly([this](ArrayView<const float> frame) { 163 return SafeEq(frame.size(), 164 CheckedDivExact(vad_sample_rate_hz(), kNumFramesPerSecond)); 165 }))).Times(1); 166 auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>( 167 kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz()); 168 FrameWithView frame(input_sample_rate_hz()); 169 vad_wrapper->Analyze(frame.view); 170 } 171 172 INSTANTIATE_TEST_SUITE_P( 173 GainController2VoiceActivityDetectorWrapper, 174 VadResamplingParametrization, 175 ::testing::Combine(::testing::Values(8000, 16000, 44100, 48000), 176 ::testing::Values(6000, 8000, 12000, 16000, 24000))); 177 178 } // namespace 179 } // namespace webrtc