tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

vector_math.h (6829B)


      1 /*
      2 *  Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
      3 *
      4 *  Use of this source code is governed by a BSD-style license
      5 *  that can be found in the LICENSE file in the root of the source
      6 *  tree. An additional intellectual property rights grant can be found
      7 *  in the file PATENTS.  All contributing project authors may
      8 *  be found in the AUTHORS file in the root of the source tree.
      9 */
     10 
     11 #ifndef MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
     12 #define MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_
     13 
     14 #include <algorithm>
     15 #include <cmath>
     16 #include <functional>
     17 
     18 #include "api/array_view.h"
     19 #include "modules/audio_processing/aec3/aec3_common.h"
     20 #include "rtc_base/checks.h"
     21 
     22 // Defines WEBRTC_ARCH_X86_FAMILY, used below.
     23 #include "rtc_base/system/arch.h"
     24 #if defined(WEBRTC_HAS_NEON)
     25 #include <arm_neon.h>
     26 #endif
     27 #if defined(WEBRTC_ARCH_X86_FAMILY)
     28 #include <emmintrin.h>
     29 #endif
     30 
     31 namespace webrtc {
     32 namespace aec3 {
     33 
     34 // Provides optimizations for mathematical operations based on vectors.
     35 class VectorMath {
     36 public:
     37  explicit VectorMath(Aec3Optimization optimization)
     38      : optimization_(optimization) {}
     39 
     40  // Elementwise square root.
     41  void SqrtAVX2(ArrayView<float> x);
     42  void Sqrt(ArrayView<float> x) {
     43    switch (optimization_) {
     44 #if defined(WEBRTC_ARCH_X86_FAMILY)
     45      case Aec3Optimization::kSse2: {
     46        const int x_size = static_cast<int>(x.size());
     47        const int vector_limit = x_size >> 2;
     48 
     49        int j = 0;
     50        for (; j < vector_limit * 4; j += 4) {
     51          __m128 g = _mm_loadu_ps(&x[j]);
     52          g = _mm_sqrt_ps(g);
     53          _mm_storeu_ps(&x[j], g);
     54        }
     55 
     56        for (; j < x_size; ++j) {
     57          x[j] = sqrtf(x[j]);
     58        }
     59      } break;
     60      case Aec3Optimization::kAvx2:
     61        SqrtAVX2(x);
     62        break;
     63 #endif
     64 #if defined(WEBRTC_HAS_NEON)
     65      case Aec3Optimization::kNeon: {
     66        const int x_size = static_cast<int>(x.size());
     67        const int vector_limit = x_size >> 2;
     68 
     69        int j = 0;
     70        for (; j < vector_limit * 4; j += 4) {
     71          float32x4_t g = vld1q_f32(&x[j]);
     72 #if !defined(WEBRTC_ARCH_ARM64)
     73          float32x4_t y = vrsqrteq_f32(g);
     74 
     75          // Code to handle sqrt(0).
     76          // If the input to sqrtf() is zero, a zero will be returned.
     77          // If the input to vrsqrteq_f32() is zero, positive infinity is
     78          // returned.
     79          const uint32x4_t vec_p_inf = vdupq_n_u32(0x7F800000);
     80          // check for divide by zero
     81          const uint32x4_t div_by_zero =
     82              vceqq_u32(vec_p_inf, vreinterpretq_u32_f32(y));
     83          // zero out the positive infinity results
     84          y = vreinterpretq_f32_u32(
     85              vandq_u32(vmvnq_u32(div_by_zero), vreinterpretq_u32_f32(y)));
     86          // from arm documentation
     87          // The Newton-Raphson iteration:
     88          //     y[n+1] = y[n] * (3 - d * (y[n] * y[n])) / 2)
     89          // converges to (1/√d) if y0 is the result of VRSQRTE applied to d.
     90          //
     91          // Note: The precision did not improve after 2 iterations.
     92          for (int i = 0; i < 2; i++) {
     93            y = vmulq_f32(vrsqrtsq_f32(vmulq_f32(y, y), g), y);
     94          }
     95          // sqrt(g) = g * 1/sqrt(g)
     96          g = vmulq_f32(g, y);
     97 #else
     98          g = vsqrtq_f32(g);
     99 #endif
    100          vst1q_f32(&x[j], g);
    101        }
    102 
    103        for (; j < x_size; ++j) {
    104          x[j] = sqrtf(x[j]);
    105        }
    106      }
    107 #endif
    108      break;
    109      default:
    110        std::for_each(x.begin(), x.end(), [](float& a) { a = sqrtf(a); });
    111    }
    112  }
    113 
    114  // Elementwise vector multiplication z = x * y.
    115  void MultiplyAVX2(ArrayView<const float> x,
    116                    ArrayView<const float> y,
    117                    ArrayView<float> z);
    118  void Multiply(ArrayView<const float> x,
    119                ArrayView<const float> y,
    120                ArrayView<float> z) {
    121    RTC_DCHECK_EQ(z.size(), x.size());
    122    RTC_DCHECK_EQ(z.size(), y.size());
    123    switch (optimization_) {
    124 #if defined(WEBRTC_ARCH_X86_FAMILY)
    125      case Aec3Optimization::kSse2: {
    126        const int x_size = static_cast<int>(x.size());
    127        const int vector_limit = x_size >> 2;
    128 
    129        int j = 0;
    130        for (; j < vector_limit * 4; j += 4) {
    131          const __m128 x_j = _mm_loadu_ps(&x[j]);
    132          const __m128 y_j = _mm_loadu_ps(&y[j]);
    133          const __m128 z_j = _mm_mul_ps(x_j, y_j);
    134          _mm_storeu_ps(&z[j], z_j);
    135        }
    136 
    137        for (; j < x_size; ++j) {
    138          z[j] = x[j] * y[j];
    139        }
    140      } break;
    141      case Aec3Optimization::kAvx2:
    142        MultiplyAVX2(x, y, z);
    143        break;
    144 #endif
    145 #if defined(WEBRTC_HAS_NEON)
    146      case Aec3Optimization::kNeon: {
    147        const int x_size = static_cast<int>(x.size());
    148        const int vector_limit = x_size >> 2;
    149 
    150        int j = 0;
    151        for (; j < vector_limit * 4; j += 4) {
    152          const float32x4_t x_j = vld1q_f32(&x[j]);
    153          const float32x4_t y_j = vld1q_f32(&y[j]);
    154          const float32x4_t z_j = vmulq_f32(x_j, y_j);
    155          vst1q_f32(&z[j], z_j);
    156        }
    157 
    158        for (; j < x_size; ++j) {
    159          z[j] = x[j] * y[j];
    160        }
    161      } break;
    162 #endif
    163      default:
    164        std::transform(x.begin(), x.end(), y.begin(), z.begin(),
    165                       std::multiplies<float>());
    166    }
    167  }
    168 
    169  // Elementwise vector accumulation z += x.
    170  void AccumulateAVX2(ArrayView<const float> x, ArrayView<float> z);
    171  void Accumulate(ArrayView<const float> x, ArrayView<float> z) {
    172    RTC_DCHECK_EQ(z.size(), x.size());
    173    switch (optimization_) {
    174 #if defined(WEBRTC_ARCH_X86_FAMILY)
    175      case Aec3Optimization::kSse2: {
    176        const int x_size = static_cast<int>(x.size());
    177        const int vector_limit = x_size >> 2;
    178 
    179        int j = 0;
    180        for (; j < vector_limit * 4; j += 4) {
    181          const __m128 x_j = _mm_loadu_ps(&x[j]);
    182          __m128 z_j = _mm_loadu_ps(&z[j]);
    183          z_j = _mm_add_ps(x_j, z_j);
    184          _mm_storeu_ps(&z[j], z_j);
    185        }
    186 
    187        for (; j < x_size; ++j) {
    188          z[j] += x[j];
    189        }
    190      } break;
    191      case Aec3Optimization::kAvx2:
    192        AccumulateAVX2(x, z);
    193        break;
    194 #endif
    195 #if defined(WEBRTC_HAS_NEON)
    196      case Aec3Optimization::kNeon: {
    197        const int x_size = static_cast<int>(x.size());
    198        const int vector_limit = x_size >> 2;
    199 
    200        int j = 0;
    201        for (; j < vector_limit * 4; j += 4) {
    202          const float32x4_t x_j = vld1q_f32(&x[j]);
    203          float32x4_t z_j = vld1q_f32(&z[j]);
    204          z_j = vaddq_f32(z_j, x_j);
    205          vst1q_f32(&z[j], z_j);
    206        }
    207 
    208        for (; j < x_size; ++j) {
    209          z[j] += x[j];
    210        }
    211      } break;
    212 #endif
    213      default:
    214        std::transform(x.begin(), x.end(), z.begin(), z.begin(),
    215                       std::plus<float>());
    216    }
    217  }
    218 
    219 private:
    220  Aec3Optimization optimization_;
    221 };
    222 
    223 }  // namespace aec3
    224 
    225 }  // namespace webrtc
    226 
    227 #endif  // MODULES_AUDIO_PROCESSING_AEC3_VECTOR_MATH_H_