tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

dec_jpeg_data_writer.cc (33652B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/jpeg/dec_jpeg_data_writer.h"
      7 
      8 #include <algorithm>
      9 #include <cstddef>
     10 #include <cstdint>
     11 #include <cstdlib>
     12 #include <cstring> /* for memset, memcpy */
     13 #include <deque>
     14 #include <utility>
     15 #include <vector>
     16 
     17 #include "lib/jxl/base/bits.h"
     18 #include "lib/jxl/base/byte_order.h"
     19 #include "lib/jxl/base/common.h"
     20 #include "lib/jxl/base/status.h"
     21 #include "lib/jxl/frame_dimensions.h"
     22 #include "lib/jxl/jpeg/dec_jpeg_serialization_state.h"
     23 #include "lib/jxl/jpeg/jpeg_data.h"
     24 
     25 namespace jxl {
     26 namespace jpeg {
     27 
     28 namespace {
     29 
     30 enum struct SerializationStatus {
     31  NEEDS_MORE_INPUT,
     32  NEEDS_MORE_OUTPUT,
     33  ERROR,
     34  DONE
     35 };
     36 
     37 const int kJpegPrecision = 8;
     38 
     39 // JpegBitWriter: buffer size
     40 const size_t kJpegBitWriterChunkSize = 16384;
     41 
     42 // Returns non-zero if and only if x has a zero byte, i.e. one of
     43 // x & 0xff, x & 0xff00, ..., x & 0xff00000000000000 is zero.
     44 JXL_INLINE uint64_t HasZeroByte(uint64_t x) {
     45  return (x - 0x0101010101010101ULL) & ~x & 0x8080808080808080ULL;
     46 }
     47 
     48 void JpegBitWriterInit(JpegBitWriter* bw,
     49                       std::deque<OutputChunk>* output_queue) {
     50  bw->output = output_queue;
     51  bw->chunk = OutputChunk(kJpegBitWriterChunkSize);
     52  bw->pos = 0;
     53  bw->put_buffer = 0;
     54  bw->put_bits = 64;
     55  bw->healthy = true;
     56  bw->data = bw->chunk.buffer->data();
     57 }
     58 
     59 JXL_NOINLINE void SwapBuffer(JpegBitWriter* bw) {
     60  bw->chunk.len = bw->pos;
     61  bw->output->emplace_back(std::move(bw->chunk));
     62  bw->chunk = OutputChunk(kJpegBitWriterChunkSize);
     63  bw->data = bw->chunk.buffer->data();
     64  bw->pos = 0;
     65 }
     66 
     67 JXL_INLINE void Reserve(JpegBitWriter* bw, size_t n_bytes) {
     68  if (JXL_UNLIKELY((bw->pos + n_bytes) > kJpegBitWriterChunkSize)) {
     69    SwapBuffer(bw);
     70  }
     71 }
     72 
     73 /**
     74 * Writes the given byte to the output, writes an extra zero if byte is 0xFF.
     75 *
     76 * This method is "careless" - caller must make sure that there is enough
     77 * space in the output buffer. Emits up to 2 bytes to buffer.
     78 */
     79 JXL_INLINE void EmitByte(JpegBitWriter* bw, int byte) {
     80  bw->data[bw->pos] = byte;
     81  bw->data[bw->pos + 1] = 0;
     82  bw->pos += (byte != 0xFF ? 1 : 2);
     83 }
     84 
     85 JXL_INLINE void DischargeBitBuffer(JpegBitWriter* bw, int nbits,
     86                                   uint64_t bits) {
     87  // At this point we are ready to emit the put_buffer to the output.
     88  // The JPEG format requires that after every 0xff byte in the entropy
     89  // coded section, there is a zero byte, therefore we first check if any of
     90  // the 8 bytes of put_buffer is 0xFF.
     91  bw->put_buffer |= (bits >> -bw->put_bits);
     92  if (JXL_UNLIKELY(HasZeroByte(~bw->put_buffer))) {
     93    // We have a 0xFF byte somewhere, examine each byte and append a zero
     94    // byte if necessary.
     95    EmitByte(bw, (bw->put_buffer >> 56) & 0xFF);
     96    EmitByte(bw, (bw->put_buffer >> 48) & 0xFF);
     97    EmitByte(bw, (bw->put_buffer >> 40) & 0xFF);
     98    EmitByte(bw, (bw->put_buffer >> 32) & 0xFF);
     99    EmitByte(bw, (bw->put_buffer >> 24) & 0xFF);
    100    EmitByte(bw, (bw->put_buffer >> 16) & 0xFF);
    101    EmitByte(bw, (bw->put_buffer >> 8) & 0xFF);
    102    EmitByte(bw, (bw->put_buffer) & 0xFF);
    103  } else {
    104    // We don't have any 0xFF bytes, output all 8 bytes without checking.
    105    StoreBE64(bw->put_buffer, bw->data + bw->pos);
    106    bw->pos += 8;
    107  }
    108 
    109  bw->put_bits += 64;
    110  bw->put_buffer = bits << bw->put_bits;
    111 }
    112 
    113 JXL_INLINE void WriteBits(JpegBitWriter* bw, int nbits, uint64_t bits) {
    114  JXL_DASSERT(nbits > 0);
    115  bw->put_bits -= nbits;
    116  if (JXL_UNLIKELY(bw->put_bits < 0)) {
    117    if (JXL_UNLIKELY(nbits > 64)) {
    118      bw->put_bits += nbits;
    119      bw->healthy = false;
    120    } else {
    121      DischargeBitBuffer(bw, nbits, bits);
    122    }
    123  } else {
    124    bw->put_buffer |= (bits << bw->put_bits);
    125  }
    126 }
    127 
    128 void EmitMarker(JpegBitWriter* bw, int marker) {
    129  Reserve(bw, 2);
    130  JXL_DASSERT(marker != 0xFF);
    131  bw->data[bw->pos++] = 0xFF;
    132  bw->data[bw->pos++] = marker;
    133 }
    134 
    135 bool JumpToByteBoundary(JpegBitWriter* bw, const uint8_t** pad_bits,
    136                        const uint8_t* pad_bits_end) {
    137  size_t n_bits = bw->put_bits & 7u;
    138  uint8_t dangling_bits = 0;
    139  uint8_t pad_pattern;
    140  if (*pad_bits == nullptr) {
    141    pad_pattern = (1u << n_bits) - 1;
    142  } else {
    143    pad_pattern = 0;
    144    const uint8_t* src = *pad_bits;
    145    // TODO(eustas): bitwise reading looks insanely ineffective!
    146    while (n_bits--) {
    147      pad_pattern <<= 1;
    148      if (src >= pad_bits_end) return false;
    149      uint8_t bit = *src;
    150      src++;
    151      dangling_bits |= bit;
    152      pad_pattern |= bit;
    153    }
    154    *pad_bits = src;
    155  }
    156  if ((dangling_bits & ~1) != 0) return false;
    157 
    158  Reserve(bw, 16);
    159 
    160  while (bw->put_bits <= 56) {
    161    int c = (bw->put_buffer >> 56) & 0xFF;
    162    EmitByte(bw, c);
    163    bw->put_buffer <<= 8;
    164    bw->put_bits += 8;
    165  }
    166  if (bw->put_bits < 64) {
    167    int pad_mask = 0xFFu >> (64 - bw->put_bits);
    168    int c = ((bw->put_buffer >> 56) & ~pad_mask) | pad_pattern;
    169    EmitByte(bw, c);
    170  }
    171  bw->put_buffer = 0;
    172  bw->put_bits = 64;
    173 
    174  return true;
    175 }
    176 
    177 void JpegBitWriterFinish(JpegBitWriter* bw) {
    178  if (bw->pos == 0) return;
    179  bw->chunk.len = bw->pos;
    180  bw->output->emplace_back(std::move(bw->chunk));
    181  bw->chunk = OutputChunk(nullptr, 0);
    182  bw->data = nullptr;
    183  bw->pos = 0;
    184 }
    185 
    186 void DCTCodingStateInit(DCTCodingState* s) {
    187  s->eob_run_ = 0;
    188  s->cur_ac_huff_ = nullptr;
    189  s->refinement_bits_.clear();
    190  s->refinement_bits_.reserve(64);
    191 }
    192 
    193 JXL_INLINE void WriteSymbol(int symbol, HuffmanCodeTable* table,
    194                            JpegBitWriter* bw) {
    195  WriteBits(bw, table->depth[symbol], table->code[symbol]);
    196 }
    197 
    198 JXL_INLINE void WriteSymbolBits(int symbol, HuffmanCodeTable* table,
    199                                JpegBitWriter* bw, int nbits, uint64_t bits) {
    200  WriteBits(bw, nbits + table->depth[symbol],
    201            bits | (table->code[symbol] << nbits));
    202 }
    203 
    204 // Emit all buffered data to the bit stream using the given Huffman code and
    205 // bit writer.
    206 JXL_INLINE void Flush(DCTCodingState* s, JpegBitWriter* bw) {
    207  if (s->eob_run_ > 0) {
    208    Reserve(bw, 16);
    209    int nbits = FloorLog2Nonzero<uint32_t>(s->eob_run_);
    210    int symbol = nbits << 4u;
    211    WriteSymbol(symbol, s->cur_ac_huff_, bw);
    212    if (nbits > 0) {
    213      WriteBits(bw, nbits, s->eob_run_ & ((1 << nbits) - 1));
    214    }
    215    s->eob_run_ = 0;
    216  }
    217  const size_t kStride = 124;  // (515 - 16) / 2 / 2
    218  size_t num_words = s->refinement_bits_count_ >> 4;
    219  size_t i = 0;
    220  while (i < num_words) {
    221    size_t limit = std::min(i + kStride, num_words);
    222    Reserve(bw, 512);
    223    for (; i < limit; ++i) {
    224      WriteBits(bw, 16, s->refinement_bits_[i]);
    225    }
    226  }
    227  Reserve(bw, 16);
    228  size_t tail = s->refinement_bits_count_ & 0xF;
    229  if (tail) {
    230    WriteBits(bw, tail, s->refinement_bits_.back());
    231  }
    232  s->refinement_bits_.clear();
    233  s->refinement_bits_count_ = 0;
    234 }
    235 
// Buffer some more data at the end-of-band (the last non-zero or newly
// non-zero coefficient within the [Ss, Se] spectral band).
//
// Instead of emitting an EOB symbol right away, this extends the pending
// end-of-band run in `s` by one block. It also appends the block's
// `new_bits_count` refinement bits to s->refinement_bits_:
//  - each value of new_bits_array is expected to be 0 or 1;
//  - new_bits_count must not exceed 64, since the bits are concatenated in a
//    uint64_t;
//  - the bits are packed MSB-first into 16-bit words, and the last word holds
//    (refinement_bits_count_ & 0xF) bits when that is non-zero.
// When the run reaches 0x7FFF, the largest run an EOBn symbol (n <= 14) can
// express, everything is flushed through `bw`.
JXL_INLINE void BufferEndOfBand(DCTCodingState* s, HuffmanCodeTable* ac_huff,
                                const int* new_bits_array,
                                size_t new_bits_count, JpegBitWriter* bw) {
  // The first block of a run decides which table Flush() uses for the EOBn
  // symbol.
  if (s->eob_run_ == 0) {
    s->cur_ac_huff_ = ac_huff;
  }
  ++s->eob_run_;
  if (new_bits_count) {
    // Concatenate the new bits; the first one ends up most significant.
    uint64_t new_bits = 0;
    for (size_t i = 0; i < new_bits_count; ++i) {
      new_bits = (new_bits << 1) | new_bits_array[i];
    }
    size_t tail = s->refinement_bits_count_ & 0xF;
    if (tail) {  // First stuff the tail item
      // Top up the partially filled last word with the leading new bits.
      size_t stuff_bits_count = std::min(16 - tail, new_bits_count);
      uint16_t stuff_bits = new_bits >> (new_bits_count - stuff_bits_count);
      stuff_bits &= ((1u << stuff_bits_count) - 1);
      s->refinement_bits_.back() =
          (s->refinement_bits_.back() << stuff_bits_count) | stuff_bits;
      new_bits_count -= stuff_bits_count;
      s->refinement_bits_count_ += stuff_bits_count;
    }
    // Whole 16-bit words. The pushed value still contains higher bits that
    // were already consumed above; they are dropped only because the element
    // type of refinement_bits_ is (presumably) 16 bits wide. Confirm this
    // before widening that type.
    while (new_bits_count >= 16) {
      s->refinement_bits_.push_back(new_bits >> (new_bits_count - 16));
      new_bits_count -= 16;
      s->refinement_bits_count_ += 16;
    }
    // The remaining (< 16) bits start a new partial word.
    if (new_bits_count) {
      s->refinement_bits_.push_back(new_bits & ((1u << new_bits_count) - 1));
      s->refinement_bits_count_ += new_bits_count;
    }
  }

  if (s->eob_run_ == 0x7FFF) {
    Flush(s, bw);
  }
}
    275 
    276 bool BuildHuffmanCodeTable(const JPEGHuffmanCode& huff,
    277                           HuffmanCodeTable* table) {
    278  int huff_code[kJpegHuffmanAlphabetSize];
    279  // +1 for a sentinel element.
    280  uint32_t huff_size[kJpegHuffmanAlphabetSize + 1];
    281  int p = 0;
    282  for (size_t l = 1; l <= kJpegHuffmanMaxBitLength; ++l) {
    283    int i = huff.counts[l];
    284    if (p + i > kJpegHuffmanAlphabetSize + 1) {
    285      return false;
    286    }
    287    while (i--) huff_size[p++] = l;
    288  }
    289 
    290  if (p == 0) {
    291    return true;
    292  }
    293 
    294  // Reuse sentinel element.
    295  int last_p = p - 1;
    296  huff_size[last_p] = 0;
    297 
    298  int code = 0;
    299  uint32_t si = huff_size[0];
    300  p = 0;
    301  while (huff_size[p]) {
    302    while ((huff_size[p]) == si) {
    303      huff_code[p++] = code;
    304      code++;
    305    }
    306    code <<= 1;
    307    si++;
    308  }
    309  for (p = 0; p < last_p; p++) {
    310    int i = huff.values[p];
    311    table->depth[i] = huff_size[p];
    312    table->code[i] = huff_code[p];
    313  }
    314  return true;
    315 }
    316 
    317 bool EncodeSOI(SerializationState* state) {
    318  state->output_queue.push_back(OutputChunk({0xFF, 0xD8}));
    319  return true;
    320 }
    321 
    322 bool EncodeEOI(const JPEGData& jpg, SerializationState* state) {
    323  state->output_queue.push_back(OutputChunk({0xFF, 0xD9}));
    324  state->output_queue.emplace_back(jpg.tail_data);
    325  return true;
    326 }
    327 
    328 bool EncodeSOF(const JPEGData& jpg, uint8_t marker, SerializationState* state) {
    329  if (marker <= 0xC2) state->is_progressive = (marker == 0xC2);
    330 
    331  const size_t n_comps = jpg.components.size();
    332  const size_t marker_len = 8 + 3 * n_comps;
    333  state->output_queue.emplace_back(marker_len + 2);
    334  uint8_t* data = state->output_queue.back().buffer->data();
    335  size_t pos = 0;
    336  data[pos++] = 0xFF;
    337  data[pos++] = marker;
    338  data[pos++] = marker_len >> 8u;
    339  data[pos++] = marker_len & 0xFFu;
    340  data[pos++] = kJpegPrecision;
    341  data[pos++] = jpg.height >> 8u;
    342  data[pos++] = jpg.height & 0xFFu;
    343  data[pos++] = jpg.width >> 8u;
    344  data[pos++] = jpg.width & 0xFFu;
    345  data[pos++] = n_comps;
    346  for (size_t i = 0; i < n_comps; ++i) {
    347    data[pos++] = jpg.components[i].id;
    348    data[pos++] = ((jpg.components[i].h_samp_factor << 4u) |
    349                   (jpg.components[i].v_samp_factor));
    350    const size_t quant_idx = jpg.components[i].quant_idx;
    351    if (quant_idx >= jpg.quant.size()) return false;
    352    data[pos++] = jpg.quant[quant_idx].index;
    353  }
    354  return true;
    355 }
    356 
    357 bool EncodeSOS(const JPEGData& jpg, const JPEGScanInfo& scan_info,
    358               SerializationState* state) {
    359  const size_t n_scans = scan_info.num_components;
    360  const size_t marker_len = 6 + 2 * n_scans;
    361  state->output_queue.emplace_back(marker_len + 2);
    362  uint8_t* data = state->output_queue.back().buffer->data();
    363  size_t pos = 0;
    364  data[pos++] = 0xFF;
    365  data[pos++] = 0xDA;
    366  data[pos++] = marker_len >> 8u;
    367  data[pos++] = marker_len & 0xFFu;
    368  data[pos++] = n_scans;
    369  for (size_t i = 0; i < n_scans; ++i) {
    370    const JPEGComponentScanInfo& si = scan_info.components[i];
    371    if (si.comp_idx >= jpg.components.size()) return false;
    372    data[pos++] = jpg.components[si.comp_idx].id;
    373    data[pos++] = (si.dc_tbl_idx << 4u) + si.ac_tbl_idx;
    374  }
    375  data[pos++] = scan_info.Ss;
    376  data[pos++] = scan_info.Se;
    377  data[pos++] = ((scan_info.Ah << 4u) | (scan_info.Al));
    378  return true;
    379 }
    380 
    381 bool EncodeDHT(const JPEGData& jpg, SerializationState* state) {
    382  const std::vector<JPEGHuffmanCode>& huffman_code = jpg.huffman_code;
    383 
    384  size_t marker_len = 2;
    385  for (size_t i = state->dht_index; i < huffman_code.size(); ++i) {
    386    const JPEGHuffmanCode& huff = huffman_code[i];
    387    marker_len += kJpegHuffmanMaxBitLength;
    388    for (uint32_t count : huff.counts) {
    389      marker_len += count;
    390    }
    391    if (huff.is_last) break;
    392  }
    393  state->output_queue.emplace_back(marker_len + 2);
    394  uint8_t* data = state->output_queue.back().buffer->data();
    395  size_t pos = 0;
    396  data[pos++] = 0xFF;
    397  data[pos++] = 0xC4;
    398  data[pos++] = marker_len >> 8u;
    399  data[pos++] = marker_len & 0xFFu;
    400  while (true) {
    401    const size_t huffman_code_index = state->dht_index++;
    402    if (huffman_code_index >= huffman_code.size()) {
    403      return false;
    404    }
    405    const JPEGHuffmanCode& huff = huffman_code[huffman_code_index];
    406    size_t index = huff.slot_id;
    407    HuffmanCodeTable* huff_table;
    408    if (index & 0x10) {
    409      index -= 0x10;
    410      huff_table = &state->ac_huff_table[index];
    411    } else {
    412      huff_table = &state->dc_huff_table[index];
    413    }
    414    // TODO(eustas): cache
    415    huff_table->InitDepths(127);
    416    if (!BuildHuffmanCodeTable(huff, huff_table)) {
    417      return false;
    418    }
    419    huff_table->initialized = true;
    420    size_t total_count = 0;
    421    size_t max_length = 0;
    422    for (size_t i = 0; i < huff.counts.size(); ++i) {
    423      if (huff.counts[i] != 0) {
    424        max_length = i;
    425      }
    426      total_count += huff.counts[i];
    427    }
    428    --total_count;
    429    data[pos++] = huff.slot_id;
    430    for (size_t i = 1; i <= kJpegHuffmanMaxBitLength; ++i) {
    431      data[pos++] = (i == max_length ? huff.counts[i] - 1 : huff.counts[i]);
    432    }
    433    for (size_t i = 0; i < total_count; ++i) {
    434      data[pos++] = huff.values[i];
    435    }
    436    if (huff.is_last) break;
    437  }
    438  return true;
    439 }
    440 
    441 bool EncodeDQT(const JPEGData& jpg, SerializationState* state) {
    442  int marker_len = 2;
    443  for (size_t i = state->dqt_index; i < jpg.quant.size(); ++i) {
    444    const JPEGQuantTable& table = jpg.quant[i];
    445    marker_len += 1 + (table.precision ? 2 : 1) * kDCTBlockSize;
    446    if (table.is_last) break;
    447  }
    448  state->output_queue.emplace_back(marker_len + 2);
    449  uint8_t* data = state->output_queue.back().buffer->data();
    450  size_t pos = 0;
    451  data[pos++] = 0xFF;
    452  data[pos++] = 0xDB;
    453  data[pos++] = marker_len >> 8u;
    454  data[pos++] = marker_len & 0xFFu;
    455  while (true) {
    456    const size_t idx = state->dqt_index++;
    457    if (idx >= jpg.quant.size()) {
    458      return false;  // corrupt input
    459    }
    460    const JPEGQuantTable& table = jpg.quant[idx];
    461    data[pos++] = (table.precision << 4u) + table.index;
    462    for (size_t i = 0; i < kDCTBlockSize; ++i) {
    463      int val_idx = kJPEGNaturalOrder[i];
    464      int val = table.values[val_idx];
    465      if (table.precision) {
    466        data[pos++] = val >> 8u;
    467      }
    468      data[pos++] = val & 0xFFu;
    469    }
    470    if (table.is_last) break;
    471  }
    472  return true;
    473 }
    474 
    475 bool EncodeDRI(const JPEGData& jpg, SerializationState* state) {
    476  state->seen_dri_marker = true;
    477  OutputChunk dri_marker = {0xFF,
    478                            0xDD,
    479                            0,
    480                            4,
    481                            static_cast<uint8_t>(jpg.restart_interval >> 8),
    482                            static_cast<uint8_t>(jpg.restart_interval & 0xFF)};
    483  state->output_queue.push_back(std::move(dri_marker));
    484  return true;
    485 }
    486 
    487 bool EncodeRestart(uint8_t marker, SerializationState* state) {
    488  state->output_queue.push_back(OutputChunk({0xFF, marker}));
    489  return true;
    490 }
    491 
    492 bool EncodeAPP(const JPEGData& jpg, uint8_t marker, SerializationState* state) {
    493  // TODO(eustas): check that marker corresponds to payload?
    494  (void)marker;
    495 
    496  size_t app_index = state->app_index++;
    497  if (app_index >= jpg.app_data.size()) return false;
    498  state->output_queue.push_back(OutputChunk({0xFF}));
    499  state->output_queue.emplace_back(jpg.app_data[app_index]);
    500  return true;
    501 }
    502 
    503 bool EncodeCOM(const JPEGData& jpg, SerializationState* state) {
    504  size_t com_index = state->com_index++;
    505  if (com_index >= jpg.com_data.size()) return false;
    506  state->output_queue.push_back(OutputChunk({0xFF}));
    507  state->output_queue.emplace_back(jpg.com_data[com_index]);
    508  return true;
    509 }
    510 
    511 bool EncodeInterMarkerData(const JPEGData& jpg, SerializationState* state) {
    512  size_t index = state->data_index++;
    513  if (index >= jpg.inter_marker_data.size()) return false;
    514  state->output_queue.emplace_back(jpg.inter_marker_data[index]);
    515  return true;
    516 }
    517 
    518 bool EncodeDCTBlockSequential(const coeff_t* coeffs, HuffmanCodeTable* dc_huff,
    519                              HuffmanCodeTable* ac_huff, int num_zero_runs,
    520                              coeff_t* last_dc_coeff, JpegBitWriter* bw) {
    521  coeff_t temp2;
    522  coeff_t temp;
    523  coeff_t litmus = 0;
    524  temp2 = coeffs[0];
    525  temp = temp2 - *last_dc_coeff;
    526  *last_dc_coeff = temp2;
    527  temp2 = temp >> (8 * sizeof(coeff_t) - 1);
    528  temp += temp2;
    529  temp2 ^= temp;
    530 
    531  int dc_nbits = (temp2 == 0) ? 0 : (FloorLog2Nonzero<uint32_t>(temp2) + 1);
    532  WriteSymbol(dc_nbits, dc_huff, bw);
    533 #if JXL_FALSE
    534  // If the input is corrupt, this could be triggered. Checking is
    535  // costly though, so it makes more sense to avoid this branch.
    536  // (producing a corrupt JPEG when the input is corrupt, instead
    537  // of catching it and returning error)
    538  if (dc_nbits >= 12) return false;
    539 #endif
    540  if (dc_nbits) {
    541    WriteBits(bw, dc_nbits, temp & ((1u << dc_nbits) - 1));
    542  }
    543  int16_t r = 0;
    544 
    545  for (size_t i = 1; i < 64; i++) {
    546    temp = coeffs[kJPEGNaturalOrder[i]];
    547    if (temp == 0) {
    548      r++;
    549    } else {
    550      temp2 = temp >> (8 * sizeof(coeff_t) - 1);
    551      temp += temp2;
    552      temp2 ^= temp;
    553      if (JXL_UNLIKELY(r > 15)) {
    554        WriteSymbol(0xf0, ac_huff, bw);
    555        r -= 16;
    556        if (r > 15) {
    557          WriteSymbol(0xf0, ac_huff, bw);
    558          r -= 16;
    559        }
    560        if (r > 15) {
    561          WriteSymbol(0xf0, ac_huff, bw);
    562          r -= 16;
    563        }
    564      }
    565      litmus |= temp2;
    566      int ac_nbits =
    567          FloorLog2Nonzero<uint32_t>(static_cast<uint16_t>(temp2)) + 1;
    568      int symbol = (r << 4u) + ac_nbits;
    569      WriteSymbolBits(symbol, ac_huff, bw, ac_nbits,
    570                      temp & ((1 << ac_nbits) - 1));
    571      r = 0;
    572    }
    573  }
    574 
    575  for (int i = 0; i < num_zero_runs; ++i) {
    576    WriteSymbol(0xf0, ac_huff, bw);
    577    r -= 16;
    578  }
    579  if (r > 0) {
    580    WriteSymbol(0, ac_huff, bw);
    581  }
    582  return (litmus >= 0);
    583 }
    584 
// Encodes one block of a progressive first pass (DoEncodeScan's kMode == 1)
// over the spectral band [Ss, Se] with point transform Al.
//  - DC (Ss == 0): codes the difference between coeffs[0] >> Al and
//    *last_dc_coeff (which is updated) as a size symbol plus extra bits.
//  - AC: each coefficient magnitude is shifted right by Al. Non-zero results
//    are coded as run/size symbols (ZRL 0xF0 per 16 zeros) plus extra bits,
//    after any buffered EOB run / refinement bits are flushed.
//  - `num_zero_runs` extra ZRL symbols follow the band (presumably to
//    reproduce the original encoder's output; confirm).
//  - Trailing zeros are not coded here: they extend the EOB run buffered in
//    `coding_state`. That run is flushed immediately when Ss was 0 on entry,
//    because EOB runs are only allowed for AC-only bands (Ss > 0).
// Returns false if negating a DC difference or coefficient overflows coeff_t.
bool EncodeDCTBlockProgressive(const coeff_t* coeffs, HuffmanCodeTable* dc_huff,
                               HuffmanCodeTable* ac_huff, int Ss, int Se,
                               int Al, int num_zero_runs,
                               DCTCodingState* coding_state,
                               coeff_t* last_dc_coeff, JpegBitWriter* bw) {
  bool eob_run_allowed = Ss > 0;
  coeff_t temp2;
  coeff_t temp;
  if (Ss == 0) {
    temp2 = coeffs[0] >> Al;
    temp = temp2 - *last_dc_coeff;
    *last_dc_coeff = temp2;
    // temp -> magnitude; temp2 -> bits to emit (value - 1 when negative).
    temp2 = temp;
    if (temp < 0) {
      temp = -temp;
      if (temp < 0) return false;  // negation did not fit in coeff_t
      temp2--;
    }
    int nbits = (temp == 0) ? 0 : (FloorLog2Nonzero<uint32_t>(temp) + 1);
    WriteSymbol(nbits, dc_huff, bw);
    if (nbits) {
      WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1));
    }
    ++Ss;
  }
  if (Ss > Se) {
    return true;
  }
  int r = 0;  // current run of zero coefficients
  for (int k = Ss; k <= Se; ++k) {
    temp = coeffs[kJPEGNaturalOrder[k]];
    if (temp == 0) {
      r++;
      continue;
    }
    // temp -> shifted magnitude; temp2 -> bits to emit (one's complement of
    // the shifted magnitude for negative coefficients).
    if (temp < 0) {
      temp = -temp;
      if (temp < 0) return false;  // negation did not fit in coeff_t
      temp >>= Al;
      temp2 = ~temp;
    } else {
      temp >>= Al;
      temp2 = temp;
    }
    // Coefficients that vanish under the point transform count as zeros.
    if (temp == 0) {
      r++;
      continue;
    }
    // A buffered EOB run must precede any other symbol of this block.
    Flush(coding_state, bw);
    while (r > 15) {
      WriteSymbol(0xf0, ac_huff, bw);
      r -= 16;
    }
    int nbits = FloorLog2Nonzero<uint32_t>(temp) + 1;
    int symbol = (r << 4u) + nbits;
    WriteSymbol(symbol, ac_huff, bw);
    WriteBits(bw, nbits, temp2 & ((1 << nbits) - 1));
    r = 0;
  }
  if (num_zero_runs > 0) {
    Flush(coding_state, bw);
    for (int i = 0; i < num_zero_runs; ++i) {
      WriteSymbol(0xf0, ac_huff, bw);
      r -= 16;
    }
  }
  if (r > 0) {
    BufferEndOfBand(coding_state, ac_huff, nullptr, 0, bw);
    if (!eob_run_allowed) {
      Flush(coding_state, bw);
    }
  }
  return true;
}
    659 
// Encodes one block of a successive-approximation refinement pass (the `else`
// branch of DoEncodeScan's kMode dispatch; presumably scans with Ah > 0 —
// confirm) over the spectral band [Ss, Se] at bit position Al.
//  - DC (Ss == 0): emits bit Al of coeffs[0] as one raw bit.
//  - AC: let v = |coefficient| >> Al.
//    * v > 1: the coefficient was already non-zero in earlier passes. It only
//      contributes its refinement bit (v & 1), which is queued until the next
//      coded symbol.
//    * v == 1: the coefficient becomes non-zero now. It is coded as a
//      run/size symbol (size 1), a sign bit (1 = positive), and then the
//      queued refinement bits.
//    * v == 0: extends the current zero run.
//  - Whatever follows the last newly non-zero coefficient (a zero run and/or
//    queued refinement bits) is appended to the EOB run buffered in
//    `coding_state` rather than coded here. That run is flushed immediately
//    when Ss was 0 on entry, because EOB runs are only used for AC-only bands.
// Always returns true.
bool EncodeRefinementBits(const coeff_t* coeffs, HuffmanCodeTable* ac_huff,
                          int Ss, int Se, int Al, DCTCodingState* coding_state,
                          JpegBitWriter* bw) {
  bool eob_run_allowed = Ss > 0;
  if (Ss == 0) {
    // Emit next bit of DC component.
    WriteBits(bw, 1, (coeffs[0] >> Al) & 1);
    ++Ss;
  }
  if (Ss > Se) {
    return true;
  }
  int abs_values[kDCTBlockSize];
  // Position of the last coefficient that becomes non-zero in this pass.
  int eob = 0;
  for (int k = Ss; k <= Se; k++) {
    // NOTE(review): the magnitude is stored back into coeff_t. Assuming
    // coeff_t is 16-bit, |-32768| does not fit and abs_val would come out
    // negative; this is only reachable with corrupt coefficients.
    const coeff_t abs_val = std::abs(coeffs[kJPEGNaturalOrder[k]]);
    abs_values[k] = abs_val >> Al;
    if (abs_values[k] == 1) {
      eob = k;
    }
  }
  int r = 0;  // current run of zero coefficients
  int refinement_bits[kDCTBlockSize];  // queued refinement bits (0/1)
  size_t refinement_bits_count = 0;
  for (int k = Ss; k <= Se; k++) {
    if (abs_values[k] == 0) {
      r++;
      continue;
    }
    // Runs longer than 15 need ZRL symbols, but only while a newly non-zero
    // coefficient still follows (k <= eob). Past that point the zeros are
    // absorbed by the end-of-band. Each ZRL carries the bits queued so far.
    while (r > 15 && k <= eob) {
      Flush(coding_state, bw);
      WriteSymbol(0xf0, ac_huff, bw);
      r -= 16;
      for (size_t i = 0; i < refinement_bits_count; ++i) {
        WriteBits(bw, 1, refinement_bits[i]);
      }
      refinement_bits_count = 0;
    }
    // Previously non-zero coefficient: queue its refinement bit only.
    if (abs_values[k] > 1) {
      refinement_bits[refinement_bits_count++] = abs_values[k] & 1u;
      continue;
    }
    // Newly non-zero coefficient: symbol, sign bit, then the queued bits.
    Flush(coding_state, bw);
    int symbol = (r << 4u) + 1;
    int new_non_zero_bit = (coeffs[kJPEGNaturalOrder[k]] < 0) ? 0 : 1;
    WriteSymbol(symbol, ac_huff, bw);
    WriteBits(bw, 1, new_non_zero_bit);
    for (size_t i = 0; i < refinement_bits_count; ++i) {
      WriteBits(bw, 1, refinement_bits[i]);
    }
    refinement_bits_count = 0;
    r = 0;
  }
  if (r > 0 || refinement_bits_count) {
    BufferEndOfBand(coding_state, ac_huff, refinement_bits,
                    refinement_bits_count, bw);
    if (!eob_run_allowed) {
      Flush(coding_state, bw);
    }
  }
  return true;
}
    722 
    723 template <int kMode>
    724 SerializationStatus JXL_NOINLINE DoEncodeScan(const JPEGData& jpg,
    725                                              SerializationState* state) {
    726  const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index];
    727  EncodeScanState& ss = state->scan_state;
    728 
    729  const int restart_interval =
    730      state->seen_dri_marker ? jpg.restart_interval : 0;
    731 
    732  const auto get_next_extra_zero_run_index = [&ss, &scan_info]() -> int {
    733    if (ss.extra_zero_runs_pos < scan_info.extra_zero_runs.size()) {
    734      return scan_info.extra_zero_runs[ss.extra_zero_runs_pos].block_idx;
    735    } else {
    736      return -1;
    737    }
    738  };
    739 
    740  const auto get_next_reset_point = [&ss, &scan_info]() -> int {
    741    if (ss.next_reset_point_pos < scan_info.reset_points.size()) {
    742      return scan_info.reset_points[ss.next_reset_point_pos++];
    743    } else {
    744      return -1;
    745    }
    746  };
    747 
    748  if (ss.stage == EncodeScanState::HEAD) {
    749    if (!EncodeSOS(jpg, scan_info, state)) return SerializationStatus::ERROR;
    750    JpegBitWriterInit(&ss.bw, &state->output_queue);
    751    DCTCodingStateInit(&ss.coding_state);
    752    ss.restarts_to_go = restart_interval;
    753    ss.next_restart_marker = 0;
    754    ss.block_scan_index = 0;
    755    ss.extra_zero_runs_pos = 0;
    756    ss.next_extra_zero_run_index = get_next_extra_zero_run_index();
    757    ss.next_reset_point_pos = 0;
    758    ss.next_reset_point = get_next_reset_point();
    759    ss.mcu_y = 0;
    760    memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff));
    761    ss.stage = EncodeScanState::BODY;
    762  }
    763  JpegBitWriter* bw = &ss.bw;
    764  DCTCodingState* coding_state = &ss.coding_state;
    765 
    766  if (ss.stage != EncodeScanState::BODY) return SerializationStatus::ERROR;
    767 
    768  // "Non-interleaved" means color data comes in separate scans, in other words
    769  // each scan can contain only one color component.
    770  const bool is_interleaved = (scan_info.num_components > 1);
    771  int MCUs_per_row = 0;
    772  int MCU_rows = 0;
    773  jpg.CalculateMcuSize(scan_info, &MCUs_per_row, &MCU_rows);
    774  const bool is_progressive = state->is_progressive;
    775  const int Al = is_progressive ? scan_info.Al : 0;
    776  const int Ss = is_progressive ? scan_info.Ss : 0;
    777  const int Se = is_progressive ? scan_info.Se : 63;
    778 
    779  // DC-only is defined by [0..0] spectral range.
    780  const bool want_ac = ((Ss != 0) || (Se != 0));
    781  const bool want_dc = (Ss == 0);
    782  // TODO(user): support streaming decoding again.
    783  const bool complete_ac = true;
    784  const bool has_ac = true;
    785  if (want_ac && !has_ac) return SerializationStatus::NEEDS_MORE_INPUT;
    786 
    787  // |has_ac| implies |complete_dc| but not vice versa; for the sake of
    788  // simplicity we pretend they are equal, because they are separated by just a
    789  // few bytes of input.
    790  const bool complete_dc = has_ac;
    791  const bool complete = want_ac ? complete_ac : complete_dc;
    792  // When "incomplete" |ac_dc| tracks information about current ("incomplete")
    793  // band parsing progress.
    794 
    795  // FIXME: Is this always complete?
    796  // const int last_mcu_y =
    797  //     complete ? MCU_rows : parsing_state.internal->ac_dc.next_mcu_y *
    798  //     v_group;
    799  (void)complete;
    800  const int last_mcu_y = complete ? MCU_rows : 0;
    801 
    802  for (; ss.mcu_y < last_mcu_y; ++ss.mcu_y) {
    803    for (int mcu_x = 0; mcu_x < MCUs_per_row; ++mcu_x) {
    804      // Possibly emit a restart marker.
    805      if (restart_interval > 0 && ss.restarts_to_go == 0) {
    806        Flush(coding_state, bw);
    807        if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) {
    808          return SerializationStatus::ERROR;
    809        }
    810        EmitMarker(bw, 0xD0 + ss.next_restart_marker);
    811        ss.next_restart_marker += 1;
    812        ss.next_restart_marker &= 0x7;
    813        ss.restarts_to_go = restart_interval;
    814        memset(ss.last_dc_coeff, 0, sizeof(ss.last_dc_coeff));
    815      }
    816 
    817      // Encode one MCU
    818      for (size_t i = 0; i < scan_info.num_components; ++i) {
    819        const JPEGComponentScanInfo& si = scan_info.components[i];
    820        const JPEGComponent& c = jpg.components[si.comp_idx];
    821        size_t dc_tbl_idx = si.dc_tbl_idx;
    822        size_t ac_tbl_idx = si.ac_tbl_idx;
    823        HuffmanCodeTable* dc_huff = &state->dc_huff_table[dc_tbl_idx];
    824        HuffmanCodeTable* ac_huff = &state->ac_huff_table[ac_tbl_idx];
    825        if (want_dc && !dc_huff->initialized) {
    826          return SerializationStatus::ERROR;
    827        }
    828        if (want_ac && !ac_huff->initialized) {
    829          return SerializationStatus::ERROR;
    830        }
    831        int n_blocks_y = is_interleaved ? c.v_samp_factor : 1;
    832        int n_blocks_x = is_interleaved ? c.h_samp_factor : 1;
    833        for (int iy = 0; iy < n_blocks_y; ++iy) {
    834          for (int ix = 0; ix < n_blocks_x; ++ix) {
    835            int block_y = ss.mcu_y * n_blocks_y + iy;
    836            int block_x = mcu_x * n_blocks_x + ix;
    837            int block_idx = block_y * c.width_in_blocks + block_x;
    838            if (ss.block_scan_index == ss.next_reset_point) {
    839              Flush(coding_state, bw);
    840              ss.next_reset_point = get_next_reset_point();
    841            }
    842            int num_zero_runs = 0;
    843            if (ss.block_scan_index == ss.next_extra_zero_run_index) {
    844              num_zero_runs = scan_info.extra_zero_runs[ss.extra_zero_runs_pos]
    845                                  .num_extra_zero_runs;
    846              ++ss.extra_zero_runs_pos;
    847              ss.next_extra_zero_run_index = get_next_extra_zero_run_index();
    848            }
    849            const coeff_t* coeffs = &c.coeffs[block_idx << 6];
    850            bool ok;
    851            // compressed size per block cannot be more than 512 bytes
    852            Reserve(bw, 512);
    853            if (kMode == 0) {
    854              ok = EncodeDCTBlockSequential(coeffs, dc_huff, ac_huff,
    855                                            num_zero_runs,
    856                                            ss.last_dc_coeff + si.comp_idx, bw);
    857            } else if (kMode == 1) {
    858              ok = EncodeDCTBlockProgressive(
    859                  coeffs, dc_huff, ac_huff, Ss, Se, Al, num_zero_runs,
    860                  coding_state, ss.last_dc_coeff + si.comp_idx, bw);
    861            } else {
    862              ok = EncodeRefinementBits(coeffs, ac_huff, Ss, Se, Al,
    863                                        coding_state, bw);
    864            }
    865            if (!ok) return SerializationStatus::ERROR;
    866            ++ss.block_scan_index;
    867          }
    868        }
    869      }
    870      --ss.restarts_to_go;
    871    }
    872  }
    873  if (ss.mcu_y < MCU_rows) {
    874    if (!bw->healthy) return SerializationStatus::ERROR;
    875    return SerializationStatus::NEEDS_MORE_INPUT;
    876  }
    877  Flush(coding_state, bw);
    878  if (!JumpToByteBoundary(bw, &state->pad_bits, state->pad_bits_end)) {
    879    return SerializationStatus::ERROR;
    880  }
    881  JpegBitWriterFinish(bw);
    882  ss.stage = EncodeScanState::HEAD;
    883  state->scan_index++;
    884  if (!bw->healthy) return SerializationStatus::ERROR;
    885 
    886  return SerializationStatus::DONE;
    887 }
    888 
    889 SerializationStatus JXL_INLINE EncodeScan(const JPEGData& jpg,
    890                                          SerializationState* state) {
    891  const JPEGScanInfo& scan_info = jpg.scan_info[state->scan_index];
    892  const bool is_progressive = state->is_progressive;
    893  const int Al = is_progressive ? scan_info.Al : 0;
    894  const int Ah = is_progressive ? scan_info.Ah : 0;
    895  const int Ss = is_progressive ? scan_info.Ss : 0;
    896  const int Se = is_progressive ? scan_info.Se : 63;
    897  const bool need_sequential =
    898      !is_progressive || (Ah == 0 && Al == 0 && Ss == 0 && Se == 63);
    899  if (need_sequential) {
    900    return DoEncodeScan<0>(jpg, state);
    901  } else if (Ah == 0) {
    902    return DoEncodeScan<1>(jpg, state);
    903  } else {
    904    return DoEncodeScan<2>(jpg, state);
    905  }
    906 }
    907 
    908 SerializationStatus SerializeSection(uint8_t marker, SerializationState* state,
    909                                     const JPEGData& jpg) {
    910  const auto to_status = [](bool result) {
    911    return result ? SerializationStatus::DONE : SerializationStatus::ERROR;
    912  };
    913  // TODO(eustas): add and use marker enum
    914  switch (marker) {
    915    case 0xC0:
    916    case 0xC1:
    917    case 0xC2:
    918    case 0xC9:
    919    case 0xCA:
    920      return to_status(EncodeSOF(jpg, marker, state));
    921 
    922    case 0xC4:
    923      return to_status(EncodeDHT(jpg, state));
    924 
    925    case 0xD0:
    926    case 0xD1:
    927    case 0xD2:
    928    case 0xD3:
    929    case 0xD4:
    930    case 0xD5:
    931    case 0xD6:
    932    case 0xD7:
    933      return to_status(EncodeRestart(marker, state));
    934 
    935    case 0xD9:
    936      return to_status(EncodeEOI(jpg, state));
    937 
    938    case 0xDA:
    939      return EncodeScan(jpg, state);
    940 
    941    case 0xDB:
    942      return to_status(EncodeDQT(jpg, state));
    943 
    944    case 0xDD:
    945      return to_status(EncodeDRI(jpg, state));
    946 
    947    case 0xE0:
    948    case 0xE1:
    949    case 0xE2:
    950    case 0xE3:
    951    case 0xE4:
    952    case 0xE5:
    953    case 0xE6:
    954    case 0xE7:
    955    case 0xE8:
    956    case 0xE9:
    957    case 0xEA:
    958    case 0xEB:
    959    case 0xEC:
    960    case 0xED:
    961    case 0xEE:
    962    case 0xEF:
    963      return to_status(EncodeAPP(jpg, marker, state));
    964 
    965    case 0xFE:
    966      return to_status(EncodeCOM(jpg, state));
    967 
    968    case 0xFF:
    969      return to_status(EncodeInterMarkerData(jpg, state));
    970 
    971    default:
    972      return SerializationStatus::ERROR;
    973  }
    974 }
    975 
    976 // TODO(veluca): add streaming support again.
    977 Status WriteJpegInternal(const JPEGData& jpg, const JPEGOutput& out,
    978                         SerializationState* ss) {
    979  const auto maybe_push_output = [&]() -> Status {
    980    if (ss->stage != SerializationState::STAGE_ERROR) {
    981      while (!ss->output_queue.empty()) {
    982        auto& chunk = ss->output_queue.front();
    983        size_t num_written = out(chunk.next, chunk.len);
    984        if (num_written == 0 && chunk.len > 0) {
    985          return StatusMessage(Status(StatusCode::kNotEnoughBytes),
    986                               "Failed to write output");
    987        }
    988        chunk.len -= num_written;
    989        if (chunk.len == 0) {
    990          ss->output_queue.pop_front();
    991        }
    992      }
    993    }
    994    return true;
    995  };
    996 
    997  while (true) {
    998    switch (ss->stage) {
    999      case SerializationState::STAGE_INIT: {
   1000        // Valid Brunsli requires, at least, 0xD9 marker.
   1001        // This might happen on corrupted stream, or on unconditioned JPEGData.
   1002        // TODO(eustas): check D9 in the only one and is the last one.
   1003        if (jpg.marker_order.empty()) {
   1004          ss->stage = SerializationState::STAGE_ERROR;
   1005          break;
   1006        }
   1007        ss->dc_huff_table.resize(kMaxHuffmanTables);
   1008        ss->ac_huff_table.resize(kMaxHuffmanTables);
   1009        if (jpg.has_zero_padding_bit) {
   1010          ss->pad_bits = jpg.padding_bits.data();
   1011          ss->pad_bits_end = ss->pad_bits + jpg.padding_bits.size();
   1012        }
   1013 
   1014        EncodeSOI(ss);
   1015        JXL_QUIET_RETURN_IF_ERROR(maybe_push_output());
   1016        ss->stage = SerializationState::STAGE_SERIALIZE_SECTION;
   1017        break;
   1018      }
   1019 
   1020      case SerializationState::STAGE_SERIALIZE_SECTION: {
   1021        if (ss->section_index >= jpg.marker_order.size()) {
   1022          ss->stage = SerializationState::STAGE_DONE;
   1023          break;
   1024        }
   1025        uint8_t marker = jpg.marker_order[ss->section_index];
   1026        SerializationStatus status = SerializeSection(marker, ss, jpg);
   1027        if (status == SerializationStatus::ERROR) {
   1028          JXL_WARNING("Failed to encode marker 0x%.2x", marker);
   1029          ss->stage = SerializationState::STAGE_ERROR;
   1030          break;
   1031        }
   1032        JXL_QUIET_RETURN_IF_ERROR(maybe_push_output());
   1033        if (status == SerializationStatus::NEEDS_MORE_INPUT) {
   1034          return JXL_FAILURE("Incomplete serialization data");
   1035        } else if (status != SerializationStatus::DONE) {
   1036          ss->stage = SerializationState::STAGE_ERROR;
   1037          return JXL_FAILURE("Internal logic error");
   1038          break;
   1039        }
   1040        ++ss->section_index;
   1041        break;
   1042      }
   1043 
   1044      case SerializationState::STAGE_DONE:
   1045        JXL_ENSURE(ss->output_queue.empty());
   1046        if (ss->pad_bits != nullptr && ss->pad_bits != ss->pad_bits_end) {
   1047          return JXL_FAILURE("Invalid number of padding bits.");
   1048        }
   1049        return true;
   1050 
   1051      case SerializationState::STAGE_ERROR:
   1052        return JXL_FAILURE("JPEG serialization error");
   1053    }
   1054  }
   1055 }
   1056 
   1057 }  // namespace
   1058 
   1059 Status WriteJpeg(const JPEGData& jpg, const JPEGOutput& out) {
   1060  auto ss = jxl::make_unique<SerializationState>();
   1061  return WriteJpegInternal(jpg, out, ss.get());
   1062 }
   1063 
   1064 }  // namespace jpeg
   1065 }  // namespace jxl