vector_math.h (6829B)
1 /* 2 * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #ifndef MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ 12 #define MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_ 13 14 #include <algorithm> 15 #include <cmath> 16 #include <functional> 17 18 #include "api/array_view.h" 19 #include "modules/audio_processing/aec3/aec3_common.h" 20 #include "rtc_base/checks.h" 21 22 // Defines WEBRTC_ARCH_X86_FAMILY, used below. 23 #include "rtc_base/system/arch.h" 24 #if defined(WEBRTC_HAS_NEON) 25 #include <arm_neon.h> 26 #endif 27 #if defined(WEBRTC_ARCH_X86_FAMILY) 28 #include <emmintrin.h> 29 #endif 30 31 namespace webrtc { 32 namespace aec3 { 33 34 // Provides optimizations for mathematical operations based on vectors. 35 class VectorMath { 36 public: 37 explicit VectorMath(Aec3Optimization optimization) 38 : optimization_(optimization) {} 39 40 // Elementwise square root. 41 void SqrtAVX2(ArrayView<float> x); 42 void Sqrt(ArrayView<float> x) { 43 switch (optimization_) { 44 #if defined(WEBRTC_ARCH_X86_FAMILY) 45 case Aec3Optimization::kSse2: { 46 const int x_size = static_cast<int>(x.size()); 47 const int vector_limit = x_size >> 2; 48 49 int j = 0; 50 for (; j < vector_limit * 4; j += 4) { 51 __m128 g = _mm_loadu_ps(&x[j]); 52 g = _mm_sqrt_ps(g); 53 _mm_storeu_ps(&x[j], g); 54 } 55 56 for (; j < x_size; ++j) { 57 x[j] = sqrtf(x[j]); 58 } 59 } break; 60 case Aec3Optimization::kAvx2: 61 SqrtAVX2(x); 62 break; 63 #endif 64 #if defined(WEBRTC_HAS_NEON) 65 case Aec3Optimization::kNeon: { 66 const int x_size = static_cast<int>(x.size()); 67 const int vector_limit = x_size >> 2; 68 69 int j = 0; 70 for (; j < vector_limit * 4; j += 4) { 71 float32x4_t g = vld1q_f32(&x[j]); 72 #if !defined(WEBRTC_ARCH_ARM64) 73 float32x4_t y = vrsqrteq_f32(g); 74 75 // Code to handle sqrt(0). 76 // If the input to sqrtf() is zero, a zero will be returned. 77 // If the input to vrsqrteq_f32() is zero, positive infinity is 78 // returned. 79 const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000); 80 // check for divide by zero 81 const uint32x4_t div_by_zero = 82 vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y)); 83 // zero out the positive infinity results 84 y = vreinterpretq_f32_u32( 85 vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y))); 86 // from arm documentation 87 // The Newton-Raphson iteration: 88 // y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2) 89 // converges to (1/√d) if y0 is the result of VRSQRTE applied to d. 90 // 91 // Note: The precision did not improve after 2 iterations. 92 for (int i = 0; i < 2; i++) { 93 y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y); 94 } 95 // sqrt(g) = g * 1/sqrt(g) 96 g = vmulq_f32(g, y); 97 #else 98 g = vsqrtq_f32(g); 99 #endif 100 vst1q_f32(&x[j], g); 101 } 102 103 for (; j < x_size; ++j) { 104 x[j] = sqrtf(x[j]); 105 } 106 } 107 #endif 108 break; 109 default: 110 std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); }); 111 } 112 } 113 114 // Elementwise vector multiplication z = x * y. 115 void MultiplyAVX2(ArrayView<const float> x, 116 ArrayView<const float> y, 117 ArrayView<float> z); 118 void Multiply(ArrayView<const float> x, 119 ArrayView<const float> y, 120 ArrayView<float> z) { 121 RTC_DCHECK_EQ(z.size(), x.size()); 122 RTC_DCHECK_EQ(z.size(), y.size()); 123 switch (optimization_) { 124 #if defined(WEBRTC_ARCH_X86_FAMILY) 125 case Aec3Optimization::kSse2: { 126 const int x_size = static_cast<int>(x.size()); 127 const int vector_limit = x_size >> 2; 128 129 int j = 0; 130 for (; j < vector_limit * 4; j += 4) { 131 const __m128 x_j = _mm_loadu_ps(&x[j]); 132 const __m128 y_j = _mm_loadu_ps(&y[j]); 133 const __m128 z_j = _mm_mul_ps(x_j, y_j); 134 _mm_storeu_ps(&z[j], z_j); 135 } 136 137 for (; j < x_size; ++j) { 138 z[j] = x[j] * y[j]; 139 } 140 } break; 141 case Aec3Optimization::kAvx2: 142 MultiplyAVX2(x, y, z); 143 break; 144 #endif 145 #if defined(WEBRTC_HAS_NEON) 146 case Aec3Optimization::kNeon: { 147 const int x_size = static_cast<int>(x.size()); 148 const int vector_limit = x_size >> 2; 149 150 int j = 0; 151 for (; j < vector_limit * 4; j += 4) { 152 const float32x4_t x_j = vld1q_f32(&x[j]); 153 const float32x4_t y_j = vld1q_f32(&y[j]); 154 const float32x4_t z_j = vmulq_f32(x_j, y_j); 155 vst1q_f32(&z[j], z_j); 156 } 157 158 for (; j < x_size; ++j) { 159 z[j] = x[j] * y[j]; 160 } 161 } break; 162 #endif 163 default: 164 std::transform(x.begin(), x.end(), y.begin(), z.begin(), 165 std::multiplies<float>()); 166 } 167 } 168 169 // Elementwise vector accumulation z += x. 170 void AccumulateAVX2(ArrayView<const float> x, ArrayView<float> z); 171 void Accumulate(ArrayView<const float> x, ArrayView<float> z) { 172 RTC_DCHECK_EQ(z.size(), x.size()); 173 switch (optimization_) { 174 #if defined(WEBRTC_ARCH_X86_FAMILY) 175 case Aec3Optimization::kSse2: { 176 const int x_size = static_cast<int>(x.size()); 177 const int vector_limit = x_size >> 2; 178 179 int j = 0; 180 for (; j < vector_limit * 4; j += 4) { 181 const __m128 x_j = _mm_loadu_ps(&x[j]); 182 __m128 z_j = _mm_loadu_ps(&z[j]); 183 z_j = _mm_add_ps(x_j, z_j); 184 _mm_storeu_ps(&z[j], z_j); 185 } 186 187 for (; j < x_size; ++j) { 188 z[j] += x[j]; 189 } 190 } break; 191 case Aec3Optimization::kAvx2: 192 AccumulateAVX2(x, z); 193 break; 194 #endif 195 #if defined(WEBRTC_HAS_NEON) 196 case Aec3Optimization::kNeon: { 197 const int x_size = static_cast<int>(x.size()); 198 const int vector_limit = x_size >> 2; 199 200 int j = 0; 201 for (; j < vector_limit * 4; j += 4) { 202 const float32x4_t x_j = vld1q_f32(&x[j]); 203 float32x4_t z_j = vld1q_f32(&z[j]); 204 z_j = vaddq_f32(z_j, x_j); 205 vst1q_f32(&z[j], z_j); 206 } 207 208 for (; j < x_size; ++j) { 209 z[j] += x[j]; 210 } 211 } break; 212 #endif 213 default: 214 std::transform(x.begin(), x.end(), z.begin(), z.begin(), 215 std::plus<float>()); 216 } 217 } 218 219 private: 220 Aec3Optimization optimization_; 221 }; 222 223 } // namespace aec3 224 225 } // namespace webrtc 226 227 #endif // MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_