tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

rnn_fc.h (2643B)


      1 /*
      2 *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
     12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
     13 
     14 #include <array>
     15 #include <cstdint>
     16 #include <vector>
     17 
     18 #include "absl/strings/string_view.h"
     19 #include "api/array_view.h"
     20 #include "api/function_view.h"
     21 #include "modules/audio_processing/agc2/cpu_features.h"
     22 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
     23 
     24 namespace webrtc {
     25 namespace rnn_vad {
     26 
     27 // Activation function for a neural network cell.
     28 enum class ActivationFunction { kTansigApproximated, kSigmoidApproximated };
     29 
     30 // Maximum number of units for an FC layer.
     31 constexpr int kFullyConnectedLayerMaxUnits = 24;
     32 
     33 // Fully-connected layer with a custom activation function which owns the output
     34 // buffer.
     35 class FullyConnectedLayer {
     36 public:
     37  // Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`.
     38  FullyConnectedLayer(int input_size,
     39                      int output_size,
     40                      ArrayView<const int8_t> bias,
     41                      ArrayView<const int8_t> weights,
     42                      ActivationFunction activation_function,
     43                      const AvailableCpuFeatures& cpu_features,
     44                      absl::string_view layer_name);
     45  FullyConnectedLayer(const FullyConnectedLayer&) = delete;
     46  FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
     47  ~FullyConnectedLayer();
     48 
     49  // Returns the size of the input vector.
     50  int input_size() const { return input_size_; }
     51  // Returns the pointer to the first element of the output buffer.
     52  const float* data() const { return output_.data(); }
     53  // Returns the size of the output buffer.
     54  int size() const { return output_size_; }
     55 
     56  // Computes the fully-connected layer output.
     57  void ComputeOutput(ArrayView<const float> input);
     58 
     59 private:
     60  const int input_size_;
     61  const int output_size_;
     62  const std::vector<float> bias_;
     63  const std::vector<float> weights_;
     64  const VectorMath vector_math_;
     65  FunctionView<float(float)> activation_function_;
     66  // Over-allocated array with size equal to `output_size_`.
     67  std::array<float, kFullyConnectedLayerMaxUnits> output_;
     68 };
     69 
     70 }  // namespace rnn_vad
     71 }  // namespace webrtc
     72 
     73 #endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_