tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git

enc_fast_lossless.cc (157763B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/base/status.h"
      7 #ifndef FJXL_SELF_INCLUDE
      8 
      9 #include <assert.h>
     10 #include <stdint.h>
     11 #include <string.h>
     12 
     13 #include <algorithm>
     14 #include <array>
     15 #include <limits>
     16 #include <memory>
     17 #include <vector>
     18 
     19 #include "lib/jxl/enc_fast_lossless.h"
     20 
     21 #if FJXL_STANDALONE
     22 #if defined(_MSC_VER)
     23 using ssize_t = intptr_t;
     24 #endif
     25 #else  // FJXL_STANDALONE
     26 #include "lib/jxl/encode_internal.h"
     27 #endif  // FJXL_STANDALONE
     28 
     29 #if defined(__x86_64__) || defined(_M_X64)
     30 #define FJXL_ARCH_IS_X86_64 1
     31 #else
     32 #define FJXL_ARCH_IS_X86_64 0
     33 #endif
     34 
     35 #if defined(__i386__) || defined(_M_IX86) || FJXL_ARCH_IS_X86_64
     36 #define FJXL_ARCH_IS_X86 1
     37 #else
     38 #define FJXL_ARCH_IS_X86 0
     39 #endif
     40 
     41 #if FJXL_ARCH_IS_X86
     42 #if defined(_MSC_VER)
     43 #include <intrin.h>
     44 #else  // _MSC_VER
     45 #include <cpuid.h>
     46 #endif  // _MSC_VER
     47 #endif  // FJXL_ARCH_IS_X86
     48 
     49 // Enable NEON and AVX2/AVX512 unless asked to do otherwise, provided the
     50 // compiler supports them.
     51 #if defined(__aarch64__) || defined(_M_ARM64)  // ARCH
     52 #include <arm_neon.h>
     53 
     54 #if !defined(FJXL_ENABLE_NEON)
     55 #define FJXL_ENABLE_NEON 1
     56 #endif  // !defined(FJXL_ENABLE_NEON)
     57 
     58 #elif FJXL_ARCH_IS_X86_64 && !defined(_MSC_VER)  // ARCH
     59 #include <immintrin.h>
     60 
     61 // manually add _mm512_cvtsi512_si32 definition if missing
     62 // (e.g. with Xcode on macOS Mojave)
     63 // copied from gcc 11.1.0 include/avx512fintrin.h, lines 14367-14373
     64 #if defined(__clang__) &&                                           \
     65    ((!defined(__apple_build_version__) && __clang_major__ < 10) || \
     66     (defined(__apple_build_version__) && __apple_build_version__ < 12000032))
     67 inline int __attribute__((__gnu_inline__, __always_inline__, __artificial__))
     68 _mm512_cvtsi512_si32(__m512i __A) {
     69  __v16si __B = (__v16si)__A;
     70  return __B[0];
     71 }
     72 #endif
     73 
     74 #if !defined(FJXL_ENABLE_AVX2)
     75 #define FJXL_ENABLE_AVX2 1
     76 #endif  // !defined(FJXL_ENABLE_AVX2)
     77 
     78 #if !defined(FJXL_ENABLE_AVX512)
     79 // On clang-7 or earlier, and gcc-10 or earlier, AVX512 seems broken.
     80 #if (defined(__clang__) &&                                             \
     81         (!defined(__apple_build_version__) && __clang_major__ > 7) || \
     82     (defined(__apple_build_version__) &&                              \
     83      __apple_build_version__ > 10010046)) ||                          \
     84    (defined(__GNUC__) && __GNUC__ > 10)
     85 #define FJXL_ENABLE_AVX512 1
     86 #endif
     87 #endif  // !defined(FJXL_ENABLE_AVX512)
     88 
     89 #endif  // ARCH
     90 
     91 #ifndef FJXL_ENABLE_NEON
     92 #define FJXL_ENABLE_NEON 0
     93 #endif
     94 
     95 #ifndef FJXL_ENABLE_AVX2
     96 #define FJXL_ENABLE_AVX2 0
     97 #endif
     98 
     99 #ifndef FJXL_ENABLE_AVX512
    100 #define FJXL_ENABLE_AVX512 0
    101 #endif
    102 
    103 namespace {
    104 
    105 enum class CpuFeature : uint32_t {
    106  kAVX2 = 0,
    107 
    108  kAVX512F,
    109  kAVX512VL,
    110  kAVX512CD,
    111  kAVX512BW,
    112 
    113  kVBMI,
    114  kVBMI2
    115 };
    116 
    117 constexpr uint32_t CpuFeatureBit(CpuFeature feature) {
    118  return 1u << static_cast<uint32_t>(feature);
    119 }
    120 
    121 #if FJXL_ARCH_IS_X86
    122 #if defined(_MSC_VER)
    123 void Cpuid(const uint32_t level, const uint32_t count,
    124           std::array<uint32_t, 4>& abcd) {
    125  int regs[4];
    126  __cpuidex(regs, level, count);
    127  for (int i = 0; i < 4; ++i) {
    128    abcd[i] = regs[i];
    129  }
    130 }
    131 uint32_t ReadXCR0() { return static_cast<uint32_t>(_xgetbv(0)); }
    132 #else   // _MSC_VER
    133 void Cpuid(const uint32_t level, const uint32_t count,
    134           std::array<uint32_t, 4>& abcd) {
    135  uint32_t a;
    136  uint32_t b;
    137  uint32_t c;
    138  uint32_t d;
    139  __cpuid_count(level, count, a, b, c, d);
    140  abcd[0] = a;
    141  abcd[1] = b;
    142  abcd[2] = c;
    143  abcd[3] = d;
    144 }
    145 uint32_t ReadXCR0() {
    146  uint32_t xcr0;
    147  uint32_t xcr0_high;
    148  const uint32_t index = 0;
    149  asm volatile(".byte 0x0F, 0x01, 0xD0"
    150               : "=a"(xcr0), "=d"(xcr0_high)
    151               : "c"(index));
    152  return xcr0;
    153 }
    154 #endif  // _MSC_VER
    155 
    156 uint32_t DetectCpuFeatures() {
    157  uint32_t flags = 0;  // return value
    158  std::array<uint32_t, 4> abcd;
    159  Cpuid(0, 0, abcd);
    160  const uint32_t max_level = abcd[0];
    161 
    162  const auto check_bit = [](uint32_t v, uint32_t idx) -> bool {
    163    return (v & (1U << idx)) != 0;
    164  };
    165 
    166  // Extended features
    167  if (max_level >= 7) {
    168    Cpuid(7, 0, abcd);
    169    flags |= check_bit(abcd[1], 5) ? CpuFeatureBit(CpuFeature::kAVX2) : 0;
    170 
    171    flags |= check_bit(abcd[1], 16) ? CpuFeatureBit(CpuFeature::kAVX512F) : 0;
    172    flags |= check_bit(abcd[1], 28) ? CpuFeatureBit(CpuFeature::kAVX512CD) : 0;
    173    flags |= check_bit(abcd[1], 30) ? CpuFeatureBit(CpuFeature::kAVX512BW) : 0;
    174    flags |= check_bit(abcd[1], 31) ? CpuFeatureBit(CpuFeature::kAVX512VL) : 0;
    175 
    176    flags |= check_bit(abcd[2], 1) ? CpuFeatureBit(CpuFeature::kVBMI) : 0;
    177    flags |= check_bit(abcd[2], 6) ? CpuFeatureBit(CpuFeature::kVBMI2) : 0;
    178  }
    179 
    180  Cpuid(1, 0, abcd);
    181  const bool os_has_xsave = check_bit(abcd[2], 27);
    182  if (os_has_xsave) {
    183    const uint32_t xcr0 = ReadXCR0();
    184    if (!check_bit(xcr0, 1) || !check_bit(xcr0, 2) || !check_bit(xcr0, 5) ||
    185        !check_bit(xcr0, 6) || !check_bit(xcr0, 7)) {
    186      flags = 0;  // TODO(eustas): be more selective?
    187    }
    188  }
    189 
    190  return flags;
    191 }
    192 #else   // FJXL_ARCH_IS_X86
    193 uint32_t DetectCpuFeatures() { return 0; }
    194 #endif  // FJXL_ARCH_IS_X86
    195 
    196 #if defined(_MSC_VER)
    197 #define FJXL_UNUSED
    198 #else
    199 #define FJXL_UNUSED __attribute__((unused))
    200 #endif
    201 
    202 FJXL_UNUSED bool HasCpuFeature(CpuFeature feature) {
    203  static uint32_t cpu_features = DetectCpuFeatures();
    204  return (cpu_features & CpuFeatureBit(feature)) != 0;
    205 }
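
        // Illustrative use: callers gate a fast path on the detected features,
        // e.g. if (HasCpuFeature(CpuFeature::kVBMI2)) { ... }; this is how
        // JxlFastLosslessWriteOutput below picks AppendBytesWithBitOffset.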
    206 
    207 #if defined(_MSC_VER) && !defined(__clang__)
    208 #define FJXL_INLINE __forceinline
    209 FJXL_INLINE uint32_t FloorLog2(uint32_t v) {
    210  unsigned long index;
    211  _BitScanReverse(&index, v);
    212  return index;
    213 }
    214 FJXL_INLINE uint32_t CtzNonZero(uint64_t v) {
    215  unsigned long index;
    216  _BitScanForward(&index, v);
    217  return index;
    218 }
    219 #else
    220 #define FJXL_INLINE inline __attribute__((always_inline))
    221 FJXL_INLINE uint32_t FloorLog2(uint32_t v) {
    222  return v ? 31 - __builtin_clz(v) : 0;
    223 }
    224 FJXL_UNUSED FJXL_INLINE uint32_t CtzNonZero(uint64_t v) {
    225  return __builtin_ctzll(v);
    226 }
    227 #endif
    228 
    229 // Compiles to a memcpy on little-endian systems.
    230 FJXL_INLINE void StoreLE64(uint8_t* tgt, uint64_t data) {
    231 #if (!defined(__BYTE_ORDER__) || (__BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__))
    232  for (int i = 0; i < 8; i++) {
    233    tgt[i] = (data >> (i * 8)) & 0xFF;
    234  }
    235 #else
    236  memcpy(tgt, &data, 8);
    237 #endif
    238 }
    239 
    240 FJXL_INLINE size_t AddBits(uint32_t count, uint64_t bits, uint8_t* data_buf,
    241                           size_t& bits_in_buffer, uint64_t& bit_buffer) {
    242  bit_buffer |= bits << bits_in_buffer;
    243  bits_in_buffer += count;
    244  StoreLE64(data_buf, bit_buffer);
    245  size_t bytes_in_buffer = bits_in_buffer / 8;
    246  bits_in_buffer -= bytes_in_buffer * 8;
    247  bit_buffer >>= bytes_in_buffer * 8;
    248  return bytes_in_buffer;
    249 }
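
        // Worked example: with bits_in_buffer = 5 and count = 7, the buffer
        // briefly holds 12 bits; one full byte (the return value) is flushed
        // to data_buf and bits_in_buffer becomes 4. Note that StoreLE64 always
        // writes 8 bytes, so callers need slack after data_buf (BitWriter
        // reserves 64 padding bytes in Allocate below).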
    250 
    251 struct BitWriter {
    252  void Allocate(size_t maximum_bit_size) {
    253    assert(data == nullptr);
    254    // Leave some padding.
    255    data.reset(static_cast<uint8_t*>(malloc(maximum_bit_size / 8 + 64)));
    256  }
    257 
    258  void Write(uint32_t count, uint64_t bits) {
    259    bytes_written += AddBits(count, bits, data.get() + bytes_written,
    260                             bits_in_buffer, buffer);
    261  }
    262 
    263  void ZeroPadToByte() {
    264    if (bits_in_buffer != 0) {
    265      Write(8 - bits_in_buffer, 0);
    266    }
    267  }
    268 
    269  FJXL_INLINE void WriteMultiple(const uint64_t* nbits, const uint64_t* bits,
    270                                 size_t n) {
    271    // Necessary because Write() is only guaranteed to work with <=56 bits.
    272    // Trying to SIMD-fy this code results in lower speed (and definitely less
    273    // clarity).
    274    {
    275      for (size_t i = 0; i < n; i++) {
    276        this->buffer |= bits[i] << this->bits_in_buffer;
    277        memcpy(this->data.get() + this->bytes_written, &this->buffer, 8);
    278        uint64_t shift = 64 - this->bits_in_buffer;
    279        this->bits_in_buffer += nbits[i];
    280        // This `if` seems to be faster than using ternaries.
    281        if (this->bits_in_buffer >= 64) {
    282          uint64_t next_buffer = bits[i] >> shift;
    283          this->buffer = next_buffer;
    284          this->bits_in_buffer -= 64;
    285          this->bytes_written += 8;
    286        }
    287      }
    288      memcpy(this->data.get() + this->bytes_written, &this->buffer, 8);
    289      size_t bytes_in_buffer = this->bits_in_buffer / 8;
    290      this->bits_in_buffer -= bytes_in_buffer * 8;
    291      this->buffer >>= bytes_in_buffer * 8;
    292      this->bytes_written += bytes_in_buffer;
    293    }
    294  }
    295 
    296  std::unique_ptr<uint8_t[], void (*)(void*)> data = {nullptr, free};
    297  size_t bytes_written = 0;
    298  size_t bits_in_buffer = 0;
    299  uint64_t buffer = 0;
    300 };
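
        // Usage sketch (illustrative): bits are accumulated LSB-first and
        // flushed byte-by-byte.
        //   BitWriter w;
        //   w.Allocate(/*maximum_bit_size=*/1 << 20);
        //   w.Write(2, 0b01);   // at most 56 bits per call
        //   w.ZeroPadToByte();  // align before concatenating sections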
    301 
    302 size_t SectionSize(const std::array<BitWriter, 4>& group_data) {
    303  size_t sz = 0;
    304  for (size_t j = 0; j < 4; j++) {
    305    const auto& writer = group_data[j];
    306    sz += writer.bytes_written * 8 + writer.bits_in_buffer;
    307  }
    308  sz = (sz + 7) / 8;
    309  return sz;
    310 }
    311 
    312 constexpr size_t kMaxFrameHeaderSize = 5;
    313 
    314 constexpr size_t kGroupSizeOffset[4] = {
    315    static_cast<size_t>(0),
    316    static_cast<size_t>(1024),
    317    static_cast<size_t>(17408),
    318    static_cast<size_t>(4211712),
    319 };
    320 constexpr size_t kTOCBits[4] = {12, 16, 24, 32};
    321 
    322 size_t TOCBucket(size_t group_size) {
    323  size_t bucket = 0;
    324  while (bucket < 3 && group_size >= kGroupSizeOffset[bucket + 1]) ++bucket;
    325  return bucket;
    326 }
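
        // Worked example: a 20000-byte group falls in bucket 2
        // (17408 <= 20000 < 4211712) and takes kTOCBits[2] = 24 TOC bits:
        // a 2-bit bucket index plus the 22-bit offset 20000 - 17408 = 2592
        // (see JxlFastLosslessPrepareHeader below).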
    327 
    328 #if !FJXL_STANDALONE
    329 size_t TOCSize(const std::vector<size_t>& group_sizes) {
    330  size_t toc_bits = 0;
    331  for (size_t group_size : group_sizes) {
    332    toc_bits += kTOCBits[TOCBucket(group_size)];
    333  }
    334  return (toc_bits + 7) / 8;
    335 }
    336 
    337 size_t FrameHeaderSize(bool have_alpha, bool is_last) {
    338  size_t nbits = 28 + (have_alpha ? 4 : 0) + (is_last ? 0 : 2);
    339  return (nbits + 7) / 8;
    340 }
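
        // E.g. a last frame without alpha needs 28 bits -> 4 bytes, and a
        // non-last frame with alpha needs 28 + 4 + 2 = 34 bits -> 5 bytes,
        // matching kMaxFrameHeaderSize above.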
    341 #endif
    342 
    343 void ComputeAcGroupDataOffset(size_t dc_global_size, size_t num_dc_groups,
    344                              size_t num_ac_groups, size_t& min_dc_global_size,
    345                              size_t& ac_group_offset) {
    346  // Max AC group size is 768 kB, so max AC group TOC bits is 24.
    347  size_t ac_toc_max_bits = num_ac_groups * 24;
    348  size_t ac_toc_min_bits = num_ac_groups * 12;
    349  size_t max_padding = 1 + (ac_toc_max_bits - ac_toc_min_bits + 7) / 8;
    350  min_dc_global_size = dc_global_size;
    351  size_t dc_global_bucket = TOCBucket(min_dc_global_size);
    352  while (TOCBucket(min_dc_global_size + max_padding) > dc_global_bucket) {
    353    dc_global_bucket = TOCBucket(min_dc_global_size + max_padding);
    354    min_dc_global_size = kGroupSizeOffset[dc_global_bucket];
    355  }
    356  assert(TOCBucket(min_dc_global_size) == dc_global_bucket);
    357  assert(TOCBucket(min_dc_global_size + max_padding) == dc_global_bucket);
    358  size_t max_toc_bits =
    359      kTOCBits[dc_global_bucket] + 12 * (1 + num_dc_groups) + ac_toc_max_bits;
    360  size_t max_toc_size = (max_toc_bits + 7) / 8;
    361  ac_group_offset = kMaxFrameHeaderSize + max_toc_size + min_dc_global_size;
    362 }
    363 
    364 #if !FJXL_STANDALONE
    365 size_t ComputeDcGlobalPadding(const std::vector<size_t>& group_sizes,
    366                              size_t ac_group_data_offset,
    367                              size_t min_dc_global_size, bool have_alpha,
    368                              bool is_last) {
    369  std::vector<size_t> new_group_sizes = group_sizes;
    370  new_group_sizes[0] = min_dc_global_size;
    371  size_t toc_size = TOCSize(new_group_sizes);
    372  size_t actual_offset =
    373      FrameHeaderSize(have_alpha, is_last) + toc_size + group_sizes[0];
    374  return ac_group_data_offset - actual_offset;
    375 }
    376 #endif
    377 
    378 constexpr size_t kNumRawSymbols = 19;
    379 constexpr size_t kNumLZ77 = 33;
    380 constexpr size_t kLZ77CacheSize = 32;
    381 
    382 constexpr size_t kLZ77Offset = 224;
    383 constexpr size_t kLZ77MinLength = 7;
    384 
    385 void EncodeHybridUintLZ77(uint32_t value, uint32_t* token, uint32_t* nbits,
    386                          uint32_t* bits) {
    387  // Hybrid-uint config "400": split_exponent = 4, msb/lsb_in_token = 0.
    388  uint32_t n = FloorLog2(value);
    389  *token = value < 16 ? value : 16 + n - 4;
    390  *nbits = value < 16 ? 0 : n;
    391  *bits = value < 16 ? 0 : value - (1 << *nbits);
    392 }
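
        // Worked example: value = 23 has n = 4, so token = 16 + 4 - 4 = 16,
        // nbits = 4 and bits = 23 - 16 = 7; any value below 16 is its own
        // token with no extra bits.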
    393 
    394 struct PrefixCode {
    395  uint8_t raw_nbits[kNumRawSymbols] = {};
    396  uint8_t raw_bits[kNumRawSymbols] = {};
    397 
    398  uint8_t lz77_nbits[kNumLZ77] = {};
    399  uint16_t lz77_bits[kNumLZ77] = {};
    400 
    401  uint64_t lz77_cache_bits[kLZ77CacheSize] = {};
    402  uint8_t lz77_cache_nbits[kLZ77CacheSize] = {};
    403 
    404  size_t numraw;
    405 
    406  static uint16_t BitReverse(size_t nbits, uint16_t bits) {
    407    constexpr uint16_t kNibbleLookup[16] = {
    408        0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110,
    409        0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111,
    410    };
    411    uint16_t rev16 = (kNibbleLookup[bits & 0xF] << 12) |
    412                     (kNibbleLookup[(bits >> 4) & 0xF] << 8) |
    413                     (kNibbleLookup[(bits >> 8) & 0xF] << 4) |
    414                     (kNibbleLookup[bits >> 12]);
    415    return rev16 >> (16 - nbits);
    416  }
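
         // Worked example: BitReverse(3, 0b110) nibble-reverses to
         // rev16 = 0x6000 and returns 0x6000 >> 13 = 0b011; canonical codes
         // are built MSB-first but the bitstream is written LSB-first.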
    417 
    418  // Create the prefix codes given the code lengths.
    419  // Supports the code lengths being split into two halves.
    420  static void ComputeCanonicalCode(const uint8_t* first_chunk_nbits,
    421                                   uint8_t* first_chunk_bits,
    422                                   size_t first_chunk_size,
    423                                   const uint8_t* second_chunk_nbits,
    424                                   uint16_t* second_chunk_bits,
    425                                   size_t second_chunk_size) {
    426    constexpr size_t kMaxCodeLength = 15;
    427    uint8_t code_length_counts[kMaxCodeLength + 1] = {};
    428    for (size_t i = 0; i < first_chunk_size; i++) {
    429      code_length_counts[first_chunk_nbits[i]]++;
    430      assert(first_chunk_nbits[i] <= kMaxCodeLength);
    431      assert(first_chunk_nbits[i] <= 8);
    432      assert(first_chunk_nbits[i] > 0);
    433    }
    434    for (size_t i = 0; i < second_chunk_size; i++) {
    435      code_length_counts[second_chunk_nbits[i]]++;
    436      assert(second_chunk_nbits[i] <= kMaxCodeLength);
    437    }
    438 
    439    uint16_t next_code[kMaxCodeLength + 1] = {};
    440 
    441    uint16_t code = 0;
    442    for (size_t i = 1; i < kMaxCodeLength + 1; i++) {
    443      code = (code + code_length_counts[i - 1]) << 1;
    444      next_code[i] = code;
    445    }
    446 
    447    for (size_t i = 0; i < first_chunk_size; i++) {
    448      first_chunk_bits[i] =
    449          BitReverse(first_chunk_nbits[i], next_code[first_chunk_nbits[i]]++);
    450    }
    451    for (size_t i = 0; i < second_chunk_size; i++) {
    452      second_chunk_bits[i] =
    453          BitReverse(second_chunk_nbits[i], next_code[second_chunk_nbits[i]]++);
    454    }
    455  }
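
         // E.g. lengths {1, 2, 2} yield the canonical (MSB-first) codes
         // 0b0, 0b10 and 0b11, stored bit-reversed as 0b0, 0b01 and 0b11 so
         // that they can be emitted LSB-first.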
    456 
    457  template <typename T>
    458  static void ComputeCodeLengthsNonZeroImpl(const uint64_t* freqs, size_t n,
    459                                            size_t precision, T infty,
    460                                            const uint8_t* min_limit,
    461                                            const uint8_t* max_limit,
    462                                            uint8_t* nbits) {
    463    assert(precision < 15);
    464    assert(n <= kMaxNumSymbols);
    465    std::vector<T> dynp(((1U << precision) + 1) * (n + 1), infty);
    466    auto d = [&](size_t sym, size_t off) -> T& {
    467      return dynp[sym * ((1 << precision) + 1) + off];
    468    };
    469    d(0, 0) = 0;
    470    for (size_t sym = 0; sym < n; sym++) {
    471      for (T bits = min_limit[sym]; bits <= max_limit[sym]; bits++) {
    472        size_t off_delta = 1U << (precision - bits);
    473        for (size_t off = 0; off + off_delta <= (1U << precision); off++) {
    474          d(sym + 1, off + off_delta) =
    475              std::min(d(sym, off) + static_cast<T>(freqs[sym]) * bits,
    476                       d(sym + 1, off + off_delta));
    477        }
    478      }
    479    }
    480 
    481    size_t sym = n;
    482    size_t off = 1U << precision;
    483 
    484    assert(d(sym, off) != infty);
    485 
    486    while (sym-- > 0) {
    487      assert(off > 0);
    488      for (size_t bits = min_limit[sym]; bits <= max_limit[sym]; bits++) {
    489        size_t off_delta = 1U << (precision - bits);
    490        if (off_delta <= off &&
    491            d(sym + 1, off) == d(sym, off - off_delta) + freqs[sym] * bits) {
    492          off -= off_delta;
    493          nbits[sym] = bits;
    494          break;
    495        }
    496      }
    497    }
    498  }
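
         // The recurrence above is a dynamic program over the Kraft sum:
         // d(sym, off) is the minimum cost (sum of freqs[i] * nbits[i]) of
         // assigning lengths to the first `sym` symbols so that
         // sum(2^(precision - nbits[i])) == off. Requiring
         // off == 1 << precision at the end enforces a complete prefix code;
         // the second loop walks the table backwards to recover the lengths.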
    499 
    500  // Computes nbits[i] for i < n, subject to min_limit[i] <= nbits[i] <=
    501  // max_limit[i] and sum 2**-nbits[i] == 1, so as to minimize
    502  // sum(nbits[i] * freqs[i]).
    503  static void ComputeCodeLengthsNonZero(const uint64_t* freqs, size_t n,
    504                                        uint8_t* min_limit, uint8_t* max_limit,
    505                                        uint8_t* nbits) {
    506    size_t precision = 0;
    507    size_t shortest_length = 255;
    508    uint64_t freqsum = 0;
    509    for (size_t i = 0; i < n; i++) {
    510      assert(freqs[i] != 0);
    511      freqsum += freqs[i];
    512      if (min_limit[i] < 1) min_limit[i] = 1;
    513      assert(min_limit[i] <= max_limit[i]);
    514      precision = std::max<size_t>(max_limit[i], precision);
    515      shortest_length = std::min<size_t>(min_limit[i], shortest_length);
    516    }
    517    // If all the minimum limits are greater than 1, shift precision so that we
    518    // behave as if the shortest was 1.
    519    precision -= shortest_length - 1;
    520    uint64_t infty = freqsum * precision;
    521    if (infty < std::numeric_limits<uint32_t>::max() / 2) {
    522      ComputeCodeLengthsNonZeroImpl(freqs, n, precision,
    523                                    static_cast<uint32_t>(infty), min_limit,
    524                                    max_limit, nbits);
    525    } else {
    526      ComputeCodeLengthsNonZeroImpl(freqs, n, precision, infty, min_limit,
    527                                    max_limit, nbits);
    528    }
    529  }
    530 
    531  static constexpr size_t kMaxNumSymbols =
    532      kNumRawSymbols + 1 < kNumLZ77 ? kNumLZ77 : kNumRawSymbols + 1;
    533  static void ComputeCodeLengths(const uint64_t* freqs, size_t n,
    534                                 const uint8_t* min_limit_in,
    535                                 const uint8_t* max_limit_in, uint8_t* nbits) {
    536    assert(n <= kMaxNumSymbols);
    537    uint64_t compact_freqs[kMaxNumSymbols];
    538    uint8_t min_limit[kMaxNumSymbols];
    539    uint8_t max_limit[kMaxNumSymbols];
    540    size_t ni = 0;
    541    for (size_t i = 0; i < n; i++) {
    542      if (freqs[i]) {
    543        compact_freqs[ni] = freqs[i];
    544        min_limit[ni] = min_limit_in[i];
    545        max_limit[ni] = max_limit_in[i];
    546        ni++;
    547      }
    548    }
    549    uint8_t num_bits[kMaxNumSymbols] = {};
    550    ComputeCodeLengthsNonZero(compact_freqs, ni, min_limit, max_limit,
    551                              num_bits);
    552    ni = 0;
    553    for (size_t i = 0; i < n; i++) {
    554      nbits[i] = 0;
    555      if (freqs[i]) {
    556        nbits[i] = num_bits[ni++];
    557      }
    558    }
    559  }
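
         // E.g. freqs = {5, 0, 3} is compacted to {5, 3} before the search,
         // and the resulting lengths are scattered back with nbits[1] = 0
         // for the unused symbol.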
    560 
    561  // Invalid code, used to construct arrays.
    562  PrefixCode() = default;
    563 
    564  template <typename BitDepth>
    565  PrefixCode(BitDepth /* bitdepth */, uint64_t* raw_counts,
    566             uint64_t* lz77_counts) {
    567    // "Merge" all the lz77 counts into a single symbol for the level 1
    568    // table (containing just the raw symbols, up to length 7).
    569    uint64_t level1_counts[kNumRawSymbols + 1];
    570    memcpy(level1_counts, raw_counts, kNumRawSymbols * sizeof(uint64_t));
    571    numraw = kNumRawSymbols;
    572    while (numraw > 0 && level1_counts[numraw - 1] == 0) numraw--;
    573 
    574    level1_counts[numraw] = 0;
    575    for (size_t i = 0; i < kNumLZ77; i++) {
    576      level1_counts[numraw] += lz77_counts[i];
    577    }
    578    uint8_t level1_nbits[kNumRawSymbols + 1] = {};
    579    ComputeCodeLengths(level1_counts, numraw + 1, BitDepth::kMinRawLength,
    580                       BitDepth::kMaxRawLength, level1_nbits);
    581 
    582    uint8_t level2_nbits[kNumLZ77] = {};
    583    uint8_t min_lengths[kNumLZ77] = {};
    584    uint8_t l = 15 - level1_nbits[numraw];
    585    uint8_t max_lengths[kNumLZ77];
    586    for (uint8_t& max_length : max_lengths) {
    587      max_length = l;
    588    }
    589    size_t num_lz77 = kNumLZ77;
    590    while (num_lz77 > 0 && lz77_counts[num_lz77 - 1] == 0) num_lz77--;
    591    ComputeCodeLengths(lz77_counts, num_lz77, min_lengths, max_lengths,
    592                       level2_nbits);
    593    for (size_t i = 0; i < numraw; i++) {
    594      raw_nbits[i] = level1_nbits[i];
    595    }
    596    for (size_t i = 0; i < num_lz77; i++) {
    597      lz77_nbits[i] =
    598          level2_nbits[i] ? level1_nbits[numraw] + level2_nbits[i] : 0;
    599    }
    600 
    601    ComputeCanonicalCode(raw_nbits, raw_bits, numraw, lz77_nbits, lz77_bits,
    602                         kNumLZ77);
    603 
    604    // Prepare lz77 cache
    605    for (size_t count = 0; count < kLZ77CacheSize; count++) {
    606      uint32_t token, nbits, bits;
    607      EncodeHybridUintLZ77(count, &token, &nbits, &bits);
    608      lz77_cache_nbits[count] = lz77_nbits[token] + nbits + raw_nbits[0];
    609      lz77_cache_bits[count] =
    610          (((bits << lz77_nbits[token]) | lz77_bits[token]) << raw_nbits[0]) |
    611          raw_bits[0];
    612    }
    613  }
    614 
    615  // Max bits written: 2 + 72 + 95 + 24 + 165 = 358
    616  void WriteTo(BitWriter* writer) const {
    617    uint64_t code_length_counts[18] = {};
    618    code_length_counts[17] = 3 + 2 * (kNumLZ77 - 1);
    619    for (uint8_t raw_nbit : raw_nbits) {
    620      code_length_counts[raw_nbit]++;
    621    }
    622    for (uint8_t lz77_nbit : lz77_nbits) {
    623      code_length_counts[lz77_nbit]++;
    624    }
    625    uint8_t code_length_nbits[18] = {};
    626    uint8_t code_length_nbits_min[18] = {};
    627    uint8_t code_length_nbits_max[18] = {
    628        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
    629    };
    630    ComputeCodeLengths(code_length_counts, 18, code_length_nbits_min,
    631                       code_length_nbits_max, code_length_nbits);
    632    writer->Write(2, 0b00);  // HSKIP = 0, i.e. don't skip code lengths.
    633 
    634    // As per Brotli RFC.
    635    uint8_t code_length_order[18] = {1, 2, 3, 4,  0,  5,  17, 6,  16,
    636                                     7, 8, 9, 10, 11, 12, 13, 14, 15};
    637    uint8_t code_length_length_nbits[] = {2, 4, 3, 2, 2, 4};
    638    uint8_t code_length_length_bits[] = {0, 7, 3, 2, 1, 15};
    639 
    640    // Encode lengths of code lengths.
    641    size_t num_code_lengths = 18;
    642    while (code_length_nbits[code_length_order[num_code_lengths - 1]] == 0) {
    643      num_code_lengths--;
    644    }
    645    // Max bits written in this loop: 18 * 4 = 72
    646    for (size_t i = 0; i < num_code_lengths; i++) {
    647      int symbol = code_length_nbits[code_length_order[i]];
    648      writer->Write(code_length_length_nbits[symbol],
    649                    code_length_length_bits[symbol]);
    650    }
    651 
    652    // Compute the canonical codes for the code-length code, i.e. the
    653    // code that encodes the lengths of the actual data codes.
    654    uint16_t code_length_bits[18] = {};
    655    ComputeCanonicalCode(nullptr, nullptr, 0, code_length_nbits,
    656                         code_length_bits, 18);
    657    // Encode raw bit code lengths.
    658    // Max bits written in this loop: 19 * 5 = 95
    659    for (uint8_t raw_nbit : raw_nbits) {
    660      writer->Write(code_length_nbits[raw_nbit], code_length_bits[raw_nbit]);
    661    }
    662    size_t num_lz77 = kNumLZ77;
    663    while (lz77_nbits[num_lz77 - 1] == 0) {
    664      num_lz77--;
    665    }
    666    // Encode 0s until 224 (start of LZ77 symbols). This is in total 224-19 =
    667    // 205.
    668    static_assert(kLZ77Offset == 224);
    669    static_assert(kNumRawSymbols == 19);
    670    {
    671      // Max bits in this block: 24
    672      writer->Write(code_length_nbits[17], code_length_bits[17]);
    673      writer->Write(3, 0b010);  // 5
    674      writer->Write(code_length_nbits[17], code_length_bits[17]);
    675      writer->Write(3, 0b000);  // (5-2)*8 + 3 = 27
    676      writer->Write(code_length_nbits[17], code_length_bits[17]);
    677      writer->Write(3, 0b010);  // (27-2)*8 + 5 = 205
    678    }
    679    // Encode LZ77 symbols, with values 224+i.
    680    // Max bits written in this loop: 33 * 5 = 165
    681    for (size_t i = 0; i < num_lz77; i++) {
    682      writer->Write(code_length_nbits[lz77_nbits[i]],
    683                    code_length_bits[lz77_nbits[i]]);
    684    }
    685  }
    686 };
    687 
    688 }  // namespace
    689 
    690 extern "C" {
    691 
    692 struct JxlFastLosslessFrameState {
    693  JxlChunkedFrameInputSource input;
    694  size_t width;
    695  size_t height;
    696  size_t num_groups_x;
    697  size_t num_groups_y;
    698  size_t num_dc_groups_x;
    699  size_t num_dc_groups_y;
    700  size_t nb_chans;
    701  size_t bitdepth;
    702  int big_endian;
    703  int effort;
    704  bool collided;
    705  PrefixCode hcode[4];
    706  std::vector<int16_t> lookup;
    707  BitWriter header;
    708  std::vector<std::array<BitWriter, 4>> group_data;
    709  std::vector<size_t> group_sizes;
    710  size_t ac_group_data_offset = 0;
    711  size_t min_dc_global_size = 0;
    712  size_t current_bit_writer = 0;
    713  size_t bit_writer_byte_pos = 0;
    714  size_t bits_in_buffer = 0;
    715  uint64_t bit_buffer = 0;
    716  bool process_done = false;
    717 };
    718 
    719 size_t JxlFastLosslessOutputSize(const JxlFastLosslessFrameState* frame) {
    720  size_t total_size_groups = 0;
    721  for (const auto& section : frame->group_data) {
    722    total_size_groups += SectionSize(section);
    723  }
    724  return frame->header.bytes_written + total_size_groups;
    725 }
    726 
    727 size_t JxlFastLosslessMaxRequiredOutput(
    728    const JxlFastLosslessFrameState* frame) {
    729  return JxlFastLosslessOutputSize(frame) + 32;
    730 }
    731 
    732 void JxlFastLosslessPrepareHeader(JxlFastLosslessFrameState* frame,
    733                                  int add_image_header, int is_last) {
    734  BitWriter* output = &frame->header;
    735  output->Allocate(1000 + frame->group_sizes.size() * 32);
    736 
    737  bool have_alpha = (frame->nb_chans == 2 || frame->nb_chans == 4);
    738 
    739 #if FJXL_STANDALONE
    740  if (add_image_header) {
    741    // Signature
    742    output->Write(16, 0x0AFF);
    743 
    744    // Size header, hand-crafted.
    745    // Not small
    746    output->Write(1, 0);
    747 
    748    auto wsz = [output](size_t size) {
    749      if (size - 1 < (1 << 9)) {
    750        output->Write(2, 0b00);
    751        output->Write(9, size - 1);
    752      } else if (size - 1 < (1 << 13)) {
    753        output->Write(2, 0b01);
    754        output->Write(13, size - 1);
    755      } else if (size - 1 < (1 << 18)) {
    756        output->Write(2, 0b10);
    757        output->Write(18, size - 1);
    758      } else {
    759        output->Write(2, 0b11);
    760        output->Write(30, size - 1);
    761      }
    762    };
    763 
    764    wsz(frame->height);
    765 
    766    // No special ratio.
    767    output->Write(3, 0);
    768 
    769    wsz(frame->width);
    770 
    771    // Hand-crafted ImageMetadata.
    772    output->Write(1, 0);  // all_default
    773    output->Write(1, 0);  // extra_fields
    774    output->Write(1, 0);  // bit_depth.floating_point_sample
    775    if (frame->bitdepth == 8) {
    776      output->Write(2, 0b00);  // bit_depth.bits_per_sample = 8
    777    } else if (frame->bitdepth == 10) {
    778      output->Write(2, 0b01);  // bit_depth.bits_per_sample = 10
    779    } else if (frame->bitdepth == 12) {
    780      output->Write(2, 0b10);  // bit_depth.bits_per_sample = 12
    781    } else {
    782      output->Write(2, 0b11);  // 1 + u(6)
    783      output->Write(6, frame->bitdepth - 1);
    784    }
    785    if (frame->bitdepth <= 14) {
    786      output->Write(1, 1);  // 16-bit-buffer sufficient
    787    } else {
    788      output->Write(1, 0);  // 16-bit-buffer NOT sufficient
    789    }
    790    if (have_alpha) {
    791      output->Write(2, 0b01);  // One extra channel
    792    output->Write(1, 1);     // ... all_default (i.e. 8-bit alpha)
    793    } else {
    794      output->Write(2, 0b00);  // No extra channel
    795    }
    796    output->Write(1, 0);  // Not XYB
    797    if (frame->nb_chans > 2) {
    798      output->Write(1, 1);  // color_encoding.all_default (sRGB)
    799    } else {
    800      output->Write(1, 0);     // color_encoding.all_default false
    801      output->Write(1, 0);     // color_encoding.want_icc false
    802      output->Write(2, 1);     // grayscale
    803      output->Write(2, 1);     // D65
    804      output->Write(1, 0);     // no gamma transfer function
    805      output->Write(2, 0b10);  // tf: 2 + u(4)
    806      output->Write(4, 11);    // tf of sRGB
    807      output->Write(2, 1);     // relative rendering intent
    808    }
    809    output->Write(2, 0b00);  // No extensions.
    810 
    811    output->Write(1, 1);  // all_default transform data
    812 
    813    // No ICC, no preview. Frame should start at byte boundary.
    814    output->ZeroPadToByte();
    815  }
    816 #else
    817  assert(!add_image_header);
    818 #endif
    819  // Handcrafted frame header.
    820  output->Write(1, 0);     // all_default
    821  output->Write(2, 0b00);  // regular frame
    822  output->Write(1, 1);     // modular
    823  output->Write(2, 0b00);  // default flags
    824  output->Write(1, 0);     // not YCbCr
    825  output->Write(2, 0b00);  // no upsampling
    826  if (have_alpha) {
    827    output->Write(2, 0b00);  // no alpha upsampling
    828  }
    829  output->Write(2, 0b01);  // default group size
    830  output->Write(2, 0b00);  // exactly one pass
    831  output->Write(1, 0);     // no custom size or origin
    832  output->Write(2, 0b00);  // kReplace blending mode
    833  if (have_alpha) {
    834    output->Write(2, 0b00);  // kReplace blending mode for alpha channel
    835  }
    836  output->Write(1, is_last);  // is_last
    837  if (!is_last) {
    838    output->Write(2, 0b00);  // can not be saved as reference
    839  }
    840  output->Write(2, 0b00);  // a frame has no name
    841  output->Write(1, 0);     // loop filter is not all_default
    842  output->Write(1, 0);     // no gaborish
    843  output->Write(2, 0);     // 0 EPF iters
    844  output->Write(2, 0b00);  // No LF extensions
    845  output->Write(2, 0b00);  // No FH extensions
    846 
    847  output->Write(1, 0);      // No TOC permutation
    848  output->ZeroPadToByte();  // TOC is byte-aligned.
    849  assert(add_image_header || output->bytes_written <= kMaxFrameHeaderSize);
    850  for (size_t group_size : frame->group_sizes) {
    851    size_t bucket = TOCBucket(group_size);
    852    output->Write(2, bucket);
    853    output->Write(kTOCBits[bucket] - 2, group_size - kGroupSizeOffset[bucket]);
    854  }
    855  output->ZeroPadToByte();  // Groups are byte-aligned.
    856 }
    857 
    858 #if !FJXL_STANDALONE
    859 bool JxlFastLosslessOutputAlignedSection(
    860    const BitWriter& bw, JxlEncoderOutputProcessorWrapper* output_processor) {
    861  assert(bw.bits_in_buffer == 0);
    862  const uint8_t* data = bw.data.get();
    863  size_t remaining_len = bw.bytes_written;
    864  while (remaining_len > 0) {
    865    JXL_ASSIGN_OR_RETURN(auto buffer,
    866                         output_processor->GetBuffer(1, remaining_len));
    867    size_t n = std::min(buffer.size(), remaining_len);
    868    if (n == 0) break;
    869    memcpy(buffer.data(), data, n);
    870    JXL_RETURN_IF_ERROR(buffer.advance(n));
    871    data += n;
    872    remaining_len -= n;
    873  }
    874  return true;
    875 }
    876 
    877 bool JxlFastLosslessOutputHeaders(
    878    JxlFastLosslessFrameState* frame_state,
    879    JxlEncoderOutputProcessorWrapper* output_processor) {
    880  JXL_RETURN_IF_ERROR(JxlFastLosslessOutputAlignedSection(frame_state->header,
    881                                                          output_processor));
    882  JXL_RETURN_IF_ERROR(JxlFastLosslessOutputAlignedSection(
    883      frame_state->group_data[0][0], output_processor));
    884  return true;
    885 }
    886 #endif
    887 
    888 #if FJXL_ENABLE_AVX512
    889 __attribute__((target("avx512vbmi2"))) static size_t AppendBytesWithBitOffset(
    890    const uint8_t* data, size_t n, size_t bit_buffer_nbits,
    891    unsigned char* output, uint64_t& bit_buffer) {
    892  if (n < 128) {
    893    return 0;
    894  }
    895 
    896  size_t i = 0;
    897  __m512i shift = _mm512_set1_epi64(64 - bit_buffer_nbits);
    898  __m512i carry = _mm512_set1_epi64(bit_buffer << (64 - bit_buffer_nbits));
    899 
    900  for (; i + 64 <= n; i += 64) {
    901    __m512i current = _mm512_loadu_si512(data + i);
    902    __m512i previous_u64 = _mm512_alignr_epi64(current, carry, 7);
    903    carry = current;
    904    __m512i out = _mm512_shrdv_epi64(previous_u64, current, shift);
    905    _mm512_storeu_si512(output + i, out);
    906  }
    907 
    908  bit_buffer = data[i - 1] >> (8 - bit_buffer_nbits);
    909 
    910  return i;
    911 }
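
        // Each _mm512_shrdv_epi64 above is a per-lane funnel shift: the output
        // u64 is (current << bit_buffer_nbits) | (previous u64 >>
        // (64 - bit_buffer_nbits)), so 64 bytes per iteration are re-emitted
        // at a sub-byte offset, and the leftover high bits of the last input
        // byte are returned through bit_buffer.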
    912 #endif
    913 
    914 size_t JxlFastLosslessWriteOutput(JxlFastLosslessFrameState* frame,
    915                                  unsigned char* output, size_t output_size) {
    916  assert(output_size >= 32);
    917  unsigned char* initial_output = output;
    918  size_t (*append_bytes_with_bit_offset)(const uint8_t*, size_t, size_t,
    919                                         unsigned char*, uint64_t&) = nullptr;
    920 
    921 #if FJXL_ENABLE_AVX512
    922  if (HasCpuFeature(CpuFeature::kVBMI2)) {
    923    append_bytes_with_bit_offset = AppendBytesWithBitOffset;
    924  }
    925 #endif
    926 
    927  while (true) {
    928    size_t& cur = frame->current_bit_writer;
    929    size_t& bw_pos = frame->bit_writer_byte_pos;
    930    if (cur >= 1 + frame->group_data.size() * frame->nb_chans) {
    931      return output - initial_output;
    932    }
    933    if (output_size <= 9) {
    934      return output - initial_output;
    935    }
    936    size_t nbc = frame->nb_chans;
    937    const BitWriter& writer =
    938        cur == 0 ? frame->header
    939                 : frame->group_data[(cur - 1) / nbc][(cur - 1) % nbc];
    940    size_t full_byte_count =
    941        std::min(output_size - 9, writer.bytes_written - bw_pos);
    942    if (frame->bits_in_buffer == 0) {
    943      memcpy(output, writer.data.get() + bw_pos, full_byte_count);
    944    } else {
    945      size_t i = 0;
    946      if (append_bytes_with_bit_offset) {
    947        i += append_bytes_with_bit_offset(
    948            writer.data.get() + bw_pos, full_byte_count, frame->bits_in_buffer,
    949            output, frame->bit_buffer);
    950      }
    951 #if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
    952      // Copy 8 bytes at a time until we reach the border.
    953      for (; i + 8 < full_byte_count; i += 8) {
    954        uint64_t chunk;
    955        memcpy(&chunk, writer.data.get() + bw_pos + i, 8);
    956        uint64_t out = frame->bit_buffer | (chunk << frame->bits_in_buffer);
    957        memcpy(output + i, &out, 8);
    958        frame->bit_buffer = chunk >> (64 - frame->bits_in_buffer);
    959      }
    960 #endif
    961      for (; i < full_byte_count; i++) {
    962        AddBits(8, writer.data.get()[bw_pos + i], output + i,
    963                frame->bits_in_buffer, frame->bit_buffer);
    964      }
    965    }
    966    output += full_byte_count;
    967    output_size -= full_byte_count;
    968    bw_pos += full_byte_count;
    969    if (bw_pos == writer.bytes_written) {
    970      auto write = [&](size_t num, uint64_t bits) {
    971        size_t n = AddBits(num, bits, output, frame->bits_in_buffer,
    972                           frame->bit_buffer);
    973        output += n;
    974        output_size -= n;
    975      };
    976      if (writer.bits_in_buffer) {
    977        write(writer.bits_in_buffer, writer.buffer);
    978      }
    979      bw_pos = 0;
    980      cur++;
    981      if ((cur - 1) % nbc == 0 && frame->bits_in_buffer != 0) {
    982        write(8 - frame->bits_in_buffer, 0);
    983      }
    984    }
    985  }
    986 }
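
        // Usage sketch (illustrative): size the buffer with
        // JxlFastLosslessMaxRequiredOutput and drain the frame; each chunk must
        // be at least 32 bytes (see the assert above), and the function returns
        // once at most 9 bytes of space remain, so it may be called again:
        //   std::vector<unsigned char> buf(JxlFastLosslessMaxRequiredOutput(f));
        //   size_t n = JxlFastLosslessWriteOutput(f, buf.data(), buf.size());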
    987 
    988 void JxlFastLosslessFreeFrameState(JxlFastLosslessFrameState* frame) {
    989  delete frame;
    990 }
    991 
    992 }  // extern "C"
    993 
    994 #endif
    995 
    996 #ifdef FJXL_SELF_INCLUDE
    997 
    998 namespace {
    999 
   1000 template <typename T>
   1001 struct VecPair {
   1002  T low;
   1003  T hi;
   1004 };
   1005 
   1006 #ifdef FJXL_GENERIC_SIMD
   1007 #undef FJXL_GENERIC_SIMD
   1008 #endif
   1009 
   1010 #ifdef FJXL_AVX512
   1011 #define FJXL_GENERIC_SIMD
   1012 struct SIMDVec32;
   1013 struct Mask32 {
   1014  __mmask16 mask;
   1015  SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
   1016  size_t CountPrefix() const {
   1017    return CtzNonZero(~uint64_t{_cvtmask16_u32(mask)});
   1018  }
   1019 };
   1020 
   1021 struct SIMDVec32 {
   1022  __m512i vec;
   1023 
   1024  static constexpr size_t kLanes = 16;
   1025 
   1026  FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
   1027    return SIMDVec32{_mm512_loadu_si512((__m512i*)data)};
   1028  }
   1029  FJXL_INLINE void Store(uint32_t* data) {
   1030    _mm512_storeu_si512((__m512i*)data, vec);
   1031  }
   1032  FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
   1033    return SIMDVec32{_mm512_set1_epi32(v)};
   1034  }
   1035  FJXL_INLINE SIMDVec32 ValToToken() const {
   1036    return SIMDVec32{
   1037        _mm512_sub_epi32(_mm512_set1_epi32(32), _mm512_lzcnt_epi32(vec))};
   1038  }
   1039  FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
   1040    return SIMDVec32{_mm512_sub_epi32(_mm512_max_epu32(vec, to_subtract.vec),
   1041                                      to_subtract.vec)};
   1042  }
   1043  FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
   1044    return SIMDVec32{_mm512_sub_epi32(vec, to_subtract.vec)};
   1045  }
   1046  FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
   1047    return SIMDVec32{_mm512_add_epi32(vec, oth.vec)};
   1048  }
   1049  FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
   1050    return SIMDVec32{_mm512_xor_epi32(vec, oth.vec)};
   1051  }
   1052  FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
   1053    return Mask32{_mm512_cmpeq_epi32_mask(vec, oth.vec)};
   1054  }
   1055  FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
   1056    return Mask32{_mm512_cmpgt_epi32_mask(vec, oth.vec)};
   1057  }
   1058  FJXL_INLINE SIMDVec32 Pow2() const {
   1059    return SIMDVec32{_mm512_sllv_epi32(_mm512_set1_epi32(1), vec)};
   1060  }
   1061  template <size_t i>
   1062  FJXL_INLINE SIMDVec32 SignedShiftRight() const {
   1063    return SIMDVec32{_mm512_srai_epi32(vec, i)};
   1064  }
   1065 };
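
        // Note: ValToToken above computes 32 - lzcnt, i.e. the bit length of
        // each lane (FloorLog2 + 1 for nonzero lanes, 0 for zero lanes).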
   1066 
   1067 struct SIMDVec16;
   1068 
   1069 struct Mask16 {
   1070  __mmask32 mask;
   1071  SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
   1072  Mask16 And(const Mask16& oth) const {
   1073    return Mask16{_kand_mask32(mask, oth.mask)};
   1074  }
   1075  size_t CountPrefix() const {
   1076    return CtzNonZero(~uint64_t{_cvtmask32_u32(mask)});
   1077  }
   1078 };
   1079 
   1080 struct SIMDVec16 {
   1081  __m512i vec;
   1082 
   1083  static constexpr size_t kLanes = 32;
   1084 
   1085  FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
   1086    return SIMDVec16{_mm512_loadu_si512((__m512i*)data)};
   1087  }
   1088  FJXL_INLINE void Store(uint16_t* data) {
   1089    _mm512_storeu_si512((__m512i*)data, vec);
   1090  }
   1091  FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
   1092    return SIMDVec16{_mm512_set1_epi16(v)};
   1093  }
   1094  FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
   1095                                         const SIMDVec32& hi) {
   1096    auto tmp = _mm512_packus_epi32(lo.vec, hi.vec);
   1097    alignas(64) uint64_t perm[8] = {0, 2, 4, 6, 1, 3, 5, 7};
   1098    return SIMDVec16{
   1099        _mm512_permutex2var_epi64(tmp, _mm512_load_si512((__m512i*)perm), tmp)};
   1100  }
   1101 
   1102  FJXL_INLINE SIMDVec16 ValToToken() const {
   1103    auto c16 = _mm512_set1_epi32(16);
   1104    auto c32 = _mm512_set1_epi32(32);
   1105    auto low16bit = _mm512_set1_epi32(0x0000FFFF);
   1106    auto lzhi =
   1107        _mm512_sub_epi32(c16, _mm512_min_epu32(c16, _mm512_lzcnt_epi32(vec)));
   1108    auto lzlo = _mm512_sub_epi32(
   1109        c32, _mm512_lzcnt_epi32(_mm512_and_si512(low16bit, vec)));
   1110    return SIMDVec16{_mm512_or_si512(lzlo, _mm512_slli_epi32(lzhi, 16))};
   1111  }
   1112 
   1113  FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
   1114    return SIMDVec16{_mm512_subs_epu16(vec, to_subtract.vec)};
   1115  }
   1116  FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
   1117    return SIMDVec16{_mm512_sub_epi16(vec, to_subtract.vec)};
   1118  }
   1119  FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
   1120    return SIMDVec16{_mm512_add_epi16(vec, oth.vec)};
   1121  }
   1122  FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
   1123    return SIMDVec16{_mm512_min_epu16(vec, oth.vec)};
   1124  }
   1125  FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
   1126    return Mask16{_mm512_cmpeq_epi16_mask(vec, oth.vec)};
   1127  }
   1128  FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
   1129    return Mask16{_mm512_cmpgt_epi16_mask(vec, oth.vec)};
   1130  }
   1131  FJXL_INLINE SIMDVec16 Pow2() const {
   1132    return SIMDVec16{_mm512_sllv_epi16(_mm512_set1_epi16(1), vec)};
   1133  }
   1134  FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
   1135    return SIMDVec16{_mm512_or_si512(vec, oth.vec)};
   1136  }
   1137  FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
   1138    return SIMDVec16{_mm512_xor_si512(vec, oth.vec)};
   1139  }
   1140  FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
   1141    return SIMDVec16{_mm512_and_si512(vec, oth.vec)};
   1142  }
   1143  FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
   1144    return SIMDVec16{_mm512_srai_epi16(_mm512_add_epi16(vec, oth.vec), 1)};
   1145  }
   1146  FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
   1147    return SIMDVec16{_mm512_or_si512(vec, _mm512_set1_epi16(0xFF00))};
   1148  }
   1149  FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
   1150    return SIMDVec16{_mm512_shuffle_epi8(
   1151        _mm512_broadcast_i32x4(_mm_loadu_si128((__m128i*)table)), vec)};
   1152  }
   1153  FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
   1154    auto lo = _mm512_unpacklo_epi16(low.vec, vec);
   1155    auto hi = _mm512_unpackhi_epi16(low.vec, vec);
   1156    alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11};
   1157    alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15};
   1158    return {SIMDVec16{_mm512_permutex2var_epi64(
   1159                lo, _mm512_load_si512((__m512i*)perm1), hi)},
   1160            SIMDVec16{_mm512_permutex2var_epi64(
   1161                lo, _mm512_load_si512((__m512i*)perm2), hi)}};
   1162  }
   1163  FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
   1164    auto lo = _mm512_unpacklo_epi16(vec, _mm512_setzero_si512());
   1165    auto hi = _mm512_unpackhi_epi16(vec, _mm512_setzero_si512());
   1166    alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11};
   1167    alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15};
   1168    return {SIMDVec32{_mm512_permutex2var_epi64(
   1169                lo, _mm512_load_si512((__m512i*)perm1), hi)},
   1170            SIMDVec32{_mm512_permutex2var_epi64(
   1171                lo, _mm512_load_si512((__m512i*)perm2), hi)}};
   1172  }
   1173  template <size_t i>
   1174  FJXL_INLINE SIMDVec16 SignedShiftRight() const {
   1175    return SIMDVec16{_mm512_srai_epi16(vec, i)};
   1176  }
   1177 
   1178  static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
   1179    __m256i bytes = _mm256_loadu_si256((__m256i*)data);
   1180    return {SIMDVec16{_mm512_cvtepu8_epi16(bytes)}};
   1181  }
   1182  static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
   1183    return {Load((const uint16_t*)data)};
   1184  }
   1185 
   1186  static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
   1187    __m512i bytes = _mm512_loadu_si512((__m512i*)data);
   1188    __m512i gray = _mm512_and_si512(bytes, _mm512_set1_epi16(0xFF));
   1189    __m512i alpha = _mm512_srli_epi16(bytes, 8);
   1190    return {SIMDVec16{gray}, SIMDVec16{alpha}};
   1191  }
   1192  static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
   1193    __m512i bytes1 = _mm512_loadu_si512((__m512i*)data);
   1194    __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64));
   1195    __m512i g_mask = _mm512_set1_epi32(0xFFFF);
   1196    __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
   1197    __m512i g = _mm512_permutexvar_epi64(
   1198        permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, g_mask),
   1199                                        _mm512_and_si512(bytes2, g_mask)));
   1200    __m512i a = _mm512_permutexvar_epi64(
   1201        permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16),
   1202                                        _mm512_srli_epi32(bytes2, 16)));
   1203    return {SIMDVec16{g}, SIMDVec16{a}};
   1204  }
   1205 
   1206  static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
   1207    __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
   1208    __m512i bytes1 =
   1209        _mm512_zextsi256_si512(_mm256_loadu_si256((__m256i*)(data + 64)));
   1210 
   1211    // 0x7A = element of upper half of second vector = 0 after lookup; still in
   1212    // the upper half once we add 1 or 2.
   1213    uint8_t z = 0x7A;
   1214    __m512i ridx =
   1215        _mm512_set_epi8(z, 93, z, 90, z, 87, z, 84, z, 81, z, 78, z, 75, z, 72,
   1216                        z, 69, z, 66, z, 63, z, 60, z, 57, z, 54, z, 51, z, 48,
   1217                        z, 45, z, 42, z, 39, z, 36, z, 33, z, 30, z, 27, z, 24,
   1218                        z, 21, z, 18, z, 15, z, 12, z, 9, z, 6, z, 3, z, 0);
   1219    __m512i gidx = _mm512_add_epi8(ridx, _mm512_set1_epi8(1));
   1220    __m512i bidx = _mm512_add_epi8(gidx, _mm512_set1_epi8(1));
   1221    __m512i r = _mm512_permutex2var_epi8(bytes0, ridx, bytes1);
   1222    __m512i g = _mm512_permutex2var_epi8(bytes0, gidx, bytes1);
   1223    __m512i b = _mm512_permutex2var_epi8(bytes0, bidx, bytes1);
   1224    return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}};
   1225  }
   1226  static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
   1227    __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
   1228    __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64));
   1229    __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128));
   1230 
   1231    __m512i ridx_lo = _mm512_set_epi16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63, 60, 57,
   1232                                       54, 51, 48, 45, 42, 39, 36, 33, 30, 27,
   1233                                       24, 21, 18, 15, 12, 9, 6, 3, 0);
   1234    // -1 is such that when adding 1 or 2, we get the correct index for
   1235    // green/blue.
   1236    __m512i ridx_hi =
   1237        _mm512_set_epi16(29, 26, 23, 20, 17, 14, 11, 8, 5, 2, -1, 0, 0, 0, 0, 0,
   1238                         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
   1239    __m512i gidx_lo = _mm512_add_epi16(ridx_lo, _mm512_set1_epi16(1));
   1240    __m512i gidx_hi = _mm512_add_epi16(ridx_hi, _mm512_set1_epi16(1));
   1241    __m512i bidx_lo = _mm512_add_epi16(gidx_lo, _mm512_set1_epi16(1));
   1242    __m512i bidx_hi = _mm512_add_epi16(gidx_hi, _mm512_set1_epi16(1));
   1243 
   1244    __mmask32 rmask = _cvtu32_mask32(0b11111111110000000000000000000000);
   1245    __mmask32 gbmask = _cvtu32_mask32(0b11111111111000000000000000000000);
   1246 
   1247    __m512i rlo = _mm512_permutex2var_epi16(bytes0, ridx_lo, bytes1);
   1248    __m512i glo = _mm512_permutex2var_epi16(bytes0, gidx_lo, bytes1);
   1249    __m512i blo = _mm512_permutex2var_epi16(bytes0, bidx_lo, bytes1);
   1250    __m512i r = _mm512_mask_permutexvar_epi16(rlo, rmask, ridx_hi, bytes2);
   1251    __m512i g = _mm512_mask_permutexvar_epi16(glo, gbmask, gidx_hi, bytes2);
   1252    __m512i b = _mm512_mask_permutexvar_epi16(blo, gbmask, bidx_hi, bytes2);
   1253    return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}};
   1254  }
   1255 
   1256  static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
   1257    __m512i bytes1 = _mm512_loadu_si512((__m512i*)data);
   1258    __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64));
   1259    __m512i rg_mask = _mm512_set1_epi32(0xFFFF);
   1260    __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
   1261    __m512i rg = _mm512_permutexvar_epi64(
   1262        permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, rg_mask),
   1263                                        _mm512_and_si512(bytes2, rg_mask)));
   1264    __m512i b_a = _mm512_permutexvar_epi64(
   1265        permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16),
   1266                                        _mm512_srli_epi32(bytes2, 16)));
   1267    __m512i r = _mm512_and_si512(rg, _mm512_set1_epi16(0xFF));
   1268    __m512i g = _mm512_srli_epi16(rg, 8);
   1269    __m512i b = _mm512_and_si512(b_a, _mm512_set1_epi16(0xFF));
   1270    __m512i a = _mm512_srli_epi16(b_a, 8);
   1271    return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1272  }
   1273  static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
   1274    __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
   1275    __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64));
   1276    __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128));
   1277    __m512i bytes3 = _mm512_loadu_si512((__m512i*)(data + 192));
   1278 
   1279    auto pack32 = [](__m512i a, __m512i b) {
   1280      __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
   1281      return _mm512_permutexvar_epi64(permuteidx, _mm512_packus_epi32(a, b));
   1282    };
   1283    auto packlow32 = [&pack32](__m512i a, __m512i b) {
   1284      __m512i mask = _mm512_set1_epi32(0xFFFF);
   1285      return pack32(_mm512_and_si512(a, mask), _mm512_and_si512(b, mask));
   1286    };
   1287    auto packhi32 = [&pack32](__m512i a, __m512i b) {
   1288      return pack32(_mm512_srli_epi32(a, 16), _mm512_srli_epi32(b, 16));
   1289    };
   1290 
   1291    __m512i rb0 = packlow32(bytes0, bytes1);
   1292    __m512i rb1 = packlow32(bytes2, bytes3);
   1293    __m512i ga0 = packhi32(bytes0, bytes1);
   1294    __m512i ga1 = packhi32(bytes2, bytes3);
   1295 
   1296    __m512i r = packlow32(rb0, rb1);
   1297    __m512i g = packlow32(ga0, ga1);
   1298    __m512i b = packhi32(rb0, rb1);
   1299    __m512i a = packhi32(ga0, ga1);
   1300    return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1301  }
   1302 
   1303  void SwapEndian() {
   1304    auto indices = _mm512_broadcast_i32x4(
   1305        _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14));
   1306    vec = _mm512_shuffle_epi8(vec, indices);
   1307  }
   1308 };
   1309 
   1310 SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
   1311                             const SIMDVec16& if_false) {
   1312  return SIMDVec16{_mm512_mask_blend_epi16(mask, if_false.vec, if_true.vec)};
   1313 }
   1314 
   1315 SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
   1316                             const SIMDVec32& if_false) {
   1317  return SIMDVec32{_mm512_mask_blend_epi32(mask, if_false.vec, if_true.vec)};
   1318 }
   1319 
   1320 struct Bits64 {
   1321  static constexpr size_t kLanes = 8;
   1322 
   1323  __m512i nbits;
   1324  __m512i bits;
   1325 
   1326  FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
   1327    _mm512_storeu_si512((__m512i*)nbits_out, nbits);
   1328    _mm512_storeu_si512((__m512i*)bits_out, bits);
   1329  }
   1330 };
   1331 
   1332 struct Bits32 {
   1333  __m512i nbits;
   1334  __m512i bits;
   1335 
   1336  static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
   1337    return Bits32{nbits.vec, bits.vec};
   1338  }
   1339 
   1340  Bits64 Merge() const {
   1341    auto nbits_hi32 = _mm512_srli_epi64(nbits, 32);
   1342    auto nbits_lo32 = _mm512_and_si512(nbits, _mm512_set1_epi64(0xFFFFFFFF));
   1343    auto bits_hi32 = _mm512_srli_epi64(bits, 32);
   1344    auto bits_lo32 = _mm512_and_si512(bits, _mm512_set1_epi64(0xFFFFFFFF));
   1345 
   1346    auto nbits64 = _mm512_add_epi64(nbits_hi32, nbits_lo32);
   1347    auto bits64 =
   1348        _mm512_or_si512(_mm512_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32);
   1349    return Bits64{nbits64, bits64};
   1350  }
   1351 
   1352  void Interleave(const Bits32& low) {
   1353    bits = _mm512_or_si512(_mm512_sllv_epi32(bits, low.nbits), low.bits);
   1354    nbits = _mm512_add_epi32(nbits, low.nbits);
   1355  }
   1356 
   1357  void ClipTo(size_t n) {
   1358    n = std::min<size_t>(n, 16);
   1359    constexpr uint32_t kMask[32] = {
   1360        ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1361        ~0u, ~0u, ~0u, ~0u, ~0u, 0,   0,   0,   0,   0,   0,
   1362        0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
   1363    };
   1364    __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n));
   1365    nbits = _mm512_and_si512(mask, nbits);
   1366    bits = _mm512_and_si512(mask, bits);
   1367  }
   1368  void Skip(size_t n) {
   1369    n = std::min<size_t>(n, 16);
   1370    constexpr uint32_t kMask[32] = {
   1371        0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
   1372        0,   0,   0,   0,   0,   ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1373        ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1374    };
   1375    __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n));
   1376    nbits = _mm512_and_si512(mask, nbits);
   1377    bits = _mm512_and_si512(mask, bits);
   1378  }
   1379 };
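// [Editorial note, not part of the original file] Bits16/Bits32/Bits64 all
// implement the same scheme: each lane holds a (bits, nbits) pair, and
// Merge() folds pairs of adjacent lanes into lanes of twice the width by
// concatenating their bit strings. A scalar sketch of one such fold,
// assuming nbits_lo + nbits_hi <= 64:
//
//   uint64_t nbits64 = nbits_hi + nbits_lo;
//   uint64_t bits64 = (bits_hi << nbits_lo) | bits_lo;
//
// ClipTo(n) keeps only the first n lanes and Skip(n) drops the first n
// lanes, in both cases by loading a sliding window of a constant mask table
// instead of branching.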
   1380 
   1381 struct Bits16 {
   1382  __m512i nbits;
   1383  __m512i bits;
   1384 
   1385  static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
   1386    return Bits16{nbits.vec, bits.vec};
   1387  }
   1388 
   1389  Bits32 Merge() const {
   1390    auto nbits_hi16 = _mm512_srli_epi32(nbits, 16);
   1391    auto nbits_lo16 = _mm512_and_si512(nbits, _mm512_set1_epi32(0xFFFF));
   1392    auto bits_hi16 = _mm512_srli_epi32(bits, 16);
   1393    auto bits_lo16 = _mm512_and_si512(bits, _mm512_set1_epi32(0xFFFF));
   1394 
   1395    auto nbits32 = _mm512_add_epi32(nbits_hi16, nbits_lo16);
   1396    auto bits32 =
   1397        _mm512_or_si512(_mm512_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16);
   1398    return Bits32{nbits32, bits32};
   1399  }
   1400 
   1401  void Interleave(const Bits16& low) {
   1402    bits = _mm512_or_si512(_mm512_sllv_epi16(bits, low.nbits), low.bits);
   1403    nbits = _mm512_add_epi16(nbits, low.nbits);
   1404  }
   1405 
   1406  void ClipTo(size_t n) {
   1407    n = std::min<size_t>(n, 32);
   1408    constexpr uint16_t kMask[64] = {
   1409        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1410        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1411        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1412        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1413        0,      0,      0,      0,      0,      0,      0,      0,
   1414        0,      0,      0,      0,      0,      0,      0,      0,
   1415        0,      0,      0,      0,      0,      0,      0,      0,
   1416        0,      0,      0,      0,      0,      0,      0,      0,
   1417    };
   1418    __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n));
   1419    nbits = _mm512_and_si512(mask, nbits);
   1420    bits = _mm512_and_si512(mask, bits);
   1421  }
   1422  void Skip(size_t n) {
   1423    n = std::min<size_t>(n, 32);
   1424    constexpr uint16_t kMask[64] = {
   1425        0,      0,      0,      0,      0,      0,      0,      0,
   1426        0,      0,      0,      0,      0,      0,      0,      0,
   1427        0,      0,      0,      0,      0,      0,      0,      0,
   1428        0,      0,      0,      0,      0,      0,      0,      0,
   1429        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1430        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1431        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1432        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1433    };
   1434    __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n));
   1435    nbits = _mm512_and_si512(mask, nbits);
   1436    bits = _mm512_and_si512(mask, bits);
   1437  }
   1438 };
   1439 
   1440 #endif
   1441 
   1442 #ifdef FJXL_AVX2
   1443 #define FJXL_GENERIC_SIMD
   1444 
   1445 struct SIMDVec32;
   1446 
   1447 struct Mask32 {
   1448  __m256i mask;
   1449  SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
   1450  size_t CountPrefix() const {
   1451    return CtzNonZero(~static_cast<uint64_t>(
   1452        static_cast<uint8_t>(_mm256_movemask_ps(_mm256_castsi256_ps(mask)))));
   1453  }
   1454 };
   1455 
   1456 struct SIMDVec32 {
   1457  __m256i vec;
   1458 
   1459  static constexpr size_t kLanes = 8;
   1460 
   1461  FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
   1462    return SIMDVec32{_mm256_loadu_si256((__m256i*)data)};
   1463  }
   1464  FJXL_INLINE void Store(uint32_t* data) {
   1465    _mm256_storeu_si256((__m256i*)data, vec);
   1466  }
   1467  FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
   1468    return SIMDVec32{_mm256_set1_epi32(v)};
   1469  }
   1470  FJXL_INLINE SIMDVec32 ValToToken() const {
   1471    auto f32 = _mm256_castps_si256(_mm256_cvtepi32_ps(vec));
   1472    return SIMDVec32{_mm256_max_epi32(
   1473        _mm256_setzero_si256(),
   1474        _mm256_sub_epi32(_mm256_srli_epi32(f32, 23), _mm256_set1_epi32(126)))};
   1475  }
   1476  FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
   1477    return SIMDVec32{_mm256_sub_epi32(_mm256_max_epu32(vec, to_subtract.vec),
   1478                                      to_subtract.vec)};
   1479  }
   1480  FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
   1481    return SIMDVec32{_mm256_sub_epi32(vec, to_subtract.vec)};
   1482  }
   1483  FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
   1484    return SIMDVec32{_mm256_add_epi32(vec, oth.vec)};
   1485  }
   1486  FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
   1487    return SIMDVec32{_mm256_xor_si256(vec, oth.vec)};
   1488  }
   1489  FJXL_INLINE SIMDVec32 Pow2() const {
   1490    return SIMDVec32{_mm256_sllv_epi32(_mm256_set1_epi32(1), vec)};
   1491  }
   1492  FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
   1493    return Mask32{_mm256_cmpeq_epi32(vec, oth.vec)};
   1494  }
   1495  FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
   1496    return Mask32{_mm256_cmpgt_epi32(vec, oth.vec)};
   1497  }
   1498  template <size_t i>
   1499  FJXL_INLINE SIMDVec32 SignedShiftRight() const {
   1500    return SIMDVec32{_mm256_srai_epi32(vec, i)};
   1501  }
   1502 };
   1503 
   1504 struct SIMDVec16;
   1505 
   1506 struct Mask16 {
   1507  __m256i mask;
   1508  SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
   1509  Mask16 And(const Mask16& oth) const {
   1510    return Mask16{_mm256_and_si256(mask, oth.mask)};
   1511  }
   1512  size_t CountPrefix() const {
   1513    return CtzNonZero(~static_cast<uint64_t>(
   1514               static_cast<uint32_t>(_mm256_movemask_epi8(mask)))) /
   1515           2;
   1516  }
   1517 };
   1518 
   1519 struct SIMDVec16 {
   1520  __m256i vec;
   1521 
   1522  static constexpr size_t kLanes = 16;
   1523 
   1524  FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
   1525    return SIMDVec16{_mm256_loadu_si256((__m256i*)data)};
   1526  }
   1527  FJXL_INLINE void Store(uint16_t* data) {
   1528    _mm256_storeu_si256((__m256i*)data, vec);
   1529  }
   1530  FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
   1531    return SIMDVec16{_mm256_set1_epi16(v)};
   1532  }
   1533  FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
   1534                                         const SIMDVec32& hi) {
   1535    auto tmp = _mm256_packus_epi32(lo.vec, hi.vec);
   1536    return SIMDVec16{_mm256_permute4x64_epi64(tmp, 0b11011000)};
   1537  }
   1538 
   1539  FJXL_INLINE SIMDVec16 ValToToken() const {
   1540    auto nibble0 =
   1541        _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi16(0xF)),
   1542                        _mm256_set1_epi16(0xFF00));
   1543    auto nibble1 = _mm256_or_si256(
   1544        _mm256_and_si256(_mm256_srli_epi16(vec, 4), _mm256_set1_epi16(0xF)),
   1545        _mm256_set1_epi16(0xFF00));
   1546    auto nibble2 = _mm256_or_si256(
   1547        _mm256_and_si256(_mm256_srli_epi16(vec, 8), _mm256_set1_epi16(0xF)),
   1548        _mm256_set1_epi16(0xFF00));
   1549    auto nibble3 =
   1550        _mm256_or_si256(_mm256_srli_epi16(vec, 12), _mm256_set1_epi16(0xFF00));
   1551 
   1552    auto lut0 = _mm256_broadcastsi128_si256(
   1553        _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4));
   1554    auto lut1 = _mm256_broadcastsi128_si256(
   1555        _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8));
   1556    auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1557        0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12));
   1558    auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1559        0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16));
   1560 
   1561    auto token0 = _mm256_shuffle_epi8(lut0, nibble0);
   1562    auto token1 = _mm256_shuffle_epi8(lut1, nibble1);
   1563    auto token2 = _mm256_shuffle_epi8(lut2, nibble2);
   1564    auto token3 = _mm256_shuffle_epi8(lut3, nibble3);
   1565 
   1566    auto token = _mm256_max_epi16(_mm256_max_epi16(token0, token1),
   1567                                  _mm256_max_epi16(token2, token3));
   1568    return SIMDVec16{token};
   1569  }
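  // [Editorial note, not part of the original file] ValToToken() computes
  // the bit width of each lane: it splits every u16 into four nibbles,
  // looks up a per-nibble token (the nibble's bit width plus 4 times its
  // nibble position) with vpshufb, and takes the maximum. A scalar
  // equivalent would be:
  //
  //   uint16_t ValToToken(uint16_t v) { return v ? 32 - __builtin_clz(v) : 0; }
  //
  // The OR with 0xFF00 sets the sign bit of the high byte of every lane, so
  // vpshufb zeroes those bytes instead of using them as indices.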
   1570 
   1571  FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
   1572    return SIMDVec16{_mm256_subs_epu16(vec, to_subtract.vec)};
   1573  }
   1574  FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
   1575    return SIMDVec16{_mm256_sub_epi16(vec, to_subtract.vec)};
   1576  }
   1577  FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
   1578    return SIMDVec16{_mm256_add_epi16(vec, oth.vec)};
   1579  }
   1580  FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
   1581    return SIMDVec16{_mm256_min_epu16(vec, oth.vec)};
   1582  }
   1583  FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
   1584    return Mask16{_mm256_cmpeq_epi16(vec, oth.vec)};
   1585  }
   1586  FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
   1587    return Mask16{_mm256_cmpgt_epi16(vec, oth.vec)};
   1588  }
   1589  FJXL_INLINE SIMDVec16 Pow2() const {
   1590    auto pow2_lo_lut = _mm256_broadcastsi128_si256(
   1591        _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6,
   1592                      1u << 7, 0, 0, 0, 0, 0, 0, 0, 0));
   1593    auto pow2_hi_lut = _mm256_broadcastsi128_si256(
   1594        _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 1 << 0, 1 << 1, 1 << 2, 1 << 3,
   1595                      1 << 4, 1 << 5, 1 << 6, 1u << 7));
   1596 
   1597    auto masked = _mm256_or_si256(vec, _mm256_set1_epi16(0xFF00));
   1598 
   1599    auto pow2_lo = _mm256_shuffle_epi8(pow2_lo_lut, masked);
   1600    auto pow2_hi = _mm256_shuffle_epi8(pow2_hi_lut, masked);
   1601 
   1602    auto pow2 = _mm256_or_si256(_mm256_slli_epi16(pow2_hi, 8), pow2_lo);
   1603    return SIMDVec16{pow2};
   1604  }
   1605  FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
   1606    return SIMDVec16{_mm256_or_si256(vec, oth.vec)};
   1607  }
   1608  FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
   1609    return SIMDVec16{_mm256_xor_si256(vec, oth.vec)};
   1610  }
   1611  FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
   1612    return SIMDVec16{_mm256_and_si256(vec, oth.vec)};
   1613  }
   1614  FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
   1615    return SIMDVec16{_mm256_srai_epi16(_mm256_add_epi16(vec, oth.vec), 1)};
   1616  }
   1617  FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
   1618    return SIMDVec16{_mm256_or_si256(vec, _mm256_set1_epi16(0xFF00))};
   1619  }
   1620  FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
   1621    return SIMDVec16{_mm256_shuffle_epi8(
   1622        _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)table)), vec)};
   1623  }
   1624  FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
   1625    auto v02 = _mm256_unpacklo_epi16(low.vec, vec);
   1626    auto v13 = _mm256_unpackhi_epi16(low.vec, vec);
   1627    return {SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x20)},
   1628            SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x31)}};
   1629  }
   1630  FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
   1631    auto v02 = _mm256_unpacklo_epi16(vec, _mm256_setzero_si256());
   1632    auto v13 = _mm256_unpackhi_epi16(vec, _mm256_setzero_si256());
   1633    return {SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x20)},
   1634            SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x31)}};
   1635  }
   1636  template <size_t i>
   1637  FJXL_INLINE SIMDVec16 SignedShiftRight() const {
   1638    return SIMDVec16{_mm256_srai_epi16(vec, i)};
   1639  }
   1640 
   1641  static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
   1642    __m128i bytes = _mm_loadu_si128((__m128i*)data);
   1643    return {SIMDVec16{_mm256_cvtepu8_epi16(bytes)}};
   1644  }
   1645  static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
   1646    return {Load((const uint16_t*)data)};
   1647  }
   1648 
   1649  static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
   1650    __m256i bytes = _mm256_loadu_si256((__m256i*)data);
   1651    __m256i gray = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF));
   1652    __m256i alpha = _mm256_srli_epi16(bytes, 8);
   1653    return {SIMDVec16{gray}, SIMDVec16{alpha}};
   1654  }
   1655  static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
   1656    __m256i bytes1 = _mm256_loadu_si256((__m256i*)data);
   1657    __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32));
   1658    __m256i g_mask = _mm256_set1_epi32(0xFFFF);
   1659    __m256i g = _mm256_permute4x64_epi64(
   1660        _mm256_packus_epi32(_mm256_and_si256(bytes1, g_mask),
   1661                            _mm256_and_si256(bytes2, g_mask)),
   1662        0b11011000);
   1663    __m256i a = _mm256_permute4x64_epi64(
   1664        _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16),
   1665                            _mm256_srli_epi32(bytes2, 16)),
   1666        0b11011000);
   1667    return {SIMDVec16{g}, SIMDVec16{a}};
   1668  }
   1669 
   1670  static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
   1671    __m128i bytes0 = _mm_loadu_si128((__m128i*)data);
   1672    __m128i bytes1 = _mm_loadu_si128((__m128i*)(data + 16));
   1673    __m128i bytes2 = _mm_loadu_si128((__m128i*)(data + 32));
   1674 
   1675    __m128i idx =
   1676        _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13);
   1677 
   1678    __m128i r6b5g5_0 = _mm_shuffle_epi8(bytes0, idx);
   1679    __m128i g6r5b5_1 = _mm_shuffle_epi8(bytes1, idx);
   1680    __m128i b6g5r5_2 = _mm_shuffle_epi8(bytes2, idx);
   1681 
   1682    __m128i mask010 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF,
   1683                                    0xFF, 0, 0, 0, 0, 0);
   1684    __m128i mask001 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF,
   1685                                    0xFF, 0xFF, 0xFF);
   1686 
   1687    __m128i b2g2b1 = _mm_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001);
   1688    __m128i b2b0b1 = _mm_blendv_epi8(b2g2b1, r6b5g5_0, mask010);
   1689 
   1690    __m128i r0r1b1 = _mm_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010);
   1691    __m128i r0r1r2 = _mm_blendv_epi8(r0r1b1, b6g5r5_2, mask001);
   1692 
   1693    __m128i g1r1g0 = _mm_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001);
   1694    __m128i g1g2g0 = _mm_blendv_epi8(g1r1g0, b6g5r5_2, mask010);
   1695 
   1696    __m128i g0g1g2 = _mm_alignr_epi8(g1g2g0, g1g2g0, 11);
   1697    __m128i b0b1b2 = _mm_alignr_epi8(b2b0b1, b2b0b1, 6);
   1698 
   1699    return {SIMDVec16{_mm256_cvtepu8_epi16(r0r1r2)},
   1700            SIMDVec16{_mm256_cvtepu8_epi16(g0g1g2)},
   1701            SIMDVec16{_mm256_cvtepu8_epi16(b0b1b2)}};
   1702  }
   1703  static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
   1704    auto load_and_split_lohi = [](const unsigned char* data) {
   1705      // LHLHLH...
   1706      __m256i bytes = _mm256_loadu_si256((__m256i*)data);
   1707      // L0L0L0...
   1708      __m256i lo = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF));
   1709      // H0H0H0...
   1710      __m256i hi = _mm256_srli_epi16(bytes, 8);
   1711      // LLLLLLLLHHHHHHHHLLLLLLLLHHHHHHHH
   1712      __m256i packed = _mm256_packus_epi16(lo, hi);
   1713      return _mm256_permute4x64_epi64(packed, 0b11011000);
   1714    };
   1715    __m256i bytes0 = load_and_split_lohi(data);
   1716    __m256i bytes1 = load_and_split_lohi(data + 32);
   1717    __m256i bytes2 = load_and_split_lohi(data + 64);
   1718 
   1719    __m256i idx = _mm256_broadcastsi128_si256(
   1720        _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13));
   1721 
   1722    __m256i r6b5g5_0 = _mm256_shuffle_epi8(bytes0, idx);
   1723    __m256i g6r5b5_1 = _mm256_shuffle_epi8(bytes1, idx);
   1724    __m256i b6g5r5_2 = _mm256_shuffle_epi8(bytes2, idx);
   1725 
   1726    __m256i mask010 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1727        0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0));
   1728    __m256i mask001 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
   1729        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF));
   1730 
   1731    __m256i b2g2b1 = _mm256_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001);
   1732    __m256i b2b0b1 = _mm256_blendv_epi8(b2g2b1, r6b5g5_0, mask010);
   1733 
   1734    __m256i r0r1b1 = _mm256_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010);
   1735    __m256i r0r1r2 = _mm256_blendv_epi8(r0r1b1, b6g5r5_2, mask001);
   1736 
   1737    __m256i g1r1g0 = _mm256_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001);
   1738    __m256i g1g2g0 = _mm256_blendv_epi8(g1r1g0, b6g5r5_2, mask010);
   1739 
   1740    __m256i g0g1g2 = _mm256_alignr_epi8(g1g2g0, g1g2g0, 11);
   1741    __m256i b0b1b2 = _mm256_alignr_epi8(b2b0b1, b2b0b1, 6);
   1742 
   1743    // Now r0r1r2, g0g1g2, b0b1b2 have the low bytes of the RGB pixels in their
   1744    // lower half, and the high bytes in their upper half.
   1745 
   1746    auto combine_low_hi = [](__m256i v) {
   1747      __m128i low = _mm256_extracti128_si256(v, 0);
   1748      __m128i hi = _mm256_extracti128_si256(v, 1);
   1749      __m256i low16 = _mm256_cvtepu8_epi16(low);
   1750      __m256i hi16 = _mm256_cvtepu8_epi16(hi);
   1751      return _mm256_or_si256(_mm256_slli_epi16(hi16, 8), low16);
   1752    };
   1753 
   1754    return {SIMDVec16{combine_low_hi(r0r1r2)},
   1755            SIMDVec16{combine_low_hi(g0g1g2)},
   1756            SIMDVec16{combine_low_hi(b0b1b2)}};
   1757  }
   1758 
   1759  static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
   1760    __m256i bytes1 = _mm256_loadu_si256((__m256i*)data);
   1761    __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32));
   1762    __m256i rg_mask = _mm256_set1_epi32(0xFFFF);
   1763    __m256i rg = _mm256_permute4x64_epi64(
   1764        _mm256_packus_epi32(_mm256_and_si256(bytes1, rg_mask),
   1765                            _mm256_and_si256(bytes2, rg_mask)),
   1766        0b11011000);
   1767    __m256i b_a = _mm256_permute4x64_epi64(
   1768        _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16),
   1769                            _mm256_srli_epi32(bytes2, 16)),
   1770        0b11011000);
   1771    __m256i r = _mm256_and_si256(rg, _mm256_set1_epi16(0xFF));
   1772    __m256i g = _mm256_srli_epi16(rg, 8);
   1773    __m256i b = _mm256_and_si256(b_a, _mm256_set1_epi16(0xFF));
   1774    __m256i a = _mm256_srli_epi16(b_a, 8);
   1775    return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1776  }
   1777  static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
   1778    __m256i bytes0 = _mm256_loadu_si256((__m256i*)data);
   1779    __m256i bytes1 = _mm256_loadu_si256((__m256i*)(data + 32));
   1780    __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 64));
   1781    __m256i bytes3 = _mm256_loadu_si256((__m256i*)(data + 96));
   1782 
   1783    auto pack32 = [](__m256i a, __m256i b) {
   1784      return _mm256_permute4x64_epi64(_mm256_packus_epi32(a, b), 0b11011000);
   1785    };
   1786    auto packlow32 = [&pack32](__m256i a, __m256i b) {
   1787      __m256i mask = _mm256_set1_epi32(0xFFFF);
   1788      return pack32(_mm256_and_si256(a, mask), _mm256_and_si256(b, mask));
   1789    };
   1790    auto packhi32 = [&pack32](__m256i a, __m256i b) {
   1791      return pack32(_mm256_srli_epi32(a, 16), _mm256_srli_epi32(b, 16));
   1792    };
   1793 
   1794    __m256i rb0 = packlow32(bytes0, bytes1);
   1795    __m256i rb1 = packlow32(bytes2, bytes3);
   1796    __m256i ga0 = packhi32(bytes0, bytes1);
   1797    __m256i ga1 = packhi32(bytes2, bytes3);
   1798 
   1799    __m256i r = packlow32(rb0, rb1);
   1800    __m256i g = packlow32(ga0, ga1);
   1801    __m256i b = packhi32(rb0, rb1);
   1802    __m256i a = packhi32(ga0, ga1);
   1803    return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
   1804  }
   1805 
   1806  void SwapEndian() {
   1807    auto indices = _mm256_broadcastsi128_si256(
   1808        _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14));
   1809    vec = _mm256_shuffle_epi8(vec, indices);
   1810  }
   1811 };
   1812 
   1813 SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
   1814                             const SIMDVec16& if_false) {
   1815  return SIMDVec16{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)};
   1816 }
   1817 
   1818 SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
   1819                             const SIMDVec32& if_false) {
   1820  return SIMDVec32{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)};
   1821 }
   1822 
   1823 struct Bits64 {
   1824  static constexpr size_t kLanes = 4;
   1825 
   1826  __m256i nbits;
   1827  __m256i bits;
   1828 
   1829  FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
   1830    _mm256_storeu_si256((__m256i*)nbits_out, nbits);
   1831    _mm256_storeu_si256((__m256i*)bits_out, bits);
   1832  }
   1833 };
   1834 
   1835 struct Bits32 {
   1836  __m256i nbits;
   1837  __m256i bits;
   1838 
   1839  static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
   1840    return Bits32{nbits.vec, bits.vec};
   1841  }
   1842 
   1843  Bits64 Merge() const {
   1844    auto nbits_hi32 = _mm256_srli_epi64(nbits, 32);
   1845    auto nbits_lo32 = _mm256_and_si256(nbits, _mm256_set1_epi64x(0xFFFFFFFF));
   1846    auto bits_hi32 = _mm256_srli_epi64(bits, 32);
   1847    auto bits_lo32 = _mm256_and_si256(bits, _mm256_set1_epi64x(0xFFFFFFFF));
   1848 
   1849    auto nbits64 = _mm256_add_epi64(nbits_hi32, nbits_lo32);
   1850    auto bits64 =
   1851        _mm256_or_si256(_mm256_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32);
   1852    return Bits64{nbits64, bits64};
   1853  }
   1854 
   1855  void Interleave(const Bits32& low) {
   1856    bits = _mm256_or_si256(_mm256_sllv_epi32(bits, low.nbits), low.bits);
   1857    nbits = _mm256_add_epi32(nbits, low.nbits);
   1858  }
   1859 
   1860  void ClipTo(size_t n) {
   1861    n = std::min<size_t>(n, 8);
   1862    constexpr uint32_t kMask[16] = {
   1863        ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0, 0,
   1864    };
   1865    __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n));
   1866    nbits = _mm256_and_si256(mask, nbits);
   1867    bits = _mm256_and_si256(mask, bits);
   1868  }
   1869  void Skip(size_t n) {
   1870    n = std::min<size_t>(n, 8);
   1871    constexpr uint32_t kMask[16] = {
   1872        0, 0, 0, 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
   1873    };
   1874    __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n));
   1875    nbits = _mm256_and_si256(mask, nbits);
   1876    bits = _mm256_and_si256(mask, bits);
   1877  }
   1878 };
   1879 
   1880 struct Bits16 {
   1881  __m256i nbits;
   1882  __m256i bits;
   1883 
   1884  static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
   1885    return Bits16{nbits.vec, bits.vec};
   1886  }
   1887 
   1888  Bits32 Merge() const {
   1889    auto nbits_hi16 = _mm256_srli_epi32(nbits, 16);
   1890    auto nbits_lo16 = _mm256_and_si256(nbits, _mm256_set1_epi32(0xFFFF));
   1891    auto bits_hi16 = _mm256_srli_epi32(bits, 16);
   1892    auto bits_lo16 = _mm256_and_si256(bits, _mm256_set1_epi32(0xFFFF));
   1893 
   1894    auto nbits32 = _mm256_add_epi32(nbits_hi16, nbits_lo16);
   1895    auto bits32 =
   1896        _mm256_or_si256(_mm256_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16);
   1897    return Bits32{nbits32, bits32};
   1898  }
   1899 
   1900  void Interleave(const Bits16& low) {
   1901    auto pow2_lo_lut = _mm256_broadcastsi128_si256(
   1902        _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6,
   1903                      1u << 7, 0, 0, 0, 0, 0, 0, 0, 0));
   1904    auto low_nbits_masked =
   1905        _mm256_or_si256(low.nbits, _mm256_set1_epi16(0xFF00));
   1906 
   1907    auto bits_shifted = _mm256_mullo_epi16(
   1908        bits, _mm256_shuffle_epi8(pow2_lo_lut, low_nbits_masked));
   1909 
   1910    nbits = _mm256_add_epi16(nbits, low.nbits);
   1911    bits = _mm256_or_si256(bits_shifted, low.bits);
   1912  }
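  // [Editorial note, not part of the original file] Unlike the 32-bit case,
  // AVX2 has no per-lane variable 16-bit shift (_mm256_sllv_epi16 needs
  // AVX-512BW/VL), so the shift by low.nbits is emulated as a multiplication
  // by 2^low.nbits fetched from a byte LUT. This works because low.nbits
  // holds Huffman code lengths, which are capped at 7 bits on the AVX2 path
  // (see the comment in UpTo8Bits below).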
   1913 
   1914  void ClipTo(size_t n) {
   1915    n = std::min<size_t>(n, 16);
   1916    constexpr uint16_t kMask[32] = {
   1917        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1918        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1919        0,      0,      0,      0,      0,      0,      0,      0,
   1920        0,      0,      0,      0,      0,      0,      0,      0,
   1921    };
   1922    __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n));
   1923    nbits = _mm256_and_si256(mask, nbits);
   1924    bits = _mm256_and_si256(mask, bits);
   1925  }
   1926 
   1927  void Skip(size_t n) {
   1928    n = std::min<size_t>(n, 16);
   1929    constexpr uint16_t kMask[32] = {
   1930        0,      0,      0,      0,      0,      0,      0,      0,
   1931        0,      0,      0,      0,      0,      0,      0,      0,
   1932        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1933        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   1934    };
   1935    __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n));
   1936    nbits = _mm256_and_si256(mask, nbits);
   1937    bits = _mm256_and_si256(mask, bits);
   1938  }
   1939 };
   1940 
   1941 #endif
   1942 
   1943 #ifdef FJXL_NEON
   1944 #define FJXL_GENERIC_SIMD
   1945 
   1946 struct SIMDVec32;
   1947 
   1948 struct Mask32 {
   1949  uint32x4_t mask;
   1950  SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
   1951  Mask32 And(const Mask32& oth) const {
   1952    return Mask32{vandq_u32(mask, oth.mask)};
   1953  }
   1954  size_t CountPrefix() const {
   1955    uint32_t val_unset[4] = {0, 1, 2, 3};
   1956    uint32_t val_set[4] = {4, 4, 4, 4};
   1957    uint32x4_t val = vbslq_u32(mask, vld1q_u32(val_set), vld1q_u32(val_unset));
   1958    return vminvq_u32(val);
   1959  }
   1960 };
   1961 
   1962 struct SIMDVec32 {
   1963  uint32x4_t vec;
   1964 
   1965  static constexpr size_t kLanes = 4;
   1966 
   1967  FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
   1968    return SIMDVec32{vld1q_u32(data)};
   1969  }
   1970  FJXL_INLINE void Store(uint32_t* data) { vst1q_u32(data, vec); }
   1971  FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
   1972    return SIMDVec32{vdupq_n_u32(v)};
   1973  }
   1974  FJXL_INLINE SIMDVec32 ValToToken() const {
   1975    return SIMDVec32{vsubq_u32(vdupq_n_u32(32), vclzq_u32(vec))};
   1976  }
   1977  FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
   1978    return SIMDVec32{vqsubq_u32(vec, to_subtract.vec)};
   1979  }
   1980  FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
   1981    return SIMDVec32{vsubq_u32(vec, to_subtract.vec)};
   1982  }
   1983  FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
   1984    return SIMDVec32{vaddq_u32(vec, oth.vec)};
   1985  }
   1986  FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
   1987    return SIMDVec32{veorq_u32(vec, oth.vec)};
   1988  }
   1989  FJXL_INLINE SIMDVec32 Pow2() const {
   1990    return SIMDVec32{vshlq_u32(vdupq_n_u32(1), vreinterpretq_s32_u32(vec))};
   1991  }
   1992  FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
   1993    return Mask32{vceqq_u32(vec, oth.vec)};
   1994  }
   1995  FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
   1996    return Mask32{
   1997        vcgtq_s32(vreinterpretq_s32_u32(vec), vreinterpretq_s32_u32(oth.vec))};
   1998  }
   1999  template <size_t i>
   2000  FJXL_INLINE SIMDVec32 SignedShiftRight() const {
   2001    return SIMDVec32{
   2002        vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(vec), i))};
   2003  }
   2004 };
   2005 
   2006 struct SIMDVec16;
   2007 
   2008 struct Mask16 {
   2009  uint16x8_t mask;
   2010  SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
   2011  Mask16 And(const Mask16& oth) const {
   2012    return Mask16{vandq_u16(mask, oth.mask)};
   2013  }
   2014  size_t CountPrefix() const {
   2015    uint16_t val_unset[8] = {0, 1, 2, 3, 4, 5, 6, 7};
   2016    uint16_t val_set[8] = {8, 8, 8, 8, 8, 8, 8, 8};
   2017    uint16x8_t val = vbslq_u16(mask, vld1q_u16(val_set), vld1q_u16(val_unset));
   2018    return vminvq_u16(val);
   2019  }
   2020 };
   2021 
   2022 struct SIMDVec16 {
   2023  uint16x8_t vec;
   2024 
   2025  static constexpr size_t kLanes = 8;
   2026 
   2027  FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
   2028    return SIMDVec16{vld1q_u16(data)};
   2029  }
   2030  FJXL_INLINE void Store(uint16_t* data) { vst1q_u16(data, vec); }
   2031  FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
   2032    return SIMDVec16{vdupq_n_u16(v)};
   2033  }
   2034  FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
   2035                                         const SIMDVec32& hi) {
   2036    return SIMDVec16{vmovn_high_u32(vmovn_u32(lo.vec), hi.vec)};
   2037  }
   2038 
   2039  FJXL_INLINE SIMDVec16 ValToToken() const {
   2040    return SIMDVec16{vsubq_u16(vdupq_n_u16(16), vclzq_u16(vec))};
   2041  }
   2042  FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
   2043    return SIMDVec16{vqsubq_u16(vec, to_subtract.vec)};
   2044  }
   2045  FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
   2046    return SIMDVec16{vsubq_u16(vec, to_subtract.vec)};
   2047  }
   2048  FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
   2049    return SIMDVec16{vaddq_u16(vec, oth.vec)};
   2050  }
   2051  FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
   2052    return SIMDVec16{vminq_u16(vec, oth.vec)};
   2053  }
   2054  FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
   2055    return Mask16{vceqq_u16(vec, oth.vec)};
   2056  }
   2057  FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
   2058    return Mask16{
   2059        vcgtq_s16(vreinterpretq_s16_u16(vec), vreinterpretq_s16_u16(oth.vec))};
   2060  }
   2061  FJXL_INLINE SIMDVec16 Pow2() const {
   2062    return SIMDVec16{vshlq_u16(vdupq_n_u16(1), vreinterpretq_s16_u16(vec))};
   2063  }
   2064  FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
   2065    return SIMDVec16{vorrq_u16(vec, oth.vec)};
   2066  }
   2067  FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
   2068    return SIMDVec16{veorq_u16(vec, oth.vec)};
   2069  }
   2070  FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
   2071    return SIMDVec16{vandq_u16(vec, oth.vec)};
   2072  }
   2073  FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
   2074    return SIMDVec16{vhaddq_u16(vec, oth.vec)};
   2075  }
   2076  FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
   2077    return SIMDVec16{vorrq_u16(vec, vdupq_n_u16(0xFF00))};
   2078  }
   2079  FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
   2080    uint8x16_t tbl = vld1q_u8(table);
   2081    uint8x16_t indices = vreinterpretq_u8_u16(vec);
   2082    return SIMDVec16{vreinterpretq_u16_u8(vqtbl1q_u8(tbl, indices))};
   2083  }
   2084  FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
   2085    return {SIMDVec16{vzip1q_u16(low.vec, vec)},
   2086            SIMDVec16{vzip2q_u16(low.vec, vec)}};
   2087  }
   2088  FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
   2089    uint32x4_t lo = vmovl_u16(vget_low_u16(vec));
   2090    uint32x4_t hi = vmovl_high_u16(vec);
   2091    return {SIMDVec32{lo}, SIMDVec32{hi}};
   2092  }
   2093  template <size_t i>
   2094  FJXL_INLINE SIMDVec16 SignedShiftRight() const {
   2095    return SIMDVec16{
   2096        vreinterpretq_u16_s16(vshrq_n_s16(vreinterpretq_s16_u16(vec), i))};
   2097  }
   2098 
   2099  static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
   2100    uint8x8_t v = vld1_u8(data);
   2101    return {SIMDVec16{vmovl_u8(v)}};
   2102  }
   2103  static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
   2104    return {Load((const uint16_t*)data)};
   2105  }
   2106 
   2107  static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
   2108    uint8x8x2_t v = vld2_u8(data);
   2109    return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}};
   2110  }
   2111  static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
   2112    uint16x8x2_t v = vld2q_u16((const uint16_t*)data);
   2113    return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}};
   2114  }
   2115 
   2116  static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
   2117    uint8x8x3_t v = vld3_u8(data);
   2118    return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])},
   2119            SIMDVec16{vmovl_u8(v.val[2])}};
   2120  }
   2121  static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
   2122    uint16x8x3_t v = vld3q_u16((const uint16_t*)data);
   2123    return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]}};
   2124  }
   2125 
   2126  static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
   2127    uint8x8x4_t v = vld4_u8(data);
   2128    return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])},
   2129            SIMDVec16{vmovl_u8(v.val[2])}, SIMDVec16{vmovl_u8(v.val[3])}};
   2130  }
   2131  static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
   2132    uint16x8x4_t v = vld4q_u16((const uint16_t*)data);
   2133    return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]},
   2134            SIMDVec16{v.val[3]}};
   2135  }
   2136 
   2137  void SwapEndian() {
   2138    vec = vreinterpretq_u16_u8(vrev16q_u8(vreinterpretq_u8_u16(vec)));
   2139  }
   2140 };
   2141 
   2142 SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
   2143                             const SIMDVec16& if_false) {
   2144  return SIMDVec16{vbslq_u16(mask, if_true.vec, if_false.vec)};
   2145 }
   2146 
   2147 SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
   2148                             const SIMDVec32& if_false) {
   2149  return SIMDVec32{vbslq_u32(mask, if_true.vec, if_false.vec)};
   2150 }
   2151 
   2152 struct Bits64 {
   2153  static constexpr size_t kLanes = 2;
   2154 
   2155  uint64x2_t nbits;
   2156  uint64x2_t bits;
   2157 
   2158  FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
   2159    vst1q_u64(nbits_out, nbits);
   2160    vst1q_u64(bits_out, bits);
   2161  }
   2162 };
   2163 
   2164 struct Bits32 {
   2165  uint32x4_t nbits;
   2166  uint32x4_t bits;
   2167 
   2168  static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
   2169    return Bits32{nbits.vec, bits.vec};
   2170  }
   2171 
   2172  Bits64 Merge() const {
   2173    // TODO(veluca): can probably be optimized.
   2174    uint64x2_t nbits_lo32 =
   2175        vandq_u64(vreinterpretq_u64_u32(nbits), vdupq_n_u64(0xFFFFFFFF));
   2176    uint64x2_t bits_hi32 =
   2177        vshlq_u64(vshrq_n_u64(vreinterpretq_u64_u32(bits), 32),
   2178                  vreinterpretq_s64_u64(nbits_lo32));
   2179    uint64x2_t bits_lo32 =
   2180        vandq_u64(vreinterpretq_u64_u32(bits), vdupq_n_u64(0xFFFFFFFF));
   2181    uint64x2_t nbits64 =
   2182        vsraq_n_u64(nbits_lo32, vreinterpretq_u64_u32(nbits), 32);
   2183    uint64x2_t bits64 = vorrq_u64(bits_hi32, bits_lo32);
   2184    return Bits64{nbits64, bits64};
   2185  }
   2186 
   2187  void Interleave(const Bits32& low) {
   2188    bits =
   2189        vorrq_u32(vshlq_u32(bits, vreinterpretq_s32_u32(low.nbits)), low.bits);
   2190    nbits = vaddq_u32(nbits, low.nbits);
   2191  }
   2192 
   2193  void ClipTo(size_t n) {
   2194    n = std::min<size_t>(n, 4);
   2195    constexpr uint32_t kMask[8] = {
   2196        ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0,
   2197    };
   2198    uint32x4_t mask = vld1q_u32(kMask + 4 - n);
   2199    nbits = vandq_u32(mask, nbits);
   2200    bits = vandq_u32(mask, bits);
   2201  }
   2202  void Skip(size_t n) {
   2203    n = std::min<size_t>(n, 4);
   2204    constexpr uint32_t kMask[8] = {
   2205        0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u,
   2206    };
   2207    uint32x4_t mask = vld1q_u32(kMask + 4 - n);
   2208    nbits = vandq_u32(mask, nbits);
   2209    bits = vandq_u32(mask, bits);
   2210  }
   2211 };
   2212 
   2213 struct Bits16 {
   2214  uint16x8_t nbits;
   2215  uint16x8_t bits;
   2216 
   2217  static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
   2218    return Bits16{nbits.vec, bits.vec};
   2219  }
   2220 
   2221  Bits32 Merge() const {
   2222    // TODO(veluca): can probably be optimized.
   2223    uint32x4_t nbits_lo16 =
   2224        vandq_u32(vreinterpretq_u32_u16(nbits), vdupq_n_u32(0xFFFF));
   2225    uint32x4_t bits_hi16 =
   2226        vshlq_u32(vshrq_n_u32(vreinterpretq_u32_u16(bits), 16),
   2227                  vreinterpretq_s32_u32(nbits_lo16));
   2228    uint32x4_t bits_lo16 =
   2229        vandq_u32(vreinterpretq_u32_u16(bits), vdupq_n_u32(0xFFFF));
   2230    uint32x4_t nbits32 =
   2231        vsraq_n_u32(nbits_lo16, vreinterpretq_u32_u16(nbits), 16);
   2232    uint32x4_t bits32 = vorrq_u32(bits_hi16, bits_lo16);
   2233    return Bits32{nbits32, bits32};
   2234  }
   2235 
   2236  void Interleave(const Bits16& low) {
   2237    bits =
   2238        vorrq_u16(vshlq_u16(bits, vreinterpretq_s16_u16(low.nbits)), low.bits);
   2239    nbits = vaddq_u16(nbits, low.nbits);
   2240  }
   2241 
   2242  void ClipTo(size_t n) {
   2243    n = std::min<size_t>(n, 8);
   2244    constexpr uint16_t kMask[16] = {
   2245        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   2246        0,      0,      0,      0,      0,      0,      0,      0,
   2247    };
   2248    uint16x8_t mask = vld1q_u16(kMask + 8 - n);
   2249    nbits = vandq_u16(mask, nbits);
   2250    bits = vandq_u16(mask, bits);
   2251  }
   2252  void Skip(size_t n) {
   2253    n = std::min<size_t>(n, 8);
   2254    constexpr uint16_t kMask[16] = {
   2255        0,      0,      0,      0,      0,      0,      0,      0,
   2256        0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
   2257    };
   2258    uint16x8_t mask = vld1q_u16(kMask + 8 - n);
   2259    nbits = vandq_u16(mask, nbits);
   2260    bits = vandq_u16(mask, bits);
   2261  }
   2262 };
   2263 
   2264 #endif
   2265 
   2266 #ifdef FJXL_GENERIC_SIMD
   2267 constexpr size_t SIMDVec32::kLanes;
   2268 constexpr size_t SIMDVec16::kLanes;
   2269 
   2270 // Each of these functions processes SIMDVec16::kLanes values at a time.
   2271 
   2272 FJXL_INLINE void TokenizeSIMD(const uint16_t* residuals, uint16_t* token_out,
   2273                              uint16_t* nbits_out, uint16_t* bits_out) {
   2274  SIMDVec16 res = SIMDVec16::Load(residuals);
   2275  SIMDVec16 token = res.ValToToken();
   2276  SIMDVec16 nbits = token.SatSubU(SIMDVec16::Val(1));
   2277  SIMDVec16 bits = res.SatSubU(nbits.Pow2());
   2278  token.Store(token_out);
   2279  nbits.Store(nbits_out);
   2280  bits.Store(bits_out);
   2281 }
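// [Editorial sketch, not part of the original file] A scalar reference for
// the 16-bit tokenization above, matching EncodeHybridUint000 further down:
//
//   uint16_t token = res ? 32 - __builtin_clz(res) : 0;  // bit width of res
//   uint16_t nbits = token ? token - 1 : 0;              // SatSubU(1)
//   uint16_t bits  = res ? res - (1u << nbits) : 0;      // SatSubU(Pow2)
//
// i.e. the token records the bit width, and the residual minus its most
// significant bit is sent as nbits raw bits.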
   2282 
   2283 FJXL_INLINE void TokenizeSIMD(const uint32_t* residuals, uint16_t* token_out,
   2284                              uint32_t* nbits_out, uint32_t* bits_out) {
   2285  static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, "");
   2286  SIMDVec32 res_lo = SIMDVec32::Load(residuals);
   2287  SIMDVec32 res_hi = SIMDVec32::Load(residuals + SIMDVec32::kLanes);
   2288  SIMDVec32 token_lo = res_lo.ValToToken();
   2289  SIMDVec32 token_hi = res_hi.ValToToken();
   2290  SIMDVec32 nbits_lo = token_lo.SatSubU(SIMDVec32::Val(1));
   2291  SIMDVec32 nbits_hi = token_hi.SatSubU(SIMDVec32::Val(1));
   2292  SIMDVec32 bits_lo = res_lo.SatSubU(nbits_lo.Pow2());
   2293  SIMDVec32 bits_hi = res_hi.SatSubU(nbits_hi.Pow2());
   2294  SIMDVec16 token = SIMDVec16::FromTwo32(token_lo, token_hi);
   2295  token.Store(token_out);
   2296  nbits_lo.Store(nbits_out);
   2297  nbits_hi.Store(nbits_out + SIMDVec32::kLanes);
   2298  bits_lo.Store(bits_out);
   2299  bits_hi.Store(bits_out + SIMDVec32::kLanes);
   2300 }
   2301 
   2302 FJXL_INLINE void HuffmanSIMDUpTo13(const uint16_t* tokens,
   2303                                   const uint8_t* raw_nbits_simd,
   2304                                   const uint8_t* raw_bits_simd,
   2305                                   uint16_t* nbits_out, uint16_t* bits_out) {
   2306  SIMDVec16 tok = SIMDVec16::Load(tokens).PrepareForU8Lookup();
   2307  tok.U8Lookup(raw_nbits_simd).Store(nbits_out);
   2308  tok.U8Lookup(raw_bits_simd).Store(bits_out);
   2309 }
   2310 
   2311 FJXL_INLINE void HuffmanSIMD14(const uint16_t* tokens,
   2312                               const uint8_t* raw_nbits_simd,
   2313                               const uint8_t* raw_bits_simd,
   2314                               uint16_t* nbits_out, uint16_t* bits_out) {
   2315  SIMDVec16 token_cap = SIMDVec16::Val(15);
   2316  SIMDVec16 tok = SIMDVec16::Load(tokens);
   2317  SIMDVec16 tok_index = tok.Min(token_cap).PrepareForU8Lookup();
   2318  SIMDVec16 huff_bits_pre = tok_index.U8Lookup(raw_bits_simd);
   2319  // Set the highest bit when token == 16; the Huffman code is constructed in
   2320  // such a way that the code for token 15 is the same as the code for 16,
   2321  // except for the highest bit.
   2322  Mask16 needs_high_bit = tok.Eq(SIMDVec16::Val(16));
   2323  SIMDVec16 huff_bits = needs_high_bit.IfThenElse(
   2324      huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre);
   2325  huff_bits.Store(bits_out);
   2326  tok_index.U8Lookup(raw_nbits_simd).Store(nbits_out);
   2327 }
   2328 
   2329 FJXL_INLINE void HuffmanSIMDAbove14(const uint16_t* tokens,
   2330                                    const uint8_t* raw_nbits_simd,
   2331                                    const uint8_t* raw_bits_simd,
   2332                                    uint16_t* nbits_out, uint16_t* bits_out) {
   2333  SIMDVec16 tok = SIMDVec16::Load(tokens);
   2334  // We assume `tok` fits in a *signed* 16-bit integer.
   2335  Mask16 above = tok.Gt(SIMDVec16::Val(12));
   2336  // 13, 14 -> 13
   2337  // 15, 16 -> 14
   2338  // 17, 18 -> 15
   2339  SIMDVec16 remap_tok = above.IfThenElse(tok.HAdd(SIMDVec16::Val(13)), tok);
   2340  SIMDVec16 tok_index = remap_tok.PrepareForU8Lookup();
   2341  SIMDVec16 huff_bits_pre = tok_index.U8Lookup(raw_bits_simd);
   2342  // Set the highest bit when token == 14, 16, 18.
   2343  Mask16 needs_high_bit = above.And(tok.Eq(tok.And(SIMDVec16::Val(0xFFFE))));
   2344  SIMDVec16 huff_bits = needs_high_bit.IfThenElse(
   2345      huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre);
   2346  huff_bits.Store(bits_out);
   2347  tok_index.U8Lookup(raw_nbits_simd).Store(nbits_out);
   2348 }
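// [Editorial worked example, not part of the original file] For tok == 16:
// 16 > 12, so remap_tok = (16 + 13) >> 1 == 14, and since 16 == (16 & 0xFFFE)
// the high bit of the Huffman code is set; tok == 15 maps to the same table
// index 14 with the high bit clear, mirroring the token-15/16 trick in
// HuffmanSIMD14 above.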
   2349 
   2350 FJXL_INLINE void StoreSIMDUpTo8(const uint16_t* nbits_tok,
   2351                                const uint16_t* bits_tok,
   2352                                const uint16_t* nbits_huff,
   2353                                const uint16_t* bits_huff, size_t n,
   2354                                size_t skip, Bits32* bits_out) {
   2355  Bits16 bits =
   2356      Bits16::FromRaw(SIMDVec16::Load(nbits_tok), SIMDVec16::Load(bits_tok));
   2357  Bits16 huff_bits =
   2358      Bits16::FromRaw(SIMDVec16::Load(nbits_huff), SIMDVec16::Load(bits_huff));
   2359  bits.Interleave(huff_bits);
   2360  bits.ClipTo(n);
   2361  bits.Skip(skip);
   2362  bits_out[0] = bits.Merge();
   2363 }
   2364 
   2365 // Huffman and raw bits don't necessarily fit in a single u16 here.
   2366 FJXL_INLINE void StoreSIMDUpTo14(const uint16_t* nbits_tok,
   2367                                 const uint16_t* bits_tok,
   2368                                 const uint16_t* nbits_huff,
   2369                                 const uint16_t* bits_huff, size_t n,
   2370                                 size_t skip, Bits32* bits_out) {
   2371  VecPair<SIMDVec16> bits =
   2372      SIMDVec16::Load(bits_tok).Interleave(SIMDVec16::Load(bits_huff));
   2373  VecPair<SIMDVec16> nbits =
   2374      SIMDVec16::Load(nbits_tok).Interleave(SIMDVec16::Load(nbits_huff));
   2375  Bits16 low = Bits16::FromRaw(nbits.low, bits.low);
   2376  Bits16 hi = Bits16::FromRaw(nbits.hi, bits.hi);
   2377  low.ClipTo(2 * n);
   2378  low.Skip(2 * skip);
   2379  hi.ClipTo(std::max(2 * n, SIMDVec16::kLanes) - SIMDVec16::kLanes);
   2380  hi.Skip(std::max(2 * skip, SIMDVec16::kLanes) - SIMDVec16::kLanes);
   2381 
   2382  bits_out[0] = low.Merge();
   2383  bits_out[1] = hi.Merge();
   2384 }
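// [Editorial note] In StoreSIMDUpTo14, each sample occupies two adjacent u16
// lanes after the Interleave (its Huffman code followed by its raw bits), so
// the sample counts n and skip are doubled, and the hi half is offset by one
// full vector of lanes.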
   2385 
   2386 FJXL_INLINE void StoreSIMDAbove14(const uint32_t* nbits_tok,
   2387                                  const uint32_t* bits_tok,
   2388                                  const uint16_t* nbits_huff,
   2389                                  const uint16_t* bits_huff, size_t n,
   2390                                  size_t skip, Bits32* bits_out) {
   2391  static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, "");
   2392  Bits32 bits_low =
   2393      Bits32::FromRaw(SIMDVec32::Load(nbits_tok), SIMDVec32::Load(bits_tok));
   2394  Bits32 bits_hi =
   2395      Bits32::FromRaw(SIMDVec32::Load(nbits_tok + SIMDVec32::kLanes),
   2396                      SIMDVec32::Load(bits_tok + SIMDVec32::kLanes));
   2397 
   2398  VecPair<SIMDVec32> huff_bits = SIMDVec16::Load(bits_huff).Upcast();
   2399  VecPair<SIMDVec32> huff_nbits = SIMDVec16::Load(nbits_huff).Upcast();
   2400 
   2401  Bits32 huff_low = Bits32::FromRaw(huff_nbits.low, huff_bits.low);
   2402  Bits32 huff_hi = Bits32::FromRaw(huff_nbits.hi, huff_bits.hi);
   2403 
   2404  bits_low.Interleave(huff_low);
   2405  bits_low.ClipTo(n);
   2406  bits_low.Skip(skip);
   2407  bits_out[0] = bits_low;
   2408  bits_hi.Interleave(huff_hi);
   2409  bits_hi.ClipTo(std::max(n, SIMDVec32::kLanes) - SIMDVec32::kLanes);
   2410  bits_hi.Skip(std::max(skip, SIMDVec32::kLanes) - SIMDVec32::kLanes);
   2411  bits_out[1] = bits_hi;
   2412 }
   2413 
   2414 #ifdef FJXL_AVX512
   2415 FJXL_INLINE void StoreToWriterAVX512(const Bits32& bits32, BitWriter& output) {
   2416  __m512i bits = bits32.bits;
   2417  __m512i nbits = bits32.nbits;
   2418 
   2419  // Insert the leftover bits from the bit buffer at the bottom of the vector
   2420  // and extract the top of the vector.
   2421  uint64_t trail_bits =
   2422      _mm512_cvtsi512_si32(_mm512_alignr_epi32(bits, bits, 15));
   2423  uint64_t trail_nbits =
   2424      _mm512_cvtsi512_si32(_mm512_alignr_epi32(nbits, nbits, 15));
   2425  __m512i lead_bits = _mm512_set1_epi32(output.buffer);
   2426  __m512i lead_nbits = _mm512_set1_epi32(output.bits_in_buffer);
   2427  bits = _mm512_alignr_epi32(bits, lead_bits, 15);
   2428  nbits = _mm512_alignr_epi32(nbits, lead_nbits, 15);
   2429 
   2430  // Merge 32 -> 64 bits.
   2431  Bits32 b{nbits, bits};
   2432  Bits64 b64 = b.Merge();
   2433  bits = b64.bits;
   2434  nbits = b64.nbits;
   2435 
   2436  __m512i zero = _mm512_setzero_si512();
   2437 
   2438  auto sh1 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 7); };
   2439  auto sh2 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 6); };
   2440  auto sh4 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 4); };
   2441 
   2442  // Compute the one-past-the-end bit position of each lane.
   2443  __m512i end_intermediate0 = _mm512_add_epi64(nbits, sh1(nbits));
   2444  __m512i end_intermediate1 =
   2445      _mm512_add_epi64(end_intermediate0, sh2(end_intermediate0));
   2446  __m512i end = _mm512_add_epi64(end_intermediate1, sh4(end_intermediate1));
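  // [Editorial note] The three additions above are a log-time inclusive
  // prefix sum across the eight u64 lanes: lane i of `end` now holds
  // nbits[0] + ... + nbits[i], the one-past-the-end bit position of lane i
  // in the output stream; lane 7 holds the total, extracted below.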
   2447 
   2448  uint64_t simd_nbits = _mm512_cvtsi512_si32(_mm512_alignr_epi64(end, end, 7));
   2449 
   2450  // Compute begin-bit-position.
   2451  __m512i begin = _mm512_sub_epi64(end, nbits);
   2452 
   2453  // Index of the last bit of each lane, or the end position if nbits==0.
   2454  __m512i last = _mm512_mask_sub_epi64(
   2455      end, _mm512_cmpneq_epi64_mask(nbits, zero), end, _mm512_set1_epi64(1));
   2456 
   2457  __m512i lane_offset_mask = _mm512_set1_epi64(63);
   2458 
   2459  // Starting position of the chunk that each lane will ultimately belong to.
   2460  __m512i chunk_start = _mm512_andnot_si512(lane_offset_mask, last);
   2461 
   2462  // For all lanes that contain bits belonging to two different 64-bit chunks,
   2463  // compute the number of bits that belong to the first chunk.
   2464  // The total number of bits fits in a u16, so we can use satsub_u16 here.
   2465  __m512i first_chunk_nbits = _mm512_subs_epu16(chunk_start, begin);
   2466 
   2467  // Move all the previous-chunk-bits to the previous lane.
   2468  __m512i negnbits = _mm512_sub_epi64(_mm512_set1_epi64(64), first_chunk_nbits);
   2469  __m512i first_chunk_bits =
   2470      _mm512_srlv_epi64(_mm512_sllv_epi64(bits, negnbits), negnbits);
   2471  __m512i first_chunk_bits_down =
   2472      _mm512_alignr_epi32(zero, first_chunk_bits, 2);
   2473  bits = _mm512_srlv_epi64(bits, first_chunk_nbits);
   2474  nbits = _mm512_sub_epi64(nbits, first_chunk_nbits);
   2475  bits = _mm512_or_si512(bits, _mm512_sllv_epi64(first_chunk_bits_down, nbits));
   2476  begin = _mm512_add_epi64(begin, first_chunk_nbits);
   2477 
   2478  // We now know that every lane should give bits to only one chunk. We can
   2479  // shift the bits and then horizontally-or-reduce them within the same chunk.
   2480  __m512i offset = _mm512_and_si512(begin, lane_offset_mask);
   2481  __m512i aligned_bits = _mm512_sllv_epi64(bits, offset);
   2482  // h-or-reduce within same chunk
   2483  __m512i red0 = _mm512_mask_or_epi64(
   2484      aligned_bits, _mm512_cmpeq_epi64_mask(sh1(chunk_start), chunk_start),
   2485      sh1(aligned_bits), aligned_bits);
   2486  __m512i red1 = _mm512_mask_or_epi64(
   2487      red0, _mm512_cmpeq_epi64_mask(sh2(chunk_start), chunk_start), sh2(red0),
   2488      red0);
   2489  __m512i reduced = _mm512_mask_or_epi64(
   2490      red1, _mm512_cmpeq_epi64_mask(sh4(chunk_start), chunk_start), sh4(red1),
   2491      red1);
   2492  // Extract the highest lane that belongs to each chunk (the lane that ends up
   2493  // with the OR-ed value of all the other lanes of that chunk).
   2494  __m512i next_chunk_start =
   2495      _mm512_alignr_epi32(_mm512_set1_epi64(~0), chunk_start, 2);
   2496  __m512i result = _mm512_maskz_compress_epi64(
   2497      _mm512_cmpneq_epi64_mask(chunk_start, next_chunk_start), reduced);
   2498 
   2499  _mm512_storeu_si512((__m512i*)(output.data.get() + output.bytes_written),
   2500                      result);
   2501 
   2502  // Update the bit writer and add the last 32-bit lane.
   2503  // Note that since trail_nbits was at most 32 to begin with, operating on
   2504  // trail_bits does not risk overflowing.
   2505  output.bytes_written += simd_nbits / 8;
   2506  // Here we are implicitly relying on the fact that simd_nbits < 512 to know
   2507  // that the byte of bitwriter data we access is initialized. This is
   2508  // guaranteed because the remaining bits in the bitwriter buffer are at most
   2509  // 7, so simd_nbits <= 505 always.
   2510  trail_bits = (trail_bits << (simd_nbits % 8)) +
   2511               output.data.get()[output.bytes_written];
   2512  trail_nbits += simd_nbits % 8;
   2513  StoreLE64(output.data.get() + output.bytes_written, trail_bits);
   2514  size_t trail_bytes = trail_nbits / 8;
   2515  output.bits_in_buffer = trail_nbits % 8;
   2516  output.buffer = trail_bits >> (trail_bytes * 8);
   2517  output.bytes_written += trail_bytes;
   2518 }
   2519 
   2520 #endif
   2521 
   2522 template <size_t n>
   2523 FJXL_INLINE void StoreToWriter(const Bits32* bits, BitWriter& output) {
   2524 #ifdef FJXL_AVX512
   2525  static_assert(n <= 2, "");
   2526  StoreToWriterAVX512(bits[0], output);
   2527  if (n == 2) {
   2528    StoreToWriterAVX512(bits[1], output);
   2529  }
   2530  return;
   2531 #endif
   2532  static_assert(n <= 4, "");
   2533  alignas(64) uint64_t nbits64[Bits64::kLanes * n];
   2534  alignas(64) uint64_t bits64[Bits64::kLanes * n];
   2535  bits[0].Merge().Store(nbits64, bits64);
   2536  if (n > 1) {
   2537    bits[1].Merge().Store(nbits64 + Bits64::kLanes, bits64 + Bits64::kLanes);
   2538  }
   2539  if (n > 2) {
   2540    bits[2].Merge().Store(nbits64 + 2 * Bits64::kLanes,
   2541                          bits64 + 2 * Bits64::kLanes);
   2542  }
   2543  if (n > 3) {
   2544    bits[3].Merge().Store(nbits64 + 3 * Bits64::kLanes,
   2545                          bits64 + 3 * Bits64::kLanes);
   2546  }
   2547  output.WriteMultiple(nbits64, bits64, Bits64::kLanes * n);
   2548 }
   2549 
   2550 namespace detail {
   2551 template <typename T>
   2552 struct IntegerTypes;
   2553 
   2554 template <>
   2555 struct IntegerTypes<SIMDVec16> {
   2556  using signed_ = int16_t;
   2557  using unsigned_ = uint16_t;
   2558 };
   2559 
   2560 template <>
   2561 struct IntegerTypes<SIMDVec32> {
   2562  using signed_ = int32_t;
   2563  using unsigned_ = uint32_t;
   2564 };
   2565 
   2566 template <typename T>
   2567 struct SIMDType;
   2568 
   2569 template <>
   2570 struct SIMDType<int16_t> {
   2571  using type = SIMDVec16;
   2572 };
   2573 
   2574 template <>
   2575 struct SIMDType<int32_t> {
   2576  using type = SIMDVec32;
   2577 };
   2578 
   2579 }  // namespace detail
   2580 
   2581 template <typename T>
   2582 using signed_t = typename detail::IntegerTypes<T>::signed_;
   2583 
   2584 template <typename T>
   2585 using unsigned_t = typename detail::IntegerTypes<T>::unsigned_;
   2586 
   2587 template <typename T>
   2588 using simd_t = typename detail::SIMDType<T>::type;
   2589 
   2590 // This function processes exactly one vector's worth of pixels.
   2591 
   2592 template <typename T>
   2593 size_t PredictPixels(const signed_t<T>* pixels, const signed_t<T>* pixels_left,
   2594                     const signed_t<T>* pixels_top,
   2595                     const signed_t<T>* pixels_topleft,
   2596                     unsigned_t<T>* residuals) {
   2597  T px = T::Load((unsigned_t<T>*)pixels);
   2598  T left = T::Load((unsigned_t<T>*)pixels_left);
   2599  T top = T::Load((unsigned_t<T>*)pixels_top);
   2600  T topleft = T::Load((unsigned_t<T>*)pixels_topleft);
   2601  T ac = left.Sub(topleft);
   2602  T ab = left.Sub(top);
   2603  T bc = top.Sub(topleft);
   2604  T grad = ac.Add(top);
   2605  T d = ab.Xor(bc);
   2606  T zero = T::Val(0);
   2607  T clamp = zero.Gt(d).IfThenElse(top, left);
   2608  T s = ac.Xor(bc);
   2609  T pred = zero.Gt(s).IfThenElse(grad, clamp);
   2610  T res = px.Sub(pred);
   2611  T res_times_2 = res.Add(res);
   2612  res = zero.Gt(res).IfThenElse(T::Val(-1).Sub(res_times_2), res_times_2);
   2613  res.Store(residuals);
   2614  return res.Eq(T::Val(0)).CountPrefix();
   2615 }
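// [Editorial sketch, not part of the original file] PredictPixels computes a
// clamped-gradient prediction and then zigzag-encodes the signed residual.
// A scalar equivalent for one 16-bit pixel, using the same wrap-around
// arithmetic:
//
//   int16_t ac = left - topleft, ab = left - top, bc = top - topleft;
//   int16_t grad = ac + top;                       // left + top - topleft
//   int16_t clamp = ((ab ^ bc) < 0) ? top : left;  // signs differ -> top
//   int16_t pred = ((ac ^ bc) < 0) ? grad : clamp;
//   int16_t res = px - pred;
//   uint16_t packed = res < 0 ? -2 * res - 1 : 2 * res;  // zigzag mapping
//
// The return value is the length of the prefix of zero residuals in the
// vector, which the caller can use to extend a run of zeros.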
   2616 
   2617 #endif
   2618 
   2619 void EncodeHybridUint000(uint32_t value, uint32_t* token, uint32_t* nbits,
   2620                         uint32_t* bits) {
   2621  uint32_t n = FloorLog2(value);
   2622  *token = value ? n + 1 : 0;
   2623  *nbits = value ? n : 0;
   2624  *bits = value ? value - (1 << n) : 0;
   2625 }
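        // Worked examples of the mapping above (derived directly from the
        // code): value 0 -> (token 0, no extra bits); value 1 -> (token 1, no
        // extra bits); value 5, with FloorLog2(5) == 2, -> (token 3, 2 extra
        // bits 0b01). In general, token n+1 covers the range [2^n, 2^(n+1)).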
   2626 
   2627 #ifdef FJXL_AVX512
   2628 constexpr static size_t kLogChunkSize = 5;
   2629 #elif defined(FJXL_AVX2) || defined(FJXL_NEON)
   2630 // Even though NEON only has 128-bit lanes, it is still significantly
   2631 // (~1.3x) faster to process two vectors at a time.
   2632 constexpr static size_t kLogChunkSize = 4;
   2633 #else
   2634 constexpr static size_t kLogChunkSize = 3;
   2635 #endif
   2636 
   2637 constexpr static size_t kChunkSize = 1 << kLogChunkSize;
   2638 
   2639 template <typename Residual>
   2640 void GenericEncodeChunk(const Residual* residuals, size_t n, size_t skip,
   2641                        const PrefixCode& code, BitWriter& output) {
   2642  for (size_t ix = skip; ix < n; ix++) {
   2643    unsigned token, nbits, bits;
   2644    EncodeHybridUint000(residuals[ix], &token, &nbits, &bits);
   2645    output.Write(code.raw_nbits[token] + nbits,
   2646                 code.raw_bits[token] | bits << code.raw_nbits[token]);
   2647  }
   2648 }
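        // The single Write above packs the Huffman code into the low bits and
        // the extra bits above it; e.g. with raw_nbits[token] == 3,
        // raw_bits[token] == 0b101, nbits == 2 and bits == 0b10, it emits the
        // 5-bit value 0b10101 (Huffman code in the least-significant bits,
        // which the writer emits first).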
   2649 
   2650 struct UpTo8Bits {
   2651  size_t bitdepth;
   2652  explicit UpTo8Bits(size_t bitdepth) : bitdepth(bitdepth) {
   2653    assert(bitdepth <= 8);
   2654  }
   2655  // Here we can fit up to 9 extra bits + 7 Huffman bits in a u16; for all
   2656  // other symbols we could go up to 8 Huffman bits, as they have at most 8
   2657  // extra bits; however, the SIMD bit-merging logic for AVX2 assumes that no
   2658  // Huffman length is 8 or more, so we cap those at 7 too. The last symbol is
   2659  // used for LZ77 lengths and has no limitation, except that the code must
   2660  // still be able to represent 32 symbols in total.
   2661  static constexpr uint8_t kMinRawLength[12] = {};
   2662  static constexpr uint8_t kMaxRawLength[12] = {
   2663      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 10,
   2664  };
   2665  static size_t MaxEncodedBitsPerSample() { return 16; }
   2666  static constexpr size_t kInputBytes = 1;
   2667  using pixel_t = int16_t;
   2668  using upixel_t = uint16_t;
   2669 
   2670  static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2671                             size_t n, uint8_t* nbits_simd,
   2672                             uint8_t* bits_simd) {
   2673    assert(n <= 16);
   2674    memcpy(nbits_simd, nbits, 16);
   2675    memcpy(bits_simd, bits, 16);
   2676  }
   2677 
   2678 #ifdef FJXL_GENERIC_SIMD
   2679  static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2680                              const uint8_t* raw_nbits_simd,
   2681                              const uint8_t* raw_bits_simd, BitWriter& output) {
   2682    Bits32 bits32[kChunkSize / SIMDVec16::kLanes];
   2683    alignas(64) uint16_t bits[SIMDVec16::kLanes];
   2684    alignas(64) uint16_t nbits[SIMDVec16::kLanes];
   2685    alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2686    alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2687    alignas(64) uint16_t token[SIMDVec16::kLanes];
   2688    for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2689      TokenizeSIMD(residuals + i, token, nbits, bits);
   2690      HuffmanSIMDUpTo13(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2691                        bits_huff);
   2692      StoreSIMDUpTo8(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2693                     std::max(skip, i) - i, bits32 + i / SIMDVec16::kLanes);
   2694    }
   2695    StoreToWriter<kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2696  }
   2697 #endif
   2698 
   2699  size_t NumSymbols(bool doing_ycocg_or_large_palette) const {
   2700    // Values gain 1 bit from the YCoCg transform and 1 bit from prediction.
   2701    // The maximum symbol is 1 + the effective bit depth of the residuals.
   2702    if (doing_ycocg_or_large_palette) {
   2703      return bitdepth + 3;
   2704    } else {
   2705      return bitdepth + 2;
   2706    }
   2707  }
   2708 };
   2709 constexpr uint8_t UpTo8Bits::kMinRawLength[];
   2710 constexpr uint8_t UpTo8Bits::kMaxRawLength[];
   2711 
   2712 struct From9To13Bits {
   2713  size_t bitdepth;
   2714  explicit From9To13Bits(size_t bitdepth) : bitdepth(bitdepth) {
   2715    assert(bitdepth <= 13 && bitdepth >= 9);
   2716  }
   2717  // The last symbol is used for LZ77 lengths and has no limitation, except
   2718  // that the code must still be able to represent 32 symbols in total.
   2719  // We cannot fit all the bits in a u16, so we do not even try, and allow up
   2720  // to 8 bits per raw symbol.
   2721  // There are at most 16 raw symbols, so Huffman coding can be SIMDified
   2722  // without any special tricks.
   2723  static constexpr uint8_t kMinRawLength[17] = {};
   2724  static constexpr uint8_t kMaxRawLength[17] = {
   2725      8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 10,
   2726  };
   2727  static size_t MaxEncodedBitsPerSample() { return 21; }
   2728  static constexpr size_t kInputBytes = 2;
   2729  using pixel_t = int16_t;
   2730  using upixel_t = uint16_t;
   2731 
   2732  static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2733                             size_t n, uint8_t* nbits_simd,
   2734                             uint8_t* bits_simd) {
   2735    assert(n <= 16);
   2736    memcpy(nbits_simd, nbits, 16);
   2737    memcpy(bits_simd, bits, 16);
   2738  }
   2739 
   2740 #ifdef FJXL_GENERIC_SIMD
   2741  static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2742                              const uint8_t* raw_nbits_simd,
   2743                              const uint8_t* raw_bits_simd, BitWriter& output) {
   2744    Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
   2745    alignas(64) uint16_t bits[SIMDVec16::kLanes];
   2746    alignas(64) uint16_t nbits[SIMDVec16::kLanes];
   2747    alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2748    alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2749    alignas(64) uint16_t token[SIMDVec16::kLanes];
   2750    for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2751      TokenizeSIMD(residuals + i, token, nbits, bits);
   2752      HuffmanSIMDUpTo13(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2753                        bits_huff);
   2754      StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2755                      std::max(skip, i) - i,
   2756                      bits32 + 2 * i / SIMDVec16::kLanes);
   2757    }
   2758    StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2759  }
   2760 #endif
   2761 
   2762  size_t NumSymbols(bool doing_ycocg_or_large_palette) const {
   2763    // Values gain 1 bit from the YCoCg transform and 1 bit from prediction.
   2764    // The maximum symbol is 1 + the effective bit depth of the residuals.
   2765    if (doing_ycocg_or_large_palette) {
   2766      return bitdepth + 3;
   2767    } else {
   2768      return bitdepth + 2;
   2769    }
   2770  }
   2771 };
   2772 constexpr uint8_t From9To13Bits::kMinRawLength[];
   2773 constexpr uint8_t From9To13Bits::kMaxRawLength[];
   2774 
   2775 void CheckHuffmanBitsSIMD(int bits1, int nbits1, int bits2, int nbits2) {
   2776  assert(nbits1 == 8);
   2777  assert(nbits2 == 8);
   2778  assert(bits2 == (bits1 | 128));
   2779 }
   2780 
   2781 struct Exactly14Bits {
   2782  explicit Exactly14Bits(size_t bitdepth) { assert(bitdepth == 14); }
   2783  // Force LZ77 symbols to have at least 8 bits, and raw symbols 15 and 16 to
   2784  // have exactly 8, and no other symbol to have 8 or more. This ensures that
   2785  // the representation for 15 and 16 is identical up to one bit.
   2786  static constexpr uint8_t kMinRawLength[18] = {
   2787      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 7,
   2788  };
   2789  static constexpr uint8_t kMaxRawLength[18] = {
   2790      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 10,
   2791  };
   2792  static constexpr size_t bitdepth = 14;
   2793  static size_t MaxEncodedBitsPerSample() { return 22; }
   2794  static constexpr size_t kInputBytes = 2;
   2795  using pixel_t = int16_t;
   2796  using upixel_t = uint16_t;
   2797 
   2798  static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2799                             size_t n, uint8_t* nbits_simd,
   2800                             uint8_t* bits_simd) {
   2801    assert(n == 17);
   2802    CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]);
   2803    memcpy(nbits_simd, nbits, 16);
   2804    memcpy(bits_simd, bits, 16);
   2805  }
   2806 
   2807 #ifdef FJXL_GENERIC_SIMD
   2808  static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2809                              const uint8_t* raw_nbits_simd,
   2810                              const uint8_t* raw_bits_simd, BitWriter& output) {
   2811    Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
   2812    alignas(64) uint16_t bits[SIMDVec16::kLanes];
   2813    alignas(64) uint16_t nbits[SIMDVec16::kLanes];
   2814    alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2815    alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2816    alignas(64) uint16_t token[SIMDVec16::kLanes];
   2817    for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2818      TokenizeSIMD(residuals + i, token, nbits, bits);
   2819      HuffmanSIMD14(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2820                    bits_huff);
   2821      StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2822                      std::max(skip, i) - i,
   2823                      bits32 + 2 * i / SIMDVec16::kLanes);
   2824    }
   2825    StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2826  }
   2827 #endif
   2828 
   2829  size_t NumSymbols(bool) const { return 17; }
   2830 };
   2831 constexpr uint8_t Exactly14Bits::kMinRawLength[];
   2832 constexpr uint8_t Exactly14Bits::kMaxRawLength[];
   2833 
   2834 struct MoreThan14Bits {
   2835  size_t bitdepth;
   2836  explicit MoreThan14Bits(size_t bitdepth) : bitdepth(bitdepth) {
   2837    assert(bitdepth > 14);
   2838    assert(bitdepth <= 16);
   2839  }
   2840  // Force LZ77 symbols to have at least 8 bits, and raw symbols 13 to 18 to
   2841  // have exactly 8, and no other symbol to have 8 or more. This ensures that
   2842  // the representation for (13, 14), (15, 16), (17, 18) is identical up to one
   2843  // bit.
   2844  static constexpr uint8_t kMinRawLength[20] = {
   2845      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 7,
   2846  };
   2847  static constexpr uint8_t kMaxRawLength[20] = {
   2848      7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 10,
   2849  };
   2850  static size_t MaxEncodedBitsPerSample() { return 24; }
   2851  static constexpr size_t kInputBytes = 2;
   2852  using pixel_t = int32_t;
   2853  using upixel_t = uint32_t;
   2854 
   2855  static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
   2856                             size_t n, uint8_t* nbits_simd,
   2857                             uint8_t* bits_simd) {
   2858    assert(n == 19);
   2859    CheckHuffmanBitsSIMD(bits[13], nbits[13], bits[14], nbits[14]);
   2860    CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]);
   2861    CheckHuffmanBitsSIMD(bits[17], nbits[17], bits[18], nbits[18]);
   2862    for (size_t i = 0; i < 14; i++) {
   2863      nbits_simd[i] = nbits[i];
   2864      bits_simd[i] = bits[i];
   2865    }
   2866    nbits_simd[14] = nbits[15];
   2867    bits_simd[14] = bits[15];
   2868    nbits_simd[15] = nbits[17];
   2869    bits_simd[15] = bits[17];
   2870  }
   2871 
   2872 #ifdef FJXL_GENERIC_SIMD
   2873  static void EncodeChunkSimd(upixel_t* residuals, size_t n, size_t skip,
   2874                              const uint8_t* raw_nbits_simd,
   2875                              const uint8_t* raw_bits_simd, BitWriter& output) {
   2876    Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
   2877    alignas(64) uint32_t bits[SIMDVec16::kLanes];
   2878    alignas(64) uint32_t nbits[SIMDVec16::kLanes];
   2879    alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
   2880    alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
   2881    alignas(64) uint16_t token[SIMDVec16::kLanes];
   2882    for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
   2883      TokenizeSIMD(residuals + i, token, nbits, bits);
   2884      HuffmanSIMDAbove14(token, raw_nbits_simd, raw_bits_simd, nbits_huff,
   2885                         bits_huff);
   2886      StoreSIMDAbove14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
   2887                       std::max(skip, i) - i,
   2888                       bits32 + 2 * i / SIMDVec16::kLanes);
   2889    }
   2890    StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
   2891  }
   2892 #endif
   2893  size_t NumSymbols(bool) const { return 19; }
   2894 };
   2895 constexpr uint8_t MoreThan14Bits::kMinRawLength[];
   2896 constexpr uint8_t MoreThan14Bits::kMaxRawLength[];
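        // Note how PrepareForSimd above folds 19 symbols into a 16-entry SIMD
        // table: symbols 13..18 all have length 8, and within each pair
        // (13,14), (15,16), (17,18) the codes differ in only one bit (checked
        // by CheckHuffmanBitsSIMD), so storing entries 13, 15 and 17 at slots
        // 13, 14 and 15 suffices; the remaining bit can be derived from the
        // token itself in the SIMD Huffman path.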
   2897 
   2898 void PrepareDCGlobalCommon(bool is_single_group, size_t width, size_t height,
   2899                           const PrefixCode code[4], BitWriter* output) {
   2900  output->Allocate(100000 + (is_single_group ? width * height * 16 : 0));
   2901  // No patches, splines, or noise.
   2902  output->Write(1, 1);  // default DC dequantization factors (?)
   2903  output->Write(1, 1);  // use global tree / histograms
   2904  output->Write(1, 0);  // no lz77 for the tree
   2905 
   2906  output->Write(1, 1);         // simple code for the tree's context map
   2907  output->Write(2, 0);         // all contexts clustered together
   2908  output->Write(1, 1);         // use prefix code for tree
   2909  output->Write(4, 0);         // 000 hybrid uint
   2910  output->Write(6, 0b100011);  // Alphabet size is 4 (var16)
   2911  output->Write(2, 1);         // simple prefix code
   2912  output->Write(2, 3);         // with 4 symbols
   2913  output->Write(2, 0);
   2914  output->Write(2, 1);
   2915  output->Write(2, 2);
   2916  output->Write(2, 3);
   2917  output->Write(1, 0);  // First tree encoding option
   2918 
   2919  // Huffman table + extra bits for the tree.
   2920  uint8_t symbol_bits[6] = {0b00, 0b10, 0b001, 0b101, 0b0011, 0b0111};
   2921  uint8_t symbol_nbits[6] = {2, 2, 3, 3, 4, 4};
   2922  // Write a tree with a leaf per channel, and gradient predictor for every
   2923  // leaf.
   2924  for (auto v : {1, 2, 1, 4, 1, 0, 0, 5, 0, 0, 0, 0, 5,
   2925                 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0}) {
   2926    output->Write(symbol_nbits[v], symbol_bits[v]);
   2927  }
   2928 
   2929  output->Write(1, 1);     // Enable lz77 for the main bitstream
   2930  output->Write(2, 0b00);  // lz77 offset 224
   2931  static_assert(kLZ77Offset == 224, "");
   2932  output->Write(4, 0b1010);  // lz77 min length 7
   2933  // 400 hybrid uint config for lz77
   2934  output->Write(4, 4);
   2935  output->Write(3, 0);
   2936  output->Write(3, 0);
   2937 
   2938  output->Write(1, 1);  // simple code for the context map
   2939  output->Write(2, 3);  // 3 bits per entry
   2940  output->Write(3, 4);  // channel 3
   2941  output->Write(3, 3);  // channel 2
   2942  output->Write(3, 2);  // channel 1
   2943  output->Write(3, 1);  // channel 0
   2944  output->Write(3, 0);  // distance histogram first
   2945 
   2946  output->Write(1, 1);  // use prefix codes
   2947  output->Write(4, 0);  // 000 hybrid uint config for distances (only need 0)
   2948  for (size_t i = 0; i < 4; i++) {
   2949    output->Write(4, 0);  // 000 hybrid uint config for symbols (only <= 10)
   2950  }
   2951 
   2952  // Distance alphabet size:
   2953  output->Write(5, 0b00001);  // 2: just need 1 for RLE (i.e. distance 1)
   2954  // Symbol + LZ77 alphabet size:
   2955  for (size_t i = 0; i < 4; i++) {
   2956    output->Write(1, 1);    // > 1
   2957    output->Write(4, 8);    // <= 512
   2958    output->Write(8, 256);  // == 512
   2959  }
   2960 
   2961  // Distance histogram:
   2962  output->Write(2, 1);  // simple prefix code
   2963  output->Write(2, 0);  // with one symbol
   2964  output->Write(1, 1);  // 1
   2965 
   2966  // Symbol + lz77 histogram:
   2967  for (size_t i = 0; i < 4; i++) {
   2968    code[i].WriteTo(output);
   2969  }
   2970 
   2971  // Group header for global modular image.
   2972  output->Write(1, 1);  // Global tree
   2973  output->Write(1, 1);  // All default wp
   2974 }
   2975 
   2976 void PrepareDCGlobal(bool is_single_group, size_t width, size_t height,
   2977                     size_t nb_chans, const PrefixCode code[4],
   2978                     BitWriter* output) {
   2979  PrepareDCGlobalCommon(is_single_group, width, height, code, output);
   2980  if (nb_chans > 2) {
   2981    output->Write(2, 0b01);     // 1 transform
   2982    output->Write(2, 0b00);     // RCT
   2983    output->Write(5, 0b00000);  // Starting from ch 0
   2984    output->Write(2, 0b00);     // YCoCg
   2985  } else {
   2986    output->Write(2, 0b00);  // no transforms
   2987  }
   2988  if (!is_single_group) {
   2989    output->ZeroPadToByte();
   2990  }
   2991 }
   2992 
   2993 template <typename BitDepth>
   2994 struct ChunkEncoder {
   2995  void PrepareForSimd() {
   2996    BitDepth::PrepareForSimd(code->raw_nbits, code->raw_bits, code->numraw,
   2997                             raw_nbits_simd, raw_bits_simd);
   2998  }
   2999  FJXL_INLINE static void EncodeRle(size_t count, const PrefixCode& code,
   3000                                    BitWriter& output) {
   3001    if (count == 0) return;
   3002    count -= kLZ77MinLength + 1;
   3003    if (count < kLZ77CacheSize) {
   3004      output.Write(code.lz77_cache_nbits[count], code.lz77_cache_bits[count]);
   3005    } else {
   3006      unsigned token, nbits, bits;
   3007      EncodeHybridUintLZ77(count, &token, &nbits, &bits);
   3008      uint64_t wbits = bits;
   3009      wbits = (wbits << code.lz77_nbits[token]) | code.lz77_bits[token];
   3010      wbits = (wbits << code.raw_nbits[0]) | code.raw_bits[0];
   3011      output.Write(code.lz77_nbits[token] + nbits + code.raw_nbits[0], wbits);
   3012    }
   3013  }
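          // In other words, a run of N zeros (N > kLZ77MinLength) is coded as
          // one literal zero (raw symbol 0) followed by an LZ77 match of
          // length N - 1 at distance 1, with the hybrid-uint token encoding
          // N - 1 - kLZ77MinLength; for small counts the same bit pattern
          // comes precomputed from the lz77_cache_* tables.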
   3014 
   3015  FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals,
   3016                         size_t skip, size_t n) {
   3017    EncodeRle(run, *code, *output);
   3018 #ifdef FJXL_GENERIC_SIMD
   3019    BitDepth::EncodeChunkSimd(residuals, n, skip, raw_nbits_simd, raw_bits_simd,
   3020                              *output);
   3021 #else
   3022    GenericEncodeChunk(residuals, n, skip, *code, *output);
   3023 #endif
   3024  }
   3025 
   3026  inline void Finalize(size_t run) { EncodeRle(run, *code, *output); }
   3027 
   3028  const PrefixCode* code;
   3029  BitWriter* output;
   3030  alignas(64) uint8_t raw_nbits_simd[16] = {};
   3031  alignas(64) uint8_t raw_bits_simd[16] = {};
   3032 };
   3033 
   3034 template <typename BitDepth>
   3035 struct ChunkSampleCollector {
   3036  FJXL_INLINE void Rle(size_t count, uint64_t* lz77_counts) {
   3037    if (count == 0) return;
   3038    raw_counts[0] += 1;
   3039    count -= kLZ77MinLength + 1;
   3040    unsigned token, nbits, bits;
   3041    EncodeHybridUintLZ77(count, &token, &nbits, &bits);
   3042    lz77_counts[token]++;
   3043  }
   3044 
   3045  FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals,
   3046                         size_t skip, size_t n) {
   3047    // Run is broken. Count the run and the individual residuals.
   3048    Rle(run, lz77_counts);
   3049    for (size_t ix = skip; ix < n; ix++) {
   3050      unsigned token, nbits, bits;
   3051      EncodeHybridUint000(residuals[ix], &token, &nbits, &bits);
   3052      raw_counts[token]++;
   3053    }
   3054  }
   3055 
   3056  // Don't count the final run, since we don't know how long it really is.
   3057  void Finalize(size_t run) {}
   3058 
   3059  uint64_t* raw_counts;
   3060  uint64_t* lz77_counts;
   3061 };
   3062 
   3063 constexpr uint32_t PackSigned(int32_t value) {
   3064  return (static_cast<uint32_t>(value) << 1) ^
   3065         ((static_cast<uint32_t>(~value) >> 31) - 1);
   3066 }
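        // PackSigned is the usual zigzag mapping: 0 -> 0, -1 -> 1, 1 -> 2,
        // -2 -> 3, 2 -> 4, ... A minimal sketch of the inverse, for
        // illustration only (UnpackSignedExample is a hypothetical name, not
        // an API used elsewhere in this file):
        constexpr int32_t UnpackSignedExample(uint32_t v) {
          return (v & 1) ? -static_cast<int32_t>(v >> 1) - 1
                         : static_cast<int32_t>(v >> 1);
        }
        // e.g. PackSigned(-3) == 5 and UnpackSignedExample(5) == -3.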
   3067 
   3068 template <typename T, typename BitDepth>
   3069 struct ChannelRowProcessor {
   3070  using upixel_t = typename BitDepth::upixel_t;
   3071  using pixel_t = typename BitDepth::pixel_t;
   3072  T* t;
   3073  void ProcessChunk(const pixel_t* row, const pixel_t* row_left,
   3074                    const pixel_t* row_top, const pixel_t* row_topleft,
   3075                    size_t n) {
   3076    alignas(64) upixel_t residuals[kChunkSize] = {};
   3077    size_t prefix_size = 0;
   3078    size_t required_prefix_size = 0;
   3079 #ifdef FJXL_GENERIC_SIMD
   3080    constexpr size_t kNum =
   3081        sizeof(pixel_t) == 2 ? SIMDVec16::kLanes : SIMDVec32::kLanes;
   3082    for (size_t ix = 0; ix < kChunkSize; ix += kNum) {
   3083      size_t c =
   3084          PredictPixels<simd_t<pixel_t>>(row + ix, row_left + ix, row_top + ix,
   3085                                         row_topleft + ix, residuals + ix);
   3086      prefix_size =
   3087          prefix_size == required_prefix_size ? prefix_size + c : prefix_size;
   3088      required_prefix_size += kNum;
   3089    }
   3090 #else
   3091    for (size_t ix = 0; ix < kChunkSize; ix++) {
   3092      pixel_t px = row[ix];
   3093      pixel_t left = row_left[ix];
   3094      pixel_t top = row_top[ix];
   3095      pixel_t topleft = row_topleft[ix];
   3096      pixel_t ac = left - topleft;
   3097      pixel_t ab = left - top;
   3098      pixel_t bc = top - topleft;
   3099      pixel_t grad = static_cast<pixel_t>(static_cast<upixel_t>(ac) +
   3100                                          static_cast<upixel_t>(top));
   3101      pixel_t d = ab ^ bc;
   3102      pixel_t clamp = d < 0 ? top : left;
   3103      pixel_t s = ac ^ bc;
   3104      pixel_t pred = s < 0 ? grad : clamp;
   3105      residuals[ix] = PackSigned(px - pred);
   3106      prefix_size = prefix_size == required_prefix_size
   3107                        ? prefix_size + (residuals[ix] == 0)
   3108                        : prefix_size;
   3109      required_prefix_size += 1;
   3110    }
   3111 #endif
   3112    prefix_size = std::min(n, prefix_size);
   3113    if (prefix_size == n && (run > 0 || prefix_size > kLZ77MinLength)) {
   3114      // Run continues, nothing to do.
   3115      run += prefix_size;
   3116    } else if (prefix_size + run > kLZ77MinLength) {
   3117      // Run is broken. Encode the run and encode the individual vector.
   3118      t->Chunk(run + prefix_size, residuals, prefix_size, n);
   3119      run = 0;
   3120    } else {
   3121      // There was no run to begin with.
   3122      t->Chunk(0, residuals, 0, n);
   3123    }
   3124  }
   3125 
   3126  void ProcessRow(const pixel_t* row, const pixel_t* row_left,
   3127                  const pixel_t* row_top, const pixel_t* row_topleft,
   3128                  size_t xs) {
   3129    for (size_t x = 0; x < xs; x += kChunkSize) {
   3130      ProcessChunk(row + x, row_left + x, row_top + x, row_topleft + x,
   3131                   std::min(kChunkSize, xs - x));
   3132    }
   3133  }
   3134 
   3135  void Finalize() { t->Finalize(run); }
   3136  // Invariant: run == 0 or run > kLZ77MinLength.
   3137  size_t run = 0;
   3138 };
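        // Illustrative trace of the run state machine above, assuming
        // kChunkSize == 16 and kLZ77MinLength == 7: three all-zero chunks
        // accumulate run = 48 without emitting anything; a fourth chunk whose
        // first 3 residuals are zero then calls t->Chunk(51, residuals, 3, n),
        // which RLE-encodes the 51-pixel run and encodes the remaining
        // residuals directly. The invariant above ensures a run is only ever
        // flushed once it is long enough to be worth an LZ77 match.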
   3139 
   3140 uint16_t LoadLE16(const unsigned char* ptr) {
   3141  return uint16_t{ptr[0]} | (uint16_t{ptr[1]} << 8);
   3142 }
   3143 
   3144 uint16_t SwapEndian(uint16_t in) { return (in >> 8) | (in << 8); }
   3145 
   3146 #ifdef FJXL_GENERIC_SIMD
   3147 void StorePixels(SIMDVec16 p, int16_t* dest) { p.Store((uint16_t*)dest); }
   3148 
   3149 void StorePixels(SIMDVec16 p, int32_t* dest) {
   3150  VecPair<SIMDVec32> p_up = p.Upcast();
   3151  p_up.low.Store((uint32_t*)dest);
   3152  p_up.hi.Store((uint32_t*)dest + SIMDVec32::kLanes);
   3153 }
   3154 #endif
   3155 
   3156 template <typename pixel_t>
   3157 void FillRowG8(const unsigned char* rgba, size_t oxs, pixel_t* luma) {
   3158  size_t x = 0;
   3159 #ifdef FJXL_GENERIC_SIMD
   3160  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3161    auto rgb = SIMDVec16::LoadG8(rgba + x);
   3162    StorePixels(rgb[0], luma + x);
   3163  }
   3164 #endif
   3165  for (; x < oxs; x++) {
   3166    luma[x] = rgba[x];
   3167  }
   3168 }
   3169 
   3170 template <bool big_endian, typename pixel_t>
   3171 void FillRowG16(const unsigned char* rgba, size_t oxs, pixel_t* luma) {
   3172  size_t x = 0;
   3173 #ifdef FJXL_GENERIC_SIMD
   3174  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3175    auto rgb = SIMDVec16::LoadG16(rgba + 2 * x);
   3176    if (big_endian) {
   3177      rgb[0].SwapEndian();
   3178    }
   3179    StorePixels(rgb[0], luma + x);
   3180  }
   3181 #endif
   3182  for (; x < oxs; x++) {
   3183    uint16_t val = LoadLE16(rgba + 2 * x);
   3184    if (big_endian) {
   3185      val = SwapEndian(val);
   3186    }
   3187    luma[x] = val;
   3188  }
   3189 }
   3190 
   3191 template <typename pixel_t>
   3192 void FillRowGA8(const unsigned char* rgba, size_t oxs, pixel_t* luma,
   3193                pixel_t* alpha) {
   3194  size_t x = 0;
   3195 #ifdef FJXL_GENERIC_SIMD
   3196  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3197    auto rgb = SIMDVec16::LoadGA8(rgba + 2 * x);
   3198    StorePixels(rgb[0], luma + x);
   3199    StorePixels(rgb[1], alpha + x);
   3200  }
   3201 #endif
   3202  for (; x < oxs; x++) {
   3203    luma[x] = rgba[2 * x];
   3204    alpha[x] = rgba[2 * x + 1];
   3205  }
   3206 }
   3207 
   3208 template <bool big_endian, typename pixel_t>
   3209 void FillRowGA16(const unsigned char* rgba, size_t oxs, pixel_t* luma,
   3210                 pixel_t* alpha) {
   3211  size_t x = 0;
   3212 #ifdef FJXL_GENERIC_SIMD
   3213  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3214    auto rgb = SIMDVec16::LoadGA16(rgba + 4 * x);
   3215    if (big_endian) {
   3216      rgb[0].SwapEndian();
   3217      rgb[1].SwapEndian();
   3218    }
   3219    StorePixels(rgb[0], luma + x);
   3220    StorePixels(rgb[1], alpha + x);
   3221  }
   3222 #endif
   3223  for (; x < oxs; x++) {
   3224    uint16_t l = LoadLE16(rgba + 4 * x);
   3225    uint16_t a = LoadLE16(rgba + 4 * x + 2);
   3226    if (big_endian) {
   3227      l = SwapEndian(l);
   3228      a = SwapEndian(a);
   3229    }
   3230    luma[x] = l;
   3231    alpha[x] = a;
   3232  }
   3233 }
   3234 
   3235 template <typename pixel_t>
   3236 void StoreYCoCg(pixel_t r, pixel_t g, pixel_t b, pixel_t* y, pixel_t* co,
   3237                pixel_t* cg) {
   3238  *co = r - b;
   3239  pixel_t tmp = b + (*co >> 1);
   3240  *cg = g - tmp;
   3241  *y = tmp + (*cg >> 1);
   3242 }
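        // The scalar transform above is the reversible YCoCg-R variant:
        // integer-only and lossless despite the truncating shifts. A minimal
        // sketch of the inverse, for illustration only (InverseYCoCgExample is
        // a hypothetical name, not part of this encoder):
        template <typename pixel_t>
        void InverseYCoCgExample(pixel_t y, pixel_t co, pixel_t cg, pixel_t* r,
                                 pixel_t* g, pixel_t* b) {
          pixel_t tmp = y - (cg >> 1);  // undoes *y = tmp + (*cg >> 1)
          *g = cg + tmp;                // undoes *cg = g - tmp
          *b = tmp - (co >> 1);         // undoes tmp = b + (*co >> 1)
          *r = *b + co;                 // undoes *co = r - b
        }
        // e.g. (r,g,b) = (255,0,0) maps to (y,co,cg) = (63,255,-127) and back.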
   3243 
   3244 #ifdef FJXL_GENERIC_SIMD
   3245 void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int16_t* y, int16_t* co,
   3246                int16_t* cg) {
   3247  SIMDVec16 co_v = r.Sub(b);
   3248  SIMDVec16 tmp = b.Add(co_v.SignedShiftRight<1>());
   3249  SIMDVec16 cg_v = g.Sub(tmp);
   3250  SIMDVec16 y_v = tmp.Add(cg_v.SignedShiftRight<1>());
   3251  y_v.Store(reinterpret_cast<uint16_t*>(y));
   3252  co_v.Store(reinterpret_cast<uint16_t*>(co));
   3253  cg_v.Store(reinterpret_cast<uint16_t*>(cg));
   3254 }
   3255 
   3256 void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int32_t* y, int32_t* co,
   3257                int32_t* cg) {
   3258  VecPair<SIMDVec32> r_up = r.Upcast();
   3259  VecPair<SIMDVec32> g_up = g.Upcast();
   3260  VecPair<SIMDVec32> b_up = b.Upcast();
   3261  SIMDVec32 co_lo_v = r_up.low.Sub(b_up.low);
   3262  SIMDVec32 tmp_lo = b_up.low.Add(co_lo_v.SignedShiftRight<1>());
   3263  SIMDVec32 cg_lo_v = g_up.low.Sub(tmp_lo);
   3264  SIMDVec32 y_lo_v = tmp_lo.Add(cg_lo_v.SignedShiftRight<1>());
   3265  SIMDVec32 co_hi_v = r_up.hi.Sub(b_up.hi);
   3266  SIMDVec32 tmp_hi = b_up.hi.Add(co_hi_v.SignedShiftRight<1>());
   3267  SIMDVec32 cg_hi_v = g_up.hi.Sub(tmp_hi);
   3268  SIMDVec32 y_hi_v = tmp_hi.Add(cg_hi_v.SignedShiftRight<1>());
   3269  y_lo_v.Store(reinterpret_cast<uint32_t*>(y));
   3270  co_lo_v.Store(reinterpret_cast<uint32_t*>(co));
   3271  cg_lo_v.Store(reinterpret_cast<uint32_t*>(cg));
   3272  y_hi_v.Store(reinterpret_cast<uint32_t*>(y) + SIMDVec32::kLanes);
   3273  co_hi_v.Store(reinterpret_cast<uint32_t*>(co) + SIMDVec32::kLanes);
   3274  cg_hi_v.Store(reinterpret_cast<uint32_t*>(cg) + SIMDVec32::kLanes);
   3275 }
   3276 #endif
   3277 
   3278 template <typename pixel_t>
   3279 void FillRowRGB8(const unsigned char* rgba, size_t oxs, pixel_t* y, pixel_t* co,
   3280                 pixel_t* cg) {
   3281  size_t x = 0;
   3282 #ifdef FJXL_GENERIC_SIMD
   3283  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3284    auto rgb = SIMDVec16::LoadRGB8(rgba + 3 * x);
   3285    StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3286  }
   3287 #endif
   3288  for (; x < oxs; x++) {
   3289    uint16_t r = rgba[3 * x];
   3290    uint16_t g = rgba[3 * x + 1];
   3291    uint16_t b = rgba[3 * x + 2];
   3292    StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3293  }
   3294 }
   3295 
   3296 template <bool big_endian, typename pixel_t>
   3297 void FillRowRGB16(const unsigned char* rgba, size_t oxs, pixel_t* y,
   3298                  pixel_t* co, pixel_t* cg) {
   3299  size_t x = 0;
   3300 #ifdef FJXL_GENERIC_SIMD
   3301  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3302    auto rgb = SIMDVec16::LoadRGB16(rgba + 6 * x);
   3303    if (big_endian) {
   3304      rgb[0].SwapEndian();
   3305      rgb[1].SwapEndian();
   3306      rgb[2].SwapEndian();
   3307    }
   3308    StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3309  }
   3310 #endif
   3311  for (; x < oxs; x++) {
   3312    uint16_t r = LoadLE16(rgba + 6 * x);
   3313    uint16_t g = LoadLE16(rgba + 6 * x + 2);
   3314    uint16_t b = LoadLE16(rgba + 6 * x + 4);
   3315    if (big_endian) {
   3316      r = SwapEndian(r);
   3317      g = SwapEndian(g);
   3318      b = SwapEndian(b);
   3319    }
   3320    StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3321  }
   3322 }
   3323 
   3324 template <typename pixel_t>
   3325 void FillRowRGBA8(const unsigned char* rgba, size_t oxs, pixel_t* y,
   3326                  pixel_t* co, pixel_t* cg, pixel_t* alpha) {
   3327  size_t x = 0;
   3328 #ifdef FJXL_GENERIC_SIMD
   3329  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3330    auto rgb = SIMDVec16::LoadRGBA8(rgba + 4 * x);
   3331    StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3332    StorePixels(rgb[3], alpha + x);
   3333  }
   3334 #endif
   3335  for (; x < oxs; x++) {
   3336    uint16_t r = rgba[4 * x];
   3337    uint16_t g = rgba[4 * x + 1];
   3338    uint16_t b = rgba[4 * x + 2];
   3339    uint16_t a = rgba[4 * x + 3];
   3340    StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3341    alpha[x] = a;
   3342  }
   3343 }
   3344 
   3345 template <bool big_endian, typename pixel_t>
   3346 void FillRowRGBA16(const unsigned char* rgba, size_t oxs, pixel_t* y,
   3347                   pixel_t* co, pixel_t* cg, pixel_t* alpha) {
   3348  size_t x = 0;
   3349 #ifdef FJXL_GENERIC_SIMD
   3350  for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
   3351    auto rgb = SIMDVec16::LoadRGBA16(rgba + 8 * x);
   3352    if (big_endian) {
   3353      rgb[0].SwapEndian();
   3354      rgb[1].SwapEndian();
   3355      rgb[2].SwapEndian();
   3356      rgb[3].SwapEndian();
   3357    }
   3358    StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
   3359    StorePixels(rgb[3], alpha + x);
   3360  }
   3361 #endif
   3362  for (; x < oxs; x++) {
   3363    uint16_t r = LoadLE16(rgba + 8 * x);
   3364    uint16_t g = LoadLE16(rgba + 8 * x + 2);
   3365    uint16_t b = LoadLE16(rgba + 8 * x + 4);
   3366    uint16_t a = LoadLE16(rgba + 8 * x + 6);
   3367    if (big_endian) {
   3368      r = SwapEndian(r);
   3369      g = SwapEndian(g);
   3370      b = SwapEndian(b);
   3371      a = SwapEndian(a);
   3372    }
   3373    StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
   3374    alpha[x] = a;
   3375  }
   3376 }
   3377 
   3378 template <typename Processor, typename BitDepth>
   3379 void ProcessImageArea(const unsigned char* rgba, size_t x0, size_t y0,
   3380                      size_t xs, size_t yskip, size_t ys, size_t row_stride,
   3381                      BitDepth bitdepth, size_t nb_chans, bool big_endian,
   3382                      Processor* processors) {
   3383  constexpr size_t kPadding = 32;
   3384 
   3385  using pixel_t = typename BitDepth::pixel_t;
   3386 
   3387  constexpr size_t kAlign = 64;
   3388  constexpr size_t kAlignPixels = kAlign / sizeof(pixel_t);
   3389 
   3390  auto align = [=](pixel_t* ptr) {
   3391    size_t offset = reinterpret_cast<uintptr_t>(ptr) % kAlign;
   3392    if (offset) {
   3393      ptr += offset / sizeof(pixel_t);
   3394    }
   3395    return ptr;
   3396  };
   3397 
   3398  constexpr size_t kNumPx =
   3399      (256 + kPadding * 2 + kAlignPixels + kAlignPixels - 1) / kAlignPixels *
   3400      kAlignPixels;
   3401 
   3402  std::vector<std::array<std::array<pixel_t, kNumPx>, 2>> group_data(nb_chans);
   3403 
   3404  for (size_t y = 0; y < ys; y++) {
   3405    const auto rgba_row =
   3406        rgba + row_stride * (y0 + y) + x0 * nb_chans * BitDepth::kInputBytes;
   3407    pixel_t* crow[4] = {};
   3408    pixel_t* prow[4] = {};
   3409    for (size_t i = 0; i < nb_chans; i++) {
   3410      crow[i] = align(&group_data[i][y & 1][kPadding]);
   3411      prow[i] = align(&group_data[i][(y - 1) & 1][kPadding]);
   3412    }
   3413 
   3414    // Pre-fill rows with YCoCg converted pixels.
   3415    if (nb_chans == 1) {
   3416      if (BitDepth::kInputBytes == 1) {
   3417        FillRowG8(rgba_row, xs, crow[0]);
   3418      } else if (big_endian) {
   3419        FillRowG16</*big_endian=*/true>(rgba_row, xs, crow[0]);
   3420      } else {
   3421        FillRowG16</*big_endian=*/false>(rgba_row, xs, crow[0]);
   3422      }
   3423    } else if (nb_chans == 2) {
   3424      if (BitDepth::kInputBytes == 1) {
   3425        FillRowGA8(rgba_row, xs, crow[0], crow[1]);
   3426      } else if (big_endian) {
   3427        FillRowGA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1]);
   3428      } else {
   3429        FillRowGA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1]);
   3430      }
   3431    } else if (nb_chans == 3) {
   3432      if (BitDepth::kInputBytes == 1) {
   3433        FillRowRGB8(rgba_row, xs, crow[0], crow[1], crow[2]);
   3434      } else if (big_endian) {
   3435        FillRowRGB16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1],
   3436                                          crow[2]);
   3437      } else {
   3438        FillRowRGB16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1],
   3439                                           crow[2]);
   3440      }
   3441    } else {
   3442      if (BitDepth::kInputBytes == 1) {
   3443        FillRowRGBA8(rgba_row, xs, crow[0], crow[1], crow[2], crow[3]);
   3444      } else if (big_endian) {
   3445        FillRowRGBA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1],
   3446                                           crow[2], crow[3]);
   3447      } else {
   3448        FillRowRGBA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1],
   3449                                            crow[2], crow[3]);
   3450      }
   3451    }
   3452    // Deal with x == 0.
   3453    for (size_t c = 0; c < nb_chans; c++) {
   3454      *(crow[c] - 1) = y > 0 ? *(prow[c]) : 0;
   3455      // Fix topleft.
   3456      *(prow[c] - 1) = y > 0 ? *(prow[c]) : 0;
   3457    }
   3458    if (y < yskip) continue;
   3459    for (size_t c = 0; c < nb_chans; c++) {
   3460      // Get pointers to px/left/top/topleft data to speed up the loop.
   3461      const pixel_t* row = crow[c];
   3462      const pixel_t* row_left = crow[c] - 1;
   3463      const pixel_t* row_top = y == 0 ? row_left : prow[c];
   3464      const pixel_t* row_topleft = y == 0 ? row_left : prow[c] - 1;
   3465 
   3466      processors[c].ProcessRow(row, row_left, row_top, row_topleft, xs);
   3467    }
   3468  }
   3469  for (size_t c = 0; c < nb_chans; c++) {
   3470    processors[c].Finalize();
   3471  }
   3472 }
   3473 
   3474 template <typename BitDepth>
   3475 void WriteACSection(const unsigned char* rgba, size_t x0, size_t y0, size_t xs,
   3476                    size_t ys, size_t row_stride, bool is_single_group,
   3477                    BitDepth bitdepth, size_t nb_chans, bool big_endian,
   3478                    const PrefixCode code[4],
   3479                    std::array<BitWriter, 4>& output) {
   3480  for (size_t i = 0; i < nb_chans; i++) {
   3481    if (is_single_group && i == 0) continue;
   3482    output[i].Allocate(xs * ys * bitdepth.MaxEncodedBitsPerSample() + 4);
   3483  }
   3484  if (!is_single_group) {
   3485    // Group header for modular image.
   3486    // When the image is single-group, the global modular image is the one
   3487    // that contains the pixel data, and there is no group header.
   3488    output[0].Write(1, 1);     // Global tree
   3489    output[0].Write(1, 1);     // All default wp
   3490    output[0].Write(2, 0b00);  // 0 transforms
   3491  }
   3492 
   3493  ChunkEncoder<BitDepth> encoders[4];
   3494  ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth> row_encoders[4];
   3495  for (size_t c = 0; c < nb_chans; c++) {
   3496    row_encoders[c].t = &encoders[c];
   3497    encoders[c].output = &output[c];
   3498    encoders[c].code = &code[c];
   3499    encoders[c].PrepareForSimd();
   3500  }
   3501  ProcessImageArea<ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth>>(
   3502      rgba, x0, y0, xs, 0, ys, row_stride, bitdepth, nb_chans, big_endian,
   3503      row_encoders);
   3504 }
   3505 
   3506 constexpr int kHashExp = 16;
   3507 constexpr uint32_t kHashSize = 1 << kHashExp;
   3508 constexpr uint32_t kHashMultiplier = 2654435761;
   3509 constexpr int kMaxColors = 512;
   3510 
   3511 // Can be any function that returns a value in 0 .. kHashSize-1,
   3512 // as long as it maps 0 to 0.
   3513 inline uint32_t pixel_hash(uint32_t p) {
   3514  return (p * kHashMultiplier) >> (32 - kHashExp);
   3515 }
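        // kHashMultiplier is a prime close to 2^32 / phi (the golden ratio),
        // in the style of Knuth's multiplicative hashing: the multiplication
        // spreads the input bits into the high bits, and the shift keeps the
        // top kHashExp of them. Since 0 * kHashMultiplier == 0, the
        // "maps 0 to 0" requirement above holds automatically.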
   3516 
   3517 template <size_t nb_chans>
   3518 void FillRowPalette(const unsigned char* inrow, size_t xs,
   3519                    const int16_t* lookup, int16_t* out) {
   3520  for (size_t x = 0; x < xs; x++) {
   3521    uint32_t p = 0;
   3522    for (size_t i = 0; i < nb_chans; ++i) {
   3523      p |= inrow[x * nb_chans + i] << (8 * i);
   3524    }
   3525    out[x] = lookup[pixel_hash(p)];
   3526  }
   3527 }
   3528 
   3529 template <typename Processor>
   3530 void ProcessImageAreaPalette(const unsigned char* rgba, size_t x0, size_t y0,
   3531                             size_t xs, size_t yskip, size_t ys,
   3532                             size_t row_stride, const int16_t* lookup,
   3533                             size_t nb_chans, Processor* processors) {
   3534  constexpr size_t kPadding = 32;
   3535 
   3536  std::vector<std::array<int16_t, 256 + kPadding * 2>> group_data(2);
   3537  Processor& row_encoder = processors[0];
   3538 
   3539  for (size_t y = 0; y < ys; y++) {
   3540    // Pre-fill rows with palette converted pixels.
   3541    const unsigned char* inrow = rgba + row_stride * (y0 + y) + x0 * nb_chans;
   3542    int16_t* outrow = &group_data[y & 1][kPadding];
   3543    if (nb_chans == 1) {
   3544      FillRowPalette<1>(inrow, xs, lookup, outrow);
   3545    } else if (nb_chans == 2) {
   3546      FillRowPalette<2>(inrow, xs, lookup, outrow);
   3547    } else if (nb_chans == 3) {
   3548      FillRowPalette<3>(inrow, xs, lookup, outrow);
   3549    } else if (nb_chans == 4) {
   3550      FillRowPalette<4>(inrow, xs, lookup, outrow);
   3551    }
   3552    // Deal with x == 0.
   3553    group_data[y & 1][kPadding - 1] =
   3554        y > 0 ? group_data[(y - 1) & 1][kPadding] : 0;
   3555    // Fix topleft.
   3556    group_data[(y - 1) & 1][kPadding - 1] =
   3557        y > 0 ? group_data[(y - 1) & 1][kPadding] : 0;
   3558    // Get pointers to px/left/top/topleft data to speed up the loop.
   3559    const int16_t* row = &group_data[y & 1][kPadding];
   3560    const int16_t* row_left = &group_data[y & 1][kPadding - 1];
   3561    const int16_t* row_top =
   3562        y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding];
   3563    const int16_t* row_topleft =
   3564        y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding - 1];
   3565 
   3566    row_encoder.ProcessRow(row, row_left, row_top, row_topleft, xs);
   3567  }
   3568  row_encoder.Finalize();
   3569 }
   3570 
   3571 void WriteACSectionPalette(const unsigned char* rgba, size_t x0, size_t y0,
   3572                           size_t xs, size_t ys, size_t row_stride,
   3573                           bool is_single_group, const PrefixCode code[4],
   3574                           const int16_t* lookup, size_t nb_chans,
   3575                           BitWriter& output) {
   3576  if (!is_single_group) {
   3577    output.Allocate(16 * xs * ys + 4);
   3578    // Group header for modular image.
   3579    // When the image is single-group, the global modular image is the one
   3580    // that contains the pixel data, and there is no group header.
   3581    output.Write(1, 1);     // Global tree
   3582    output.Write(1, 1);     // All default wp
   3583    output.Write(2, 0b00);  // 0 transforms
   3584  }
   3585 
   3586  ChunkEncoder<UpTo8Bits> encoder;
   3587  ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
   3588 
   3589  row_encoder.t = &encoder;
   3590  encoder.output = &output;
   3591  encoder.code = &code[is_single_group ? 1 : 0];
   3592  encoder.PrepareForSimd();
   3593  ProcessImageAreaPalette<
   3594      ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits>>(
   3595      rgba, x0, y0, xs, 0, ys, row_stride, lookup, nb_chans, &row_encoder);
   3596 }
   3597 
   3598 template <typename BitDepth>
   3599 void CollectSamples(const unsigned char* rgba, size_t x0, size_t y0, size_t xs,
   3600                    size_t row_stride, size_t row_count,
   3601                    uint64_t raw_counts[4][kNumRawSymbols],
   3602                    uint64_t lz77_counts[4][kNumLZ77], bool is_single_group,
   3603                    bool palette, BitDepth bitdepth, size_t nb_chans,
   3604                    bool big_endian, const int16_t* lookup) {
   3605  if (palette) {
   3606    ChunkSampleCollector<UpTo8Bits> sample_collectors[4];
   3607    ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>
   3608        row_sample_collectors[4];
   3609    for (size_t c = 0; c < nb_chans; c++) {
   3610      row_sample_collectors[c].t = &sample_collectors[c];
   3611      sample_collectors[c].raw_counts = raw_counts[is_single_group ? 1 : 0];
   3612      sample_collectors[c].lz77_counts = lz77_counts[is_single_group ? 1 : 0];
   3613    }
   3614    ProcessImageAreaPalette<
   3615        ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>>(
   3616        rgba, x0, y0, xs, 1, 1 + row_count, row_stride, lookup, nb_chans,
   3617        row_sample_collectors);
   3618  } else {
   3619    ChunkSampleCollector<BitDepth> sample_collectors[4];
   3620    ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>
   3621        row_sample_collectors[4];
   3622    for (size_t c = 0; c < nb_chans; c++) {
   3623      row_sample_collectors[c].t = &sample_collectors[c];
   3624      sample_collectors[c].raw_counts = raw_counts[c];
   3625      sample_collectors[c].lz77_counts = lz77_counts[c];
   3626    }
   3627    ProcessImageArea<
   3628        ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>>(
   3629        rgba, x0, y0, xs, 1, 1 + row_count, row_stride, bitdepth, nb_chans,
   3630        big_endian, row_sample_collectors);
   3631  }
   3632 }
   3633 
   3634 void PrepareDCGlobalPalette(bool is_single_group, size_t width, size_t height,
   3635                            size_t nb_chans, const PrefixCode code[4],
   3636                            const std::vector<uint32_t>& palette,
   3637                            size_t pcolors, BitWriter* output) {
   3638  PrepareDCGlobalCommon(is_single_group, width, height, code, output);
   3639  output->Write(2, 0b01);     // 1 transform
   3640  output->Write(2, 0b01);     // Palette
   3641  output->Write(5, 0b00000);  // Starting from ch 0
   3642  if (nb_chans == 1) {
   3643    output->Write(2, 0b00);  // 1-channel palette (Gray)
   3644  } else if (nb_chans == 3) {
   3645    output->Write(2, 0b01);  // 3-channel palette (RGB)
   3646  } else if (nb_chans == 4) {
   3647    output->Write(2, 0b10);  // 4-channel palette (RGBA)
   3648  } else {
   3649    output->Write(2, 0b11);
   3650    output->Write(13, nb_chans - 1);
   3651  }
   3652  // pcolors <= kMaxColors + kChunkSize - 1
   3653  static_assert(kMaxColors + kChunkSize < 1281,
   3654                "add code to signal larger palette sizes");
   3655  if (pcolors < 256) {
   3656    output->Write(2, 0b00);
   3657    output->Write(8, pcolors);
   3658  } else {
   3659    output->Write(2, 0b01);
   3660    output->Write(10, pcolors - 256);
   3661  }
   3662 
   3663  output->Write(2, 0b00);  // nb_deltas == 0
   3664  output->Write(4, 0);     // Zero predictor for delta palette
   3665  // Encode palette
   3666  ChunkEncoder<UpTo8Bits> encoder;
   3667  ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
   3668  row_encoder.t = &encoder;
   3669  encoder.output = output;
   3670  encoder.code = &code[0];
   3671  encoder.PrepareForSimd();
   3672  int16_t p[4][32 + 1024] = {};
   3673  size_t i = 0;
   3674  size_t have_zero = 1;
   3675  for (; i < pcolors; i++) {
   3676    p[0][16 + i + have_zero] = palette[i] & 0xFF;
   3677    p[1][16 + i + have_zero] = (palette[i] >> 8) & 0xFF;
   3678    p[2][16 + i + have_zero] = (palette[i] >> 16) & 0xFF;
   3679    p[3][16 + i + have_zero] = (palette[i] >> 24) & 0xFF;
   3680  }
   3681  p[0][15] = 0;
   3682  row_encoder.ProcessRow(p[0] + 16, p[0] + 15, p[0] + 15, p[0] + 15, pcolors);
   3683  p[1][15] = p[0][16];
   3684  p[0][15] = p[0][16];
   3685  if (nb_chans > 1) {
   3686    row_encoder.ProcessRow(p[1] + 16, p[1] + 15, p[0] + 16, p[0] + 15, pcolors);
   3687  }
   3688  p[2][15] = p[1][16];
   3689  p[1][15] = p[1][16];
   3690  if (nb_chans > 2) {
   3691    row_encoder.ProcessRow(p[2] + 16, p[2] + 15, p[1] + 16, p[1] + 15, pcolors);
   3692  }
   3693  p[3][15] = p[2][16];
   3694  p[2][15] = p[2][16];
   3695  if (nb_chans > 3) {
   3696    row_encoder.ProcessRow(p[3] + 16, p[3] + 15, p[2] + 16, p[2] + 15, pcolors);
   3697  }
   3698  row_encoder.Finalize();
   3699 
   3700  if (!is_single_group) {
   3701    output->ZeroPadToByte();
   3702  }
   3703 }
   3704 
   3705 template <size_t nb_chans>
   3706 bool detect_palette(const unsigned char* r, size_t width,
   3707                    std::vector<uint32_t>& palette) {
   3708  size_t x = 0;
   3709  bool collided = false;
   3710  // This is just an unrolled version of the loop that follows.
   3711  size_t look_ahead = 7 + ((nb_chans == 1) ? 3 : ((nb_chans < 4) ? 1 : 0));
   3712  for (; x + look_ahead < width; x += 8) {
   3713    uint32_t p[8] = {}, index[8];
   3714    for (int i = 0; i < 8; i++) {
   3715      for (int j = 0; j < 4; ++j) {
   3716        p[i] |= r[(x + i) * nb_chans + j] << (8 * j);
   3717      }
   3718    }
   3719    for (int i = 0; i < 8; i++) p[i] &= ((1llu << (8 * nb_chans)) - 1);
   3720    for (int i = 0; i < 8; i++) index[i] = pixel_hash(p[i]);
   3721    for (int i = 0; i < 8; i++) {
   3722      collided |= (palette[index[i]] != 0 && p[i] != palette[index[i]]);
   3723    }
   3724    for (int i = 0; i < 8; i++) palette[index[i]] = p[i];
   3725  }
   3726  for (; x < width; x++) {
   3727    uint32_t p = 0;
   3728    for (size_t i = 0; i < nb_chans; ++i) {
   3729      p |= r[x * nb_chans + i] << (8 * i);
   3730    }
   3731    uint32_t index = pixel_hash(p);
   3732    collided |= (palette[index] != 0 && p != palette[index]);
   3733    palette[index] = p;
   3734  }
   3735  return collided;
   3736 }
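        // Note that `collided` is conservative: two distinct colors that land
        // in the same hash bucket also set it, so the encoder may fall back to
        // the non-palette path even when the true color count is small. A
        // spurious collision only costs compression, never correctness.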
   3737 
   3738 template <typename BitDepth>
   3739 JxlFastLosslessFrameState* LLPrepare(JxlChunkedFrameInputSource input,
   3740                                     size_t width, size_t height,
   3741                                     BitDepth bitdepth, size_t nb_chans,
   3742                                     bool big_endian, int effort, int oneshot) {
   3743  assert(width != 0);
   3744  assert(height != 0);
   3745 
   3746  // Count colors to decide whether to try palette mode.
   3747  std::vector<uint32_t> palette(kHashSize);
   3748  std::vector<int16_t> lookup(kHashSize);
   3749  lookup[0] = 0;
   3750  int pcolors = 0;
   3751  bool collided = effort < 2 || bitdepth.bitdepth != 8 || !oneshot;
   3752  for (size_t y0 = 0; y0 < height && !collided; y0 += 256) {
   3753    size_t ys = std::min<size_t>(height - y0, 256);
   3754    for (size_t x0 = 0; x0 < width && !collided; x0 += 256) {
   3755      size_t xs = std::min<size_t>(width - x0, 256);
   3756      size_t stride;
   3757      // TODO(szabadka): Add RAII wrapper around this.
   3758      const void* buffer = input.get_color_channel_data_at(input.opaque, x0, y0,
   3759                                                           xs, ys, &stride);
   3760      auto rgba = reinterpret_cast<const unsigned char*>(buffer);
   3761      for (size_t y = 0; y < ys && !collided; y++) {
   3762        const unsigned char* r = rgba + stride * y;
   3763        if (nb_chans == 1) collided = detect_palette<1>(r, xs, palette);
   3764        if (nb_chans == 2) collided = detect_palette<2>(r, xs, palette);
   3765        if (nb_chans == 3) collided = detect_palette<3>(r, xs, palette);
   3766        if (nb_chans == 4) collided = detect_palette<4>(r, xs, palette);
   3767      }
   3768      input.release_buffer(input.opaque, buffer);
   3769    }
   3770  }
   3771  int nb_entries = 0;
   3772  if (!collided) {
   3773    pcolors = 1;  // always have all-zero as a palette color
   3774    bool have_color = false;
   3775    uint8_t minG = 255, maxG = 0;
   3776    for (uint32_t k = 0; k < kHashSize; k++) {
   3777      if (palette[k] == 0) continue;
   3778      uint8_t p[4];
   3779      for (int i = 0; i < 4; ++i) {
   3780        p[i] = (palette[k] >> (8 * i)) & 0xFF;
   3781      }
   3782      // move entries to front so sort has less work
   3783      palette[nb_entries] = palette[k];
   3784      if (p[0] != p[1] || p[0] != p[2]) have_color = true;
   3785      if (p[1] < minG) minG = p[1];
   3786      if (p[1] > maxG) maxG = p[1];
   3787      nb_entries++;
   3788      // don't do palette if too many colors are needed
   3789      if (nb_entries + pcolors > kMaxColors) {
   3790        collided = true;
   3791        break;
   3792      }
   3793    }
   3794    if (!have_color) {
   3795      // don't do palette if it's just grayscale without many holes
   3796      if (maxG - minG < nb_entries * 1.4f) collided = true;
   3797    }
   3798  }
   3799  if (!collided) {
   3800    std::sort(
   3801        palette.begin(), palette.begin() + nb_entries,
   3802        [&nb_chans](uint32_t ap, uint32_t bp) {
   3803          if (ap == 0) return false;
   3804          if (bp == 0) return true;
   3805          uint8_t a[4], b[4];
   3806          for (int i = 0; i < 4; ++i) {
   3807            a[i] = (ap >> (8 * i)) & 0xFF;
   3808            b[i] = (bp >> (8 * i)) & 0xFF;
   3809          }
   3810          float ay, by;
   3811          if (nb_chans == 4) {
   3812            ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f) * a[3];
   3813            by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f) * b[3];
   3814          } else {
   3815            ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f);
   3816            by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f);
   3817          }
   3818          return ay < by;  // sort on alpha*luma
   3819        });
   3820    for (int k = 0; k < nb_entries; k++) {
   3821      if (palette[k] == 0) break;
   3822      lookup[pixel_hash(palette[k])] = pcolors++;
   3823    }
   3824  }
   3825 
   3826  size_t num_groups_x = (width + 255) / 256;
   3827  size_t num_groups_y = (height + 255) / 256;
   3828  size_t num_dc_groups_x = (width + 2047) / 2048;
   3829  size_t num_dc_groups_y = (height + 2047) / 2048;
   3830 
   3831  uint64_t raw_counts[4][kNumRawSymbols] = {};
   3832  uint64_t lz77_counts[4][kNumLZ77] = {};
   3833 
   3834  bool onegroup = num_groups_x == 1 && num_groups_y == 1;
   3835 
   3836  auto sample_rows = [&](size_t xg, size_t yg, size_t num_rows) {
   3837    size_t y0 = yg * 256;
   3838    size_t x0 = xg * 256;
   3839    size_t ys = std::min<size_t>(height - y0, 256);
   3840    size_t xs = std::min<size_t>(width - x0, 256);
   3841    size_t stride;
   3842    const void* buffer =
   3843        input.get_color_channel_data_at(input.opaque, x0, y0, xs, ys, &stride);
   3844    auto rgba = reinterpret_cast<const unsigned char*>(buffer);
   3845    int y_begin_group =
   3846        std::max<ssize_t>(
   3847            0, static_cast<ssize_t>(ys) - static_cast<ssize_t>(num_rows)) /
   3848        2;
   3849    int y_count = std::min<int>(num_rows, ys - y_begin_group);
   3850    int x_max = xs / kChunkSize * kChunkSize;
   3851    CollectSamples(rgba, 0, y_begin_group, x_max, stride, y_count, raw_counts,
   3852                   lz77_counts, onegroup, !collided, bitdepth, nb_chans,
   3853                   big_endian, lookup.data());
   3854    input.release_buffer(input.opaque, buffer);
   3855  };
   3856 
   3857  // TODO(veluca): that `64` is an arbitrary constant, meant to correspond to
   3858  // the point where the number of processed rows is large enough that loading
   3859  // the entire image is cost-effective.
   3860  if (oneshot || effort >= 64) {
   3861    for (size_t g = 0; g < num_groups_y * num_groups_x; g++) {
   3862      size_t xg = g % num_groups_x;
   3863      size_t yg = g / num_groups_x;
   3864      size_t y0 = yg * 256;
   3865      size_t ys = std::min<size_t>(height - y0, 256);
   3866      size_t num_rows = 2 * effort * ys / 256;
   3867      sample_rows(xg, yg, num_rows);
   3868    }
   3869  } else {
   3870    // sample the middle (effort * 2 * num_groups) rows of the center group
   3871    // (possibly all of them).
   3872    sample_rows((num_groups_x - 1) / 2, (num_groups_y - 1) / 2,
   3873                2 * effort * num_groups_x * num_groups_y);
   3874  }
   3875 
   3876  // TODO(veluca): can probably improve this and make it bitdepth-dependent.
   3877  uint64_t base_raw_counts[kNumRawSymbols] = {
   3878      3843, 852, 1270, 1214, 1014, 727, 481, 300, 159, 51,
   3879      5,    1,   1,    1,    1,    1,   1,   1,   1};
   3880 
   3881  bool doing_ycocg = nb_chans > 2 && collided;
   3882  bool large_palette = !collided || pcolors >= 256;
   3883  for (size_t i = bitdepth.NumSymbols(doing_ycocg || large_palette);
   3884       i < kNumRawSymbols; i++) {
   3885    base_raw_counts[i] = 0;
   3886  }
   3887 
   3888  for (size_t c = 0; c < 4; c++) {
   3889    for (size_t i = 0; i < kNumRawSymbols; i++) {
   3890      raw_counts[c][i] = (raw_counts[c][i] << 8) + base_raw_counts[i];
   3891    }
   3892  }
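          // (The << 8 scaling makes the measured counts dominate the
          // base_raw_counts, which act only as a smoothing prior so that no
          // plausible symbol ends up with a zero count.)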
   3893 
   3894  if (!collided) {
   3895    unsigned token, nbits, bits;
   3896    EncodeHybridUint000(PackSigned(pcolors - 1), &token, &nbits, &bits);
   3897    // ensure all palette indices can actually be encoded
   3898    for (size_t i = 0; i < token + 1; i++)
   3899      raw_counts[0][i] = std::max<uint64_t>(raw_counts[0][i], 1);
   3900    // These tokens are only used for the palette itself, so a poor code
   3901    // for them is fine.
   3902    for (size_t i = token + 1; i < 10; i++) raw_counts[0][i] = 1;
   3903  }
   3904 
   3905  uint64_t base_lz77_counts[kNumLZ77] = {
   3906      29, 27, 25,  23, 21, 21, 19, 18, 21, 17, 16, 15, 15, 14,
   3907      13, 13, 137, 98, 61, 34, 1,  1,  1,  1,  1,  1,  1,  1,
   3908  };
   3909 
   3910  for (size_t c = 0; c < 4; c++) {
   3911    for (size_t i = 0; i < kNumLZ77; i++) {
   3912      lz77_counts[c][i] = (lz77_counts[c][i] << 8) + base_lz77_counts[i];
   3913    }
   3914  }
   3915 
   3916  JxlFastLosslessFrameState* frame_state = new JxlFastLosslessFrameState();
   3917  for (size_t i = 0; i < 4; i++) {
   3918    frame_state->hcode[i] = PrefixCode(bitdepth, raw_counts[i], lz77_counts[i]);
   3919  }
   3920 
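         // Section count: unless everything fits in one group, the frame has
         // one DC-global section, one AC-global section, one section per DC
         // group and one per AC group, hence the "2 +" below.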
   3921  size_t num_dc_groups = num_dc_groups_x * num_dc_groups_y;
   3922  size_t num_ac_groups = num_groups_x * num_groups_y;
   3923  size_t num_groups = onegroup ? 1 : (2 + num_dc_groups + num_ac_groups);
   3924  frame_state->input = input;
   3925  frame_state->width = width;
   3926  frame_state->height = height;
   3927  frame_state->num_groups_x = num_groups_x;
   3928  frame_state->num_groups_y = num_groups_y;
   3929  frame_state->num_dc_groups_x = num_dc_groups_x;
   3930  frame_state->num_dc_groups_y = num_dc_groups_y;
   3931  frame_state->nb_chans = nb_chans;
   3932  frame_state->bitdepth = bitdepth.bitdepth;
   3933  frame_state->big_endian = big_endian;
   3934  frame_state->effort = effort;
   3935  frame_state->collided = collided;
   3936  frame_state->lookup = lookup;
   3937 
   3938  frame_state->group_data = std::vector<std::array<BitWriter, 4>>(num_groups);
   3939  frame_state->group_sizes.resize(num_groups);
   3940  if (collided) {
   3941    PrepareDCGlobal(onegroup, width, height, nb_chans, frame_state->hcode,
   3942                    &frame_state->group_data[0][0]);
   3943  } else {
   3944    PrepareDCGlobalPalette(onegroup, width, height, nb_chans,
   3945                           frame_state->hcode, palette, pcolors,
   3946                           &frame_state->group_data[0][0]);
   3947  }
   3948  frame_state->group_sizes[0] = SectionSize(frame_state->group_data[0]);
   3949  if (!onegroup) {
   3950    ComputeAcGroupDataOffset(frame_state->group_sizes[0], num_dc_groups,
   3951                             num_ac_groups, frame_state->min_dc_global_size,
   3952                             frame_state->ac_group_data_offset);
   3953  }
   3954 
   3955  return frame_state;
   3956 }
   3957 
   3958 template <typename BitDepth>
   3959 jxl::Status LLProcess(JxlFastLosslessFrameState* frame_state, bool is_last,
   3960                      BitDepth bitdepth, void* runner_opaque,
   3961                      FJxlParallelRunner runner,
   3962                      JxlEncoderOutputProcessorWrapper* output_processor) {
   3963 #if !FJXL_STANDALONE
   3964  if (frame_state->process_done) {
   3965    JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last);
   3966    if (output_processor) {
   3967      JXL_RETURN_IF_ERROR(
   3968          JxlFastLosslessOutputFrame(frame_state, output_processor));
   3969    }
   3970    return true;
   3971  }
   3972 #endif
   3973  // The maximum number of groups that we process concurrently here.
   3974  // TODO(szabadka) Use the number of threads or some outside parameter for the
   3975  // maximum memory usage instead.
   3976  constexpr size_t kMaxLocalGroups = 16;
   3977  bool onegroup = frame_state->group_sizes.size() == 1;
   3978  bool streaming = !onegroup && output_processor;
   3979  size_t total_groups = frame_state->num_groups_x * frame_state->num_groups_y;
   3980  size_t max_groups = streaming ? kMaxLocalGroups : total_groups;
   3981 #if !FJXL_STANDALONE
   3982  size_t start_pos = 0;
   3983  if (streaming) {
   3984    start_pos = output_processor->CurrentPosition();
   3985    JXL_RETURN_IF_ERROR(
   3986        output_processor->Seek(start_pos + frame_state->ac_group_data_offset));
   3987  }
   3988 #endif
   3989  for (size_t offset = 0; offset < total_groups; offset += max_groups) {
   3990    size_t num_groups = std::min(max_groups, total_groups - offset);
   3991    JxlFastLosslessFrameState local_frame_state;
   3992    if (streaming) {
   3993      local_frame_state.group_data =
   3994          std::vector<std::array<BitWriter, 4>>(num_groups);
   3995    }
   3996    auto run_one = [&](size_t i) {
   3997      size_t g = offset + i;
   3998      size_t xg = g % frame_state->num_groups_x;
   3999      size_t yg = g / frame_state->num_groups_x;
   4000      size_t num_dc_groups =
   4001          frame_state->num_dc_groups_x * frame_state->num_dc_groups_y;
   4002      size_t group_id = onegroup ? 0 : (2 + num_dc_groups + g);
   4003      size_t xs = std::min<size_t>(frame_state->width - xg * 256, 256);
   4004      size_t ys = std::min<size_t>(frame_state->height - yg * 256, 256);
   4005      size_t x0 = xg * 256;
   4006      size_t y0 = yg * 256;
   4007      size_t stride;
   4008      JxlChunkedFrameInputSource input = frame_state->input;
   4009      const void* buffer = input.get_color_channel_data_at(input.opaque, x0, y0,
   4010                                                           xs, ys, &stride);
   4011      const unsigned char* rgba =
   4012          reinterpret_cast<const unsigned char*>(buffer);
   4013 
   4014      auto& gd = streaming ? local_frame_state.group_data[i]
   4015                           : frame_state->group_data[group_id];
   4016      if (frame_state->collided) {
   4017        WriteACSection(rgba, 0, 0, xs, ys, stride, onegroup, bitdepth,
   4018                       frame_state->nb_chans, frame_state->big_endian,
   4019                       frame_state->hcode, gd);
   4020      } else {
   4021        WriteACSectionPalette(rgba, 0, 0, xs, ys, stride, onegroup,
   4022                              frame_state->hcode, frame_state->lookup.data(),
   4023                              frame_state->nb_chans, gd[0]);
   4024      }
   4025      frame_state->group_sizes[group_id] = SectionSize(gd);
   4026      input.release_buffer(input.opaque, buffer);
   4027    };
   4028    runner(
   4029        runner_opaque, &run_one,
   4030        +[](void* r, size_t i) {
   4031          (*reinterpret_cast<decltype(&run_one)>(r))(i);
   4032        },
   4033        num_groups);
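           // The `+` forces the capture-less lambda to decay to a plain
           // function pointer, while the capturing `run_one` travels through
           // the runner as an opaque pointer. The same pattern in miniature
           // (a sketch, not part of the original source):
           //   auto task = [&](size_t i) { /* work on item i */ };
           //   runner(runner_opaque, &task,
           //          +[](void* t, size_t i) {
           //            (*reinterpret_cast<decltype(&task)>(t))(i);
           //          },
           //          count);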
   4034 #if !FJXL_STANDALONE
   4035    if (streaming) {
   4036      local_frame_state.nb_chans = frame_state->nb_chans;
   4037      local_frame_state.current_bit_writer = 1;
   4038      JXL_RETURN_IF_ERROR(
   4039          JxlFastLosslessOutputFrame(&local_frame_state, output_processor));
   4040    }
   4041 #endif
   4042  }
   4043 #if !FJXL_STANDALONE
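         // Streaming epilogue: the AC group data was written starting at
         // ac_group_data_offset (space reserved by the Seek above). Seek back,
         // pad the DC-global section so the headers fill exactly the reserved
         // space, emit them, then restore the output position.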
   4044  if (streaming) {
   4045    size_t end_pos = output_processor->CurrentPosition();
   4046    JXL_RETURN_IF_ERROR(output_processor->Seek(start_pos));
   4047    frame_state->group_data.resize(1);
   4048    bool have_alpha = frame_state->nb_chans == 2 || frame_state->nb_chans == 4;
   4049    size_t padding = ComputeDcGlobalPadding(
   4050        frame_state->group_sizes, frame_state->ac_group_data_offset,
   4051        frame_state->min_dc_global_size, have_alpha, is_last);
   4052 
   4053    for (size_t i = 0; i < padding; ++i) {
   4054      frame_state->group_data[0][0].Write(8, 0);
   4055    }
   4056    frame_state->group_sizes[0] += padding;
   4057    JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last);
   4058    assert(frame_state->ac_group_data_offset ==
   4059           JxlFastLosslessOutputSize(frame_state));
   4060    JXL_RETURN_IF_ERROR(
   4061        JxlFastLosslessOutputHeaders(frame_state, output_processor));
   4062    JXL_RETURN_IF_ERROR(output_processor->Seek(end_pos));
   4063  } else if (output_processor) {
   4064    assert(onegroup);
   4065    JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/0, is_last);
   4067    JXL_RETURN_IF_ERROR(
   4068        JxlFastLosslessOutputFrame(frame_state, output_processor));
   4070  }
   4071  frame_state->process_done = true;
   4072 #endif
   4073  return true;
   4074 }
   4075 
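        // Dispatch on bit depth: each branch instantiates LLPrepare with a
        // different BitDepth policy type, so the per-sample token math is
        // specialized at compile time rather than branched on per pixel.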
   4076 JxlFastLosslessFrameState* JxlFastLosslessPrepareImpl(
   4077    JxlChunkedFrameInputSource input, size_t width, size_t height,
   4078    size_t nb_chans, size_t bitdepth, bool big_endian, int effort,
   4079    int oneshot) {
   4080  assert(bitdepth > 0);
   4081  assert(nb_chans <= 4);
   4082  assert(nb_chans != 0);
   4083  if (bitdepth <= 8) {
   4084    return LLPrepare(input, width, height, UpTo8Bits(bitdepth), nb_chans,
   4085                     big_endian, effort, oneshot);
   4086  }
   4087  if (bitdepth <= 13) {
   4088    return LLPrepare(input, width, height, From9To13Bits(bitdepth), nb_chans,
   4089                     big_endian, effort, oneshot);
   4090  }
   4091  if (bitdepth == 14) {
   4092    return LLPrepare(input, width, height, Exactly14Bits(bitdepth), nb_chans,
   4093                     big_endian, effort, oneshot);
   4094  }
   4095  return LLPrepare(input, width, height, MoreThan14Bits(bitdepth), nb_chans,
   4096                   big_endian, effort, oneshot);
   4097 }
   4098 
   4099 jxl::Status JxlFastLosslessProcessFrameImpl(
   4100    JxlFastLosslessFrameState* frame_state, bool is_last, void* runner_opaque,
   4101    FJxlParallelRunner runner,
   4102    JxlEncoderOutputProcessorWrapper* output_processor) {
   4103  const size_t bitdepth = frame_state->bitdepth;
   4104  if (bitdepth <= 8) {
   4105    JXL_RETURN_IF_ERROR(LLProcess(frame_state, is_last, UpTo8Bits(bitdepth),
   4106                                  runner_opaque, runner, output_processor));
   4107  } else if (bitdepth <= 13) {
   4108    JXL_RETURN_IF_ERROR(LLProcess(frame_state, is_last, From9To13Bits(bitdepth),
   4109                                  runner_opaque, runner, output_processor));
   4110  } else if (bitdepth == 14) {
   4111    JXL_RETURN_IF_ERROR(LLProcess(frame_state, is_last, Exactly14Bits(bitdepth),
   4112                                  runner_opaque, runner, output_processor));
   4113  } else {
   4114    JXL_RETURN_IF_ERROR(LLProcess(frame_state, is_last,
   4115                                  MoreThan14Bits(bitdepth), runner_opaque,
   4116                                  runner, output_processor));
   4117  }
   4118  return true;
   4119 }
   4120 
   4121 }  // namespace
   4122 
   4123 #endif  // FJXL_SELF_INCLUDE
   4124 
   4125 #ifndef FJXL_SELF_INCLUDE
   4126 
   4127 #define FJXL_SELF_INCLUDE
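        // Multi-target build trick: this file re-includes itself below with
        // FJXL_SELF_INCLUDE defined and one ISA macro (FJXL_NEON, FJXL_AVX2,
        // FJXL_AVX512) set at a time, wrapping each copy in its own namespace
        // so several ISA-specific encoders coexist in one binary. The same
        // technique in miniature (a sketch, not part of the original source):
        //   // lib.cc
        //   #ifndef SELF_INCLUDE
        //   #define SELF_INCLUDE
        //   namespace scalar {
        //   #include "lib.cc"  // compiled once per target namespace
        //   }  // namespace scalar
        //   #else
        //   int kernel(int x) { return x + 1; }  // per-target body
        //   #endif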
   4128 
   4129 // If we have NEON enabled, it is the default target.
   4130 #if FJXL_ENABLE_NEON
   4131 
   4132 namespace default_implementation {
   4133 #define FJXL_NEON
   4134 #include "lib/jxl/enc_fast_lossless.cc"
   4135 #undef FJXL_NEON
   4136 }  // namespace default_implementation
   4137 
   4138 #else                                    // FJXL_ENABLE_NEON
   4139 
   4140 namespace default_implementation {
   4141 #include "lib/jxl/enc_fast_lossless.cc"  // NOLINT
   4142 }
   4143 
   4144 #if FJXL_ENABLE_AVX2
   4145 #ifdef __clang__
   4146 #pragma clang attribute push(__attribute__((target("avx,avx2"))), \
   4147                             apply_to = function)
   4148 // Causes spurious warnings on clang5.
   4149 #pragma clang diagnostic push
   4150 #pragma clang diagnostic ignored "-Wmissing-braces"
   4151 #elif defined(__GNUC__)
   4152 #pragma GCC push_options
   4153 // Seems to cause spurious errors on GCC8.
   4154 #pragma GCC diagnostic ignored "-Wpsabi"
   4155 #pragma GCC target "avx,avx2"
   4156 #endif
   4157 
   4158 namespace AVX2 {
   4159 #define FJXL_AVX2
   4160 #include "lib/jxl/enc_fast_lossless.cc"  // NOLINT
   4161 #undef FJXL_AVX2
   4162 }  // namespace AVX2
   4163 
   4164 #ifdef __clang__
   4165 #pragma clang attribute pop
   4166 #pragma clang diagnostic pop
   4167 #elif defined(__GNUC__)
   4168 #pragma GCC pop_options
   4169 #endif
   4170 #endif  // FJXL_ENABLE_AVX2
   4171 
   4172 #if FJXL_ENABLE_AVX512
   4173 #ifdef __clang__
   4174 #pragma clang attribute push(                                                 \
   4175    __attribute__((target("avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"))), \
   4176    apply_to = function)
   4177 #elif defined(__GNUC__)
   4178 #pragma GCC push_options
   4179 #pragma GCC target "avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"
   4180 #endif
   4181 
   4182 namespace AVX512 {
   4183 #define FJXL_AVX512
   4184 #include "lib/jxl/enc_fast_lossless.cc"
   4185 #undef FJXL_AVX512
   4186 }  // namespace AVX512
   4187 
   4188 #ifdef __clang__
   4189 #pragma clang attribute pop
   4190 #elif defined(__GNUC__)
   4191 #pragma GCC pop_options
   4192 #endif
   4193 #endif  // FJXL_ENABLE_AVX512
   4194 
   4195 #endif
   4196 
   4197 extern "C" {
   4198 
   4199 #if FJXL_STANDALONE
   4200 class FJxlFrameInput {
   4201 public:
   4202  FJxlFrameInput(const unsigned char* rgba, size_t row_stride, size_t nb_chans,
   4203                 size_t bitdepth)
   4204      : rgba_(rgba),
   4205        row_stride_(row_stride),
   4206        bytes_per_pixel_(bitdepth <= 8 ? nb_chans : 2 * nb_chans) {}
   4207 
   4208  JxlChunkedFrameInputSource GetInputSource() {
   4209    return JxlChunkedFrameInputSource{this, GetDataAt,
   4210                                      [](void*, const void*) {}};
   4211  }
   4212 
   4213 private:
   4214  static const void* GetDataAt(void* opaque, size_t xpos, size_t ypos,
   4215                               size_t xsize, size_t ysize, size_t* row_offset) {
   4216    FJxlFrameInput* self = static_cast<FJxlFrameInput*>(opaque);
   4217    *row_offset = self->row_stride_;
   4218    return self->rgba_ + ypos * (*row_offset) + xpos * self->bytes_per_pixel_;
   4219  }
   4220 
   4221  const uint8_t* rgba_;
   4222  size_t row_stride_;
   4223  size_t bytes_per_pixel_;
   4224 };
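        // The chunked-input contract as used here: get_color_channel_data_at
        // returns a pointer to the first requested pixel and reports the row
        // stride through its last argument; the encoder later hands the
        // pointer back via the release callback (a no-op above, since
        // FJxlFrameInput never copies).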
   4225 
   4226 size_t JxlFastLosslessEncode(const unsigned char* rgba, size_t width,
   4227                             size_t row_stride, size_t height, size_t nb_chans,
   4228                             size_t bitdepth, bool big_endian, int effort,
   4229                             unsigned char** output, void* runner_opaque,
   4230                             FJxlParallelRunner runner) {
   4231  FJxlFrameInput input(rgba, row_stride, nb_chans, bitdepth);
   4232  auto frame_state = JxlFastLosslessPrepareFrame(
   4233      input.GetInputSource(), width, height, nb_chans, bitdepth, big_endian,
   4234      effort, /*oneshot=*/true);
   4235  if (!JxlFastLosslessProcessFrame(frame_state, /*is_last=*/true, runner_opaque,
   4236                                   runner, nullptr)) {
   4237    return 0;
   4238  }
   4239  JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/1,
   4240                               /*is_last=*/1);
   4241  size_t output_size = JxlFastLosslessMaxRequiredOutput(frame_state);
   4242  *output = (unsigned char*)malloc(output_size);
   4243  size_t written = 0;
   4244  size_t total = 0;
   4245  while ((written = JxlFastLosslessWriteOutput(frame_state, *output + total,
   4246                                               output_size - total)) != 0) {
   4247    total += written;
   4248  }
   4249  JxlFastLosslessFreeFrameState(frame_state);
   4250  return total;
   4251 }
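        // Example call (a sketch; `pixels` is a hypothetical caller-owned
        // buffer of 8-bit RGBA rows, tightly packed):
        //   unsigned char* encoded = nullptr;
        //   size_t size = JxlFastLosslessEncode(
        //       pixels, width, /*row_stride=*/width * 4, height,
        //       /*nb_chans=*/4, /*bitdepth=*/8, /*big_endian=*/false,
        //       /*effort=*/2, &encoded, /*runner_opaque=*/nullptr,
        //       /*runner=*/nullptr);
        //   // On success `size` is the encoded length; the caller owns
        //   // `encoded` and must free() it.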
   4252 #endif
   4253 
   4254 JxlFastLosslessFrameState* JxlFastLosslessPrepareFrame(
   4255    JxlChunkedFrameInputSource input, size_t width, size_t height,
   4256    size_t nb_chans, size_t bitdepth, bool big_endian, int effort,
   4257    int oneshot) {
   4258 #if FJXL_ENABLE_AVX512
   4259  if (HasCpuFeature(CpuFeature::kAVX512CD) &&
   4260      HasCpuFeature(CpuFeature::kVBMI) &&
   4261      HasCpuFeature(CpuFeature::kAVX512BW) &&
   4262      HasCpuFeature(CpuFeature::kAVX512F) &&
   4263      HasCpuFeature(CpuFeature::kAVX512VL)) {
   4264    return AVX512::JxlFastLosslessPrepareImpl(
   4265        input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot);
   4266  }
   4267 #endif
   4268 #if FJXL_ENABLE_AVX2
   4269  if (HasCpuFeature(CpuFeature::kAVX2)) {
   4270    return AVX2::JxlFastLosslessPrepareImpl(
   4271        input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot);
   4272  }
   4273 #endif
   4274 
   4275  return default_implementation::JxlFastLosslessPrepareImpl(
   4276      input, width, height, nb_chans, bitdepth, big_endian, effort, oneshot);
   4277 }
   4278 
   4279 bool JxlFastLosslessProcessFrame(
   4280    JxlFastLosslessFrameState* frame_state, bool is_last, void* runner_opaque,
   4281    FJxlParallelRunner runner,
   4282    JxlEncoderOutputProcessorWrapper* output_processor) {
   4283  auto trivial_runner =
   4284      +[](void*, void* opaque, void fun(void*, size_t), size_t count) {
   4285        for (size_t i = 0; i < count; i++) {
   4286          fun(opaque, i);
   4287        }
   4288      };
   4289 
   4290  if (runner == nullptr) {
   4291    runner = trivial_runner;
   4292  }
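         // Any thread pool can be adapted to this signature. A minimal
         // std::thread-based sketch (illustrative only, one thread per item):
         //   void ThreadRunner(void* /*runner_opaque*/, void* opaque,
         //                     void fun(void*, size_t), size_t count) {
         //     std::vector<std::thread> workers;
         //     for (size_t i = 0; i < count; i++)
         //       workers.emplace_back([=] { fun(opaque, i); });
         //     for (auto& t : workers) t.join();
         //   }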
   4293 
   4294 #if FJXL_ENABLE_AVX512
   4295  if (HasCpuFeature(CpuFeature::kAVX512CD) &&
   4296      HasCpuFeature(CpuFeature::kVBMI) &&
   4297      HasCpuFeature(CpuFeature::kAVX512BW) &&
   4298      HasCpuFeature(CpuFeature::kAVX512F) &&
   4299      HasCpuFeature(CpuFeature::kAVX512VL)) {
   4300    JXL_RETURN_IF_ERROR(AVX512::JxlFastLosslessProcessFrameImpl(
   4301        frame_state, is_last, runner_opaque, runner, output_processor));
   4302    return true;
   4303  }
   4304 #endif
   4305 #if FJXL_ENABLE_AVX2
   4306  if (HasCpuFeature(CpuFeature::kAVX2)) {
   4307    JXL_RETURN_IF_ERROR(AVX2::JxlFastLosslessProcessFrameImpl(
   4308        frame_state, is_last, runner_opaque, runner, output_processor));
   4309    return true;
   4310  }
   4311 #endif
   4312 
   4313  JXL_RETURN_IF_ERROR(default_implementation::JxlFastLosslessProcessFrameImpl(
   4314      frame_state, is_last, runner_opaque, runner, output_processor));
   4315  return true;
   4316 }
   4317 
   4318 }  // extern "C"
   4319 
   4320 #if !FJXL_STANDALONE
   4321 bool JxlFastLosslessOutputFrame(
   4322    JxlFastLosslessFrameState* frame_state,
   4323    JxlEncoderOutputProcessorWrapper* output_processor) {
   4324  size_t fl_size = JxlFastLosslessOutputSize(frame_state);
   4325  size_t written = 0;
   4326  while (written < fl_size) {
   4327    JXL_ASSIGN_OR_RETURN(auto buffer,
   4328                         output_processor->GetBuffer(32, fl_size - written));
   4329    size_t n =
   4330        JxlFastLosslessWriteOutput(frame_state, buffer.data(), buffer.size());
   4331    if (n == 0) break;
   4332    JXL_RETURN_IF_ERROR(buffer.advance(n));
   4333    written += n;
   4334  }
   4335  return true;
   4336 }
   4337 #endif
   4338 
   4339 #endif  // FJXL_SELF_INCLUDE