tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

rnn_gru.cc (8559B)


      1 /*
      2 *  Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <cstddef>
     16 #include <cstdint>
     17 #include <vector>
     18 
     19 #include "absl/strings/string_view.h"
     20 #include "api/array_view.h"
     21 #include "modules/audio_processing/agc2/cpu_features.h"
     22 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
     23 #include "rtc_base/checks.h"
     24 #include "rtc_base/numerics/safe_conversions.h"
     25 #include "third_party/rnnoise/src/rnn_activations.h"
     26 #include "third_party/rnnoise/src/rnn_vad_weights.h"
     27 
     28 namespace webrtc {
     29 namespace rnn_vad {
     30 namespace {
     31 
// Each GRU unit has three weight/bias groups, one per gate.
constexpr int kNumGruGates = 3;  // Update, reset, output.
     33 
     34 std::vector<float> PreprocessGruTensor(ArrayView<const int8_t> tensor_src,
     35                                       int output_size) {
     36  // Transpose, cast and scale.
     37  // `n` is the size of the first dimension of the 3-dim tensor `weights`.
     38  const int n = CheckedDivExact(dchecked_cast<int>(tensor_src.size()),
     39                                output_size * kNumGruGates);
     40  const int stride_src = kNumGruGates * output_size;
     41  const int stride_dst = n * output_size;
     42  std::vector<float> tensor_dst(tensor_src.size());
     43  for (int g = 0; g < kNumGruGates; ++g) {
     44    for (int o = 0; o < output_size; ++o) {
     45      for (int i = 0; i < n; ++i) {
     46        tensor_dst[g * stride_dst + o * n + i] =
     47            ::rnnoise::kWeightsScale *
     48            static_cast<float>(
     49                tensor_src[i * stride_src + g * output_size + o]);
     50      }
     51    }
     52  }
     53  return tensor_dst;
     54 }
     55 
     56 // Computes the output for the update or the reset gate.
     57 // Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
     58 // - `g`: output gate vector
     59 // - `W`: weights matrix
     60 // - `i`: input vector
     61 // - `R`: recurrent weights matrix
     62 // - `s`: state gate vector
     63 // - `b`: bias vector
     64 void ComputeUpdateResetGate(int input_size,
     65                            int output_size,
     66                            const VectorMath& vector_math,
     67                            ArrayView<const float> input,
     68                            ArrayView<const float> state,
     69                            ArrayView<const float> bias,
     70                            ArrayView<const float> weights,
     71                            ArrayView<const float> recurrent_weights,
     72                            ArrayView<float> gate) {
     73  RTC_DCHECK_EQ(input.size(), input_size);
     74  RTC_DCHECK_EQ(state.size(), output_size);
     75  RTC_DCHECK_EQ(bias.size(), output_size);
     76  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
     77  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
     78  RTC_DCHECK_GE(gate.size(), output_size);  // `gate` is over-allocated.
     79  for (int o = 0; o < output_size; ++o) {
     80    float x = bias[o];
     81    x += vector_math.DotProduct(input,
     82                                weights.subview(o * input_size, input_size));
     83    x += vector_math.DotProduct(
     84        state, recurrent_weights.subview(o * output_size, output_size));
     85    gate[o] = ::rnnoise::SigmoidApproximated(x);
     86  }
     87 }
     88 
     89 // Computes the output for the state gate.
     90 // Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
     91 // - `s'`: output state gate vector
     92 // - `s`: previous state gate vector
     93 // - `u`: update gate vector
     94 // - `W`: weights matrix
     95 // - `i`: input vector
     96 // - `R`: recurrent weights matrix
     97 // - `r`: reset gate vector
     98 // - `b`: bias vector
     99 // - `.*` element-wise product
    100 void ComputeStateGate(int input_size,
    101                      int output_size,
    102                      const VectorMath& vector_math,
    103                      ArrayView<const float> input,
    104                      ArrayView<const float> update,
    105                      ArrayView<const float> reset,
    106                      ArrayView<const float> bias,
    107                      ArrayView<const float> weights,
    108                      ArrayView<const float> recurrent_weights,
    109                      ArrayView<float> state) {
    110  RTC_DCHECK_EQ(input.size(), input_size);
    111  RTC_DCHECK_GE(update.size(), output_size);  // `update` is over-allocated.
    112  RTC_DCHECK_GE(reset.size(), output_size);   // `reset` is over-allocated.
    113  RTC_DCHECK_EQ(bias.size(), output_size);
    114  RTC_DCHECK_EQ(weights.size(), input_size * output_size);
    115  RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
    116  RTC_DCHECK_EQ(state.size(), output_size);
    117  std::array<float, kGruLayerMaxUnits> reset_x_state;
    118  for (int o = 0; o < output_size; ++o) {
    119    reset_x_state[o] = state[o] * reset[o];
    120  }
    121  for (int o = 0; o < output_size; ++o) {
    122    float x = bias[o];
    123    x += vector_math.DotProduct(input,
    124                                weights.subview(o * input_size, input_size));
    125    x += vector_math.DotProduct(
    126        {reset_x_state.data(), static_cast<size_t>(output_size)},
    127        recurrent_weights.subview(o * output_size, output_size));
    128    state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
    129  }
    130 }
    131 
    132 }  // namespace
    133 
// Builds a GRU layer from quantized int8 tensors. The weights are cast,
// scaled and transposed once here (see `PreprocessGruTensor`) so that
// `ComputeOutput()` only performs float dot products. `layer_name` is used
// solely to make the DCHECK failure messages identifiable.
GatedRecurrentLayer::GatedRecurrentLayer(
   const int input_size,
   const int output_size,
   const ArrayView<const int8_t> bias,
   const ArrayView<const int8_t> weights,
   const ArrayView<const int8_t> recurrent_weights,
   const AvailableCpuFeatures& cpu_features,
   absl::string_view layer_name)
   : input_size_(input_size),
     output_size_(output_size),
     bias_(PreprocessGruTensor(bias, output_size)),
     weights_(PreprocessGruTensor(weights, output_size)),
     recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
     vector_math_(cpu_features) {
 // The state buffer is a fixed-size array; the layer must fit in it.
 RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
     << "Insufficient GRU layer over-allocation (" << layer_name << ").";
 // Validate that the preprocessed tensor sizes agree with the layer shape.
 RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
     << "Mismatching output size and bias terms array size (" << layer_name
     << ").";
 RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
     << "Mismatching input-output size and weight coefficients array size ("
     << layer_name << ").";
 RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
               recurrent_weights_.size())
     << "Mismatching input-output size and recurrent weight coefficients array"
        " size ("
     << layer_name << ").";
 // Start from an all-zero recurrent state.
 Reset();
}
    163 
// Defaulted out-of-line so the destructor definition lives in this .cc file.
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
    165 
    166 void GatedRecurrentLayer::Reset() {
    167  state_.fill(0.f);
    168 }
    169 
    170 void GatedRecurrentLayer::ComputeOutput(ArrayView<const float> input) {
    171  RTC_DCHECK_EQ(input.size(), input_size_);
    172 
    173  // The tensors below are organized as a sequence of flattened tensors for the
    174  // `update`, `reset` and `state` gates.
    175  ArrayView<const float> bias(bias_);
    176  ArrayView<const float> weights(weights_);
    177  ArrayView<const float> recurrent_weights(recurrent_weights_);
    178  // Strides to access to the flattened tensors for a specific gate.
    179  const int stride_weights = input_size_ * output_size_;
    180  const int stride_recurrent_weights = output_size_ * output_size_;
    181 
    182  ArrayView<float> state(state_.data(), output_size_);
    183 
    184  // Update gate.
    185  std::array<float, kGruLayerMaxUnits> update;
    186  ComputeUpdateResetGate(
    187      input_size_, output_size_, vector_math_, input, state,
    188      bias.subview(0, output_size_), weights.subview(0, stride_weights),
    189      recurrent_weights.subview(0, stride_recurrent_weights), update);
    190  // Reset gate.
    191  std::array<float, kGruLayerMaxUnits> reset;
    192  ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
    193                         bias.subview(output_size_, output_size_),
    194                         weights.subview(stride_weights, stride_weights),
    195                         recurrent_weights.subview(stride_recurrent_weights,
    196                                                   stride_recurrent_weights),
    197                         reset);
    198  // State gate.
    199  ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
    200                   reset, bias.subview(2 * output_size_, output_size_),
    201                   weights.subview(2 * stride_weights, stride_weights),
    202                   recurrent_weights.subview(2 * stride_recurrent_weights,
    203                                             stride_recurrent_weights),
    204                   state);
    205 }
    206 
    207 }  // namespace rnn_vad
    208 }  // namespace webrtc