IntegerGemmIntrinsic.cpp (18488B)
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- 2 * vim: set ts=8 sts=2 et sw=2 tw=80: 3 * 4 * This Source Code Form is subject to the terms of the Mozilla Public 5 * License, v. 2.0. If a copy of the MPL was not distributed with this 6 * file, You can obtain one at https://mozilla.org/MPL/2.0/. 7 */ 8 9 #include "intgemm/IntegerGemmIntrinsic.h" 10 11 #include "mozilla/CheckedInt.h" 12 #include "mozilla/IntegerPrintfMacros.h" 13 #include "mozilla/TimeStamp.h" 14 15 #include <gemmology_fwd.h> 16 #include "fmt/format.h" 17 18 #include "js/ErrorReport.h" 19 #include "js/HeapAPI.h" 20 #include "vm/ArrayBufferObject.h" 21 #include "vm/GeckoProfiler.h" 22 #include "vm/JSContext.h" 23 #include "wasm/WasmBuiltins.h" 24 #include "wasm/WasmInstance.h" 25 #include "wasm/WasmLog.h" 26 27 #if defined(USE_AVX512BW) 28 # if defined(USE_AVX512VNNI) 29 # if defined(USE_AVXVNNI) 30 # define SUPPORTED_ARCHS \ 31 xsimd::arch_list<xsimd::avx512vnni<xsimd::avx512bw>, xsimd::avx512bw, \ 32 xsimd::avxvnni, xsimd::avx2, xsimd::ssse3, \ 33 xsimd::sse2> 34 # else 35 # define SUPPORTED_ARCHS \ 36 xsimd::arch_list<xsimd::avx512vnni<xsimd::avx512bw>, xsimd::avx512bw, \ 37 xsimd::avx2, xsimd::ssse3, xsimd::sse2> 38 # endif 39 # elif defined(USE_AVXVNNI) 40 # define SUPPORTED_ARCHS \ 41 xsimd::arch_list<xsimd::avx512bw, xsimd::avxvnni, xsimd::avx2, \ 42 xsimd::ssse3, xsimd::sse2> 43 # else 44 # define SUPPORTED_ARCHS \ 45 xsimd::arch_list<xsimd::avx512bw, xsimd::avx2, xsimd::ssse3, xsimd::sse2> 46 # endif 47 #elif defined(USE_AVXVNNI) 48 # define SUPPORTED_ARCHS \ 49 xsimd::arch_list<xsimd::avxvnni, xsimd::avx2, xsimd::ssse3, xsimd::sse2> 50 #elif defined(USE_AVX2) 51 # define SUPPORTED_ARCHS \ 52 xsimd::arch_list<xsimd::avx2, xsimd::ssse3, xsimd::sse2> 53 #elif defined(USE_SSSE3) 54 # define SUPPORTED_ARCHS xsimd::arch_list<xsimd::ssse3, xsimd::sse2> 55 #elif defined(USE_SSE2) 56 # define SUPPORTED_ARCHS xsimd::arch_list<xsimd::sse2> 57 #elif defined(USE_NEON) and defined(XSIMD_WITH_NEON64) 58 # if defined(USE_NEON_I8MM) 59 # define SUPPORTED_ARCHS \ 60 xsimd::arch_list<xsimd::i8mm<xsimd::neon64>, xsimd::neon64> 61 # else 62 # define SUPPORTED_ARCHS xsimd::arch_list<xsimd::neon64> 63 # endif 64 #else 65 # error no supported architecture 66 #endif 67 68 // Dispatch *at runtime* based on run-time hardware and compile-time 69 // architectures. 70 // 71 // FIXME: Ideally we would not run the dispatch code at each function call. 72 #define GEMMOLOGY_DISPATCH(FUNC) \ 73 xsimd::dispatch<SUPPORTED_ARCHS>([](auto arch, auto... args) { \ 74 return gemmology::Engine<decltype(arch)>::FUNC(args...); \ 75 }) 76 77 template <size_t TextLength = 512, typename CharT = char> 78 struct AutoProfilerMarker { 79 AutoProfilerMarker(js::GeckoProfilerRuntime& profiler, const CharT* name) 80 : profiler(profiler), name(name) { 81 if (profiler.enabled()) { 82 startTime = mozilla::TimeStamp::Now(); 83 } 84 } 85 86 template <typename... Args> 87 AutoProfilerMarker(js::GeckoProfilerRuntime& profiler, const CharT* name, 88 fmt::format_string<Args...> aFormatStr, Args&&... aArgs) 89 : profiler(profiler), name(name) { 90 if (profiler.enabled()) { 91 startTime = mozilla::TimeStamp::Now(); 92 auto [out, size] = fmt::vformat_to_n( 93 text, sizeof(text) - 1, aFormatStr, 94 fmt::make_format_args<fmt::buffered_context<CharT>>(aArgs...)); 95 96 MOZ_ASSERT(size > sizeof(text) - 1, 97 "Truncated marker, consider increasing the buffer"); 98 99 *out = 0; 100 } 101 } 102 ~AutoProfilerMarker() { 103 if (profiler.enabled()) { 104 profiler.markInterval(name, startTime, text, 105 JS::ProfilingCategoryPair::JS); 106 } 107 } 108 js::GeckoProfilerRuntime& profiler; 109 const char* name; 110 char text[TextLength]{}; 111 mozilla::TimeStamp startTime; 112 }; 113 114 static constexpr uint32_t ARRAY_ALIGNMENT = 64; 115 static constexpr uint32_t ROWS_A_MULTIPLIER = 1; 116 static constexpr uint32_t COLUMNS_A_MULTIPLIER = 64; 117 static constexpr uint32_t ROWS_B_MULTIPLIER = COLUMNS_A_MULTIPLIER; 118 static constexpr uint32_t COLUMNS_B_MULTIPLIER = 8; 119 static constexpr uint32_t SELECTED_COLUMNS_B_MULTIPLIER = 8; 120 121 size_t GetWasmRawBufferLength(const uint8_t* memBase) { 122 const js::WasmArrayRawBuffer* rawBuf = 123 js::WasmArrayRawBuffer::fromDataPtr(memBase); 124 return rawBuf->byteLength(); 125 } 126 127 bool CheckMatrixDimension(uint32_t size, uint32_t sizeMultiplier) { 128 // A valid size is a positive integral multiple of Multiplier 129 return !((size == 0) || (size % sizeMultiplier != 0)); 130 } 131 132 bool CheckMatrixBound(uint32_t input, uint64_t inputSize, 133 size_t wasmBufferSize) { 134 mozilla::CheckedUint64 inputUpperLimit(inputSize); 135 inputUpperLimit += input; 136 137 // Bound check fails if size overflows or it spans outside the wasm memory 138 return !(!inputUpperLimit.isValid() || 139 (inputUpperLimit.value() >= (uint64_t)wasmBufferSize)); 140 } 141 142 bool CheckMatrixBoundAndAlignment(uint32_t input, uint64_t inputSize, 143 size_t wasmBufferSize) { 144 // Alignment check: It is sufficient to check alignment for the offset rather 145 // than for the actual pointer within wasm memory (as long as following assert 146 // is satisfied) 147 static_assert(js::gc::PageSize >= ARRAY_ALIGNMENT, 148 "PageSize should be bigger than Alignment"); 149 if (input % ARRAY_ALIGNMENT != 0) { 150 return false; 151 } 152 153 // Check Bound 154 return CheckMatrixBound(input, inputSize, wasmBufferSize); 155 } 156 157 int32_t js::intgemm::IntrI8PrepareB(wasm::Instance* instance, 158 uint32_t inputMatrixB, float scale, 159 float zeroPoint, uint32_t rowsB, 160 uint32_t colsB, uint32_t outputMatrixB, 161 uint8_t* memBase) { 162 MOZ_ASSERT(wasm::SASigIntrI8PrepareB.failureMode == 163 wasm::FailureMode::FailOnNegI32); 164 JSContext* cx = instance->cx(); 165 AutoUnsafeCallWithABI unsafe; 166 167 // Size checks for matricies 168 if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) || 169 !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) { 170 return -1; 171 } 172 173 // Memory Bound and Alignment checks for matricies 174 uint64_t sizeB = (uint64_t)rowsB * (uint64_t)colsB; 175 size_t wasmBufferSize = GetWasmRawBufferLength(memBase); 176 if (!CheckMatrixBoundAndAlignment(inputMatrixB, sizeB, wasmBufferSize) || 177 !CheckMatrixBoundAndAlignment(outputMatrixB, sizeB, wasmBufferSize)) { 178 return -1; 179 } 180 181 // Actual call to the 3rd party library (intgemm) for PrepareB 182 const float* inputMatrixBPtr = 183 reinterpret_cast<const float*>(&memBase[inputMatrixB]); 184 int8_t* outputMatrixBPtr = reinterpret_cast<int8_t*>(&memBase[outputMatrixB]); 185 AutoProfilerMarker marker(cx->runtime()->geckoProfiler(), "integemm::PreparB", 186 FMT_STRING("rowsB: {} colsB: {} sizeB: {}"), rowsB, 187 colsB, sizeB); 188 GEMMOLOGY_DISPATCH(PrepareB) 189 (inputMatrixBPtr, outputMatrixBPtr, 190 scale, // Quant Mult 191 rowsB, colsB); 192 return 0; 193 } 194 195 int32_t js::intgemm::IntrI8PrepareBFromTransposed( 196 wasm::Instance* instance, uint32_t inputMatrixBTransposed, float scale, 197 float zeroPoint, uint32_t rowsB, uint32_t colsB, uint32_t outputMatrixB, 198 uint8_t* memBase) { 199 MOZ_ASSERT(wasm::SASigIntrI8PrepareBFromTransposed.failureMode == 200 wasm::FailureMode::FailOnNegI32); 201 JSContext* cx = instance->cx(); 202 AutoUnsafeCallWithABI unsafe; 203 204 // Size checks for matricies 205 if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) || 206 !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) { 207 return -1; 208 } 209 210 // Memory Bound checks for all matricies 211 uint64_t sizeB = (uint64_t)rowsB * (uint64_t)colsB; 212 size_t wasmBufferSize = GetWasmRawBufferLength(memBase); 213 if (!CheckMatrixBoundAndAlignment(inputMatrixBTransposed, sizeB, 214 wasmBufferSize) || 215 !CheckMatrixBoundAndAlignment(outputMatrixB, sizeB, wasmBufferSize)) { 216 return -1; 217 } 218 219 // Actual call to the 3rd party library (intgemm) for PrepareBTransposed 220 const float* inputMatrixBTransposedPtr = 221 reinterpret_cast<const float*>(&memBase[inputMatrixBTransposed]); 222 int8_t* outputMatrixBPtr = reinterpret_cast<int8_t*>(&memBase[outputMatrixB]); 223 AutoProfilerMarker marker( 224 cx->runtime()->geckoProfiler(), "intgemm::PreparBTransposed", 225 FMT_STRING("rowsB: {} colsB: {} sizeB: {}"), rowsB, colsB, sizeB); 226 GEMMOLOGY_DISPATCH(PrepareBTransposed) 227 (inputMatrixBTransposedPtr, outputMatrixBPtr, 228 scale, // Quant Mult 229 rowsB, colsB); 230 return 0; 231 } 232 233 int32_t js::intgemm::IntrI8PrepareBFromQuantizedTransposed( 234 wasm::Instance* instance, uint32_t inputMatrixBQuantizedTransposed, 235 uint32_t rowsB, uint32_t colsB, uint32_t outputMatrixB, uint8_t* memBase) { 236 MOZ_ASSERT(wasm::SASigIntrI8PrepareBFromQuantizedTransposed.failureMode == 237 wasm::FailureMode::FailOnNegI32); 238 JSContext* cx = instance->cx(); 239 AutoUnsafeCallWithABI unsafe; 240 241 // Size checks for matricies 242 if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) || 243 !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) { 244 return -1; 245 } 246 247 // Memory Bound checks for all matricies 248 uint64_t sizeB = (uint64_t)rowsB * (uint64_t)colsB; 249 size_t wasmBufferSize = GetWasmRawBufferLength(memBase); 250 if (!CheckMatrixBoundAndAlignment(inputMatrixBQuantizedTransposed, sizeB, 251 wasmBufferSize) || 252 !CheckMatrixBoundAndAlignment(outputMatrixB, sizeB, wasmBufferSize)) { 253 return -1; 254 } 255 256 // Actual call to the 3rd party library (intgemm) 257 const int8_t* inputMatrixBQuantizedTransposedPtr = 258 reinterpret_cast<const int8_t*>( 259 &memBase[inputMatrixBQuantizedTransposed]); 260 int8_t* outputMatrixBPtr = reinterpret_cast<int8_t*>(&memBase[outputMatrixB]); 261 AutoProfilerMarker marker(cx->runtime()->geckoProfiler(), 262 "intgemm::PrepareBQuantizedTransposed", 263 FMT_STRING("rowsB: {}, colsB: {}"), rowsB, colsB); 264 GEMMOLOGY_DISPATCH(PrepareBQuantizedTransposed) 265 (inputMatrixBQuantizedTransposedPtr, outputMatrixBPtr, rowsB, colsB); 266 return 0; 267 } 268 269 int32_t js::intgemm::IntrI8PrepareA(wasm::Instance* instance, 270 uint32_t inputMatrixA, float scale, 271 float zeroPoint, uint32_t rowsA, 272 uint32_t colsA, uint32_t outputMatrixA, 273 uint8_t* memBase) { 274 MOZ_ASSERT(wasm::SASigIntrI8PrepareA.failureMode == 275 wasm::FailureMode::FailOnNegI32); 276 JSContext* cx = instance->cx(); 277 AutoUnsafeCallWithABI unsafe; 278 279 // Size checks for matricies 280 if (!CheckMatrixDimension(rowsA, ROWS_A_MULTIPLIER) || 281 !CheckMatrixDimension(colsA, COLUMNS_A_MULTIPLIER)) { 282 return -1; 283 } 284 285 // Memory Bound checks for all matricies 286 uint64_t sizeA = (uint64_t)rowsA * (uint64_t)colsA; 287 size_t wasmBufferSize = GetWasmRawBufferLength(memBase); 288 if (!CheckMatrixBoundAndAlignment(inputMatrixA, sizeA, wasmBufferSize) || 289 !CheckMatrixBoundAndAlignment(outputMatrixA, sizeA, wasmBufferSize)) { 290 return -1; 291 } 292 293 // Actual call to the 3rd party library (intgemm) 294 const float* inputMatrixAPtr = 295 reinterpret_cast<const float*>(&memBase[inputMatrixA]); 296 uint8_t* outputMatrixAPtr = &memBase[outputMatrixA]; 297 AutoProfilerMarker marker(cx->runtime()->geckoProfiler(), "intgemm::PrepareA", 298 FMT_STRING("rowsA: {}, colsA: {}"), rowsA, colsA); 299 GEMMOLOGY_DISPATCH(Shift::PrepareA) 300 (inputMatrixAPtr, outputMatrixAPtr, scale, rowsA, colsA); 301 return 0; 302 } 303 304 int32_t js::intgemm::IntrI8PrepareBias( 305 wasm::Instance* instance, uint32_t inputMatrixBPrepared, float scaleA, 306 float zeroPointA, float scaleB, float zeroPointB, uint32_t rowsB, 307 uint32_t colsB, uint32_t inputBias, uint32_t output, uint8_t* memBase) { 308 MOZ_ASSERT(wasm::SASigIntrI8PrepareBias.failureMode == 309 wasm::FailureMode::FailOnNegI32); 310 JSContext* cx = instance->cx(); 311 AutoUnsafeCallWithABI unsafe; 312 313 // Size checks for matricies 314 if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) || 315 !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) { 316 return -1; 317 } 318 319 // Memory Bound checks for all matrices 320 uint64_t sizeB = (uint64_t)rowsB * (uint64_t)colsB; 321 uint64_t sizeBias = colsB; 322 size_t wasmBufferSize = GetWasmRawBufferLength(memBase); 323 if (!CheckMatrixBoundAndAlignment(inputMatrixBPrepared, sizeB, 324 wasmBufferSize) || 325 !CheckMatrixBound(output, sizeBias, wasmBufferSize)) { 326 return -1; 327 } 328 329 // Actual call to the 3rd party library (intgemm) 330 const int8_t* inputMatrixBPreparedPtr = 331 (const int8_t*)&memBase[inputMatrixBPrepared]; 332 float* outputPtr = (float*)&memBase[output]; 333 float unquantFactor = 334 (-1) * ((127.0f / scaleA) * (127.0f / scaleB)) / (127.0f); 335 336 if (inputBias) { 337 if (!CheckMatrixBound(inputBias, sizeBias, wasmBufferSize)) { 338 return -1; 339 } 340 const float* inputBiasPtr = reinterpret_cast<float*>(&memBase[inputBias]); 341 342 AutoProfilerMarker marker( 343 cx->runtime()->geckoProfiler(), "intgemm::PrepareBias w/ input bias", 344 FMT_STRING("rowsB: {} colsB: {} sizeB: {}"), rowsB, colsB, sizeB); 345 GEMMOLOGY_DISPATCH(Shift::PrepareBias) 346 (inputMatrixBPreparedPtr, rowsB, colsB, 347 gemmology::callbacks::UnquantizeAndAddBiasAndWrite( 348 unquantFactor, inputBiasPtr, outputPtr)); 349 } else { 350 AutoProfilerMarker marker( 351 cx->runtime()->geckoProfiler(), "intgemm::PrepareBias", 352 FMT_STRING("rowsB: {} colsB: {} sizeB: {}"), rowsB, colsB, sizeB); 353 GEMMOLOGY_DISPATCH(Shift::PrepareBias) 354 (inputMatrixBPreparedPtr, rowsB, colsB, 355 gemmology::callbacks::UnquantizeAndWrite(unquantFactor, outputPtr)); 356 } 357 return 0; 358 } 359 360 int32_t js::intgemm::IntrI8MultiplyAndAddBias( 361 wasm::Instance* instance, uint32_t inputMatrixAPrepared, float scaleA, 362 float zeroPointA, uint32_t inputMatrixBPrepared, float scaleB, 363 float zeroPointB, uint32_t inputBiasPrepared, float unquantMultiplier, 364 uint32_t rowsA, uint32_t width, uint32_t colsB, uint32_t output, 365 uint8_t* memBase) { 366 MOZ_ASSERT(wasm::SASigIntrI8MultiplyAndAddBias.failureMode == 367 wasm::FailureMode::FailOnNegI32); 368 JSContext* cx = instance->cx(); 369 AutoUnsafeCallWithABI unsafe; 370 371 // Size checks for matricies 372 if (!CheckMatrixDimension(rowsA, ROWS_A_MULTIPLIER) || 373 !CheckMatrixDimension(width, COLUMNS_A_MULTIPLIER) || 374 !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER)) { 375 return -1; 376 } 377 378 // Memory Bound checks for all matricies 379 uint64_t sizeA = (uint64_t)rowsA * (uint64_t)width; 380 uint64_t sizeB = (uint64_t)width * (uint64_t)colsB; 381 uint64_t sizeBias = (uint64_t)colsB; 382 uint64_t sizeOutput = (uint64_t)rowsA * (uint64_t)colsB; 383 size_t wasmBufferSize = GetWasmRawBufferLength(memBase); 384 if (!CheckMatrixBoundAndAlignment(inputMatrixAPrepared, sizeA, 385 wasmBufferSize) || 386 !CheckMatrixBoundAndAlignment(inputMatrixBPrepared, sizeB, 387 wasmBufferSize) || 388 !CheckMatrixBound(inputBiasPrepared, sizeBias, wasmBufferSize) || 389 !CheckMatrixBound(output, sizeOutput, wasmBufferSize)) { 390 return -1; 391 } 392 393 // Actual call to the 3rd party library (intgemm) 394 const uint8_t* inputMatrixAPreparedPtr = &memBase[inputMatrixAPrepared]; 395 const int8_t* inputMatrixBPreparedPtr = 396 reinterpret_cast<const int8_t*>(&memBase[inputMatrixBPrepared]); 397 const float* inputBiasPreparedPtr = 398 reinterpret_cast<const float*>(&memBase[inputBiasPrepared]); 399 float* outputPtr = reinterpret_cast<float*>(&memBase[output]); 400 float unquantFactor = unquantMultiplier / (scaleA * scaleB); 401 402 AutoProfilerMarker marker( 403 cx->runtime()->geckoProfiler(), "intgemm::Shift::Multiply", 404 FMT_STRING("rowsA: {}, width: {}, colsA: {}"), rowsA, width, colsB); 405 GEMMOLOGY_DISPATCH(Shift::Multiply) 406 (inputMatrixAPreparedPtr, inputMatrixBPreparedPtr, rowsA, width, colsB, 407 gemmology::callbacks::UnquantizeAndAddBiasAndWrite( 408 unquantFactor, inputBiasPreparedPtr, outputPtr)); 409 return 0; 410 } 411 412 int32_t js::intgemm::IntrI8SelectColumnsOfB(wasm::Instance* instance, 413 uint32_t inputMatrixBPrepared, 414 uint32_t rowsB, uint32_t colsB, 415 uint32_t colIndexList, 416 uint32_t sizeColIndexList, 417 uint32_t output, uint8_t* memBase) { 418 MOZ_ASSERT(wasm::SASigIntrI8SelectColumnsOfB.failureMode == 419 wasm::FailureMode::FailOnNegI32); 420 JSContext* cx = instance->cx(); 421 AutoUnsafeCallWithABI unsafe; 422 423 // Size checks for matricies 424 if (!CheckMatrixDimension(rowsB, ROWS_B_MULTIPLIER) || 425 !CheckMatrixDimension(colsB, COLUMNS_B_MULTIPLIER) || 426 !CheckMatrixDimension(sizeColIndexList, SELECTED_COLUMNS_B_MULTIPLIER)) { 427 return -1; 428 } 429 430 // Memory Bound checks for all matricies 431 uint64_t sizeB = (uint64_t)rowsB * (uint64_t)colsB; 432 uint64_t sizeOutput = (uint64_t)rowsB * (uint64_t)sizeColIndexList; 433 size_t wasmBufferSize = GetWasmRawBufferLength(memBase); 434 if (!CheckMatrixBoundAndAlignment(inputMatrixBPrepared, sizeB, 435 wasmBufferSize) || 436 !CheckMatrixBound(colIndexList, sizeColIndexList, wasmBufferSize) || 437 !CheckMatrixBound(output, sizeOutput, wasmBufferSize)) { 438 return -1; 439 } 440 441 // Actual call to the 3rd party library (intgemm) 442 const int8_t* inputMatrixBPreparedPtr = 443 reinterpret_cast<const int8_t*>(&memBase[inputMatrixBPrepared]); 444 const uint32_t* colIndexListPtr = 445 reinterpret_cast<const uint32_t*>(&memBase[colIndexList]); 446 int8_t* outputPtr = reinterpret_cast<int8_t*>(&memBase[output]); 447 AutoProfilerMarker marker( 448 cx->runtime()->geckoProfiler(), "integemm::SelectColumnsB", 449 FMT_STRING("rowsB: {} colsB: {} sizecolList: {}, sizeB: {}"), rowsB, 450 colsB, sizeColIndexList, sizeB); 451 GEMMOLOGY_DISPATCH(SelectColumnsB) 452 (inputMatrixBPreparedPtr, outputPtr, rowsB, colIndexListPtr, 453 colIndexListPtr + sizeColIndexList); 454 return 0; 455 } 456 457 #undef GEMMOLOGY_DISPATCH 458 #undef SUPPORTED_ARCHS