rnn_fc.cc (3878B)
1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h" 12 13 #include <algorithm> 14 #include <cstdint> 15 #include <vector> 16 17 #include "absl/strings/string_view.h" 18 #include "api/array_view.h" 19 #include "api/function_view.h" 20 #include "modules/audio_processing/agc2/cpu_features.h" 21 #include "rtc_base/checks.h" 22 #include "rtc_base/numerics/safe_conversions.h" 23 #include "third_party/rnnoise/src/rnn_activations.h" 24 #include "third_party/rnnoise/src/rnn_vad_weights.h" 25 26 namespace webrtc { 27 namespace rnn_vad { 28 namespace { 29 30 std::vector<float> GetScaledParams(ArrayView<const int8_t> params) { 31 std::vector<float> scaled_params(params.size()); 32 std::transform(params.begin(), params.end(), scaled_params.begin(), 33 [](int8_t x) -> float { 34 return ::rnnoise::kWeightsScale * static_cast<float>(x); 35 }); 36 return scaled_params; 37 } 38 39 // TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this 40 // function to improve setup time. 41 // Casts and scales `weights` and re-arranges the layout. 42 std::vector<float> PreprocessWeights(ArrayView<const int8_t> weights, 43 int output_size) { 44 if (output_size == 1) { 45 return GetScaledParams(weights); 46 } 47 // Transpose, scale and cast. 48 const int input_size = 49 CheckedDivExact(dchecked_cast<int>(weights.size()), output_size); 50 std::vector<float> w(weights.size()); 51 for (int o = 0; o < output_size; ++o) { 52 for (int i = 0; i < input_size; ++i) { 53 w[o * input_size + i] = rnnoise::kWeightsScale * 54 static_cast<float>(weights[i * output_size + o]); 55 } 56 } 57 return w; 58 } 59 60 FunctionView<float(float)> GetActivationFunction( 61 ActivationFunction activation_function) { 62 switch (activation_function) { 63 case ActivationFunction::kTansigApproximated: 64 return ::rnnoise::TansigApproximated; 65 case ActivationFunction::kSigmoidApproximated: 66 return ::rnnoise::SigmoidApproximated; 67 } 68 } 69 70 } // namespace 71 72 FullyConnectedLayer::FullyConnectedLayer( 73 const int input_size, 74 const int output_size, 75 const ArrayView<const int8_t> bias, 76 const ArrayView<const int8_t> weights, 77 ActivationFunction activation_function, 78 const AvailableCpuFeatures& cpu_features, 79 absl::string_view layer_name) 80 : input_size_(input_size), 81 output_size_(output_size), 82 bias_(GetScaledParams(bias)), 83 weights_(PreprocessWeights(weights, output_size)), 84 vector_math_(cpu_features), 85 activation_function_(GetActivationFunction(activation_function)) { 86 RTC_DCHECK_LE(output_size_, kFullyConnectedLayerMaxUnits) 87 << "Insufficient FC layer over-allocation (" << layer_name << ")."; 88 RTC_DCHECK_EQ(output_size_, bias_.size()) 89 << "Mismatching output size and bias terms array size (" << layer_name 90 << ")."; 91 RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size()) 92 << "Mismatching input-output size and weight coefficients array size (" 93 << layer_name << ")."; 94 } 95 96 FullyConnectedLayer::~FullyConnectedLayer() = default; 97 98 void FullyConnectedLayer::ComputeOutput(ArrayView<const float> input) { 99 RTC_DCHECK_EQ(input.size(), input_size_); 100 ArrayView<const float> weights(weights_); 101 for (int o = 0; o < output_size_; ++o) { 102 output_[o] = activation_function_( 103 bias_[o] + vector_math_.DotProduct( 104 input, weights.subview(o * input_size_, input_size_))); 105 } 106 } 107 108 } // namespace rnn_vad 109 } // namespace webrtc