rnn_fc.h (2643B)
1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_ 12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_ 13 14 #include <array> 15 #include <cstdint> 16 #include <vector> 17 18 #include "absl/strings/string_view.h" 19 #include "api/array_view.h" 20 #include "api/function_view.h" 21 #include "modules/audio_processing/agc2/cpu_features.h" 22 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h" 23 24 namespace webrtc { 25 namespace rnn_vad { 26 27 // Activation function for a neural network cell. 28 enum class ActivationFunction { kTansigApproximated, kSigmoidApproximated }; 29 30 // Maximum number of units for an FC layer. 31 constexpr int kFullyConnectedLayerMaxUnits = 24; 32 33 // Fully-connected layer with a custom activation function which owns the output 34 // buffer. 35 class FullyConnectedLayer { 36 public: 37 // Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`. 38 FullyConnectedLayer(int input_size, 39 int output_size, 40 ArrayView<const int8_t> bias, 41 ArrayView<const int8_t> weights, 42 ActivationFunction activation_function, 43 const AvailableCpuFeatures& cpu_features, 44 absl::string_view layer_name); 45 FullyConnectedLayer(const FullyConnectedLayer&) = delete; 46 FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete; 47 ~FullyConnectedLayer(); 48 49 // Returns the size of the input vector. 50 int input_size() const { return input_size_; } 51 // Returns the pointer to the first element of the output buffer. 52 const float* data() const { return output_.data(); } 53 // Returns the size of the output buffer. 54 int size() const { return output_size_; } 55 56 // Computes the fully-connected layer output. 57 void ComputeOutput(ArrayView<const float> input); 58 59 private: 60 const int input_size_; 61 const int output_size_; 62 const std::vector<float> bias_; 63 const std::vector<float> weights_; 64 const VectorMath vector_math_; 65 FunctionView<float(float)> activation_function_; 66 // Over-allocated array with size equal to `output_size_`. 67 std::array<float, kFullyConnectedLayerMaxUnits> output_; 68 }; 69 70 } // namespace rnn_vad 71 } // namespace webrtc 72 73 #endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_