gemmology_fwd.h (7626B)
1 /*************************************************************** 2 * _ * 3 * | | * 4 * __ _ ___ _ __ ___ _ __ ___ ___ | | ___ __ _ _ _ * 5 * / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _ \| |/ _ \ / _` | | | | * 6 * | (_| | __/ | | | | | | | | | | (_) | | (_) | (_| | |_| | * 7 * \__, |\___|_| |_| |_|_| |_| |_|\___/|_|\___/ \__, |\__, | * 8 * __/ | __/ | __/ | * 9 * |___/ |___/ |___/ * 10 * * 11 * version 0.1 * 12 ***************************************************************/ 13 14 #ifndef GEMMOLOGY_FWD_H 15 #define GEMMOLOGY_FWD_H 16 17 #include <cstdint> 18 #include <cstring> 19 #include <tuple> 20 #include <xsimd/xsimd.hpp> 21 22 namespace gemmology { 23 24 namespace callbacks { 25 26 struct Unquantize { 27 float unquant_mult; 28 template <class Arch> 29 xsimd::batch<float, Arch> operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t, size_t); 30 template <class Arch> 31 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> operator()( 32 std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> 33 total, 34 size_t, size_t, size_t); 35 }; 36 37 struct AddBias { 38 const float *bias_addr; 39 template <class Arch> 40 xsimd::batch<float, Arch> operator()(xsimd::batch<float, Arch> total, size_t, size_t col_idx, 41 size_t); 42 template <class Arch> 43 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> 44 operator()( 45 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total, 46 size_t, size_t col_idx, size_t); 47 }; 48 49 struct Write { 50 float *output_addr; 51 52 Write(float *o) : output_addr(o) {} 53 54 template <class Arch> 55 void operator()(xsimd::batch<float, Arch> result, size_t row_idx, 56 size_t col_idx, size_t col_size); 57 template <class Arch> 58 void operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx, 59 size_t col_idx, size_t col_size); 60 61 template <class Arch> 62 void operator()( 63 std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result, 64 size_t row_idx, size_t col_idx, size_t col_size); 65 66 template <class Arch> 67 void operator()( 68 std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>> 69 result, 70 size_t row_idx, size_t col_idx, size_t col_size); 71 }; 72 73 struct UnquantizeAndWrite { 74 75 Unquantize unquantize; 76 Write write; 77 78 UnquantizeAndWrite(float factor, float *output) 79 : unquantize{factor}, write{output} {} 80 81 template <class T> 82 void operator()(T const &total, size_t row_idx, size_t col_idx, 83 size_t col_size); 84 }; 85 86 struct UnquantizeAndAddBiasAndWrite { 87 88 Unquantize unquantize; 89 AddBias add_bias; 90 Write write; 91 92 UnquantizeAndAddBiasAndWrite(float factor, const float *bias, float *output) 93 : unquantize{factor}, add_bias{bias}, write{output} {} 94 95 template <class T> 96 void operator()(T const &total, size_t row_idx, size_t col_idx, 97 size_t col_size); 98 }; 99 100 } // namespace callbacks 101 102 // 103 // Arch-specific implementation of each routine 104 // 105 template <class Arch> struct Engine { 106 107 static void QuantizeU(const float *input, uint8_t *output, float quant_mult, 108 size_t size); 109 110 static void Quantize(const float *const input, int8_t *const output, 111 float quant_mult, size_t size); 112 113 template <typename IntegerTy> 114 static void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows, 115 const IntegerTy *cols_begin, 116 const IntegerTy *cols_end); 117 118 static void PrepareBTransposed(const float *input, int8_t *output, 119 float quant_mult, size_t cols, size_t rows); 120 121 static void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, 122 size_t cols, size_t rows); 123 124 static void PrepareB(const float *input, int8_t *output_shadow, 125 float quant_mult, size_t rows, size_t cols); 126 127 static void PrepareA(const float *input, int8_t *output, float quant_mult, 128 size_t rows, size_t cols); 129 130 struct Shift { 131 132 static void PrepareA(const float *input, uint8_t *output, float quant_mult, 133 size_t rows, size_t cols); 134 135 template <class Callback> 136 static void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, 137 size_t width, size_t B_cols, Callback callback); 138 139 template <class Callback> 140 static void PrepareBias(const int8_t *B, size_t width, size_t B_cols, 141 Callback C); 142 }; 143 }; 144 145 // 146 // Top-level wrappers that mostly match intgemm API 147 // 148 149 template <class Arch = xsimd::default_arch> 150 inline void QuantizeU(const float *input, uint8_t *output, float quant_mult, 151 size_t size) { 152 return Engine<Arch>::QuantizeU(input, output, quant_mult, size); 153 } 154 155 template <class Arch = xsimd::default_arch> 156 inline void Quantize(const float *const input, int8_t *const output, 157 float quant_mult, size_t size) { 158 return Engine<Arch>::Quantize(input, output, quant_mult, size); 159 } 160 161 template <class Arch = xsimd::default_arch, typename IntegerTy> 162 inline void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows, 163 const IntegerTy *cols_begin, 164 const IntegerTy *cols_end) { 165 return Engine<Arch>::SelectColumnsB(input, output, rows, cols_begin, 166 cols_end); 167 } 168 169 template <class Arch = xsimd::default_arch> 170 inline void PrepareBTransposed(const float *input, int8_t *output, 171 float quant_mult, size_t cols, size_t rows) { 172 return Engine<Arch>::PrepareBTransposed(input, output, quant_mult, cols, 173 rows); 174 } 175 176 template <class Arch = xsimd::default_arch> 177 inline void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, 178 size_t cols, size_t rows) { 179 return Engine<Arch>::PrepareBQuantizedTransposed(input, output, cols, rows); 180 } 181 182 template <class Arch = xsimd::default_arch> 183 inline void PrepareB(const float *input, int8_t *output_shadow, 184 float quant_mult, size_t rows, size_t cols) { 185 return Engine<Arch>::PrepareB(input, output_shadow, quant_mult, rows, cols); 186 } 187 188 template <class Arch = xsimd::default_arch> 189 inline void PrepareA(const float *input, int8_t *output, float quant_mult, 190 size_t rows, size_t cols) { 191 return Engine<Arch>::PrepareA(input, output, quant_mult, rows, cols); 192 } 193 194 namespace Shift { 195 196 template <class Arch = xsimd::default_arch> 197 inline void PrepareA(const float *input, uint8_t *output, float quant_mult, 198 size_t rows, size_t cols) { 199 return Engine<Arch>::Shift::PrepareA(input, output, quant_mult, rows, cols); 200 } 201 202 template <class Arch = xsimd::default_arch, class Callback> 203 inline void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, 204 size_t width, size_t B_cols, Callback C) { 205 return Engine<Arch>::Shift::Multiply(A, B, A_rows, width, B_cols, C); 206 } 207 208 template <class Arch = xsimd::default_arch, class Callback> 209 inline void PrepareBias(const int8_t *B, size_t width, size_t B_cols, 210 Callback C) { 211 return Engine<Arch>::Shift::PrepareBias(B, width, B_cols, C); 212 } 213 214 } // namespace Shift 215 216 } // namespace gemmology 217 218 #endif