vad_wrapper.cc (3948B)
1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/agc2/vad_wrapper.h" 12 13 #include <array> 14 #include <memory> 15 #include <utility> 16 17 #include "api/audio/audio_view.h" 18 #include "common_audio/resampler/include/push_resampler.h" 19 #include "modules/audio_processing/agc2/agc2_common.h" 20 #include "modules/audio_processing/agc2/cpu_features.h" 21 #include "modules/audio_processing/agc2/rnn_vad/common.h" 22 #include "modules/audio_processing/agc2/rnn_vad/features_extraction.h" 23 #include "modules/audio_processing/agc2/rnn_vad/rnn.h" 24 #include "rtc_base/checks.h" 25 26 namespace webrtc { 27 namespace { 28 29 constexpr int kNumFramesPerSecond = 100; 30 31 class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad { 32 public: 33 explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features) 34 : features_extractor_(cpu_features), rnn_vad_(cpu_features) {} 35 MonoVadImpl(const MonoVadImpl&) = delete; 36 MonoVadImpl& operator=(const MonoVadImpl&) = delete; 37 ~MonoVadImpl() override = default; 38 39 int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; } 40 void Reset() override { rnn_vad_.Reset(); } 41 float Analyze(MonoView<const float> frame) override { 42 RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz); 43 std::array<float, rnn_vad::kFeatureVectorSize> feature_vector; 44 const bool is_silence = features_extractor_.CheckSilenceComputeFeatures( 45 /*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz}, 46 feature_vector); 47 return rnn_vad_.ComputeVadProbability(feature_vector, is_silence); 48 } 49 50 private: 51 rnn_vad::FeaturesExtractor features_extractor_; 52 rnn_vad::RnnVad rnn_vad_; 53 }; 54 55 } // namespace 56 57 VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( 58 const AvailableCpuFeatures& cpu_features, 59 int sample_rate_hz) 60 : VoiceActivityDetectorWrapper(kVadResetPeriodMs, 61 cpu_features, 62 sample_rate_hz) {} 63 64 VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( 65 int vad_reset_period_ms, 66 const AvailableCpuFeatures& cpu_features, 67 int sample_rate_hz) 68 : VoiceActivityDetectorWrapper(vad_reset_period_ms, 69 std::make_unique<MonoVadImpl>(cpu_features), 70 sample_rate_hz) {} 71 72 VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper( 73 int vad_reset_period_ms, 74 std::unique_ptr<MonoVad> vad, 75 int sample_rate_hz) 76 : vad_reset_period_frames_( 77 CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)), 78 frame_size_(CheckedDivExact(sample_rate_hz, kNumFramesPerSecond)), 79 time_to_vad_reset_(vad_reset_period_frames_), 80 vad_(std::move(vad)), 81 resampled_buffer_( 82 CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond)), 83 resampler_(frame_size_, 84 resampled_buffer_.size(), 85 /*num_channels=*/1) { 86 RTC_DCHECK_GT(vad_reset_period_frames_, 1); 87 vad_->Reset(); 88 } 89 90 VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default; 91 92 float VoiceActivityDetectorWrapper::Analyze( 93 DeinterleavedView<const float> frame) { 94 // Periodically reset the VAD. 95 time_to_vad_reset_--; 96 if (time_to_vad_reset_ <= 0) { 97 vad_->Reset(); 98 time_to_vad_reset_ = vad_reset_period_frames_; 99 } 100 101 // Resample the first channel of `frame`. 102 RTC_DCHECK_EQ(frame.samples_per_channel(), frame_size_); 103 MonoView<float> dst(resampled_buffer_.data(), resampled_buffer_.size()); 104 resampler_.Resample(frame[0], dst); 105 106 return vad_->Analyze(resampled_buffer_); 107 } 108 109 } // namespace webrtc