enc_icc_codec.cc (17955B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_icc_codec.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <cstdint> 11 #include <limits> 12 #include <map> 13 #include <vector> 14 15 #include "lib/jxl/base/status.h" 16 #include "lib/jxl/enc_ans.h" 17 #include "lib/jxl/enc_aux_out.h" 18 #include "lib/jxl/fields.h" 19 #include "lib/jxl/icc_codec_common.h" 20 #include "lib/jxl/padded_bytes.h" 21 22 namespace jxl { 23 namespace { 24 25 // Unshuffles or de-interleaves bytes, for example with width 2, turns 26 // "AaBbCcDc" into "ABCDabcd", this for example de-interleaves UTF-16 bytes into 27 // first all the high order bytes, then all the low order bytes. 28 // Transposes a matrix of width columns and ceil(size / width) rows. There are 29 // size elements, size may be < width * height, if so the 30 // last elements of the bottom row are missing, the missing spots are 31 // transposed along with the filled spots, and the result has the missing 32 // elements at the bottom of the rightmost column. The input is the input matrix 33 // in scanline order, the output is the result matrix in scanline order, with 34 // missing elements skipped over (this may occur at multiple positions). 
35 Status Unshuffle(JxlMemoryManager* memory_manager, uint8_t* data, size_t size, 36 size_t width) { 37 size_t height = (size + width - 1) / width; // amount of rows of input 38 PaddedBytes result(memory_manager); 39 JXL_ASSIGN_OR_RETURN(result, 40 PaddedBytes::WithInitialSpace(memory_manager, size)); 41 42 // i = input index, j output index 43 size_t s = 0; 44 size_t j = 0; 45 for (size_t i = 0; i < size; i++) { 46 result[j] = data[i]; 47 j += height; 48 if (j >= size) j = ++s; 49 } 50 51 for (size_t i = 0; i < size; i++) { 52 data[i] = result[i]; 53 } 54 return true; 55 } 56 57 // This is performed by the encoder, the encoder must be able to encode any 58 // random byte stream (not just byte streams that are a valid ICC profile), so 59 // an error returned by this function is an implementation error. 60 Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num, 61 const uint8_t* data, size_t size, size_t* pos, 62 PaddedBytes* result) { 63 JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size)); 64 JxlMemoryManager* memory_manager = result->memory_manager(); 65 // Required by the specification, see decoder. stride * 4 must be < *pos. 
66 if (!*pos || ((*pos - 1u) >> 2u) < stride) { 67 return JXL_FAILURE("Invalid stride"); 68 } 69 if (*pos < stride * 4) return JXL_FAILURE("Too large stride"); 70 size_t start = result->size(); 71 for (size_t i = 0; i < num; i++) { 72 uint8_t predicted = 73 LinearPredictICCValue(data, *pos, i, stride, width, order); 74 JXL_RETURN_IF_ERROR(result->push_back(data[*pos + i] - predicted)); 75 } 76 *pos += num; 77 if (width > 1) { 78 JXL_RETURN_IF_ERROR( 79 Unshuffle(memory_manager, result->data() + start, num, width)); 80 } 81 return true; 82 } 83 84 inline Status EncodeVarInt(uint64_t value, PaddedBytes* data) { 85 size_t pos = data->size(); 86 JXL_RETURN_IF_ERROR(data->resize(data->size() + 9)); 87 size_t output_size = data->size(); 88 uint8_t* output = data->data(); 89 90 // While more than 7 bits of data are left, 91 // store 7 bits and set the next byte flag 92 while (value > 127) { 93 // TODO(eustas): should it be `<` ? 94 JXL_ENSURE(pos <= output_size); 95 // |128: Set the next byte flag 96 output[pos++] = (static_cast<uint8_t>(value & 127)) | 128; 97 // Remove the seven bits we just wrote 98 value >>= 7; 99 } 100 // TODO(eustas): should it be `<` ? 101 JXL_ENSURE(pos <= output_size); 102 output[pos++] = static_cast<uint8_t>(value & 127); 103 104 return data->resize(pos); 105 } 106 107 constexpr size_t kSizeLimit = std::numeric_limits<uint32_t>::max() >> 2; 108 109 } // namespace 110 111 // Outputs a transformed form of the given icc profile. The result itself is 112 // not particularly smaller than the input data in bytes, but it will be in a 113 // form that is easier to compress (more zeroes, ...) and will compress better 114 // with brotli. 
// Layout of the bytes appended to `result`:
//   varint(size)
//   varint(number of command bytes)   (0 when size <= kICCHeaderSize)
//   command bytes
//   data bytes (header residuals first, then everything the commands refer to)
// Returns an error if `size` exceeds kSizeLimit or if any buffer operation or
// helper fails.
Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) {
  JxlMemoryManager* memory_manager = result->memory_manager();
  // The two output streams; they are concatenated into `result` at the end.
  PaddedBytes commands{memory_manager};
  PaddedBytes data{memory_manager};

  static_assert(sizeof(size_t) >= 4, "size_t is too short");
  // Fuzzer expects that PredictICC can accept any input,
  // but 1GB should be enough for any purpose.
  if (size > kSizeLimit) {
    return JXL_FAILURE("ICC profile is too large");
  }

  JXL_RETURN_IF_ERROR(EncodeVarInt(size, result));

  // Header
  // Each of the first min(size, kICCHeaderSize) bytes is stored as the
  // difference between the actual byte and the predicted header byte.
  // NOTE(review): ICCPredictHeader presumably refines header[i] from the bytes
  // already seen - confirm in icc_codec_common.
  PaddedBytes header{memory_manager};
  JXL_RETURN_IF_ERROR(header.append(ICCInitialHeaderPrediction(size)));
  for (size_t i = 0; i < kICCHeaderSize && i < size; i++) {
    ICCPredictHeader(icc, size, header.data(), i);
    JXL_RETURN_IF_ERROR(data.push_back(icc[i] - header[i]));
  }
  // Nothing beyond the header: emit an empty command stream and the header
  // residuals, then stop.
  if (size <= kICCHeaderSize) {
    JXL_RETURN_IF_ERROR(EncodeVarInt(0, result));  // 0 commands
    for (uint8_t b : data) {
      JXL_RETURN_IF_ERROR(result->push_back(b));
    }
    return true;
  }

  // Tag table entries that were decoded below, in order of appearance.
  // `tagmap` maps a tag's start offset to its index in these vectors.
  std::vector<Tag> tags;
  std::vector<size_t> tagstarts;
  std::vector<size_t> tagsizes;
  std::map<size_t, size_t> tagmap;

  // Tag list
  size_t pos = kICCHeaderSize;
  if (pos + 4 <= size) {
    uint64_t numtags = DecodeUint32(icc, size, pos);
    pos += 4;
    JXL_RETURN_IF_ERROR(EncodeVarInt(numtags + 1, &commands));
    // Prediction for the first tag's offset: directly after a table of
    // `numtags` 12-byte entries placed at kICCHeaderSize.
    uint64_t prevtagstart = kICCHeaderSize + numtags * 12;
    uint32_t prevtagsize = 0;
    for (size_t i = 0; i < numtags; i++) {
      // The declared tag count may exceed what the input actually holds.
      if (pos + 12 > size) break;

      // One table entry: 4-byte keyword, 4-byte start offset, 4-byte size.
      Tag tag = DecodeKeyword(icc, size, pos + 0);
      uint32_t tagstart = DecodeUint32(icc, size, pos + 4);
      uint32_t tagsize = DecodeUint32(icc, size, pos + 8);
      pos += 12;

      tags.push_back(tag);
      tagstarts.push_back(tagstart);
      tagsizes.push_back(tagsize);
      tagmap[tagstart] = tags.size() - 1;

      // Keywords found in kTagStrings get a dedicated command code; anything
      // else stays kCommandTagUnknown.
      uint8_t tagcode = kCommandTagUnknown;
      for (size_t j = 0; j < kNumTagStrings; j++) {
        if (tag == *kTagStrings[j]) {
          tagcode = j + kCommandTagStringFirst;
          break;
        }
      }

      // rTRC directly followed by gTRC and bTRC entries whose 8 bytes of
      // start+size equal those of the rTRC entry: encode all three as one
      // kCommandTagTRC and skip the two following entries.
      if (tag == kRtrcTag && pos + 24 < size) {
        bool ok = true;
        ok &= DecodeKeyword(icc, size, pos + 0) == kGtrcTag;
        ok &= DecodeKeyword(icc, size, pos + 12) == kBtrcTag;
        if (ok) {
          for (size_t kk = 0; kk < 8; kk++) {
            // pos - 8 is the start field of the current (rTRC) entry.
            if (icc[pos - 8 + kk] != icc[pos + 4 + kk]) ok = false;
            if (icc[pos - 8 + kk] != icc[pos + 16 + kk]) ok = false;
          }
        }
        if (ok) {
          tagcode = kCommandTagTRC;
          pos += 24;
          i += 2;
        }
      }

      // rXYZ directly followed by gXYZ and bXYZ, each of size 20 and stored
      // back to back (offsets r, r + 20, r + 40): encode all three as one
      // kCommandTagXYZ and skip the two following entries.
      if (tag == kRxyzTag && pos + 24 < size) {
        bool ok = true;
        ok &= DecodeKeyword(icc, size, pos + 0) == kGxyzTag;
        ok &= DecodeKeyword(icc, size, pos + 12) == kBxyzTag;
        uint32_t offsetr = tagstart;
        uint32_t offsetg = DecodeUint32(icc, size, pos + 4);
        uint32_t offsetb = DecodeUint32(icc, size, pos + 16);
        uint32_t sizer = tagsize;
        uint32_t sizeg = DecodeUint32(icc, size, pos + 8);
        uint32_t sizeb = DecodeUint32(icc, size, pos + 20);
        ok &= sizer == 20;
        ok &= sizeg == 20;
        ok &= sizeb == 20;
        ok &= (offsetg == offsetr + 20);
        ok &= (offsetb == offsetr + 40);
        if (ok) {
          tagcode = kCommandTagXYZ;
          pos += 24;
          i += 2;
        }
      }

      // The command byte is the tag code plus flag bits saying whether the
      // offset and/or size differ from their predictions and therefore follow
      // explicitly as varints.
      uint8_t command = tagcode;
      uint64_t predicted_tagstart = prevtagstart + prevtagsize;
      if (predicted_tagstart != tagstart) command |= kFlagBitOffset;
      size_t predicted_tagsize = prevtagsize;
      if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag ||
          tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag ||
          tag == kLumiTag) {
        predicted_tagsize = 20;
      }
      if (predicted_tagsize != tagsize) command |= kFlagBitSize;
      JXL_RETURN_IF_ERROR(commands.push_back(command));
      // NOTE(review): the literal 1 presumably equals kCommandTagUnknown (the
      // only case in which the raw keyword has to be stored) - confirm and
      // consider using the named constant.
      if (tagcode == 1) {
        JXL_RETURN_IF_ERROR(AppendKeyword(tag, &data));
      }
      if (command & kFlagBitOffset)
        JXL_RETURN_IF_ERROR(EncodeVarInt(tagstart, &commands));
      if (command & kFlagBitSize)
        JXL_RETURN_IF_ERROR(EncodeVarInt(tagsize, &commands));

      prevtagstart = tagstart;
      prevtagsize = tagsize;
    }
  }
  // Indicate end of tag list or varint indicating there's none
  JXL_RETURN_IF_ERROR(commands.push_back(0));

  // Main content
  // The main content in a valid ICC profile contains tagged elements, with the
  // tag types (4 letter names) given by the tag list above, and the tag list
  // pointing to the start and indicating the size of each tagged element. It is
  // allowed for tagged elements to overlap, e.g. the curve for R, G and B could
  // all point to the same one.
  // State of the tagged element `pos` is currently inside; kept across loop
  // iterations.
  Tag tag;
  size_t tagstart = 0;
  size_t tagsize = 0;
  size_t clutstart = 0;

  // Should always check tag_sane before doing math with tagsize.
  const auto tag_sane = [&tagsize]() {
    return (tagsize > 8) && (tagsize < kSizeLimit);
  };

  // Start of the input bytes that have not been written to the output yet.
  size_t last0 = pos;
  // This loop appends commands to the output, processing some sub-section of a
  // current tagged element each time. We need to keep track of the tagtype of
  // the current element, and update it when we encounter the boundary of a
  // next one.
  // It is not required that the input data is a valid ICC profile, if the
  // encoder does not recognize the data it will still be able to output bytes
  // but will not predict as well.
  while (pos <= size) {
    // Position at the start of this iteration: bytes in [last0, last1) were
    // not matched by any command and will be emitted verbatim.
    size_t last1 = pos;
    // Output of this iteration only; both stay empty when nothing is
    // recognized at `pos`.
    PaddedBytes commands_add{memory_manager};
    PaddedBytes data_add{memory_manager};

    // This means the loop brought the position beyond the tag end.
    // If tagsize is nonsensical, any pos looks "ok-ish".
    if ((pos > tagstart + tagsize) && (tagsize < kSizeLimit)) {
      tag = {{0, 0, 0, 0}};  // nonsensical value
    }

    // `pos` is the start offset of an element listed in the tag table: switch
    // the current-element state to it.
    if (commands_add.empty() && data_add.empty() && tagmap.count(pos) &&
        pos + 4 <= size) {
      size_t index = tagmap[pos];
      tag = DecodeKeyword(icc, size, pos);
      tagstart = tagstarts[index];
      tagsize = tagsizes[index];

      // 'mluc' element with four zero bytes after the type keyword: emit a
      // type-start command, then the remaining tagsize - 8 bytes as a shuffle
      // command whose payload is de-interleaved with width 2.
      if (tag == kMlucTag && tag_sane() && pos + tagsize <= size &&
          icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
          icc[pos + 7] == 0) {
        size_t num = tagsize - 8;
        JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandTypeStartFirst + 3));
        pos += 8;
        JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandShuffle2));
        JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
        size_t start = data_add.size();
        for (size_t i = 0; i < num; i++) {
          JXL_RETURN_IF_ERROR(data_add.push_back(icc[pos]));
          pos++;
        }
        JXL_RETURN_IF_ERROR(
            Unshuffle(memory_manager, data_add.data() + start, num, 2));
      }

      // 'curv' element with four zero bytes after the type keyword: emit a
      // type-start command, then predict the remaining tagsize - 8 bytes with
      // order 1, width 2 and stride 2.
      if (tag == kCurvTag && tag_sane() && pos + tagsize <= size &&
          icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
          icc[pos + 7] == 0) {
        size_t num = tagsize - 8;
        if (num > 16 && num < (1 << 28) && pos + num <= size && pos > 0) {
          JXL_RETURN_IF_ERROR(
              commands_add.push_back(kCommandTypeStartFirst + 5));
          pos += 8;
          JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
          int order = 1;
          int width = 2;
          int stride = width;
          // Flags byte: order in bits 2 and up, width - 1 in the low bits.
          JXL_RETURN_IF_ERROR(
              commands_add.push_back((order << 2) | (width - 1)));
          JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
          JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
                                                size, &pos, &data_add));
        }
      }
    }

    // Inside an 'mAB ' / 'mBA ' element: look for embedded curves and the
    // CLUT.
    if (tag == kMab_Tag || tag == kMba_Tag) {
      Tag subTag = DecodeKeyword(icc, size, pos);
      // Embedded 'curv' / 'vcgt' block: keyword, zero word, then a 32-bit
      // count; the count is doubled to get the number of bytes to predict.
      if (pos + 12 < size && (subTag == kCurvTag || subTag == kVcgtTag) &&
          DecodeUint32(icc, size, pos + 4) == 0) {
        uint32_t num = DecodeUint32(icc, size, pos + 8) * 2;
        if (num > 16 && num < (1 << 28) && pos + 12 + num <= size) {
          pos += 12;
          // The 12 bytes just skipped are left to the verbatim insert below.
          last1 = pos;
          JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
          int order = 1;
          int width = 2;
          int stride = width;
          JXL_RETURN_IF_ERROR(
              commands_add.push_back((order << 2) | (width - 1)));
          JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
          JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
                                                size, &pos, &data_add));
        }
      }

      // The 32-bit word at tagstart + 24 is taken as the CLUT offset relative
      // to the start of the element.
      if (pos == tagstart + 24 && pos + 4 < size) {
        // Note that this value can be remembered for next iterations of the
        // loop, so the "pos == clutstart" if below can trigger during a later
        // iteration.
        clutstart = tagstart + DecodeUint32(icc, size, pos);
      }

      if (pos == clutstart && clutstart + 16 < size) {
        // Bytes 8 and 9 of the element are used as the input and output
        // channel counts; byte 16 of the CLUT as the sample width in bytes.
        size_t numi = icc[tagstart + 8];
        size_t numo = icc[tagstart + 9];
        size_t width = icc[clutstart + 16];
        size_t stride = width * numo;
        // Number of bytes to predict: width * numo times the product of the
        // first `numi` bytes of the CLUT.
        size_t num = width * numo;
        for (size_t i = 0; i < numi && clutstart + i < size; i++) {
          num *= icc[clutstart + i];
        }
        if ((width == 1 || width == 2) && num > 64 && num < (1 << 28) &&
            pos + num <= size && pos > stride * 4) {
          JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
          int order = 1;
          // Bit 16 of the flags signals that an explicit stride follows.
          uint8_t flags =
              (order << 2) | (width - 1) | (stride == width ? 0 : 16);
          JXL_RETURN_IF_ERROR(commands_add.push_back(flags));
          if (flags & 16) {
            JXL_RETURN_IF_ERROR(EncodeVarInt(stride, &commands_add));
          }
          JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
          JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
                                                size, &pos, &data_add));
        }
      }
    }

    // 'gbd ' element, 8 bytes past its start: predict the remaining
    // tagsize - 8 bytes with order 0, width 4 and stride 4.
    if (commands_add.empty() && data_add.empty() && tag == kGbd_Tag &&
        tag_sane() && pos == tagstart + 8 && pos + tagsize - 8 <= size &&
        pos > 16) {
      size_t width = 4;
      size_t order = 0;
      size_t stride = width;
      size_t num = tagsize - 8;
      uint8_t flags = (order << 2) | (width - 1) | (stride == width ? 0 : 16);
      JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandPredict));
      JXL_RETURN_IF_ERROR(commands_add.push_back(flags));
      if (flags & 16) {
        JXL_RETURN_IF_ERROR(EncodeVarInt(stride, &commands_add));
      }
      JXL_RETURN_IF_ERROR(EncodeVarInt(num, &commands_add));
      JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
                                            size, &pos, &data_add));
    }

    // 'XYZ ' keyword followed by a zero word: one kCommandXYZ replaces the 8
    // header bytes, and the 12 bytes that follow go to the data stream.
    if (commands_add.empty() && data_add.empty() && pos + 20 <= size) {
      Tag subTag = DecodeKeyword(icc, size, pos);
      if (subTag == kXyz_Tag && DecodeUint32(icc, size, pos + 4) == 0) {
        JXL_RETURN_IF_ERROR(commands_add.push_back(kCommandXYZ));
        pos += 8;
        for (size_t j = 0; j < 12; j++) {
          JXL_RETURN_IF_ERROR(data_add.push_back(icc[pos++]));
        }
      }
    }

    // Any keyword from kTypeStrings followed by a zero word: a single
    // type-start command replaces those 8 bytes.
    if (commands_add.empty() && data_add.empty() && pos + 8 <= size) {
      if (DecodeUint32(icc, size, pos + 4) == 0) {
        Tag subTag = DecodeKeyword(icc, size, pos);
        for (size_t i = 0; i < kNumTypeStrings; i++) {
          if (subTag == *kTypeStrings[i]) {
            JXL_RETURN_IF_ERROR(
                commands_add.push_back(kCommandTypeStartFirst + i));
            pos += 8;
            break;
          }
        }
      }
    }

    // Flush when something was recognized in this iteration or the end of the
    // input has been reached: first the pending unrecognized bytes
    // [last0, last1) as a verbatim insert, then this iteration's output.
    if (!(commands_add.empty() && data_add.empty()) || pos == size) {
      if (last0 < last1) {
        JXL_RETURN_IF_ERROR(commands.push_back(kCommandInsert));
        JXL_RETURN_IF_ERROR(EncodeVarInt(last1 - last0, &commands));
        while (last0 < last1) {
          JXL_RETURN_IF_ERROR(data.push_back(icc[last0++]));
        }
      }
      for (uint8_t b : commands_add) {
        JXL_RETURN_IF_ERROR(commands.push_back(b));
      }
      for (uint8_t b : data_add) {
        JXL_RETURN_IF_ERROR(data.push_back(b));
      }
      last0 = pos;
    }
    // Nothing recognized here: move on by one byte. This also takes `pos`
    // past `size` and thereby ends the loop.
    if (commands_add.empty() && data_add.empty()) {
      pos++;
    }
  }

  // Final layout: size of the command stream, the command stream, the data
  // stream.
  JXL_RETURN_IF_ERROR(EncodeVarInt(commands.size(), result));
  for (uint8_t b : commands) {
    JXL_RETURN_IF_ERROR(result->push_back(b));
  }
  for (uint8_t b : data) {
    JXL_RETURN_IF_ERROR(result->push_back(b));
  }

  return true;
}

