rnn.h (1575B)
1 /* 2 * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_ 13 14 #include <stddef.h> 15 16 17 #include "api/array_view.h" 18 #include "modules/audio_processing/agc2/cpu_features.h" 19 #include "modules/audio_processing/agc2/rnn_vad/common.h" 20 #include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h" 21 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h" 22 23 namespace webrtc { 24 namespace rnn_vad { 25 26 // Recurrent network with hard-coded architecture and weights for voice activity 27 // detection. 28 class RnnVad { 29 public: 30 explicit RnnVad(const AvailableCpuFeatures& cpu_features); 31 RnnVad(const RnnVad&) = delete; 32 RnnVad& operator=(const RnnVad&) = delete; 33 ~RnnVad(); 34 void Reset(); 35 // Observes `feature_vector` and `is_silence`, updates the RNN and returns the 36 // current voice probability. 37 float ComputeVadProbability( 38 ArrayView<const float, kFeatureVectorSize> feature_vector, 39 bool is_silence); 40 41 private: 42 FullyConnectedLayer input_; 43 GatedRecurrentLayer hidden_; 44 FullyConnectedLayer output_; 45 }; 46 47 } // namespace rnn_vad 48 } // namespace webrtc 49 50 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_