tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

gemmology.h (52788B)


#ifndef GEMMOLOGY_H
#define GEMMOLOGY_H

#include "gemmology_fwd.h"

#include <cstdint>
#include <cstring>
#include <iterator> // std::size
#include <tuple>

#include <xsimd/xsimd.hpp>

namespace gemmology {

namespace {

//
// Arch specific implementation of various elementary operations
//

namespace kernel {

#ifdef __AVX512BW__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi8(first, second),
          _mm512_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi16(first, second),
          _mm512_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi32(first, second),
          _mm512_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi64(first, second),
          _mm512_unpackhi_epi64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  // _mm512_madd_epi16 operates on 16-bit lanes, so it cannot be used here.
  // Emulate the signed 8-bit madd like the AVX2 kernel: negate y wherever x
  // is negative (there is no _mm512_sign_epi8), then use |x| as unsigned.
  return _mm512_maddubs_epi16(
      _mm512_abs_epi8(x),
      _mm512_mask_sub_epi8(y, _mm512_movepi8_mask(x), _mm512_setzero_si512(),
                           y));
}

template <class Arch>
inline xsimd::batch<int32_t, xsimd::avx2>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  // Form [0th 128-bit register of pack0123, 0th 128-bit register of pack4567,
  // 2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567]
  __m512i mix0 =
      _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6));
  // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567,
  // 3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567]
  __m512i mix1 =
      _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2));
  __m512i added = _mm512_add_epi32(mix0, mix1);
  // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7.
  // Fold the register over itself.
  return _mm256_add_epi32(_mm512_castsi512_si256(added),
                          _mm512_extracti64x4_epi64(added, 1));
}
#endif

#ifdef __AVX2__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi8(first, second),
          _mm256_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi16(first, second),
          _mm256_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi32(first, second),
          _mm256_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi64(first, second),
          _mm256_unpackhi_epi64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_maddubs_epi16(xsimd::abs(x), _mm256_sign_epi8(y, x));
}
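
// Worked example of the abs/sign trick above (values are illustrative, not
// from the library). _mm256_maddubs_epi16 wants its first operand unsigned,
// so the signed product x*y is rewritten as |x| * (sign(x) * y), which is
// identical lane by lane:
//
//   one pair of lanes:  x = {-2, 3},  y = {5, -7}
//   |x|          = {2, 3}
//   sign(x) * y  = {-5, -7}                   // y negated where x < 0
//   maddubs      = 2 * (-5) + 3 * (-7) = -31
//   direct       = (-2) * 5 + 3 * (-7) = -31  // same result
//
// The usual vpsignb caveat applies: x == -128 is treated as unsigned 128,
// and y == -128 stays -128 when negated.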

template <class Arch>
inline xsimd::batch<int32_t, Arch>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx2>) {
  // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f
  __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
  // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s
  __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
  return _mm256_add_epi32(rev, blended);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::avx2>) {
  auto pack01 = _mm256_hadd_epi32(sum0, sum1);
  auto pack23 = _mm256_hadd_epi32(sum2, sum3);
  return _mm256_hadd_epi32(pack01, pack23);
}

#ifdef __AVXVNNI__

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avxvnni>) {
  return _mm256_dpbusd_avx_epi32(z, x, y);
}
#endif

#ifdef __AVX512VNNI__

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512bw>>) {
  return _mm512_dpbusd_epi32(z, x, y);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512vbmi>>) {
  return _mm512_dpbusd_epi32(z, x, y);
}
#endif

#endif

#ifdef __SSSE3__

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::ssse3>) {
  auto pack01 = _mm_hadd_epi32(sum0, sum1);
  auto pack23 = _mm_hadd_epi32(sum2, sum3);
  return _mm_hadd_epi32(pack01, pack23);
}
#endif

#ifdef __SSE2__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // Adapted from
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To extend signed 8-bit value, MSB has to be set to 0xFF
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());

  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack a with 0x00 (it is unsigned, so zero-extend), b with 0x00 for
  // positives and 0xFF for negatives
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128());
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128());
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0x00 0x04 0x00 ...

  // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values & signed saturate
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // Adapted from
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To extend signed 8-bit value, MSB has to be set to 0xFF
  __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128());
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());

  // sign_mask_a = 0x00 0x00 0xFF 0x00 ...
  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack positives with 0x00, negatives with 0xFF
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a);
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a);
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit signed integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...

  // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values & signed saturate
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}

template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {pack0123, pack4567};
}

#endif

#if __ARM_ARCH >= 7
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
  int32x4_t high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

  int32x2_t low_sum = vpadd_s32(vget_low_s32(low), vget_high_s32(low));
  int32x2_t high_sum = vpadd_s32(vget_low_s32(high), vget_high_s32(high));

  return vcombine_s32(low_sum, high_sum);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  // This would be much simpler if x86 would choose to zero extend OR sign
  // extend, not both. This could probably be optimized better.

  // Zero extend x
  int16x8_t x_odd =
      vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x), 8));
  int16x8_t x_even = vreinterpretq_s16_u16(
      vbicq_u16(vreinterpretq_u16_u8(x), vdupq_n_u16(0xff00)));

  // Sign extend by shifting left then shifting right.
  int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y), 8), 8);
  int16x8_t y_odd = vshrq_n_s16(vreinterpretq_s16_s8(y), 8);

  // multiply
  int16x8_t prod1 = vmulq_s16(x_even, y_even);
  int16x8_t prod2 = vmulq_s16(x_odd, y_odd);

  // saturated add
  return vqaddq_s16(prod1, prod2);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
  int16x8_t high = vmull_s8(vget_high_s8(x), vget_high_s8(y));

  int16x4_t low_sum = vpadd_s16(vget_low_s16(low), vget_high_s16(low));
  int16x4_t high_sum = vpadd_s16(vget_low_s16(high), vget_high_s16(high));

  return vcombine_s16(low_sum, high_sum);
}

template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::neon>) {
  return {pack0123, pack4567};
}
#endif

#ifdef __aarch64__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s8(first, second), vzip2q_s8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s16(first, second), vzip2q_s16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s32(first, second), vzip2q_s32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s64(first, second), vzip2q_s64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  return vqmovn_high_s16(vqmovn_s16(first), second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  return vqmovn_high_s32(vqmovn_s32(first), second);
}

#ifdef __ARM_FEATURE_MATMUL_INT8
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::i8mm<xsimd::neon64>>) {
  return vusdotq_s32(z, x, y);
}
#endif

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::neon64>) {
  int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))),
                           vmovl_s8(vget_low_s8(y)));
  int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))),
                           vmovl_s8(vget_high_s8(y)));
  return vpadalq_s16(vpadalq_s16(z, tl), th);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::neon64>) {
  auto pack01 = vpaddq_s32(sum0, sum1);
  auto pack23 = vpaddq_s32(sum2, sum3);
  return vpaddq_s32(pack01, pack23);
}

#endif

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::generic>) {
  return z + madd(xsimd::batch<int16_t, Arch>(1), madd(x, y, Arch{}), Arch{});
}
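
// The generic maddw above composes the two madd kernels: the 8-bit madd
// produces saturating pairwise sums in 16-bit lanes, and multiplying those
// by a batch of 16-bit ones sums adjacent pairs again into 32-bit lanes.
// Net effect, sketched on one 32-bit lane with illustrative values:
//
//   x = {1, 2, 3, 4} (uint8), y = {10, -10, 20, -20} (int8), z = {100}
//   madd(x, y)   -> {1*10 + 2*(-10), 3*20 + 4*(-20)} = {-10, -20}
//   madd(1, ..)  -> {-10 + (-20)} = {-30}
//   result       -> z + {-30} = {70}
//
// i.e. each int32 lane accumulates the dot product of 4 consecutive u8/i8
// pairs, which is what the VNNI and i8mm kernels above do in a single step.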

} // namespace kernel

//
// Generic dispatchers for interleave, deinterleave, madd, and PermuteSummer
//

template <class T, class Arch>
std::tuple<xsimd::batch<T, Arch>, xsimd::batch<T, Arch>>
interleave(xsimd::batch<T, Arch> first, xsimd::batch<T, Arch> second) {
  return kernel::interleave(first, second, Arch{});
}

template <class Arch>
xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first,
                                        xsimd::batch<int16_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}

template <class Arch>
xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first,
                                         xsimd::batch<int32_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> madd(xsimd::batch<int16_t, Arch> x,
                                        xsimd::batch<int16_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<int8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<uint8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
                                         xsimd::batch<int8_t, Arch> y,
                                         xsimd::batch<int32_t, Arch> z) {
  return kernel::maddw(x, y, z, Arch{});
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
                                         xsimd::batch<int8_t, Arch> y) {
  return maddw(x, y, xsimd::batch<int32_t, Arch>((int32_t)0));
}
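
// Hedged usage sketch of these dispatchers (the variable names are
// illustrative, not part of the API). The Arch tag is deduced from the batch
// type and forwarded to the best kernel:: overload compiled in:
//
//   using arch = xsimd::default_arch;
//   xsimd::batch<uint8_t, arch> a(3);   // broadcast 3 to every lane
//   xsimd::batch<int8_t, arch> b(-2);   // broadcast -2
//   auto acc = gemmology::maddw(a, b);  // int32 lanes: 4 * (3 * -2) = -24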

template <class Arch>
inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
                          xsimd::batch<int32_t, Arch> pack4567)
    -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) {
  return kernel::PermuteSummer(pack0123, pack4567, Arch{});
}

namespace kernel {

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3,
         xsimd::kernel::requires_arch<xsimd::generic>) {
  std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{});
  auto pack01 = sum0 + sum1;
  std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{});
  auto pack23 = sum2 + sum3;

  auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),
                           xsimd::bitwise_cast<int64_t>(pack23), Arch{});
  return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +
         xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
}

} // namespace kernel

template <class Arch>
inline xsimd::batch<int32_t, Arch>
Pack0123(xsimd::batch<int32_t, Arch> sum0, xsimd::batch<int32_t, Arch> sum1,
         xsimd::batch<int32_t, Arch> sum2, xsimd::batch<int32_t, Arch> sum3) {
  return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{});
}
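
// What Pack0123 computes, in scalar terms (SSE-width sketch, 4 lanes per
// register):
//
//   sum0 = {a0, a1, a2, a3}  ->  result lane 0 = a0 + a1 + a2 + a3
//   sum1 = {b0, b1, b2, b3}  ->  result lane 1 = b0 + b1 + b2 + b3
//   sum2 = {c0, c1, c2, c3}  ->  result lane 2 = c0 + c1 + c2 + c3
//   sum3 = {d0, d1, d2, d3}  ->  result lane 3 = d0 + d1 + d2 + d3
//
// Four independent horizontal reductions packed into one register. On wider
// architectures the same reduction happens within each 128-bit lane, which
// is why PermuteSummer still has to fold the lanes afterwards.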

template <class Arch>
static inline xsimd::batch<int32_t, Arch>
quantize(xsimd::batch<float, Arch> input,
         xsimd::batch<float, Arch> quant_mult) {
  return xsimd::nearbyint_as_int(input * quant_mult);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrab(const float *input, xsimd::batch<float, Arch> quant_mult_reg) {
  return quantize(xsimd::batch<float, Arch>::load_unaligned(input),
                  quant_mult_reg);
}
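
// Example of the rounding behaviour (illustrative numbers): with
// quant_mult = 127 / 2.0f = 63.5f, the float 0.5f maps to
// nearbyint(0.5f * 63.5f) = nearbyint(31.75f) = 32. QuantizerGrab simply
// does this for a whole unaligned register of floats at once.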

#ifdef __AVX512BW__
inline __m512 Concat(const __m256 first, const __m256 second) {
  // Needs AVX512DQ, but that goes with AVX512BW anyway.
  return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1);
}

// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be
// controlled independently.
/* Only INTGEMM_AVX512F would be necessary, but due to a GCC 5.4 bug we have
 * to require INTGEMM_AVX512BW. */
inline __m512i QuantizerGrabHalves(const float *input0, const float *input1,
                                   const __m512 quant_mult_reg) {
  __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1));
  appended = _mm512_mul_ps(appended, quant_mult_reg);
  return _mm512_cvtps_epi32(appended);
}
#else
template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrabHalves(const float *input0, const float *input1,
                    xsimd::batch<float, Arch> quant_mult_reg);
#endif

/* Read 8 floats at a time from input0, input1, input2, and input3.  Quantize
 * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate
 * the result into one register and return it.
 */
class QuantizeTile8 {
  template <class Arch> struct Tiler {
    static constexpr uint32_t get(std::size_t i, std::size_t n) {
      size_t factor = xsimd::batch<float, Arch>::size / 4;
      return (i % factor) * 4 + i / factor;
    }
  };

public:
  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Consecutive(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return Tile(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                input + 1 * xsimd::batch<float, Arch>::size,
                input + 2 * xsimd::batch<float, Arch>::size,
                input + 3 * xsimd::batch<float, Arch>::size);
  }

  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  ConsecutiveU(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return TileU(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                 input + 1 * xsimd::batch<float, Arch>::size,
                 input + 2 * xsimd::batch<float, Arch>::size,
                 input + 3 * xsimd::batch<float, Arch>::size);
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ConsecutiveWithWrapping(xsimd::batch<float, Arch> quant_mult,
                          const float *input, size_t cols_left, size_t cols,
                          size_t row_step) {
    using batchf32 = xsimd::batch<float, Arch>;
    const float *inputs[4];
    for (size_t i = 0; i < std::size(inputs); ++i) {
      while (cols_left < batchf32::size) {
        input += cols * (row_step - 1);
        cols_left += cols;
      }
      inputs[i] = input;
      input += batchf32::size;
      cols_left -= batchf32::size;
    }
    return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]);
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ForReshape(xsimd::batch<float, Arch> quant_mult, const float *input,
             size_t cols) {
    using batchf32 = xsimd::batch<float, Arch>;
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;

    // Put higher rows in the second half of the register.  These will jumble
    // around in the same way then conveniently land in the right place.
    if constexpr (batchf32::size == 16) {
      const batch8 neg127(-127);
      // In reverse order: grabbing the first 32-bit values from each 128-bit
      // register, then the second 32-bit values, etc. Grab 4 registers at a
      // time in 32-bit format.
      batch32 g0 =
          QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult);
      batch32 g1 =
          QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult);
      batch32 g2 =
          QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult);
      batch32 g3 =
          QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult);

      // Pack 32-bit to 16-bit.
      batch16 packed0 = deinterleave(g0, g1);
      batch16 packed1 = deinterleave(g2, g3);
      // Pack 16-bit to 8-bit.
      batch8 packed = deinterleave(packed0, packed1);
      // Ban -128.
      packed = xsimd::max(packed, neg127);

      return xsimd::bitwise_cast<int8_t>(
          xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                         xsimd::make_batch_constant<uint32_t, Arch, Tiler<Arch>>()));
    } else if constexpr (batchf32::size == 8)
      return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols,
                  input + 18 * cols);
    else if constexpr (batchf32::size == 4)
      // Skip a row.
      return Tile(quant_mult, input, input + 4, input + 2 * cols,
                  input + 2 * cols + 4);
    else
      return {};
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Tile(xsimd::batch<float, Arch> quant_mult, const float *input0,
       const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;

    const batch8 neg127(-127);
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    // Ban -128.
    packed = xsimd::max(packed, neg127);

    if constexpr (batch32::size == 4)
      return packed;
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<int8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                       xsimd::make_batch_constant<uint32_t, Arch, Tiler<Arch>>()));
  }

private:
  // A version that produces uint8_ts
  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  TileU(xsimd::batch<float, Arch> quant_mult, const float *input0,
        const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;

    const batch8 neg127 = -127;
    const batch8 pos127 = +127;
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    // Ban -128.
    packed = xsimd::max(packed, neg127); // Could be removed if we added 128
                                         // instead of 127 below.
    packed = packed + pos127;
    if constexpr (batch32::size == 4)
      return xsimd::bitwise_cast<uint8_t>(packed);
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<uint8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                       xsimd::make_batch_constant<uint32_t, Arch, Tiler<Arch>>()));
  }
};
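
// Hedged usage sketch of QuantizeTile8 (the buffer and variable names are
// illustrative, not part of the header):
//
//   using arch = xsimd::default_arch;
//   alignas(64) float buf[4 * xsimd::batch<float, arch>::size];
//   xsimd::batch<float, arch> q(quant_mult);
//   auto packed = QuantizeTile8::Consecutive(q, buf);
//   // `packed` holds 4 float registers' worth of int8 values, swizzled into
//   // the column-major order the GEMM kernels expect; -128 is clamped to
//   // -127 so the sign trick in madd stays exact.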

template <class Arch>
inline void Transpose16InLane(
    xsimd::batch<int8_t, Arch> &r0, xsimd::batch<int8_t, Arch> &r1,
    xsimd::batch<int8_t, Arch> &r2, xsimd::batch<int8_t, Arch> &r3,
    xsimd::batch<int8_t, Arch> &r4, xsimd::batch<int8_t, Arch> &r5,
    xsimd::batch<int8_t, Arch> &r6, xsimd::batch<int8_t, Arch> &r7) {
  /* r0: columns 0 1 2 3 4 5 6 7 from row 0
     r1: columns 0 1 2 3 4 5 6 7 from row 1 */
  auto r0_16 = xsimd::bitwise_cast<int16_t>(r0);
  auto r1_16 = xsimd::bitwise_cast<int16_t>(r1);
  auto r2_16 = xsimd::bitwise_cast<int16_t>(r2);
  auto r3_16 = xsimd::bitwise_cast<int16_t>(r3);
  auto r4_16 = xsimd::bitwise_cast<int16_t>(r4);
  auto r5_16 = xsimd::bitwise_cast<int16_t>(r5);
  auto r6_16 = xsimd::bitwise_cast<int16_t>(r6);
  auto r7_16 = xsimd::bitwise_cast<int16_t>(r7);

  std::tie(r0_16, r1_16) = interleave(r0_16, r1_16);
  std::tie(r2_16, r3_16) = interleave(r2_16, r3_16);
  std::tie(r4_16, r5_16) = interleave(r4_16, r5_16);
  std::tie(r6_16, r7_16) = interleave(r6_16, r7_16);
  /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
     r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
     r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
     r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
     r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
     r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
     r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
     r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7 */
  auto r0_32 = xsimd::bitwise_cast<int32_t>(r0_16);
  auto r2_32 = xsimd::bitwise_cast<int32_t>(r2_16);
  auto r1_32 = xsimd::bitwise_cast<int32_t>(r1_16);
  auto r3_32 = xsimd::bitwise_cast<int32_t>(r3_16);
  auto r4_32 = xsimd::bitwise_cast<int32_t>(r4_16);
  auto r6_32 = xsimd::bitwise_cast<int32_t>(r6_16);
  auto r5_32 = xsimd::bitwise_cast<int32_t>(r5_16);
  auto r7_32 = xsimd::bitwise_cast<int32_t>(r7_16);

  std::tie(r0_32, r2_32) = interleave(r0_32, r2_32);
  std::tie(r1_32, r3_32) = interleave(r1_32, r3_32);
  std::tie(r4_32, r6_32) = interleave(r4_32, r6_32);
  std::tie(r5_32, r7_32) = interleave(r5_32, r7_32);
  /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
     r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
     r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
     r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
     r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
     r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
     r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
     r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7 */

  auto r0_64 = xsimd::bitwise_cast<int64_t>(r0_32);
  auto r2_64 = xsimd::bitwise_cast<int64_t>(r2_32);
  auto r1_64 = xsimd::bitwise_cast<int64_t>(r1_32);
  auto r3_64 = xsimd::bitwise_cast<int64_t>(r3_32);
  auto r4_64 = xsimd::bitwise_cast<int64_t>(r4_32);
  auto r6_64 = xsimd::bitwise_cast<int64_t>(r6_32);
  auto r5_64 = xsimd::bitwise_cast<int64_t>(r5_32);
  auto r7_64 = xsimd::bitwise_cast<int64_t>(r7_32);

  std::tie(r0_64, r4_64) = interleave(r0_64, r4_64);
  std::tie(r1_64, r5_64) = interleave(r1_64, r5_64);
  std::tie(r2_64, r6_64) = interleave(r2_64, r6_64);
  std::tie(r3_64, r7_64) = interleave(r3_64, r7_64);

  r0 = xsimd::bitwise_cast<int8_t>(r0_64);
  r1 = xsimd::bitwise_cast<int8_t>(r1_64);
  r2 = xsimd::bitwise_cast<int8_t>(r2_64);
  r3 = xsimd::bitwise_cast<int8_t>(r3_64);
  r4 = xsimd::bitwise_cast<int8_t>(r4_64);
  r5 = xsimd::bitwise_cast<int8_t>(r5_64);
  r6 = xsimd::bitwise_cast<int8_t>(r6_64);
  r7 = xsimd::bitwise_cast<int8_t>(r7_64);
  /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
     r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
     r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
     r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
     r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
     r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7 */
  /* Empirically gcc is able to remove these movs and just rename the outputs
   * of the 64-bit interleave. */
  std::swap(r1, r4);
  std::swap(r3, r6);
}

template <class Arch, typename IntegerTy>
void SelectColumnsOfB(const xsimd::batch<int8_t, Arch> *input,
                      xsimd::batch<int8_t, Arch> *output,
                      size_t rows_bytes /* number of bytes in a row */,
                      const IntegerTy *cols_begin, const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  /* Do columns for multiples of 8. */
  size_t register_rows = rows_bytes / batch8::size;
  const batch8 *starts[8];
  for (; cols_begin != cols_end; cols_begin += 8) {
    for (size_t k = 0; k < 8; ++k) {
      starts[k] =
          input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows;
    }
    for (size_t r = 0; r < register_rows; ++r) {
      for (size_t k = 0; k < 8; ++k) {
        *(output++) = *starts[k];
        starts[k] += 8;
      }
    }
  }
}
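
// The index arithmetic above, unpacked: B is stored with columns interleaved
// in groups of 8, so column c begins at register offset
// (c & 7) + (c & ~7) * register_rows, and each register-row advances the
// pointer by 8. E.g. (illustrative) with register_rows = 4, column 11 starts
// at (11 & 7) + 8 * 4 = 35 and its rows sit at offsets 35, 43, 51, 59.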

} // namespace

namespace callbacks {
template <class Arch>
xsimd::batch<float, Arch>
Unquantize::operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t,
                       size_t) {
  return xsimd::batch_cast<float>(total) * unquant_mult;
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
Unquantize::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> total,
    size_t, size_t, size_t) {
  return std::make_tuple(
      xsimd::batch_cast<float>(std::get<0>(total)) * unquant_mult,
      xsimd::batch_cast<float>(std::get<1>(total)) * unquant_mult);
}

template <class Arch>
xsimd::batch<float, Arch> AddBias::operator()(xsimd::batch<float, Arch> total,
                                              size_t, size_t col_idx, size_t) {
  return total + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx);
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
AddBias::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
    size_t, size_t col_idx, size_t) {
  return std::make_tuple(
      std::get<0>(total) +
          xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx + 0),
      std::get<1>(total) +
          xsimd::batch<float, Arch>::load_aligned(
              bias_addr + col_idx + xsimd::batch<float, Arch>::size));
}

template <class Arch>
void Write::operator()(xsimd::batch<float, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  result.store_aligned(output_addr + row_idx * col_size + col_idx);
}

template <class Arch>
void Write::operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(result).store_aligned(
      output_addr + row_idx * col_size + col_idx);
}

template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  std::get<0>(result).store_aligned(output_addr + row_idx * col_size +
                                    col_idx + 0);
  std::get<1>(result).store_aligned(output_addr + row_idx * col_size +
                                    col_idx + xsimd::batch<float, Arch>::size);
}

template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(std::get<0>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx + 0);
  xsimd::bitwise_cast<float>(std::get<1>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx +
                     xsimd::batch<int32_t, Arch>::size);
}

template <class T>
void UnquantizeAndWrite::operator()(T const &total, size_t row_idx,
                                    size_t col_idx, size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  write(unquantized, row_idx, col_idx, col_size);
}

template <class T>
void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx,
                                              size_t col_idx, size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size);
  write(bias_added, row_idx, col_idx, col_size);
}
} // namespace callbacks

template <class Arch>
void Engine<Arch>::QuantizeU(const float *input, uint8_t *output,
                             float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  xsimd::batch<float, Arch> q(quant_mult);
  const float *end = input + size;
  for (; input != end; input += batch8::size, output += batch8::size) {
    auto tile = QuantizeTile8::ConsecutiveU(q, input);
    tile.store_aligned(output);
  }
}

template <class Arch>
void Engine<Arch>::Quantize(const float *const input, int8_t *const output,
                            float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  const std::size_t kBatch = batch8::size;
  const std::size_t fast_end = size & ~(kBatch - 1);

  xsimd::batch<float, Arch> q(quant_mult);
  for (std::size_t i = 0; i < fast_end; i += kBatch) {
    auto tile = QuantizeTile8::Consecutive(q, input + i);
    tile.store_aligned(output + i);
  }

  std::size_t overhang = size & (kBatch - 1);
  if (!overhang)
    return;
  /* Each QuantizerGrab covers one float register, i.e. kBatch / 4 floats at a
   * time. If we're allowed to read one of them, then we can read the whole
   * register.
   */
  const float *inputs[4];
  std::size_t i;
  for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) {
    inputs[i] = &input[fast_end + i * (kBatch / 4)];
  }
  /* These will be clipped off. */
  for (; i < 4; ++i) {
    inputs[i] = &input[fast_end];
  }
  auto result =
      QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]);
  std::memcpy(output + (size & ~(kBatch - 1)), &result, overhang);
}

template <class Arch>
template <typename IntegerTy>
void Engine<Arch>::SelectColumnsB(const int8_t *input, int8_t *output,
                                  size_t rows, const IntegerTy *cols_begin,
                                  const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  SelectColumnsOfB(reinterpret_cast<const batch8 *>(input),
                   reinterpret_cast<batch8 *>(output), rows, cols_begin,
                   cols_end);
}

template <class Arch>
void Engine<Arch>::PrepareBTransposed(const float *input, int8_t *output,
                                      float quant_mult, size_t cols,
                                      size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElemsInt = batch8::size;
  const size_t kColStride = 8;

  xsimd::batch<float, Arch> q(quant_mult);
  auto *output_it = reinterpret_cast<batch8 *>(output);
  size_t r = 0;
  size_t c = 0;
  while (r < rows) {
    for (size_t ri = 0; ri < 8; ++ri)
      *output_it++ = QuantizeTile8::ConsecutiveWithWrapping(
          q, input + (r + ri) * cols + c, cols - c, cols, 8);
    c += RegisterElemsInt;
    while (c >= cols) {
      r += kColStride;
      c -= cols;
    }
  }
}

template <class Arch>
void Engine<Arch>::PrepareBQuantizedTransposed(const int8_t *input,
                                               int8_t *output, size_t cols,
                                               size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElems = batch8::size;
  const size_t kColStride = 8;

  auto *output_it = reinterpret_cast<batch8 *>(output);
  for (size_t r = 0; r < rows; r += kColStride)
    for (size_t c = 0; c < cols; c += RegisterElems)
      for (size_t ri = 0; ri < 8; ++ri)
        *output_it++ =
            *reinterpret_cast<const batch8 *>(input + (r + ri) * cols + c);
}

template <class Arch>
void Engine<Arch>::PrepareB(const float *input, int8_t *output_shadow,
                            float quant_mult, size_t rows, size_t cols) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  xsimd::batch<float, Arch> q(quant_mult);
  /* Currently all multipliers have a stride of 8 columns. */
  const size_t kColStride = 8;
  auto *output = reinterpret_cast<batch8 *>(output_shadow);
  for (size_t c = 0; c < cols; c += kColStride) {
    for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) {
      output[0] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols);
      output[1] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols);
      output[2] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols);
      output[3] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols);
      output[4] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 8) + c, cols);
      output[5] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 9) + c, cols);
      output[6] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols);
      output[7] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols);
      std::tie(output[0], output[1]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[0]),
                     xsimd::bitwise_cast<int8_t>(output[1]));
      std::tie(output[2], output[3]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[2]),
                     xsimd::bitwise_cast<int8_t>(output[3]));
      std::tie(output[4], output[5]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[4]),
                     xsimd::bitwise_cast<int8_t>(output[5]));
      std::tie(output[6], output[7]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[6]),
                     xsimd::bitwise_cast<int8_t>(output[7]));
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4],
                        output[5], output[6], output[7]);
    }
  }
}

template <class Arch>
void Engine<Arch>::PrepareA(const float *input, int8_t *output,
                            float quant_mult, size_t rows, size_t cols) {
  Quantize(input, output, quant_mult, rows * cols);
}

template <class Arch>
void Engine<Arch>::Shift::PrepareA(const float *input, uint8_t *output,
                                   float quant_mult, size_t rows, size_t cols) {
  QuantizeU(input, output, quant_mult, rows * cols);
}

template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
                                   size_t A_rows, size_t width, size_t B_cols,
                                   Callback callback) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  using ubatch8 = xsimd::batch<uint8_t, Arch>;
  using batch32 = xsimd::batch<int32_t, Arch>;

  const size_t simd_width = width / batch8::size;
  for (size_t B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
    const auto *B0_col =
        reinterpret_cast<const batch8 *>(B) + simd_width * B0_colidx;
    /* Process one row of A at a time.  Doesn't seem to be faster to do
     * multiple rows of A at once. */
    for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
      const auto *A_row =
          reinterpret_cast<const ubatch8 *>(A + A_rowidx * width);
      /* These will be packed 16-bit integers containing sums for each row of B
         multiplied by the row of A. Iterate over shared (inner) dimension. */
      /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
       * declared here. */
      size_t k = 0;
      ubatch8 a = *(A_row + k);
      batch32 isum0 = maddw(a, *(B0_col + k * 8));
      batch32 isum1 = maddw(a, *(B0_col + k * 8 + 1));
      batch32 isum2 = maddw(a, *(B0_col + k * 8 + 2));
      batch32 isum3 = maddw(a, *(B0_col + k * 8 + 3));
      batch32 isum4 = maddw(a, *(B0_col + k * 8 + 4));
      batch32 isum5 = maddw(a, *(B0_col + k * 8 + 5));
      batch32 isum6 = maddw(a, *(B0_col + k * 8 + 6));
      batch32 isum7 = maddw(a, *(B0_col + k * 8 + 7));
      for (k = 1; k < simd_width; ++k) {
        a = *(A_row + k);
        /* Multiply 8-bit, horizontally add to packed 16-bit integers. */
        /* Upcast to 32-bit and horizontally add. */
        isum0 = maddw(a, *(B0_col + k * 8 + 0), isum0);
        isum1 = maddw(a, *(B0_col + k * 8 + 1), isum1);
        isum2 = maddw(a, *(B0_col + k * 8 + 2), isum2);
        isum3 = maddw(a, *(B0_col + k * 8 + 3), isum3);
        isum4 = maddw(a, *(B0_col + k * 8 + 4), isum4);
        isum5 = maddw(a, *(B0_col + k * 8 + 5), isum5);
        isum6 = maddw(a, *(B0_col + k * 8 + 6), isum6);
        isum7 = maddw(a, *(B0_col + k * 8 + 7), isum7);
      }
      /* Reduce sums within 128-bit lanes. */
      auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
      auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
      /* The specific implementation may need to reduce further. */
      auto total = PermuteSummer(pack0123, pack4567);
      callback(total, A_rowidx, B0_colidx, B_cols);
    }
  }
}
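
// Hedged end-to-end sketch of the Shift pipeline (shapes and the callback
// constructor arguments are assumed from gemmology_fwd.h, not spelled out in
// this header):
//
//   using E = gemmology::Engine<xsimd::default_arch>;
//   E::Shift::PrepareA(A_float, A_u8, a_quant, A_rows, width);  // unsigned A
//   E::PrepareB(B_float, B_i8, b_quant, width, B_cols);         // signed B
//   E::Shift::PrepareBias(B_i8, width, B_cols, bias_cb);  // shift correction
//   E::Shift::Multiply(A_u8, B_i8, A_rows, width, B_cols, write_cb);
//
// Multiply walks B eight columns at a time and A one row at a time,
// accumulating int32 dot products with maddw; Pack0123 and PermuteSummer
// then reduce the eight accumulators before `callback` receives each tile.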

template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
                                      size_t B_cols, Callback C) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t simd_width = width / batch8::size;
  xsimd::batch<uint8_t, Arch> a(1);
  for (size_t j = 0; j < B_cols; j += 8) {
    /* Process one row of A at a time.  Doesn't seem to be faster to do
     * multiple rows of A at once. */
    const int8_t *B_j = B + j * width;

    /* Rather than initializing as zeros and adding, just initialize the
     * first. */
    /* These will be packed 16-bit integers containing sums for each column of
     * B multiplied by the row of A. */
    /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
     * declared here. */
    auto isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]));
    auto isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]));
    auto isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]));
    auto isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]));
    auto isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]));
    auto isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]));
    auto isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]));
    auto isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]));

    B_j += 8 * batch8::size;

    for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) {
      isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]), isum0);
      isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]), isum1);
      isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]), isum2);
      isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]), isum3);
      isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]), isum4);
      isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]), isum5);
      isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]), isum6);
      isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]), isum7);
    }

    auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
    auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);

    auto total = PermuteSummer(pack0123, pack4567);
    C(total, 0, j, B_cols);
  }
}

} // namespace gemmology

#endif