tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

rnn_gru.h (2573B)


      1 /*
      2 *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
     12 #define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
     13 
     14 #include <array>
     15 #include <cstdint>
     16 #include <vector>
     17 
     18 #include "absl/strings/string_view.h"
     19 #include "api/array_view.h"
     20 #include "modules/audio_processing/agc2/cpu_features.h"
     21 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
     22 
     23 namespace webrtc {
     24 namespace rnn_vad {
     25 
// Maximum number of units for a GRU layer. Used as the fixed capacity of
// `GatedRecurrentLayer::state_`; a layer's `output_size` must not exceed it.
constexpr int kGruLayerMaxUnits = 24;
     28 
     29 // Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
     30 // activation functions for the update/reset and output gates respectively.
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
// activation functions for the update/reset and output gates respectively.
// The layer keeps its recurrent state (`state_`) across calls to
// `ComputeOutput()`; call `Reset()` to clear it. Weights and biases are
// passed to the ctor quantized as int8 and stored internally as float
// (see the `std::vector<float>` members). Copying is disabled.
class GatedRecurrentLayer {
public:
 // Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
 // `bias`, `weights` and `recurrent_weights` are int8-quantized parameters;
 // `cpu_features` selects the SIMD code path used by `vector_math_`;
 // `layer_name` is used for diagnostics (e.g., error reporting/logging —
 // exact use is in the .cc, not visible here).
 GatedRecurrentLayer(int input_size,
                     int output_size,
                     ArrayView<const int8_t> bias,
                     ArrayView<const int8_t> weights,
                     ArrayView<const int8_t> recurrent_weights,
                     const AvailableCpuFeatures& cpu_features,
                     absl::string_view layer_name);
 GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
 GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
 ~GatedRecurrentLayer();

 // Returns the size of the input vector.
 int input_size() const { return input_size_; }
 // Returns the pointer to the first element of the output buffer.
 // Only the first `size()` elements are meaningful (see `state_` below).
 const float* data() const { return state_.data(); }
 // Returns the size of the output buffer.
 int size() const { return output_size_; }

 // Resets the GRU state.
 void Reset();
 // Computes the recurrent layer output and updates the status.
 // `input` must have `input_size()` elements — TODO confirm in the .cc.
 void ComputeOutput(ArrayView<const float> input);

private:
 const int input_size_;
 const int output_size_;
 // Dequantized layer parameters (converted from the int8 ctor arguments).
 const std::vector<float> bias_;
 const std::vector<float> weights_;
 const std::vector<float> recurrent_weights_;
 // SIMD-accelerated dot-product helper selected via `cpu_features`.
 const VectorMath vector_math_;
 // Over-allocated array with size equal to `output_size_`.
 // Fixed capacity `kGruLayerMaxUnits` avoids a heap allocation; doubles as
 // both the recurrent state and the output exposed through `data()`.
 std::array<float, kGruLayerMaxUnits> state_;
};
     67 
     68 }  // namespace rnn_vad
     69 }  // namespace webrtc
     70 
     71 #endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_