rnn_gru.h (2573B)
1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_ 13 14 #include <array> 15 #include <cstdint> 16 #include <vector> 17 18 #include "absl/strings/string_view.h" 19 #include "api/array_view.h" 20 #include "modules/audio_processing/agc2/cpu_features.h" 21 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h" 22 23 namespace webrtc { 24 namespace rnn_vad { 25 26 // Maximum number of units for a GRU layer. 27 constexpr int kGruLayerMaxUnits = 24; 28 29 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as 30 // activation functions for the update/reset and output gates respectively. 31 class GatedRecurrentLayer { 32 public: 33 // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`. 34 GatedRecurrentLayer(int input_size, 35 int output_size, 36 ArrayView<const int8_t> bias, 37 ArrayView<const int8_t> weights, 38 ArrayView<const int8_t> recurrent_weights, 39 const AvailableCpuFeatures& cpu_features, 40 absl::string_view layer_name); 41 GatedRecurrentLayer(const GatedRecurrentLayer&) = delete; 42 GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete; 43 ~GatedRecurrentLayer(); 44 45 // Returns the size of the input vector. 46 int input_size() const { return input_size_; } 47 // Returns the pointer to the first element of the output buffer. 48 const float* data() const { return state_.data(); } 49 // Returns the size of the output buffer. 50 int size() const { return output_size_; } 51 52 // Resets the GRU state. 53 void Reset(); 54 // Computes the recurrent layer output and updates the status. 55 void ComputeOutput(ArrayView<const float> input); 56 57 private: 58 const int input_size_; 59 const int output_size_; 60 const std::vector<float> bias_; 61 const std::vector<float> weights_; 62 const std::vector<float> recurrent_weights_; 63 const VectorMath vector_math_; 64 // Over-allocated array with size equal to `output_size_`. 65 std::array<float, kGruLayerMaxUnits> state_; 66 }; 67 68 } // namespace rnn_vad 69 } // namespace webrtc 70 71 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_