tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

gemmology_fwd.h (7626B)


      1 /***************************************************************
      2 *                                       _                     *
      3 *                                      | |                    *
      4 *   __ _  ___ _ __ ___  _ __ ___   ___ | | ___   __ _ _   _   *
      5 *  / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _ \| |/ _ \ / _` | | | |  *
      6 * | (_| |  __/ | | | | | | | | | | (_) | | (_) | (_| | |_| |  *
      7 *  \__, |\___|_| |_| |_|_| |_| |_|\___/|_|\___/ \__, |\__, |  *
      8 *   __/ |                                        __/ | __/ |  *
      9 *  |___/                                        |___/ |___/   *
     10 *                                                             *
     11 *                                                 version 0.1 *
     12 ***************************************************************/
     13 
     14 #ifndef GEMMOLOGY_FWD_H
     15 #define GEMMOLOGY_FWD_H
     16 
     17 #include <cstdint>
     18 #include <cstring>
     19 #include <tuple>
     20 #include <xsimd/xsimd.hpp>
     21 
     22 namespace gemmology {
     23 
     24 namespace callbacks {
     25 
// Callback: converts int32 accumulator batches to float — presumably by
// scaling with `unquant_mult` (definitions are provided per-Arch elsewhere;
// this is a forward-declaration header).
struct Unquantize {
  float unquant_mult;
  // Single-batch overload. The unnamed size_t parameters match the
  // (row_idx, col_idx, col_size) tile coordinates named by the sibling
  // callbacks in this namespace; they are unused by name here.
  template <class Arch>
  xsimd::batch<float, Arch> operator()(xsimd::batch<int32_t, Arch> total, size_t, size_t, size_t);
  // Two-batch overload: processes both halves of the tuple in one call.
  template <class Arch>
  std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> operator()(
      std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
          total,
      size_t, size_t, size_t);
};
     36 
// Callback: adds bias values read from `bias_addr` to a float batch —
// presumably offset by `col_idx` (the only named index); confirm against
// the per-Arch definition.
struct AddBias {
  const float *bias_addr;  // bias vector; not owned by this callback
  // Single-batch overload.
  template <class Arch>
  xsimd::batch<float, Arch> operator()(xsimd::batch<float, Arch> total, size_t, size_t col_idx,
                  size_t);
  // Two-batch overload: both halves of the tuple in one call.
  template <class Arch>
  std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>>
  operator()(
      std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> total,
      size_t, size_t col_idx, size_t);
};
     48 
// Callback: stores a result batch (or a pair of batches) into the output
// buffer at the (row_idx, col_idx) position of a matrix with `col_size`
// columns. Definitions are per-Arch, elsewhere.
struct Write {
  float *output_addr;  // destination buffer; not owned

  // NOTE(review): single-argument ctor is implicit — consider `explicit`
  // if no caller relies on float* -> Write conversion.
  Write(float *o) : output_addr(o) {}

  // Store one float batch.
  template <class Arch>
  void operator()(xsimd::batch<float, Arch> result, size_t row_idx,
                  size_t col_idx, size_t col_size);
  // Store one int32 batch (conversion/layout handled in the per-Arch
  // definition — confirm there).
  template <class Arch>
  void operator()(xsimd::batch<int32_t, Arch> result, size_t row_idx,
                  size_t col_idx, size_t col_size);

  // Pair overload: store two float batches in one call.
  template <class Arch>
  void operator()(
      std::tuple<xsimd::batch<float, Arch>, xsimd::batch<float, Arch>> result,
      size_t row_idx, size_t col_idx, size_t col_size);

  // Pair overload: store two int32 batches in one call.
  template <class Arch>
  void operator()(
      std::tuple<xsimd::batch<int32_t, Arch>, xsimd::batch<int32_t, Arch>>
          result,
      size_t row_idx, size_t col_idx, size_t col_size);
};
     72 
// Composite callback: unquantize the accumulator, then write the result.
// Composes the Unquantize and Write callbacks declared above.
struct UnquantizeAndWrite {

  Unquantize unquantize;  // scaling step (factor passed at construction)
  Write write;            // output step (buffer passed at construction)

  UnquantizeAndWrite(float factor, float *output)
      : unquantize{factor}, write{output} {}

  // T is deduced, so this covers both the single-batch and the tuple
  // forms accepted by the composed callbacks.
  template <class T>
  void operator()(T const &total, size_t row_idx, size_t col_idx,
                  size_t col_size);
};
     85 
// Composite callback: unquantize, add bias, then write the result.
// Composes Unquantize, AddBias and Write declared above.
struct UnquantizeAndAddBiasAndWrite {

  Unquantize unquantize;  // scaling step
  AddBias add_bias;       // bias step (bias pointer passed at construction)
  Write write;            // output step

  UnquantizeAndAddBiasAndWrite(float factor, const float *bias, float *output)
      : unquantize{factor}, add_bias{bias}, write{output} {}

  // T is deduced, so this covers both the single-batch and tuple forms.
  template <class T>
  void operator()(T const &total, size_t row_idx, size_t col_idx,
                  size_t col_size);
};
     99 
    100 } // namespace callbacks
    101 
    102 //
    103 // Arch-specific implementation of each routine
    104 //
// Declaration of the per-architecture engine. Definitions are supplied
// elsewhere for each supported Arch (this is a _fwd header). Routine
// semantics below are inferred from the intgemm-style naming — confirm
// against the per-Arch definitions.
template <class Arch> struct Engine {

  // Quantize `size` floats from `input` to unsigned 8-bit in `output`,
  // scaling by `quant_mult`.
  static void QuantizeU(const float *input, uint8_t *output, float quant_mult,
                        size_t size);

  // Quantize `size` floats from `input` to signed 8-bit in `output`,
  // scaling by `quant_mult`.
  static void Quantize(const float *const input, int8_t *const output,
                       float quant_mult, size_t size);

  // Copy the columns of a prepared B (with `rows` rows) whose indices lie
  // in [cols_begin, cols_end) into `output`.
  template <typename IntegerTy>
  static void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows,
                             const IntegerTy *cols_begin,
                             const IntegerTy *cols_end);

  // Quantize and lay out a B matrix supplied already transposed.
  static void PrepareBTransposed(const float *input, int8_t *output,
                                 float quant_mult, size_t cols, size_t rows);

  // Lay out a B matrix supplied already quantized and transposed.
  static void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output,
                                          size_t cols, size_t rows);

  // Quantize and lay out B. NOTE(review): the `output_shadow` name hints
  // at a shadow/scratch buffer convention — verify in the definition.
  static void PrepareB(const float *input, int8_t *output_shadow,
                       float quant_mult, size_t rows, size_t cols);

  // Quantize A (rows x cols floats) to signed 8-bit.
  static void PrepareA(const float *input, int8_t *output, float quant_mult,
                       size_t rows, size_t cols);

  // "Shift" variants: A is quantized unsigned here, so a bias correction
  // (see PrepareBias) is presumably applied downstream — confirm.
  struct Shift {

    // Quantize A (rows x cols floats) to unsigned 8-bit.
    static void PrepareA(const float *input, uint8_t *output, float quant_mult,
                         size_t rows, size_t cols);

    // Multiply prepared A and B (dimensions per the parameter names:
    // A_rows x width times width x B_cols) and hand results to `callback`.
    template <class Callback>
    static void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows,
                         size_t width, size_t B_cols, Callback callback);

    // Compute the bias correction for the shifted path, delivering the
    // results through callback `C`.
    template <class Callback>
    static void PrepareBias(const int8_t *B, size_t width, size_t B_cols,
                            Callback C);
  };
};
    144 
    145 //
    146 // Top-level wrappers that mostly match intgemm API
    147 //
    148 
    149 template <class Arch = xsimd::default_arch>
    150 inline void QuantizeU(const float *input, uint8_t *output, float quant_mult,
    151                      size_t size) {
    152  return Engine<Arch>::QuantizeU(input, output, quant_mult, size);
    153 }
    154 
    155 template <class Arch = xsimd::default_arch>
    156 inline void Quantize(const float *const input, int8_t *const output,
    157                     float quant_mult, size_t size) {
    158  return Engine<Arch>::Quantize(input, output, quant_mult, size);
    159 }
    160 
    161 template <class Arch = xsimd::default_arch, typename IntegerTy>
    162 inline void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows,
    163                           const IntegerTy *cols_begin,
    164                           const IntegerTy *cols_end) {
    165  return Engine<Arch>::SelectColumnsB(input, output, rows, cols_begin,
    166                                      cols_end);
    167 }
    168 
    169 template <class Arch = xsimd::default_arch>
    170 inline void PrepareBTransposed(const float *input, int8_t *output,
    171                               float quant_mult, size_t cols, size_t rows) {
    172  return Engine<Arch>::PrepareBTransposed(input, output, quant_mult, cols,
    173                                          rows);
    174 }
    175 
    176 template <class Arch = xsimd::default_arch>
    177 inline void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output,
    178                                        size_t cols, size_t rows) {
    179  return Engine<Arch>::PrepareBQuantizedTransposed(input, output, cols, rows);
    180 }
    181 
    182 template <class Arch = xsimd::default_arch>
    183 inline void PrepareB(const float *input, int8_t *output_shadow,
    184                     float quant_mult, size_t rows, size_t cols) {
    185  return Engine<Arch>::PrepareB(input, output_shadow, quant_mult, rows, cols);
    186 }
    187 
    188 template <class Arch = xsimd::default_arch>
    189 inline void PrepareA(const float *input, int8_t *output, float quant_mult,
    190                     size_t rows, size_t cols) {
    191  return Engine<Arch>::PrepareA(input, output, quant_mult, rows, cols);
    192 }
    193 
    194 namespace Shift {
    195 
    196 template <class Arch = xsimd::default_arch>
    197 inline void PrepareA(const float *input, uint8_t *output, float quant_mult,
    198                     size_t rows, size_t cols) {
    199  return Engine<Arch>::Shift::PrepareA(input, output, quant_mult, rows, cols);
    200 }
    201 
    202 template <class Arch = xsimd::default_arch, class Callback>
    203 inline void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows,
    204                     size_t width, size_t B_cols, Callback C) {
    205  return Engine<Arch>::Shift::Multiply(A, B, A_rows, width, B_cols, C);
    206 }
    207 
    208 template <class Arch = xsimd::default_arch, class Callback>
    209 inline void PrepareBias(const int8_t *B, size_t width, size_t B_cols,
    210                        Callback C) {
    211  return Engine<Arch>::Shift::PrepareBias(B, width, B_cols, C);
    212 }
    213 
    214 } // namespace Shift
    215 
    216 } // namespace gemmology
    217 
    218 #endif