// Writes the ICC profile `icc` to `writer`. The profile is first transformed
// with PredictICC; the length of the transformed stream is written with
// U64Coder, followed by the histograms and the entropy-coded bytes.
// `layer` and `aux_out` are passed through to the writer and coding helpers.
// Fails when `icc` is empty or when any of the called helpers fails.
Status WriteICC(const Span<const uint8_t> icc, BitWriter* JXL_RESTRICT writer,
                LayerType layer, AuxOut* JXL_RESTRICT aux_out) {
  if (icc.empty()) return JXL_FAILURE("ICC must be non-empty");
  JxlMemoryManager* memory_manager = writer->memory_manager();
  PaddedBytes enc{memory_manager};
  JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc));
  // A single token stream holding one token per transformed byte.
  std::vector<std::vector<Token>> tokens(1);
  JXL_RETURN_IF_ERROR(writer->WithMaxBits(128, layer, aux_out, [&] {
    return U64Coder::Write(enc.size(), writer);
  }));

  // The context of each byte is derived from its index and the two preceding
  // bytes (0 where there is no predecessor).
  for (size_t i = 0; i < enc.size(); i++) {
    tokens[0].emplace_back(
        ICCANSContext(i, i > 0 ? enc[i - 1] : 0, i > 1 ? enc[i - 2] : 0),
        enc[i]);
  }
  HistogramParams params;
  // Streams shorter than 4096 bytes use the kOptimal LZ77 method, longer ones
  // use kLZ77.
  params.lz77_method = enc.size() < 4096 ? HistogramParams::LZ77Method::kOptimal
                                         : HistogramParams::LZ77Method::kLZ77;
  EntropyEncodingData code;
  std::vector<uint8_t> context_map;
  params.force_huffman = true;
  JXL_ASSIGN_OR_RETURN(
      size_t cost,
      BuildAndEncodeHistograms(memory_manager, params, kNumICCContexts, tokens,
                               &code, &context_map, writer, layer, aux_out));
  // The returned cost is not needed here.
  (void)cost;
  JXL_RETURN_IF_ERROR(
      WriteTokens(tokens[0], code, context_map, 0, writer, layer, aux_out));
  return true;
}

}  // namespace jxl