tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

vad_wrapper_unittest.cc (6938B)


      1 /*
      2 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/agc2/vad_wrapper.h"
     12 
     13 #include <limits>
     14 #include <memory>
     15 #include <tuple>
     16 #include <utility>
     17 #include <vector>
     18 
     19 #include "api/array_view.h"
     20 #include "api/audio/audio_view.h"
     21 #include "modules/audio_processing/agc2/agc2_common.h"
     22 #include "rtc_base/checks.h"
     23 #include "rtc_base/numerics/safe_compare.h"
     24 #include "test/gmock.h"
     25 #include "test/gtest.h"
     26 
     27 namespace webrtc {
     28 namespace {
     29 
     30 using ::testing::AnyNumber;
     31 using ::testing::Return;
     32 using ::testing::ReturnRoundRobin;
     33 using ::testing::Truly;
     34 
     35 constexpr int kNumFramesPerSecond = 100;
     36 
     37 constexpr int kNoVadPeriodicReset =
     38    kFrameDurationMs * (std::numeric_limits<int>::max() / kFrameDurationMs);
     39 
     40 constexpr int kSampleRate8kHz = 8000;
     41 
     42 class MockVad : public VoiceActivityDetectorWrapper::MonoVad {
     43 public:
     44  MOCK_METHOD(int, SampleRateHz, (), (const, override));
     45  MOCK_METHOD(void, Reset, (), (override));
     46  MOCK_METHOD(float, Analyze, (ArrayView<const float> frame), (override));
     47 };
     48 
     49 // Checks that the ctor and `Initialize()` read the sample rate of the wrapped
     50 // VAD.
     51 TEST(GainController2VoiceActivityDetectorWrapper, CtorAndInitReadSampleRate) {
     52  auto vad = std::make_unique<MockVad>();
     53  EXPECT_CALL(*vad, SampleRateHz)
     54      .Times(1)
     55      .WillRepeatedly(Return(kSampleRate8kHz));
     56  EXPECT_CALL(*vad, Reset).Times(AnyNumber());
     57  auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
     58      kNoVadPeriodicReset, std::move(vad), kSampleRate8kHz);
     59 }
     60 
     61 // Creates a `VoiceActivityDetectorWrapper` injecting a mock VAD that
     62 // repeatedly returns the next value from `speech_probabilities` and that
     63 // restarts from the beginning when after the last element is returned.
     64 std::unique_ptr<VoiceActivityDetectorWrapper> CreateMockVadWrapper(
     65    int vad_reset_period_ms,
     66    int sample_rate_hz,
     67    const std::vector<float>& speech_probabilities,
     68    int expected_vad_reset_calls) {
     69  auto vad = std::make_unique<MockVad>();
     70  EXPECT_CALL(*vad, SampleRateHz)
     71      .Times(AnyNumber())
     72      .WillRepeatedly(Return(sample_rate_hz));
     73  if (expected_vad_reset_calls >= 0) {
     74    EXPECT_CALL(*vad, Reset).Times(expected_vad_reset_calls);
     75  }
     76  EXPECT_CALL(*vad, Analyze)
     77      .Times(AnyNumber())
     78      .WillRepeatedly(ReturnRoundRobin(speech_probabilities));
     79  return std::make_unique<VoiceActivityDetectorWrapper>(
     80      vad_reset_period_ms, std::move(vad), kSampleRate8kHz);
     81 }
     82 
     83 // 10 ms mono frame.
     84 struct FrameWithView {
     85  // Ctor. Initializes the frame samples with `value`.
     86  explicit FrameWithView(int sample_rate_hz)
     87      : samples(CheckedDivExact(sample_rate_hz, kNumFramesPerSecond), 0.0f),
     88        view(samples.data(), samples.size(), /*num_channels=*/1) {}
     89  std::vector<float> samples;
     90  const DeinterleavedView<const float> view;
     91 };
     92 
     93 // Checks that the expected speech probabilities are returned.
     94 TEST(GainController2VoiceActivityDetectorWrapper, CheckSpeechProbabilities) {
     95  const std::vector<float> speech_probabilities{0.709f, 0.484f, 0.882f, 0.167f,
     96                                                0.44f,  0.525f, 0.858f, 0.314f,
     97                                                0.653f, 0.965f, 0.413f, 0.0f};
     98  auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
     99                                          speech_probabilities,
    100                                          /*expected_vad_reset_calls=*/1);
    101  FrameWithView frame(kSampleRate8kHz);
    102  for (int i = 0; SafeLt(i, speech_probabilities.size()); ++i) {
    103    SCOPED_TRACE(i);
    104    EXPECT_EQ(speech_probabilities[i], vad_wrapper->Analyze(frame.view));
    105  }
    106 }
    107 
    108 // Checks that the VAD is not periodically reset.
    109 TEST(GainController2VoiceActivityDetectorWrapper, VadNoPeriodicReset) {
    110  constexpr int kNumFrames = 19;
    111  auto vad_wrapper = CreateMockVadWrapper(kNoVadPeriodicReset, kSampleRate8kHz,
    112                                          /*speech_probabilities=*/{1.0f},
    113                                          /*expected_vad_reset_calls=*/1);
    114  FrameWithView frame(kSampleRate8kHz);
    115  for (int i = 0; i < kNumFrames; ++i) {
    116    vad_wrapper->Analyze(frame.view);
    117  }
    118 }
    119 
    120 class VadPeriodResetParametrization
    121    : public ::testing::TestWithParam<std::tuple<int, int>> {
    122 protected:
    123  int num_frames() const { return std::get<0>(GetParam()); }
    124  int vad_reset_period_frames() const { return std::get<1>(GetParam()); }
    125 };
    126 
    127 // Checks that the VAD is periodically reset with the expected period.
    128 TEST_P(VadPeriodResetParametrization, VadPeriodicReset) {
    129  auto vad_wrapper = CreateMockVadWrapper(
    130      /*vad_reset_period_ms=*/vad_reset_period_frames() * kFrameDurationMs,
    131      kSampleRate8kHz,
    132      /*speech_probabilities=*/{1.0f},
    133      /*expected_vad_reset_calls=*/1 +
    134          num_frames() / vad_reset_period_frames());
    135  FrameWithView frame(kSampleRate8kHz);
    136  for (int i = 0; i < num_frames(); ++i) {
    137    vad_wrapper->Analyze(frame.view);
    138  }
    139 }
    140 
    141 INSTANTIATE_TEST_SUITE_P(GainController2VoiceActivityDetectorWrapper,
    142                         VadPeriodResetParametrization,
    143                         ::testing::Combine(::testing::Values(1, 19, 123),
    144                                            ::testing::Values(2, 5, 20, 53)));
    145 
    146 class VadResamplingParametrization
    147    : public ::testing::TestWithParam<std::tuple<int, int>> {
    148 protected:
    149  int input_sample_rate_hz() const { return std::get<0>(GetParam()); }
    150  int vad_sample_rate_hz() const { return std::get<1>(GetParam()); }
    151 };
    152 
    153 // Checks that regardless of the input audio sample rate, the wrapped VAD
    154 // analyzes frames having the expected size, that is according to its internal
    155 // sample rate.
    156 TEST_P(VadResamplingParametrization, CheckResampledFrameSize) {
    157  auto vad = std::make_unique<MockVad>();
    158  EXPECT_CALL(*vad, SampleRateHz)
    159      .Times(AnyNumber())
    160      .WillRepeatedly(Return(vad_sample_rate_hz()));
    161  EXPECT_CALL(*vad, Reset).Times(1);
    162  EXPECT_CALL(*vad, Analyze(Truly([this](ArrayView<const float> frame) {
    163    return SafeEq(frame.size(),
    164                  CheckedDivExact(vad_sample_rate_hz(), kNumFramesPerSecond));
    165  }))).Times(1);
    166  auto vad_wrapper = std::make_unique<VoiceActivityDetectorWrapper>(
    167      kNoVadPeriodicReset, std::move(vad), input_sample_rate_hz());
    168  FrameWithView frame(input_sample_rate_hz());
    169  vad_wrapper->Analyze(frame.view);
    170 }
    171 
    172 INSTANTIATE_TEST_SUITE_P(
    173    GainController2VoiceActivityDetectorWrapper,
    174    VadResamplingParametrization,
    175    ::testing::Combine(::testing::Values(8000, 16000, 44100, 48000),
    176                       ::testing::Values(6000, 8000, 12000, 16000, 24000)));
    177 
    178 }  // namespace
    179 }  // namespace webrtc