gemmology.h (52788B)
#ifndef GEMMOLOGY_H
#define GEMMOLOGY_H

#include "gemmology_fwd.h"

#include <cstdint>
#include <cstring>
#include <tuple>

#include <xsimd/xsimd.hpp>

namespace gemmology {

namespace {

//
// Arch specific implementation of various elementary operations
//

namespace kernel {

#ifdef __AVX512BW__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi8(first, second),
          _mm512_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi16(first, second),
          _mm512_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi32(first, second),
          _mm512_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return {_mm512_unpacklo_epi64(first, second),
          _mm512_unpackhi_epi64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  return _mm512_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int32_t, xsimd::avx2>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx512bw>) {
  // Form [0th 128-bit register of pack0123, 0th 128-bit register of pack4567,
  //       2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567]
  __m512i mix0 =
      _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6));
  // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567,
  //       3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567]
  __m512i mix1 =
      _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2));
  __m512i added = _mm512_add_epi32(mix0, mix1);
  // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7.
  // Fold register over itself.
  return _mm256_add_epi32(_mm512_castsi512_si256(added),
                          _mm512_extracti64x4_epi64(added, 1));
}
#endif
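
// Note: each overload in namespace kernel is selected through its trailing
// xsimd::kernel::requires_arch<...> tag; the generic wrappers defined after
// this namespace forward the call with `Arch{}`, so overload resolution picks
// the most specific implementation available for the compilation target.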

#ifdef __AVX2__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi8(first, second),
          _mm256_unpackhi_epi8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi16(first, second),
          _mm256_unpackhi_epi16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi32(first, second),
          _mm256_unpackhi_epi32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::avx2>) {
  return {_mm256_unpacklo_epi64(first, second),
          _mm256_unpackhi_epi64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::avx2>) {
  return _mm256_maddubs_epi16(xsimd::abs(x), _mm256_sign_epi8(y, x));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::avx2>) {
  // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f
  __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
  // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s
  __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
  return _mm256_add_epi32(rev, blended);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3,
                                            xsimd::kernel::requires_arch<xsimd::avx2>) {
  auto pack01 = _mm256_hadd_epi32(sum0, sum1);
  auto pack23 = _mm256_hadd_epi32(sum2, sum3);
  return _mm256_hadd_epi32(pack01, pack23);
}

#ifdef __AVXVNNI__

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avxvnni>) {
  return _mm256_dpbusd_avx_epi32(z, x, y);
}
#endif

#ifdef __AVX512VNNI__

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512bw>>) {
  return _mm512_dpbusd_epi32(z, x, y);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::avx512vnni<xsimd::avx512vbmi>>) {
  return _mm512_dpbusd_epi32(z, x, y);
}
#endif

#endif

#ifdef __SSSE3__

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::ssse3>) {
  return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3,
                                            xsimd::kernel::requires_arch<xsimd::ssse3>) {
  auto pack01 = _mm_hadd_epi32(sum0, sum1);
  auto pack23 = _mm_hadd_epi32(sum2, sum3);
  return _mm_hadd_epi32(pack01, pack23);
}
#endif

#ifdef __SSE2__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi16(first, second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_packs_epi32(first, second);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  return _mm_madd_epi16(x, y);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // Adapted from
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To extend signed 8-bit value, MSB has to be set to 0xFF
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());

  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack positives with 0x00, negatives with 0xFF
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128());
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128());
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit signed integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...

  // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values & signed saturate
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> a, xsimd::batch<int8_t, Arch> b,
     xsimd::kernel::requires_arch<xsimd::sse2>) {
  // Adapted from
  // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2
  // a = 0x00 0x01 0xFE 0x04 ...
  // b = 0x00 0x02 0x80 0x84 ...

  // To extend signed 8-bit value, MSB has to be set to 0xFF
  __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128());
  __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128());

  // sign_mask_a = 0x00 0x00 0xFF 0x00 ...
  // sign_mask_b = 0x00 0x00 0xFF 0xFF ...

  // Unpack positives with 0x00, negatives with 0xFF
  __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a);
  __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a);
  __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b);
  __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b);

  // Here - valid 16-bit signed integers corresponding to the 8-bit input
  // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ...

  // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts
  __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l);
  __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h);

  // Now go back from 32-bit values to 16-bit values & signed saturate
  return _mm_packs_epi32(madd_epi32_l, madd_epi32_h);
}
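
// Scalar reference for the two 8-bit madd overloads above (illustrative
// sketch, not part of the dispatch): each 16-bit output lane holds the
// saturated sum of two adjacent 8-bit products,
//
//   for (size_t i = 0; i < out_lanes; ++i)
//     out[i] = saturate_int16(int(a[2 * i]) * int(b[2 * i]) +
//                             int(a[2 * i + 1]) * int(b[2 * i + 1]));
//
// which is the semantics of _mm_maddubs_epi16 (u8 x s8) on SSSE3 and newer.
// The s8 x s8 overload emulates the same pairwise pattern with both operands
// sign-extended.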

template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::sse2>) {
  return {pack0123, pack4567};
}

#endif

#if __ARM_ARCH >= 7
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon>) {
  return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {

  return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second));
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon>) {
  return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second));
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {

  int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
  int32x4_t high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

  int32x2_t low_sum = vpadd_s32(vget_low_s32(low), vget_high_s32(low));
  int32x2_t high_sum = vpadd_s32(vget_low_s32(high), vget_high_s32(high));

  return vcombine_s32(low_sum, high_sum);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {

  // This would be much simpler if x86 would choose to zero extend OR sign
  // extend, not both. This could probably be optimized better.

  // Zero extend x
  int16x8_t x_odd =
      vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x), 8));
  int16x8_t x_even = vreinterpretq_s16_u16(
      vbicq_u16(vreinterpretq_u16_u8(x), vdupq_n_u16(0xff00)));

  // Sign extend by shifting left then shifting right.
  int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y), 8), 8);
  int16x8_t y_odd = vshrq_n_s16(vreinterpretq_s16_s8(y), 8);

  // multiply
  int16x8_t prod1 = vmulq_s16(x_even, y_even);
  int16x8_t prod2 = vmulq_s16(x_odd, y_odd);

  // saturated add
  return vqaddq_s16(prod1, prod2);
}

template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
     xsimd::kernel::requires_arch<xsimd::neon>) {
  int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
  int16x8_t high = vmull_s8(vget_high_s8(x), vget_high_s8(y));

  int16x4_t low_sum = vpadd_s16(vget_low_s16(low), vget_high_s16(low));
  int16x4_t high_sum = vpadd_s16(vget_low_s16(high), vget_high_s16(high));

  return vcombine_s16(low_sum, high_sum);
}

template <class Arch>
inline std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
              xsimd::batch<int32_t, Arch> pack4567,
              xsimd::kernel::requires_arch<xsimd::neon>) {
  return {pack0123, pack4567};
}
#endif

#ifdef __aarch64__
template <class Arch>
std::tuple<xsimd::batch<int8_t, Arch>, xsimd::batch<int8_t, Arch>>
interleave(xsimd::batch<int8_t, Arch> first, xsimd::batch<int8_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s8(first, second), vzip2q_s8(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int16_t, Arch>, xsimd::batch<int16_t, Arch>>
interleave(xsimd::batch<int16_t, Arch> first,
           xsimd::batch<int16_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s16(first, second), vzip2q_s16(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
interleave(xsimd::batch<int32_t, Arch> first,
           xsimd::batch<int32_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s32(first, second), vzip2q_s32(first, second)};
}

template <class Arch>
std::tuple<xsimd::batch<int64_t, Arch>, xsimd::batch<int64_t, Arch>>
interleave(xsimd::batch<int64_t, Arch> first,
           xsimd::batch<int64_t, Arch> second,
           xsimd::kernel::requires_arch<xsimd::neon64>) {
  return {vzip1q_s64(first, second), vzip2q_s64(first, second)};
}

template <class Arch>
xsimd::batch<int8_t, Arch>
deinterleave(xsimd::batch<int16_t, Arch> first,
             xsimd::batch<int16_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {

  return vqmovn_high_s16(vqmovn_s16(first), second);
}

template <class Arch>
xsimd::batch<int16_t, Arch>
deinterleave(xsimd::batch<int32_t, Arch> first,
             xsimd::batch<int32_t, Arch> second,
             xsimd::kernel::requires_arch<xsimd::neon64>) {
  return vqmovn_high_s32(vqmovn_s32(first), second);
}

#ifdef __ARM_FEATURE_MATMUL_INT8
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::i8mm<xsimd::neon64>>) {
  return vusdotq_s32(z, x, y);
}
#endif

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::neon64>) {
  int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))),
                           vmovl_s8(vget_low_s8(y)));
  int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))),
                           vmovl_s8(vget_high_s8(y)));
  return vpadalq_s16(vpadalq_s16(z, tl), th);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3,
                                            xsimd::kernel::requires_arch<xsimd::neon64>) {
  auto pack01 = vpaddq_s32(sum0, sum1);
  auto pack23 = vpaddq_s32(sum2, sum3);
  return vpaddq_s32(pack01, pack23);
}

#endif

template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
      xsimd::batch<int32_t, Arch> z,
      xsimd::kernel::requires_arch<xsimd::generic>) {
  return z + madd(xsimd::batch<int16_t, Arch>(1), madd(x, y, Arch{}), Arch{});
}

} // namespace kernel

//
// Generic dispatcher for interleave, deinterleave, madd and PermuteSummer
//

template <class T, class Arch>
std::tuple<xsimd::batch<T, Arch>, xsimd::batch<T, Arch>>
interleave(xsimd::batch<T, Arch> first, xsimd::batch<T, Arch> second) {
  return kernel::interleave(first, second, Arch{});
}

template <class Arch>
xsimd::batch<int8_t, Arch> deinterleave(xsimd::batch<int16_t, Arch> first,
                                        xsimd::batch<int16_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}
template <class Arch>
xsimd::batch<int16_t, Arch> deinterleave(xsimd::batch<int32_t, Arch> first,
                                         xsimd::batch<int32_t, Arch> second) {
  return kernel::deinterleave(first, second, Arch{});
}

template <class Arch>
inline xsimd::batch<int32_t, Arch> madd(xsimd::batch<int16_t, Arch> x,
                                        xsimd::batch<int16_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}
template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<int8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}
template <class Arch>
inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<uint8_t, Arch> x,
                                        xsimd::batch<int8_t, Arch> y) {
  return kernel::madd(x, y, Arch{});
}
template <class Arch>
inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
                                         xsimd::batch<int8_t, Arch> y,
                                         xsimd::batch<int32_t, Arch> z) {
  return kernel::maddw(x, y, z, Arch{});
}
template <class Arch>
inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
                                         xsimd::batch<int8_t, Arch> y) {
  return maddw(x, y, xsimd::batch<int32_t, Arch>((int32_t)0));
}

template <class Arch>
inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
                          xsimd::batch<int32_t, Arch> pack4567)
    -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) {
  return kernel::PermuteSummer(pack0123, pack4567, Arch{});
}
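
// Scalar reference for maddw (illustrative sketch only): each 32-bit lane of
// the accumulator gathers four adjacent unsigned-by-signed 8-bit products,
//
//   for (size_t i = 0; i < out_lanes; ++i)
//     z[i] += int(x[4 * i + 0]) * int(y[4 * i + 0]) + ... +
//             int(x[4 * i + 3]) * int(y[4 * i + 3]);
//
// which is what the VNNI / i8mm kernels compute natively. The generic
// fallback reaches the same shape through two madd steps, so its 16-bit
// intermediate may saturate where the native instructions do not.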

namespace kernel {

template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3,
                                            xsimd::kernel::requires_arch<xsimd::generic>) {

  std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{});
  auto pack01 = sum0 + sum1;
  std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{});
  auto pack23 = sum2 + sum3;

  auto packed = interleave(xsimd::bitwise_cast<int64_t>(pack01),
                           xsimd::bitwise_cast<int64_t>(pack23), Arch{});
  return xsimd::bitwise_cast<int32_t>(std::get<0>(packed)) +
         xsimd::bitwise_cast<int32_t>(std::get<1>(packed));
}
} // namespace kernel

template <class Arch>
inline xsimd::batch<int32_t, Arch> Pack0123(xsimd::batch<int32_t, Arch> sum0,
                                            xsimd::batch<int32_t, Arch> sum1,
                                            xsimd::batch<int32_t, Arch> sum2,
                                            xsimd::batch<int32_t, Arch> sum3) {
  return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{});
}

template <class Arch>
static inline xsimd::batch<int32_t, Arch>
quantize(xsimd::batch<float, Arch> input,
         xsimd::batch<float, Arch> quant_mult) {
  return xsimd::nearbyint_as_int(input * quant_mult);
}

template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrab(const float *input, xsimd::batch<float, Arch> quant_mult_reg) {
  return quantize(xsimd::batch<float, Arch>::load_unaligned(input),
                  quant_mult_reg);
}

#ifdef __AVX512BW__
inline __m512 Concat(const __m256 first, const __m256 second) {
  // INTGEMM_AVX512DQ but that goes with INTGEMM_AVX512BW anyway.
  return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1);
}

// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be
// controlled independently.
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set
 * INTGEMM_AVX512BW */
inline __m512i QuantizerGrabHalves(const float *input0, const float *input1,
                                   const __m512 quant_mult_reg) {
  __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1));
  appended = _mm512_mul_ps(appended, quant_mult_reg);
  return _mm512_cvtps_epi32(appended);
}
#else
template <class Arch>
inline xsimd::batch<int32_t, Arch>
QuantizerGrabHalves(const float *input0, const float *input1,
                    xsimd::batch<float, Arch> quant_mult_reg);
#endif
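
// quantize() rounds to nearest: q = nearbyint(x * quant_mult), kept in 32-bit
// lanes until the packing steps below narrow it to 8 bits. For example, with
// quant_mult = 127 / absmax, an input equal to absmax maps to 127 and an
// input of 0 maps to 0. (Illustrative comment; absmax is a caller-chosen
// scale, not something defined in this header.)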

/* Read 8 floats at a time from input0, input1, input2, and input3. Quantize
 * them to 8-bit by multiplying with quant_mult_reg then rounding. Concatenate
 * the result into one register and return it.
 */
class QuantizeTile8 {
  template <class Arch> struct Tiler {
    static constexpr uint32_t get(std::size_t i, std::size_t n) {
      size_t factor = xsimd::batch<float, Arch>::size / 4;
      return (i % factor) * 4 + i / factor;
    }
  };

public:
  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Consecutive(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return Tile(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                input + 1 * xsimd::batch<float, Arch>::size,
                input + 2 * xsimd::batch<float, Arch>::size,
                input + 3 * xsimd::batch<float, Arch>::size);
  }

  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  ConsecutiveU(xsimd::batch<float, Arch> quant_mult, const float *input) {
    return TileU(quant_mult, input + 0 * xsimd::batch<float, Arch>::size,
                 input + 1 * xsimd::batch<float, Arch>::size,
                 input + 2 * xsimd::batch<float, Arch>::size,
                 input + 3 * xsimd::batch<float, Arch>::size);
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ConsecutiveWithWrapping(xsimd::batch<float, Arch> quant_mult,
                          const float *input, size_t cols_left, size_t cols,
                          size_t row_step) {
    using batchf32 = xsimd::batch<float, Arch>;
    const float *inputs[4];
    for (size_t i = 0; i < std::size(inputs); ++i) {
      while (cols_left < batchf32::size) {
        input += cols * (row_step - 1);
        cols_left += cols;
      }
      inputs[i] = input;
      input += batchf32::size;
      cols_left -= batchf32::size;
    }
    return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]);
  }

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  ForReshape(xsimd::batch<float, Arch> quant_mult, const float *input,
             size_t cols) {
    using batchf32 = xsimd::batch<float, Arch>;
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;

    // Put higher rows in the second half of the register. These will jumble
    // around in the same way then conveniently land in the right place.
    if constexpr (batchf32::size == 16) {
      const batch8 neg127(-127);
      // In reverse order: grabbing the first 32-bit values from each 128-bit
      // register, then the second 32-bit values, etc. Grab 4 registers at a
      // time in 32-bit format.
      batch32 g0 =
          QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult);
      batch32 g1 =
          QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult);
      batch32 g2 =
          QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult);
      batch32 g3 =
          QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult);

      // Pack 32-bit to 16-bit.
      batch16 packed0 = deinterleave(g0, g1);
      batch16 packed1 = deinterleave(g2, g3);
      // Pack 16-bit to 8-bit.
      batch8 packed = deinterleave(packed0, packed1);
      // Ban -128.
      packed = xsimd::max(packed, neg127);

      return xsimd::bitwise_cast<int8_t>(
          xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                         xsimd::make_batch_constant<uint32_t, Arch, Tiler<Arch>>()));
    } else if constexpr (batchf32::size == 8)
      return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols,
                  input + 18 * cols);
    else if constexpr (batchf32::size == 4)
      // Skip a row.
      return Tile(quant_mult, input, input + 4, input + 2 * cols,
                  input + 2 * cols + 4);
    else
      return {};
  }
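
  // Worked example of the Tiler permutation above (illustrative only): for an
  // AVX2-sized batch, batch<float>::size == 8, so factor == 2 and the
  // generated index sequence is {0, 4, 1, 5, 2, 6, 3, 7}. Applied to the
  // packed 32-bit chunks, which at that point sit in order 0 2 4 6 1 3 5 7,
  // the swizzle restores the natural order 0 1 2 3 4 5 6 7 (see the comment
  // in Tile below).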

  template <class Arch>
  static inline xsimd::batch<int8_t, Arch>
  Tile(xsimd::batch<float, Arch> quant_mult, const float *input0,
       const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;

    const batch8 neg127(-127);
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    // Ban -128.
    packed = xsimd::max(packed, neg127);

    if constexpr (batch32::size == 4)
      return packed;
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31. Or as 32-bit integers: 0 2 4 6 1 3 5 7.
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<int8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                       xsimd::make_batch_constant<uint32_t, Arch, Tiler<Arch>>()));
  }

private:
  // A version that produces uint8_ts
  template <class Arch>
  static inline xsimd::batch<uint8_t, Arch>
  TileU(xsimd::batch<float, Arch> quant_mult, const float *input0,
        const float *input1, const float *input2, const float *input3) {
    using batch8 = xsimd::batch<int8_t, Arch>;
    using batch16 = xsimd::batch<int16_t, Arch>;
    using batch32 = xsimd::batch<int32_t, Arch>;

    const batch8 neg127 = -127;
    const batch8 pos127 = +127;
    // Grab 4 registers at a time in 32-bit format.
    batch32 g0 = QuantizerGrab(input0, quant_mult);
    batch32 g1 = QuantizerGrab(input1, quant_mult);
    batch32 g2 = QuantizerGrab(input2, quant_mult);
    batch32 g3 = QuantizerGrab(input3, quant_mult);
    // Pack 32-bit to 16-bit.
    batch16 packed0 = deinterleave(g0, g1);
    batch16 packed1 = deinterleave(g2, g3);
    // Pack 16-bit to 8-bit.
    batch8 packed = deinterleave(packed0, packed1);
    // Ban -128.
    packed = xsimd::max(packed, neg127); // Could be removed if we use +128
    packed = packed + pos127;
    if (batch32::size == 4)
      return xsimd::bitwise_cast<uint8_t>(packed);
    // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14
    // 15 20 21 22 23 28 29 30 31. Or as 32-bit integers: 0 2 4 6 1 3 5 7.
    // Technically this could be removed so long as the rows are bigger than 16
    // and the values are only used for GEMM.
    return xsimd::bitwise_cast<uint8_t>(
        xsimd::swizzle(xsimd::bitwise_cast<int32_t>(packed),
                       xsimd::make_batch_constant<uint32_t, Arch, Tiler<Arch>>()));
  }
};
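
// Illustrative use of QuantizeTile8 (a reading aid, not an exported API): one
// call to Consecutive consumes 4 * batch<float>::size consecutive floats and
// emits a single batch<int8_t> whose values are clamped to [-127, 127];
// Engine<Arch>::Quantize below is essentially this call in a loop, and
// ConsecutiveU / TileU produce the +127-shifted uint8_t variant used by the
// Shift interface.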

template <class Arch>
inline void Transpose16InLane(
    xsimd::batch<int8_t, Arch> &r0, xsimd::batch<int8_t, Arch> &r1,
    xsimd::batch<int8_t, Arch> &r2, xsimd::batch<int8_t, Arch> &r3,
    xsimd::batch<int8_t, Arch> &r4, xsimd::batch<int8_t, Arch> &r5,
    xsimd::batch<int8_t, Arch> &r6, xsimd::batch<int8_t, Arch> &r7) {
  /* r0: columns 0 1 2 3 4 5 6 7 from row 0
     r1: columns 0 1 2 3 4 5 6 7 from row 1*/
  auto r0_16 = xsimd::bitwise_cast<int16_t>(r0);
  auto r1_16 = xsimd::bitwise_cast<int16_t>(r1);
  auto r2_16 = xsimd::bitwise_cast<int16_t>(r2);
  auto r3_16 = xsimd::bitwise_cast<int16_t>(r3);
  auto r4_16 = xsimd::bitwise_cast<int16_t>(r4);
  auto r5_16 = xsimd::bitwise_cast<int16_t>(r5);
  auto r6_16 = xsimd::bitwise_cast<int16_t>(r6);
  auto r7_16 = xsimd::bitwise_cast<int16_t>(r7);

  std::tie(r0_16, r1_16) = interleave(r0_16, r1_16);
  std::tie(r2_16, r3_16) = interleave(r2_16, r3_16);
  std::tie(r4_16, r5_16) = interleave(r4_16, r5_16);
  std::tie(r6_16, r7_16) = interleave(r6_16, r7_16);
  /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1
     r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1
     r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3
     r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3
     r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5
     r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5
     r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7
     r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/
  auto r0_32 = xsimd::bitwise_cast<int32_t>(r0_16);
  auto r2_32 = xsimd::bitwise_cast<int32_t>(r2_16);
  auto r1_32 = xsimd::bitwise_cast<int32_t>(r1_16);
  auto r3_32 = xsimd::bitwise_cast<int32_t>(r3_16);
  auto r4_32 = xsimd::bitwise_cast<int32_t>(r4_16);
  auto r6_32 = xsimd::bitwise_cast<int32_t>(r6_16);
  auto r5_32 = xsimd::bitwise_cast<int32_t>(r5_16);
  auto r7_32 = xsimd::bitwise_cast<int32_t>(r7_16);

  std::tie(r0_32, r2_32) = interleave(r0_32, r2_32);
  std::tie(r1_32, r3_32) = interleave(r1_32, r3_32);
  std::tie(r4_32, r6_32) = interleave(r4_32, r6_32);
  std::tie(r5_32, r7_32) = interleave(r5_32, r7_32);
  /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3
     r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3
     r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3
     r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3
     r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7
     r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7
     r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7
     r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/

  auto r0_64 = xsimd::bitwise_cast<int64_t>(r0_32);
  auto r2_64 = xsimd::bitwise_cast<int64_t>(r2_32);
  auto r1_64 = xsimd::bitwise_cast<int64_t>(r1_32);
  auto r3_64 = xsimd::bitwise_cast<int64_t>(r3_32);
  auto r4_64 = xsimd::bitwise_cast<int64_t>(r4_32);
  auto r6_64 = xsimd::bitwise_cast<int64_t>(r6_32);
  auto r5_64 = xsimd::bitwise_cast<int64_t>(r5_32);
  auto r7_64 = xsimd::bitwise_cast<int64_t>(r7_32);

  std::tie(r0_64, r4_64) = interleave(r0_64, r4_64);
  std::tie(r1_64, r5_64) = interleave(r1_64, r5_64);
  std::tie(r2_64, r6_64) = interleave(r2_64, r6_64);
  std::tie(r3_64, r7_64) = interleave(r3_64, r7_64);

  r0 = xsimd::bitwise_cast<int8_t>(r0_64);
  r1 = xsimd::bitwise_cast<int8_t>(r1_64);
  r2 = xsimd::bitwise_cast<int8_t>(r2_64);
  r3 = xsimd::bitwise_cast<int8_t>(r3_64);
  r4 = xsimd::bitwise_cast<int8_t>(r4_64);
  r5 = xsimd::bitwise_cast<int8_t>(r5_64);
  r6 = xsimd::bitwise_cast<int8_t>(r6_64);
  r7 = xsimd::bitwise_cast<int8_t>(r7_64);
  /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7
     r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7
     r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7
     r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7
     r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7
     r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/
  /* Empirically gcc is able to remove these movs and just rename the outputs
   * of Interleave64. */
  std::swap(r1, r4);
  std::swap(r3, r6);
}

template <class Arch, typename IntegerTy>
void SelectColumnsOfB(const xsimd::batch<int8_t, Arch> *input,
                      xsimd::batch<int8_t, Arch> *output,
                      size_t rows_bytes /* number of bytes in a row */,
                      const IntegerTy *cols_begin, const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  /* Do columns for multiples of 8.*/
  size_t register_rows = rows_bytes / batch8::size;
  const batch8 *starts[8];
  for (; cols_begin != cols_end; cols_begin += 8) {
    for (size_t k = 0; k < 8; ++k) {
      starts[k] =
          input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows;
    }
    for (size_t r = 0; r < register_rows; ++r) {
      for (size_t k = 0; k < 8; ++k) {
        *(output++) = *starts[k];
        starts[k] += 8;
      }
    }
  }
}

} // namespace

namespace callbacks {
template <class Arch>
xsimd::batch<float, Arch> Unquantize::operator()(xsimd::batch<int32_t, Arch> total,
                                                 size_t, size_t, size_t) {
  return xsimd::batch_cast<float>(total) * unquant_mult;
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
Unquantize::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> total,
    size_t, size_t, size_t) {
  return std::make_tuple(
      xsimd::batch_cast<float>(std::get<0>(total)) * unquant_mult,
      xsimd::batch_cast<float>(std::get<1>(total)) * unquant_mult);
}

template <class Arch>
xsimd::batch<float, Arch> AddBias::operator()(xsimd::batch<float, Arch> total,
                                              size_t, size_t col_idx, size_t) {
  return total + xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx);
}

template <class Arch>
std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
AddBias::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
    size_t, size_t col_idx, size_t) {
  return std::make_tuple(
      std::get<0>(total) +
          xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx + 0),
      std::get<1>(total) +
          xsimd::batch<float, Arch>::load_aligned(bias_addr + col_idx +
                                                  xsimd::batch<float, Arch>::size));
}

template <class Arch>
void Write::operator()(xsimd::batch<float, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  result.store_aligned(output_addr + row_idx * col_size + col_idx);
}

template <class Arch>
void Write::operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
                       size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(result).store_aligned(
      output_addr + row_idx * col_size + col_idx);
}

template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  std::get<0>(result).store_aligned(output_addr + row_idx * col_size +
                                    col_idx + 0);
  std::get<1>(result).store_aligned(output_addr + row_idx * col_size +
                                    col_idx + xsimd::batch<float, Arch>::size);
}

template <class Arch>
void Write::operator()(
    std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> result,
    size_t row_idx, size_t col_idx, size_t col_size) {
  xsimd::bitwise_cast<float>(std::get<0>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx + 0);
  xsimd::bitwise_cast<float>(std::get<1>(result))
      .store_aligned(output_addr + row_idx * col_size + col_idx +
                     xsimd::batch<int32_t, Arch>::size);
}

template <class T>
void UnquantizeAndWrite::operator()(T const &total, size_t row_idx,
                                    size_t col_idx, size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  write(unquantized, row_idx, col_idx, col_size);
}

template <class T>
void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx,
                                              size_t col_idx, size_t col_size) {
  auto unquantized = unquantize(total, row_idx, col_idx, col_size);
  auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size);
  write(bias_added, row_idx, col_idx, col_size);
}
} // namespace callbacks

template <class Arch>
void Engine<Arch>::QuantizeU(const float *input, uint8_t *output,
                             float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  xsimd::batch<float, Arch> q(quant_mult);
  const float *end = input + size;
  for (; input != end; input += batch8::size, output += batch8::size) {
    auto tile = QuantizeTile8::ConsecutiveU(q, input);
    tile.store_aligned(output);
  }
}

template <class Arch>
void Engine<Arch>::Quantize(const float *const input, int8_t *const output,
                            float quant_mult, size_t size) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  const std::size_t kBatch = batch8::size;
  const std::size_t fast_end = size & ~(kBatch - 1);

  xsimd::batch<float, Arch> q(quant_mult);
  for (std::size_t i = 0; i < fast_end; i += kBatch) {
    auto tile = QuantizeTile8::Consecutive(q, input + i);
    tile.store_aligned(output + i);
  }

  std::size_t overhang = size & (kBatch - 1);
  if (!overhang)
    return;
  /* Each does size(xsimd::batch<int8_t, Arch>) / 32 == kBatch / 4 floats at a
   * time. If we're allowed to read one of them, then we can read the whole
   * register.
   */
  const float *inputs[4];
  std::size_t i;
  for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) {
    inputs[i] = &input[fast_end + i * (kBatch / 4)];
  }
  /* These will be clipped off. */
  for (; i < 4; ++i) {
    inputs[i] = &input[fast_end];
  }
  auto result =
      QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]);
  std::memcpy(output + (size & ~(kBatch - 1)), &result, overhang);
}

template <class Arch>
template <typename IntegerTy>
void Engine<Arch>::SelectColumnsB(const int8_t *input, int8_t *output,
                                  size_t rows, const IntegerTy *cols_begin,
                                  const IntegerTy *cols_end) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  SelectColumnsOfB(reinterpret_cast<const batch8 *>(input),
                   reinterpret_cast<batch8 *>(output), rows, cols_begin,
                   cols_end);
}

template <class Arch>
void Engine<Arch>::PrepareBTransposed(const float *input, int8_t *output,
                                      float quant_mult, size_t cols,
                                      size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElemsInt = batch8::size;
  const size_t kColStride = 8;

  xsimd::batch<float, Arch> q(quant_mult);
  auto *output_it = reinterpret_cast<batch8 *>(output);
  size_t r = 0;
  size_t c = 0;
  while (r < rows) {
    for (size_t ri = 0; ri < 8; ++ri)
      *output_it++ = QuantizeTile8::ConsecutiveWithWrapping(
          q, input + (r + ri) * cols + c, cols - c, cols, 8);
    c += RegisterElemsInt;
    while (c >= cols) {
      r += kColStride;
      c -= cols;
    }
  }
}

template <class Arch>
void Engine<Arch>::PrepareBQuantizedTransposed(const int8_t *input,
                                               int8_t *output, size_t cols,
                                               size_t rows) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t RegisterElems = batch8::size;
  const size_t kColStride = 8;

  auto *output_it = reinterpret_cast<batch8 *>(output);
  for (size_t r = 0; r < rows; r += kColStride)
    for (size_t c = 0; c < cols; c += RegisterElems)
      for (size_t ri = 0; ri < 8; ++ri)
        *output_it++ =
            *reinterpret_cast<const batch8 *>(input + (r + ri) * cols + c);
}
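
// Note on the prepared-B layout (a reading aid, not a new invariant): every
// routine in this file walks B eight columns at a time, which is why
// SelectColumnsOfB consumes the selected column indices eight at a time and
// why Shift::Multiply strides B0_colidx by 8. PrepareB, PrepareBTransposed
// and PrepareBQuantizedTransposed all emit registers grouped for that access
// pattern.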

template <class Arch>
void Engine<Arch>::PrepareB(const float *input, int8_t *output_shadow,
                            float quant_mult, size_t rows, size_t cols) {
  using batch8 = xsimd::batch<int8_t, Arch>;

  xsimd::batch<float, Arch> q(quant_mult);
  /* Currently all multipliers have a stride of 8 columns.*/
  const size_t kColStride = 8;
  auto *output = reinterpret_cast<batch8 *>(output_shadow);
  for (size_t c = 0; c < cols; c += kColStride) {
    for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) {
      output[0] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols);
      output[1] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols);
      output[2] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols);
      output[3] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols);
      output[4] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 8) + c, cols);
      output[5] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 9) + c, cols);
      output[6] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols);
      output[7] =
          QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols);
      std::tie(output[0], output[1]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[0]),
                     xsimd::bitwise_cast<int8_t>(output[1]));
      std::tie(output[2], output[3]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[2]),
                     xsimd::bitwise_cast<int8_t>(output[3]));
      std::tie(output[4], output[5]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[4]),
                     xsimd::bitwise_cast<int8_t>(output[5]));
      std::tie(output[6], output[7]) =
          interleave(xsimd::bitwise_cast<int8_t>(output[6]),
                     xsimd::bitwise_cast<int8_t>(output[7]));
      Transpose16InLane(output[0], output[1], output[2], output[3], output[4],
                        output[5], output[6], output[7]);
    }
  }
}

template <class Arch>
void Engine<Arch>::PrepareA(const float *input, int8_t *output,
                            float quant_mult, size_t rows, size_t cols) {
  Quantize(input, output, quant_mult, rows * cols);
}

template <class Arch>
void Engine<Arch>::Shift::PrepareA(const float *input, uint8_t *output,
                                   float quant_mult, size_t rows,
                                   size_t cols) {
  QuantizeU(input, output, quant_mult, rows * cols);
}
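
// Note on the Shift variant defined below (a summary of this file, with the
// usual caller-side convention hedged): Shift::PrepareA quantizes A to
// uint8_t, which TileU achieves by clamping to [-127, 127] and then adding
// 127, so A is stored shifted by +127. Shift::Multiply then accumulates
// unsigned-by-signed products with maddw. Shift::PrepareBias multiplies a row
// of ones with B and hands the resulting column sums to the callback; callers
// commonly use that, together with the unquantization multiplier, to cancel
// the +127 shift and fold in the real bias before Multiply's results are
// written.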

template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
                                   size_t A_rows, size_t width, size_t B_cols,
                                   Callback callback) {

  using batch8 = xsimd::batch<int8_t, Arch>;
  using ubatch8 = xsimd::batch<uint8_t, Arch>;
  using batch32 = xsimd::batch<int32_t, Arch>;

  const size_t simd_width = width / batch8::size;
  for (size_t B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
    const auto *B0_col =
        reinterpret_cast<const batch8 *>(B) + simd_width * B0_colidx;
    /* Process one row of A at a time. Doesn't seem to be faster to do multiple
     * rows of A at once.*/
    for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
      const auto *A_row =
          reinterpret_cast<const ubatch8 *>(A + A_rowidx * width);
      /* These will be packed 16-bit integers containing sums for each row of B
         multiplied by the row of A. Iterate over shared (inner) dimension.*/
      /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
       * declared here.*/
      size_t k = 0;
      ubatch8 a = *(A_row + k);
      batch32 isum0 = maddw(a, *(B0_col + k * 8));
      batch32 isum1 = maddw(a, *(B0_col + k * 8 + 1));
      batch32 isum2 = maddw(a, *(B0_col + k * 8 + 2));
      batch32 isum3 = maddw(a, *(B0_col + k * 8 + 3));
      batch32 isum4 = maddw(a, *(B0_col + k * 8 + 4));
      batch32 isum5 = maddw(a, *(B0_col + k * 8 + 5));
      batch32 isum6 = maddw(a, *(B0_col + k * 8 + 6));
      batch32 isum7 = maddw(a, *(B0_col + k * 8 + 7));
      for (k = 1; k < simd_width; ++k) {
        a = *(A_row + k);
        /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/
        /* Upcast to 32-bit and horizontally add.*/
        isum0 = maddw(a, *(B0_col + k * 8 + 0), isum0);
        isum1 = maddw(a, *(B0_col + k * 8 + 1), isum1);
        isum2 = maddw(a, *(B0_col + k * 8 + 2), isum2);
        isum3 = maddw(a, *(B0_col + k * 8 + 3), isum3);
        isum4 = maddw(a, *(B0_col + k * 8 + 4), isum4);
        isum5 = maddw(a, *(B0_col + k * 8 + 5), isum5);
        isum6 = maddw(a, *(B0_col + k * 8 + 6), isum6);
        isum7 = maddw(a, *(B0_col + k * 8 + 7), isum7);
      }
      /* Reduce sums within 128-bit lanes.*/
      auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
      auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);
      /* The specific implementation may need to reduce further.*/
      auto total = PermuteSummer(pack0123, pack4567);
      callback(total, A_rowidx, B0_colidx, B_cols);
    }
  }
}
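
// Typical call sequence (an illustrative sketch; the callback constructors
// live in gemmology_fwd.h and are not shown in this header):
//
//   using E = gemmology::Engine<xsimd::default_arch>;
//   E::Shift::PrepareA(A_float, A_quant, a_scale, A_rows, width);
//   E::PrepareB(B_float, B_quant, b_scale, width, B_cols);
//   E::Shift::PrepareBias(B_quant, width, B_cols, /* bias callback */ ...);
//   E::Shift::Multiply(A_quant, B_quant, A_rows, width, B_cols,
//                      /* e.g. callbacks::UnquantizeAndAddBiasAndWrite */ ...);
//
// Dimensions follow the parameter names used above: A is A_rows x width and
// B is width x B_cols; width is assumed to be a multiple of the register
// size (see simd_width in Multiply).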

template <class Arch>
template <class Callback>
void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
                                      size_t B_cols, Callback C) {
  using batch8 = xsimd::batch<int8_t, Arch>;
  const size_t simd_width = width / batch8::size;
  xsimd::batch<uint8_t, Arch> a(1);
  for (size_t j = 0; j < B_cols; j += 8) {
    /* Process one row of A at a time. Doesn't seem to be faster to do multiple
     * rows of A at once.*/
    const int8_t *B_j = B + j * width;

    /* Rather than initializing as zeros and adding, just initialize the
     * first.*/
    /* These will be packed 16-bit integers containing sums for each column of
     * B multiplied by the row of A.*/
    /* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
     * declared here.*/
    auto isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]));
    auto isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]));
    auto isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]));
    auto isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]));
    auto isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]));
    auto isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]));
    auto isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]));
    auto isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]));

    B_j += 8 * batch8::size;

    for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) {
      isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]), isum0);
      isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]), isum1);
      isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]), isum2);
      isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]), isum3);
      isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]), isum4);
      isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]), isum5);
      isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]), isum6);
      isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]), isum7);
    }

    auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
    auto pack4567 = Pack0123(isum4, isum5, isum6, isum7);

    auto total = PermuteSummer(pack0123, pack4567);
    C(total, 0, j, B_cols);
  }
}

} // namespace gemmology

#endif