vector_math_avx2.cc (2221B)
1 /* 2 * Copyright (c) 2020 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11 #include <immintrin.h> 12 13 #include <cmath> 14 15 #include "api/array_view.h" 16 #include "modules/audio_processing/aec3/vector_math.h" 17 #include "rtc_base/checks.h" 18 19 namespace webrtc { 20 namespace aec3 { 21 22 // Elementwise square root. 23 void VectorMath::SqrtAVX2(ArrayView<float> x) { 24 const int x_size = static_cast<int>(x.size()); 25 const int vector_limit = x_size >> 3; 26 27 int j = 0; 28 for (; j < vector_limit * 8; j += 8) { 29 __m256 g = _mm256_loadu_ps(&x[j]); 30 g = _mm256_sqrt_ps(g); 31 _mm256_storeu_ps(&x[j], g); 32 } 33 34 for (; j < x_size; ++j) { 35 x[j] = sqrtf(x[j]); 36 } 37 } 38 39 // Elementwise vector multiplication z = x * y. 40 void VectorMath::MultiplyAVX2(ArrayView<const float> x, 41 ArrayView<const float> y, 42 ArrayView<float> z) { 43 RTC_DCHECK_EQ(z.size(), x.size()); 44 RTC_DCHECK_EQ(z.size(), y.size()); 45 const int x_size = static_cast<int>(x.size()); 46 const int vector_limit = x_size >> 3; 47 48 int j = 0; 49 for (; j < vector_limit * 8; j += 8) { 50 const __m256 x_j = _mm256_loadu_ps(&x[j]); 51 const __m256 y_j = _mm256_loadu_ps(&y[j]); 52 const __m256 z_j = _mm256_mul_ps(x_j, y_j); 53 _mm256_storeu_ps(&z[j], z_j); 54 } 55 56 for (; j < x_size; ++j) { 57 z[j] = x[j] * y[j]; 58 } 59 } 60 61 // Elementwise vector accumulation z += x. 62 void VectorMath::AccumulateAVX2(ArrayView<const float> x, ArrayView<float> z) { 63 RTC_DCHECK_EQ(z.size(), x.size()); 64 const int x_size = static_cast<int>(x.size()); 65 const int vector_limit = x_size >> 3; 66 67 int j = 0; 68 for (; j < vector_limit * 8; j += 8) { 69 const __m256 x_j = _mm256_loadu_ps(&x[j]); 70 __m256 z_j = _mm256_loadu_ps(&z[j]); 71 z_j = _mm256_add_ps(x_j, z_j); 72 _mm256_storeu_ps(&z[j], z_j); 73 } 74 75 for (; j < x_size; ++j) { 76 z[j] += x[j]; 77 } 78 } 79 80 } // namespace aec3 81 } // namespace webrtc