rnn_gru.cc (8559B)
1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h" 12 13 #include <algorithm> 14 #include <array> 15 #include <cstddef> 16 #include <cstdint> 17 #include <vector> 18 19 #include "absl/strings/string_view.h" 20 #include "api/array_view.h" 21 #include "modules/audio_processing/agc2/cpu_features.h" 22 #include "modules/audio_processing/agc2/rnn_vad/vector_math.h" 23 #include "rtc_base/checks.h" 24 #include "rtc_base/numerics/safe_conversions.h" 25 #include "third_party/rnnoise/src/rnn_activations.h" 26 #include "third_party/rnnoise/src/rnn_vad_weights.h" 27 28 namespace webrtc { 29 namespace rnn_vad { 30 namespace { 31 32 constexpr int kNumGruGates = 3; // Update, reset, output. 33 34 std::vector<float> PreprocessGruTensor(ArrayView<const int8_t> tensor_src, 35 int output_size) { 36 // Transpose, cast and scale. 37 // `n` is the size of the first dimension of the 3-dim tensor `weights`. 38 const int n = CheckedDivExact(dchecked_cast<int>(tensor_src.size()), 39 output_size * kNumGruGates); 40 const int stride_src = kNumGruGates * output_size; 41 const int stride_dst = n * output_size; 42 std::vector<float> tensor_dst(tensor_src.size()); 43 for (int g = 0; g < kNumGruGates; ++g) { 44 for (int o = 0; o < output_size; ++o) { 45 for (int i = 0; i < n; ++i) { 46 tensor_dst[g * stride_dst + o * n + i] = 47 ::rnnoise::kWeightsScale * 48 static_cast<float>( 49 tensor_src[i * stride_src + g * output_size + o]); 50 } 51 } 52 } 53 return tensor_dst; 54 } 55 56 // Computes the output for the update or the reset gate. 57 // Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where 58 // - `g`: output gate vector 59 // - `W`: weights matrix 60 // - `i`: input vector 61 // - `R`: recurrent weights matrix 62 // - `s`: state gate vector 63 // - `b`: bias vector 64 void ComputeUpdateResetGate(int input_size, 65 int output_size, 66 const VectorMath& vector_math, 67 ArrayView<const float> input, 68 ArrayView<const float> state, 69 ArrayView<const float> bias, 70 ArrayView<const float> weights, 71 ArrayView<const float> recurrent_weights, 72 ArrayView<float> gate) { 73 RTC_DCHECK_EQ(input.size(), input_size); 74 RTC_DCHECK_EQ(state.size(), output_size); 75 RTC_DCHECK_EQ(bias.size(), output_size); 76 RTC_DCHECK_EQ(weights.size(), input_size * output_size); 77 RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size); 78 RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated. 79 for (int o = 0; o < output_size; ++o) { 80 float x = bias[o]; 81 x += vector_math.DotProduct(input, 82 weights.subview(o * input_size, input_size)); 83 x += vector_math.DotProduct( 84 state, recurrent_weights.subview(o * output_size, output_size)); 85 gate[o] = ::rnnoise::SigmoidApproximated(x); 86 } 87 } 88 89 // Computes the output for the state gate. 90 // Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where 91 // - `s'`: output state gate vector 92 // - `s`: previous state gate vector 93 // - `u`: update gate vector 94 // - `W`: weights matrix 95 // - `i`: input vector 96 // - `R`: recurrent weights matrix 97 // - `r`: reset gate vector 98 // - `b`: bias vector 99 // - `.*` element-wise product 100 void ComputeStateGate(int input_size, 101 int output_size, 102 const VectorMath& vector_math, 103 ArrayView<const float> input, 104 ArrayView<const float> update, 105 ArrayView<const float> reset, 106 ArrayView<const float> bias, 107 ArrayView<const float> weights, 108 ArrayView<const float> recurrent_weights, 109 ArrayView<float> state) { 110 RTC_DCHECK_EQ(input.size(), input_size); 111 RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated. 112 RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated. 113 RTC_DCHECK_EQ(bias.size(), output_size); 114 RTC_DCHECK_EQ(weights.size(), input_size * output_size); 115 RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size); 116 RTC_DCHECK_EQ(state.size(), output_size); 117 std::array<float, kGruLayerMaxUnits> reset_x_state; 118 for (int o = 0; o < output_size; ++o) { 119 reset_x_state[o] = state[o] * reset[o]; 120 } 121 for (int o = 0; o < output_size; ++o) { 122 float x = bias[o]; 123 x += vector_math.DotProduct(input, 124 weights.subview(o * input_size, input_size)); 125 x += vector_math.DotProduct( 126 {reset_x_state.data(), static_cast<size_t>(output_size)}, 127 recurrent_weights.subview(o * output_size, output_size)); 128 state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x); 129 } 130 } 131 132 } // namespace 133 134 GatedRecurrentLayer::GatedRecurrentLayer( 135 const int input_size, 136 const int output_size, 137 const ArrayView<const int8_t> bias, 138 const ArrayView<const int8_t> weights, 139 const ArrayView<const int8_t> recurrent_weights, 140 const AvailableCpuFeatures& cpu_features, 141 absl::string_view layer_name) 142 : input_size_(input_size), 143 output_size_(output_size), 144 bias_(PreprocessGruTensor(bias, output_size)), 145 weights_(PreprocessGruTensor(weights, output_size)), 146 recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)), 147 vector_math_(cpu_features) { 148 RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits) 149 << "Insufficient GRU layer over-allocation (" << layer_name << ")."; 150 RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size()) 151 << "Mismatching output size and bias terms array size (" << layer_name 152 << ")."; 153 RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size()) 154 << "Mismatching input-output size and weight coefficients array size (" 155 << layer_name << ")."; 156 RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_, 157 recurrent_weights_.size()) 158 << "Mismatching input-output size and recurrent weight coefficients array" 159 " size (" 160 << layer_name << ")."; 161 Reset(); 162 } 163 164 GatedRecurrentLayer::~GatedRecurrentLayer() = default; 165 166 void GatedRecurrentLayer::Reset() { 167 state_.fill(0.f); 168 } 169 170 void GatedRecurrentLayer::ComputeOutput(ArrayView<const float> input) { 171 RTC_DCHECK_EQ(input.size(), input_size_); 172 173 // The tensors below are organized as a sequence of flattened tensors for the 174 // `update`, `reset` and `state` gates. 175 ArrayView<const float> bias(bias_); 176 ArrayView<const float> weights(weights_); 177 ArrayView<const float> recurrent_weights(recurrent_weights_); 178 // Strides to access to the flattened tensors for a specific gate. 179 const int stride_weights = input_size_ * output_size_; 180 const int stride_recurrent_weights = output_size_ * output_size_; 181 182 ArrayView<float> state(state_.data(), output_size_); 183 184 // Update gate. 185 std::array<float, kGruLayerMaxUnits> update; 186 ComputeUpdateResetGate( 187 input_size_, output_size_, vector_math_, input, state, 188 bias.subview(0, output_size_), weights.subview(0, stride_weights), 189 recurrent_weights.subview(0, stride_recurrent_weights), update); 190 // Reset gate. 191 std::array<float, kGruLayerMaxUnits> reset; 192 ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state, 193 bias.subview(output_size_, output_size_), 194 weights.subview(stride_weights, stride_weights), 195 recurrent_weights.subview(stride_recurrent_weights, 196 stride_recurrent_weights), 197 reset); 198 // State gate. 199 ComputeStateGate(input_size_, output_size_, vector_math_, input, update, 200 reset, bias.subview(2 * output_size_, output_size_), 201 weights.subview(2 * stride_weights, stride_weights), 202 recurrent_weights.subview(2 * stride_recurrent_weights, 203 stride_recurrent_weights), 204 state); 205 } 206 207 } // namespace rnn_vad 208 } // namespace webrtc