enc_heuristics.cc (48256B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_heuristics.h" 7 8 #include <jxl/cms_interface.h> 9 #include <jxl/memory_manager.h> 10 11 #include <algorithm> 12 #include <cstddef> 13 #include <cstdint> 14 #include <cstdlib> 15 #include <limits> 16 #include <memory> 17 #include <numeric> 18 #include <string> 19 #include <utility> 20 #include <vector> 21 22 #include "lib/jxl/ac_context.h" 23 #include "lib/jxl/ac_strategy.h" 24 #include "lib/jxl/base/common.h" 25 #include "lib/jxl/base/compiler_specific.h" 26 #include "lib/jxl/base/data_parallel.h" 27 #include "lib/jxl/base/override.h" 28 #include "lib/jxl/base/rect.h" 29 #include "lib/jxl/base/status.h" 30 #include "lib/jxl/butteraugli/butteraugli.h" 31 #include "lib/jxl/chroma_from_luma.h" 32 #include "lib/jxl/coeff_order.h" 33 #include "lib/jxl/coeff_order_fwd.h" 34 #include "lib/jxl/common.h" 35 #include "lib/jxl/dec_cache.h" 36 #include "lib/jxl/dec_group.h" 37 #include "lib/jxl/dec_noise.h" 38 #include "lib/jxl/dec_xyb.h" 39 #include "lib/jxl/enc_ac_strategy.h" 40 #include "lib/jxl/enc_adaptive_quantization.h" 41 #include "lib/jxl/enc_cache.h" 42 #include "lib/jxl/enc_chroma_from_luma.h" 43 #include "lib/jxl/enc_gaborish.h" 44 #include "lib/jxl/enc_modular.h" 45 #include "lib/jxl/enc_noise.h" 46 #include "lib/jxl/enc_params.h" 47 #include "lib/jxl/enc_patch_dictionary.h" 48 #include "lib/jxl/enc_quant_weights.h" 49 #include "lib/jxl/enc_splines.h" 50 #include "lib/jxl/epf.h" 51 #include "lib/jxl/frame_dimensions.h" 52 #include "lib/jxl/frame_header.h" 53 #include "lib/jxl/image.h" 54 #include "lib/jxl/image_metadata.h" 55 #include "lib/jxl/image_ops.h" 56 #include "lib/jxl/memory_manager_internal.h" 57 #include "lib/jxl/passes_state.h" 58 #include "lib/jxl/quant_weights.h" 59 60 namespace jxl { 61 62 struct AuxOut; 63 64 void 
FindBestBlockEntropyModel(const CompressParams& cparams, const ImageI& rqf, 65 const AcStrategyImage& ac_strategy, 66 BlockCtxMap* block_ctx_map) { 67 if (cparams.decoding_speed_tier >= 1) { 68 static constexpr uint8_t kSimpleCtxMap[] = { 69 // Cluster all blocks together 70 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 71 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 72 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 73 }; 74 static_assert( 75 3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap, 76 "Update simple context map"); 77 78 auto bcm = *block_ctx_map; 79 bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap)); 80 bcm.num_ctxs = 2; 81 bcm.num_dc_ctxs = 1; 82 return; 83 } 84 if (cparams.speed_tier >= SpeedTier::kFalcon) { 85 return; 86 } 87 // No need to change context modeling for small images. 88 size_t tot = rqf.xsize() * rqf.ysize(); 89 size_t size_for_ctx_model = (1 << 10) * cparams.butteraugli_distance; 90 if (tot < size_for_ctx_model) return; 91 92 struct OccCounters { 93 // count the occurrences of each qf value and each strategy type. 94 OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) { 95 for (size_t y = 0; y < rqf.ysize(); y++) { 96 const int32_t* qf_row = rqf.Row(y); 97 AcStrategyRow acs_row = ac_strategy.ConstRow(y); 98 for (size_t x = 0; x < rqf.xsize(); x++) { 99 int ord = kStrategyOrder[acs_row[x].RawStrategy()]; 100 int qf = qf_row[x] - 1; 101 qf_counts[qf]++; 102 qf_ord_counts[ord][qf]++; 103 ord_counts[ord]++; 104 } 105 } 106 } 107 108 size_t qf_counts[256] = {}; 109 size_t qf_ord_counts[kNumOrders][256] = {}; 110 size_t ord_counts[kNumOrders] = {}; 111 }; 112 // The OccCounters struct is too big to allocate on the stack. 113 std::unique_ptr<OccCounters> counters(new OccCounters(rqf, ac_strategy)); 114 115 // Splitting the context model according to the quantization field seems to 116 // mostly benefit only large images. 
117 size_t size_for_qf_split = (1 << 13) * cparams.butteraugli_distance; 118 size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2; 119 std::vector<uint32_t>& qft = block_ctx_map->qf_thresholds; 120 qft.clear(); 121 // Divide the quant field in up to num_qf_segments segments. 122 size_t cumsum = 0; 123 size_t next = 1; 124 size_t last_cut = 256; 125 size_t cut = tot * next / num_qf_segments; 126 for (uint32_t j = 0; j < 256; j++) { 127 cumsum += counters->qf_counts[j]; 128 if (cumsum > cut) { 129 if (j != 0) { 130 qft.push_back(j); 131 } 132 last_cut = j; 133 while (cumsum > cut) { 134 next++; 135 cut = tot * next / num_qf_segments; 136 } 137 } else if (next > qft.size() + 1) { 138 if (j - 1 == last_cut && j != 0) { 139 qft.push_back(j); 140 } 141 } 142 } 143 144 // Count the occurrences of each segment. 145 std::vector<size_t> counts(kNumOrders * (qft.size() + 1)); 146 size_t qft_pos = 0; 147 for (size_t j = 0; j < 256; j++) { 148 if (qft_pos < qft.size() && j == qft[qft_pos]) { 149 qft_pos++; 150 } 151 for (size_t i = 0; i < kNumOrders; i++) { 152 counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j]; 153 } 154 } 155 156 // Repeatedly merge the lowest-count pair. 157 std::vector<uint8_t> remap((qft.size() + 1) * kNumOrders); 158 std::iota(remap.begin(), remap.end(), 0); 159 std::vector<uint8_t> clusters(remap); 160 size_t nb_clusters = 161 Clamp1(static_cast<int>(tot / size_for_ctx_model / 2), 2, 9); 162 size_t nb_clusters_chroma = 163 Clamp1(static_cast<int>(tot / size_for_ctx_model / 3), 1, 5); 164 // This is O(n^2 log n), but n is small. 
165 while (clusters.size() > nb_clusters) { 166 std::sort(clusters.begin(), clusters.end(), 167 [&](int a, int b) { return counts[a] > counts[b]; }); 168 counts[clusters[clusters.size() - 2]] += counts[clusters.back()]; 169 counts[clusters.back()] = 0; 170 remap[clusters.back()] = clusters[clusters.size() - 2]; 171 clusters.pop_back(); 172 } 173 for (size_t i = 0; i < remap.size(); i++) { 174 while (remap[remap[i]] != remap[i]) { 175 remap[i] = remap[remap[i]]; 176 } 177 } 178 // Relabel starting from 0. 179 std::vector<uint8_t> remap_remap(remap.size(), remap.size()); 180 size_t num = 0; 181 for (size_t i = 0; i < remap.size(); i++) { 182 if (remap_remap[remap[i]] == remap.size()) { 183 remap_remap[remap[i]] = num++; 184 } 185 remap[i] = remap_remap[remap[i]]; 186 } 187 // Write the block context map. 188 auto& ctx_map = block_ctx_map->ctx_map; 189 ctx_map = remap; 190 ctx_map.resize(remap.size() * 3); 191 // for chroma, only use up to nb_clusters_chroma separate block contexts 192 // (those for the biggest clusters) 193 for (size_t i = remap.size(); i < remap.size() * 3; i++) { 194 ctx_map[i] = num + Clamp1(static_cast<int>(remap[i % remap.size()]), 0, 195 static_cast<int>(nb_clusters_chroma) - 1); 196 } 197 block_ctx_map->num_ctxs = 198 *std::max_element(ctx_map.begin(), ctx_map.end()) + 1; 199 } 200 201 namespace { 202 203 Status FindBestDequantMatrices(JxlMemoryManager* memory_manager, 204 const CompressParams& cparams, 205 ModularFrameEncoder* modular_frame_encoder, 206 DequantMatrices* dequant_matrices) { 207 // TODO(veluca): quant matrices for no-gaborish. 208 // TODO(veluca): heuristics for in-bitstream quant tables. 209 *dequant_matrices = DequantMatrices(); 210 if (cparams.max_error_mode || cparams.disable_perceptual_optimizations) { 211 constexpr float kMSEWeights[3] = {0.001, 0.001, 0.001}; 212 const float* wp = cparams.disable_perceptual_optimizations 213 ? 
kMSEWeights 214 : cparams.max_error; 215 // Set numerators of all quantization matrices to constant values. 216 float weights[3][1] = {{1.0f / wp[0]}, {1.0f / wp[1]}, {1.0f / wp[2]}}; 217 DctQuantWeightParams dct_params(weights); 218 std::vector<QuantEncoding> encodings(kNumQuantTables, 219 QuantEncoding::DCT(dct_params)); 220 JXL_RETURN_IF_ERROR(DequantMatricesSetCustom(dequant_matrices, encodings, 221 modular_frame_encoder)); 222 float dc_weights[3] = {1.0f / wp[0], 1.0f / wp[1], 1.0f / wp[2]}; 223 JXL_RETURN_IF_ERROR(DequantMatricesSetCustomDC( 224 memory_manager, dequant_matrices, dc_weights)); 225 } 226 return true; 227 } 228 229 void StoreMin2(const float v, float& min1, float& min2) { 230 if (v < min2) { 231 if (v < min1) { 232 min2 = min1; 233 min1 = v; 234 } else { 235 min2 = v; 236 } 237 } 238 } 239 240 void CreateMask(const ImageF& image, ImageF& mask) { 241 for (size_t y = 0; y < image.ysize(); y++) { 242 const auto* row_n = y > 0 ? image.Row(y - 1) : image.Row(y); 243 const auto* row_in = image.Row(y); 244 const auto* row_s = y + 1 < image.ysize() ? image.Row(y + 1) : image.Row(y); 245 auto* row_out = mask.Row(y); 246 for (size_t x = 0; x < image.xsize(); x++) { 247 // Center, west, east, north, south values and their absolute difference 248 float c = row_in[x]; 249 float w = x > 0 ? row_in[x - 1] : row_in[x]; 250 float e = x + 1 < image.xsize() ? row_in[x + 1] : row_in[x]; 251 float n = row_n[x]; 252 float s = row_s[x]; 253 float dw = std::abs(c - w); 254 float de = std::abs(c - e); 255 float dn = std::abs(c - n); 256 float ds = std::abs(c - s); 257 float min = std::numeric_limits<float>::max(); 258 float min2 = std::numeric_limits<float>::max(); 259 StoreMin2(dw, min, min2); 260 StoreMin2(de, min, min2); 261 StoreMin2(dn, min, min2); 262 StoreMin2(ds, min, min2); 263 row_out[x] = min2; 264 } 265 } 266 } 267 268 // Downsamples the image by a factor of 2 with a kernel that's sharper than 269 // the standard 2x2 box kernel used by DownsampleImage. 
270 // The kernel is optimized against the result of the 2x2 upsampling kernel used 271 // by the decoder. Ringing is slightly reduced by clamping the values of the 272 // resulting pixels within certain bounds of a small region in the original 273 // image. 274 Status DownsampleImage2_Sharper(const ImageF& input, ImageF* output) { 275 const int64_t kernelx = 12; 276 const int64_t kernely = 12; 277 JxlMemoryManager* memory_manager = input.memory_manager(); 278 279 static const float kernel[144] = { 280 -0.000314256996835, -0.000314256996835, -0.000897597057705, 281 -0.000562751488849, -0.000176807273646, 0.001864627368902, 282 0.001864627368902, -0.000176807273646, -0.000562751488849, 283 -0.000897597057705, -0.000314256996835, -0.000314256996835, 284 -0.000314256996835, -0.001527942804748, -0.000121760530512, 285 0.000191123989093, 0.010193185932466, 0.058637519197110, 286 0.058637519197110, 0.010193185932466, 0.000191123989093, 287 -0.000121760530512, -0.001527942804748, -0.000314256996835, 288 -0.000897597057705, -0.000121760530512, 0.000946363683751, 289 0.007113577630288, 0.000437956841058, -0.000372823835211, 290 -0.000372823835211, 0.000437956841058, 0.007113577630288, 291 0.000946363683751, -0.000121760530512, -0.000897597057705, 292 -0.000562751488849, 0.000191123989093, 0.007113577630288, 293 0.044592622228814, 0.000222278879007, -0.162864473015945, 294 -0.162864473015945, 0.000222278879007, 0.044592622228814, 295 0.007113577630288, 0.000191123989093, -0.000562751488849, 296 -0.000176807273646, 0.010193185932466, 0.000437956841058, 297 0.000222278879007, -0.000913092543974, -0.017071696107902, 298 -0.017071696107902, -0.000913092543974, 0.000222278879007, 299 0.000437956841058, 0.010193185932466, -0.000176807273646, 300 0.001864627368902, 0.058637519197110, -0.000372823835211, 301 -0.162864473015945, -0.017071696107902, 0.414660099370354, 302 0.414660099370354, -0.017071696107902, -0.162864473015945, 303 -0.000372823835211, 0.058637519197110, 
0.001864627368902, 304 0.001864627368902, 0.058637519197110, -0.000372823835211, 305 -0.162864473015945, -0.017071696107902, 0.414660099370354, 306 0.414660099370354, -0.017071696107902, -0.162864473015945, 307 -0.000372823835211, 0.058637519197110, 0.001864627368902, 308 -0.000176807273646, 0.010193185932466, 0.000437956841058, 309 0.000222278879007, -0.000913092543974, -0.017071696107902, 310 -0.017071696107902, -0.000913092543974, 0.000222278879007, 311 0.000437956841058, 0.010193185932466, -0.000176807273646, 312 -0.000562751488849, 0.000191123989093, 0.007113577630288, 313 0.044592622228814, 0.000222278879007, -0.162864473015945, 314 -0.162864473015945, 0.000222278879007, 0.044592622228814, 315 0.007113577630288, 0.000191123989093, -0.000562751488849, 316 -0.000897597057705, -0.000121760530512, 0.000946363683751, 317 0.007113577630288, 0.000437956841058, -0.000372823835211, 318 -0.000372823835211, 0.000437956841058, 0.007113577630288, 319 0.000946363683751, -0.000121760530512, -0.000897597057705, 320 -0.000314256996835, -0.001527942804748, -0.000121760530512, 321 0.000191123989093, 0.010193185932466, 0.058637519197110, 322 0.058637519197110, 0.010193185932466, 0.000191123989093, 323 -0.000121760530512, -0.001527942804748, -0.000314256996835, 324 -0.000314256996835, -0.000314256996835, -0.000897597057705, 325 -0.000562751488849, -0.000176807273646, 0.001864627368902, 326 0.001864627368902, -0.000176807273646, -0.000562751488849, 327 -0.000897597057705, -0.000314256996835, -0.000314256996835}; 328 329 int64_t xsize = input.xsize(); 330 int64_t ysize = input.ysize(); 331 332 JXL_ASSIGN_OR_RETURN(ImageF box_downsample, 333 ImageF::Create(memory_manager, xsize, ysize)); 334 JXL_RETURN_IF_ERROR(CopyImageTo(input, &box_downsample)); 335 JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2)); 336 337 JXL_ASSIGN_OR_RETURN(ImageF mask, 338 ImageF::Create(memory_manager, box_downsample.xsize(), 339 box_downsample.ysize())); 340 
CreateMask(box_downsample, mask); 341 342 for (size_t y = 0; y < output->ysize(); y++) { 343 float* row_out = output->Row(y); 344 const float* row_in[kernely]; 345 const float* row_mask = mask.Row(y); 346 // get the rows in the support 347 for (size_t ky = 0; ky < kernely; ky++) { 348 int64_t iy = y * 2 + ky - (kernely - 1) / 2; 349 if (iy < 0) iy = 0; 350 if (iy >= ysize) iy = ysize - 1; 351 row_in[ky] = input.Row(iy); 352 } 353 354 for (size_t x = 0; x < output->xsize(); x++) { 355 // get min and max values of the original image in the support 356 float min = std::numeric_limits<float>::max(); 357 float max = std::numeric_limits<float>::min(); 358 // kernelx - R and kernely - R are the radius of a rectangular region in 359 // which the values of a pixel are bounded to reduce ringing. 360 static constexpr int64_t R = 5; 361 for (int64_t ky = R; ky + R < kernely; ky++) { 362 for (int64_t kx = R; kx + R < kernelx; kx++) { 363 int64_t ix = x * 2 + kx - (kernelx - 1) / 2; 364 if (ix < 0) ix = 0; 365 if (ix >= xsize) ix = xsize - 1; 366 min = std::min<float>(min, row_in[ky][ix]); 367 max = std::max<float>(max, row_in[ky][ix]); 368 } 369 } 370 371 float sum = 0; 372 for (int64_t ky = 0; ky < kernely; ky++) { 373 for (int64_t kx = 0; kx < kernelx; kx++) { 374 int64_t ix = x * 2 + kx - (kernelx - 1) / 2; 375 if (ix < 0) ix = 0; 376 if (ix >= xsize) ix = xsize - 1; 377 sum += row_in[ky][ix] * kernel[ky * kernelx + kx]; 378 } 379 } 380 381 row_out[x] = sum; 382 383 // Clamp the pixel within the value of a small area to prevent ringning. 384 // The mask determines how much to clamp, clamp more to reduce more 385 // ringing in smooth areas, clamp less in noisy areas to get more 386 // sharpness. Higher mask_multiplier gives less clamping, so less 387 // ringing reduction. 
388 const constexpr float mask_multiplier = 1; 389 float a = row_mask[x] * mask_multiplier; 390 float clip_min = min - a; 391 float clip_max = max + a; 392 if (row_out[x] < clip_min) { 393 row_out[x] = clip_min; 394 } else if (row_out[x] > clip_max) { 395 row_out[x] = clip_max; 396 } 397 } 398 } 399 return true; 400 } 401 402 } // namespace 403 404 Status DownsampleImage2_Sharper(Image3F* opsin) { 405 // Allocate extra space to avoid a reallocation when padding. 406 JxlMemoryManager* memory_manager = opsin->memory_manager(); 407 JXL_ASSIGN_OR_RETURN( 408 Image3F downsampled, 409 Image3F::Create(memory_manager, DivCeil(opsin->xsize(), 2) + kBlockDim, 410 DivCeil(opsin->ysize(), 2) + kBlockDim)); 411 JXL_RETURN_IF_ERROR(downsampled.ShrinkTo(downsampled.xsize() - kBlockDim, 412 downsampled.ysize() - kBlockDim)); 413 414 for (size_t c = 0; c < 3; c++) { 415 JXL_RETURN_IF_ERROR( 416 DownsampleImage2_Sharper(opsin->Plane(c), &downsampled.Plane(c))); 417 } 418 *opsin = std::move(downsampled); 419 return true; 420 } 421 422 namespace { 423 424 // The default upsampling kernels used by Upsampler in the decoder. 
425 const constexpr int64_t kSize = 5; 426 427 const float kernel00[25] = { 428 -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, 429 -0.03452303f, 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, 430 -0.04022174f, 0.28896755f, 0.56661550f, 0.03777607f, -0.01986694f, 431 -0.02921014f, 0.00278718f, 0.03777607f, -0.03144731f, -0.01185068f, 432 -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f, 433 }; 434 const float kernel01[25] = { 435 -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f, 436 -0.02921014f, 0.00278718f, 0.03777607f, -0.03144731f, -0.01185068f, 437 -0.04022174f, 0.28896755f, 0.56661550f, 0.03777607f, -0.01986694f, 438 -0.03452303f, 0.14111091f, 0.28896755f, 0.00278718f, -0.01610267f, 439 -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f, 440 }; 441 const float kernel10[25] = { 442 -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f, 443 -0.01610267f, 0.00278718f, 0.28896755f, 0.14111091f, -0.03452303f, 444 -0.01986694f, 0.03777607f, 0.56661550f, 0.28896755f, -0.04022174f, 445 -0.01185068f, -0.03144731f, 0.03777607f, 0.00278718f, -0.02921014f, 446 -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f, 447 }; 448 const float kernel11[25] = { 449 -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f, 450 -0.01185068f, -0.03144731f, 0.03777607f, 0.00278718f, -0.02921014f, 451 -0.01986694f, 0.03777607f, 0.56661550f, 0.28896755f, -0.04022174f, 452 -0.01610267f, 0.00278718f, 0.28896755f, 0.14111091f, -0.03452303f, 453 -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f, 454 }; 455 456 // Does exactly the same as the Upsampler in dec_upsampler for 2x2 pixels, with 457 // default CustomTransformData. 458 // TODO(lode): use Upsampler instead. However, it requires pre-initialization 459 // and padding on the left side of the image which requires refactoring the 460 // other code using this. 
461 void UpsampleImage(const ImageF& input, ImageF* output) { 462 int64_t xsize = input.xsize(); 463 int64_t ysize = input.ysize(); 464 int64_t xsize2 = output->xsize(); 465 int64_t ysize2 = output->ysize(); 466 for (int64_t y = 0; y < ysize2; y++) { 467 for (int64_t x = 0; x < xsize2; x++) { 468 const auto* kernel = kernel00; 469 if ((x & 1) && (y & 1)) { 470 kernel = kernel11; 471 } else if (x & 1) { 472 kernel = kernel10; 473 } else if (y & 1) { 474 kernel = kernel01; 475 } 476 float sum = 0; 477 int64_t x2 = x / 2; 478 int64_t y2 = y / 2; 479 480 // get min and max values of the original image in the support 481 float min = std::numeric_limits<float>::max(); 482 float max = std::numeric_limits<float>::min(); 483 484 for (int64_t ky = 0; ky < kSize; ky++) { 485 for (int64_t kx = 0; kx < kSize; kx++) { 486 int64_t xi = x2 - kSize / 2 + kx; 487 int64_t yi = y2 - kSize / 2 + ky; 488 if (xi < 0) xi = 0; 489 if (xi >= xsize) xi = input.xsize() - 1; 490 if (yi < 0) yi = 0; 491 if (yi >= ysize) yi = input.ysize() - 1; 492 min = std::min<float>(min, input.Row(yi)[xi]); 493 max = std::max<float>(max, input.Row(yi)[xi]); 494 } 495 } 496 497 for (int64_t ky = 0; ky < kSize; ky++) { 498 for (int64_t kx = 0; kx < kSize; kx++) { 499 int64_t xi = x2 - kSize / 2 + kx; 500 int64_t yi = y2 - kSize / 2 + ky; 501 if (xi < 0) xi = 0; 502 if (xi >= xsize) xi = input.xsize() - 1; 503 if (yi < 0) yi = 0; 504 if (yi >= ysize) yi = input.ysize() - 1; 505 sum += input.Row(yi)[xi] * kernel[ky * kSize + kx]; 506 } 507 } 508 output->Row(y)[x] = sum; 509 if (output->Row(y)[x] < min) output->Row(y)[x] = min; 510 if (output->Row(y)[x] > max) output->Row(y)[x] = max; 511 } 512 } 513 } 514 515 // Returns the derivative of Upsampler, with respect to input pixel x2, y2, to 516 // output pixel x, y (ignoring the clamping). 
517 float UpsamplerDeriv(int64_t x2, int64_t y2, int64_t x, int64_t y) { 518 const auto* kernel = kernel00; 519 if ((x & 1) && (y & 1)) { 520 kernel = kernel11; 521 } else if (x & 1) { 522 kernel = kernel10; 523 } else if (y & 1) { 524 kernel = kernel01; 525 } 526 527 int64_t ix = x / 2; 528 int64_t iy = y / 2; 529 int64_t kx = x2 - ix + kSize / 2; 530 int64_t ky = y2 - iy + kSize / 2; 531 532 // This should not happen. 533 if (kx < 0 || kx >= kSize || ky < 0 || ky >= kSize) return 0; 534 535 return kernel[ky * kSize + kx]; 536 } 537 538 // Apply the derivative of the Upsampler to the input, reversing the effect of 539 // its coefficients. The output image is 2x2 times smaller than the input. 540 void AntiUpsample(const ImageF& input, ImageF* d) { 541 int64_t xsize = input.xsize(); 542 int64_t ysize = input.ysize(); 543 int64_t xsize2 = d->xsize(); 544 int64_t ysize2 = d->ysize(); 545 int64_t k0 = kSize - 1; 546 int64_t k1 = kSize; 547 for (int64_t y2 = 0; y2 < ysize2; ++y2) { 548 auto* row = d->Row(y2); 549 for (int64_t x2 = 0; x2 < xsize2; ++x2) { 550 int64_t x0 = x2 * 2 - k0; 551 if (x0 < 0) x0 = 0; 552 int64_t x1 = x2 * 2 + k1 + 1; 553 if (x1 > xsize) x1 = xsize; 554 int64_t y0 = y2 * 2 - k0; 555 if (y0 < 0) y0 = 0; 556 int64_t y1 = y2 * 2 + k1 + 1; 557 if (y1 > ysize) y1 = ysize; 558 559 float sum = 0; 560 for (int64_t y = y0; y < y1; ++y) { 561 const auto* row_in = input.Row(y); 562 for (int64_t x = x0; x < x1; ++x) { 563 double deriv = UpsamplerDeriv(x2, y2, x, y); 564 sum += deriv * row_in[x]; 565 } 566 } 567 row[x2] = sum; 568 } 569 } 570 } 571 572 // Element-wise multiplies two images. 
template <typename T>
// Element-wise product: out[x] = image1[x] * image2[x]. All three images must
// have identical dimensions.
Status ElwiseMul(const Plane<T>& image1, const Plane<T>& image2,
                 Plane<T>* out) {
  const size_t xsize = image1.xsize();
  const size_t ysize = image1.ysize();
  JXL_ENSURE(xsize == image2.xsize());
  JXL_ENSURE(ysize == image2.ysize());
  JXL_ENSURE(xsize == out->xsize());
  JXL_ENSURE(ysize == out->ysize());
  for (size_t y = 0; y < ysize; ++y) {
    const T* const JXL_RESTRICT row1 = image1.Row(y);
    const T* const JXL_RESTRICT row2 = image2.Row(y);
    T* const JXL_RESTRICT row_out = out->Row(y);
    for (size_t x = 0; x < xsize; ++x) {
      row_out[x] = row1[x] * row2[x];
    }
  }
  return true;
}

// Element-wise divides two images.
// out[x] = image1[x] / image2[x]; no guard against division by zero, so the
// caller must ensure image2 has no zero entries.
template <typename T>
Status ElwiseDiv(const Plane<T>& image1, const Plane<T>& image2,
                 Plane<T>* out) {
  const size_t xsize = image1.xsize();
  const size_t ysize = image1.ysize();
  JXL_ENSURE(xsize == image2.xsize());
  JXL_ENSURE(ysize == image2.ysize());
  JXL_ENSURE(xsize == out->xsize());
  JXL_ENSURE(ysize == out->ysize());
  for (size_t y = 0; y < ysize; ++y) {
    const T* const JXL_RESTRICT row1 = image1.Row(y);
    const T* const JXL_RESTRICT row2 = image2.Row(y);
    T* const JXL_RESTRICT row_out = out->Row(y);
    for (size_t x = 0; x < xsize; ++x) {
      row_out[x] = row1[x] / row2[x];
    }
  }
  return true;
}

// Clamps each pixel of `down` to the min/max of the corresponding 3x3
// neighborhood in `initial`, widened by the per-pixel mask value. `initial`
// and `down` are expected to have the same (half-resolution) dimensions.
void ReduceRinging(const ImageF& initial, const ImageF& mask, ImageF& down) {
  int64_t xsize2 = down.xsize();
  int64_t ysize2 = down.ysize();

  for (size_t y = 0; y < down.ysize(); y++) {
    const float* row_mask = mask.Row(y);
    float* row_out = down.Row(y);
    for (size_t x = 0; x < down.xsize(); x++) {
      float v = down.Row(y)[x];
      // Min/max over the 3x3 neighborhood of `initial` (border-clipped).
      float min = initial.Row(y)[x];
      float max = initial.Row(y)[x];
      for (int64_t yi = -1; yi < 2; yi++) {
        for (int64_t xi = -1; xi < 2; xi++) {
          int64_t x2 = static_cast<int64_t>(x) + xi;
          int64_t y2 = static_cast<int64_t>(y) + yi;
          if (x2 < 0 || y2 < 0 || x2 >= xsize2 || y2 >= ysize2) continue;
          min = std::min<float>(min, initial.Row(y2)[x2]);
          max = std::max<float>(max, initial.Row(y2)[x2]);
        }
      }

      row_out[x] = v;

      // Clamp the pixel within the value of a small area to prevent ringing.
      // The mask determines how much to clamp, clamp more to reduce more
      // ringing in smooth areas, clamp less in noisy areas to get more
      // sharpness. Higher mask_multiplier gives less clamping, so less
      // ringing reduction.
      const constexpr float mask_multiplier = 2;
      float a = row_mask[x] * mask_multiplier;
      float clip_min = min - a;
      float clip_max = max + a;
      if (row_out[x] < clip_min) row_out[x] = clip_min;
      if (row_out[x] > clip_max) row_out[x] = clip_max;
    }
  }
}

// TODO(lode): move this to a separate file enc_downsample.cc
// Gradient-descent-style 2x downsample: starts from the sharper downsample,
// then iteratively nudges the low-res image so that upsampling it (with the
// decoder's kernel) reproduces `orig` as closely as possible, and finally
// applies the anti-ringing clamp.
Status DownsampleImage2_Iterative(const ImageF& orig, ImageF* output) {
  int64_t xsize = orig.xsize();
  int64_t ysize = orig.ysize();
  int64_t xsize2 = DivCeil(orig.xsize(), 2);
  int64_t ysize2 = DivCeil(orig.ysize(), 2);
  JxlMemoryManager* memory_manager = orig.memory_manager();

  // Smoothness mask computed from a plain 2x2 box downsample.
  JXL_ASSIGN_OR_RETURN(ImageF box_downsample,
                       ImageF::Create(memory_manager, xsize, ysize));
  JXL_RETURN_IF_ERROR(CopyImageTo(orig, &box_downsample));
  JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2));
  JXL_ASSIGN_OR_RETURN(ImageF mask,
                       ImageF::Create(memory_manager, box_downsample.xsize(),
                                      box_downsample.ysize()));
  CreateMask(box_downsample, mask);

  JXL_RETURN_IF_ERROR(output->ShrinkTo(xsize2, ysize2));

  // Initial result image using the sharper downsampling.
  // Allocate extra space to avoid a reallocation when padding.
  JXL_ASSIGN_OR_RETURN(
      ImageF initial,
      ImageF::Create(memory_manager, DivCeil(orig.xsize(), 2) + kBlockDim,
                     DivCeil(orig.ysize(), 2) + kBlockDim));
  JXL_RETURN_IF_ERROR(initial.ShrinkTo(initial.xsize() - kBlockDim,
                                       initial.ysize() - kBlockDim));
  JXL_RETURN_IF_ERROR(DownsampleImage2_Sharper(orig, &initial));

  JXL_ASSIGN_OR_RETURN(
      ImageF down,
      ImageF::Create(memory_manager, initial.xsize(), initial.ysize()));
  JXL_RETURN_IF_ERROR(CopyImageTo(initial, &down));
  JXL_ASSIGN_OR_RETURN(ImageF up, ImageF::Create(memory_manager, xsize, ysize));
  JXL_ASSIGN_OR_RETURN(ImageF corr,
                       ImageF::Create(memory_manager, xsize, ysize));
  JXL_ASSIGN_OR_RETURN(ImageF corr2,
                       ImageF::Create(memory_manager, xsize2, ysize2));

  // In the weights map, relatively higher values will allow less ringing but
  // also less sharpness. With all constant values, it optimizes equally
  // everywhere. Even in this case, the weights2 computed from
  // this is still used and differs at the borders of the image.
  // TODO(lode): Make use of the weights field for anti-ringing and clamping,
  // the values are all set to 1 for now, but it is intended to be used for
  // reducing ringing based on the mask, and taking clamping into account.
  JXL_ASSIGN_OR_RETURN(ImageF weights,
                       ImageF::Create(memory_manager, xsize, ysize));
  for (size_t y = 0; y < weights.ysize(); y++) {
    auto* row = weights.Row(y);
    for (size_t x = 0; x < weights.xsize(); x++) {
      row[x] = 1;
    }
  }
  // weights2 is the anti-upsampled weights; acts as the per-pixel
  // normalization denominator for the correction term below.
  JXL_ASSIGN_OR_RETURN(ImageF weights2,
                       ImageF::Create(memory_manager, xsize2, ysize2));
  AntiUpsample(weights, &weights2);

  const size_t num_it = 3;
  for (size_t it = 0; it < num_it; ++it) {
    // Residual at full resolution: corr = orig - Upsample(down).
    UpsampleImage(down, &up);
    JXL_ASSIGN_OR_RETURN(corr, LinComb<float>(1, orig, -1, up));
    JXL_RETURN_IF_ERROR(ElwiseMul(corr, weights, &corr));
    // Project the residual back to low resolution and normalize.
    AntiUpsample(corr, &corr2);
    JXL_RETURN_IF_ERROR(ElwiseDiv(corr2, weights2, &corr2));

    // Apply the correction step.
    JXL_ASSIGN_OR_RETURN(down, LinComb<float>(1, down, 1, corr2));
  }

  ReduceRinging(initial, mask, down);

  // can't just use CopyImage, because the output image was prepared with
  // padding.
  for (size_t y = 0; y < down.ysize(); y++) {
    for (size_t x = 0; x < down.xsize(); x++) {
      float v = down.Row(y)[x];
      output->Row(y)[x] = v;
    }
  }
  return true;
}

}  // namespace

// Applies the iterative 2x downsampling to all three opsin planes in place.
Status DownsampleImage2_Iterative(Image3F* opsin) {
  JxlMemoryManager* memory_manager = opsin->memory_manager();
  // Allocate extra space to avoid a reallocation when padding.
  // Allocate the 2x-downsampled image with kBlockDim of slack per axis, then
  // shrink the view back down so only the payload area is visible.
  JXL_ASSIGN_OR_RETURN(
      Image3F downsampled,
      Image3F::Create(memory_manager, DivCeil(opsin->xsize(), 2) + kBlockDim,
                      DivCeil(opsin->ysize(), 2) + kBlockDim));
  JXL_RETURN_IF_ERROR(downsampled.ShrinkTo(downsampled.xsize() - kBlockDim,
                                           downsampled.ysize() - kBlockDim));

  // Convert XYB back to linear RGB so a Butteraugli comparator can be built.
  JXL_ASSIGN_OR_RETURN(
      Image3F rgb,
      Image3F::Create(memory_manager, opsin->xsize(), opsin->ysize()));
  OpsinParams opsin_params;  // TODO(user): use the ones that are actually used
  opsin_params.Init(kDefaultIntensityTarget);
  JXL_RETURN_IF_ERROR(
      OpsinToLinear(*opsin, Rect(rgb), nullptr, &rgb, opsin_params));

  // NOTE(review): `mask` is filled by butter->Mask() but never read in this
  // visible span, and `mask_fuzzy` is allocated but never written/read here —
  // possibly leftovers; confirm against the full function before removing.
  JXL_ASSIGN_OR_RETURN(
      ImageF mask,
      ImageF::Create(memory_manager, opsin->xsize(), opsin->ysize()));
  ButteraugliParams butter_params;
  JXL_ASSIGN_OR_RETURN(std::unique_ptr<ButteraugliComparator> butter,
                       ButteraugliComparator::Make(rgb, butter_params));
  JXL_RETURN_IF_ERROR(butter->Mask(&mask));
  JXL_ASSIGN_OR_RETURN(
      ImageF mask_fuzzy,
      ImageF::Create(memory_manager, opsin->xsize(), opsin->ysize()));

  // Downsample each channel independently, then replace the input in place.
  for (size_t c = 0; c < 3; c++) {
    JXL_RETURN_IF_ERROR(
        DownsampleImage2_Iterative(opsin->Plane(c), &downsampled.Plane(c)));
  }
  *opsin = std::move(downsampled);
  return true;
}

// Runs the decoder on the encoder's current coefficients to obtain the
// round-trip frame in linear sRGB. Used by the heuristics below (see
// ComputeARHeuristics) to measure the error introduced by candidate encoding
// decisions such as the per-block EPF sharpness value.
StatusOr<Image3F> ReconstructImage(
    const FrameHeader& orig_frame_header, const PassesSharedState& shared,
    const std::vector<std::unique_ptr<ACImage>>& coeffs, ThreadPool* pool) {
  const FrameDimensions& frame_dim = shared.frame_dim;
  JxlMemoryManager* memory_manager = shared.memory_manager;

  // Work on a copy of the header: the patch/spline flags must reflect what
  // was actually computed, and no further color transform is applied since
  // the coefficients are already in the encoder's color space.
  FrameHeader frame_header = orig_frame_header;
  frame_header.UpdateFlag(shared.image_features.patches.HasAny(),
                          FrameHeader::kPatches);
  frame_header.UpdateFlag(shared.image_features.splines.HasAny(),
                          FrameHeader::kSplines);
  frame_header.color_transform = ColorTransform::kNone;

  // Strip extra channels: only the color channels are reconstructed here.
  CodecMetadata metadata = *frame_header.nonserialized_metadata;
  metadata.m.extra_channel_info.clear();
  metadata.m.num_extra_channels = metadata.m.extra_channel_info.size();
  frame_header.nonserialized_metadata = &metadata;
  frame_header.extra_channel_upsampling.clear();

  // Decode straight into linear sRGB (gray-aware) for error measurement.
  const bool is_gray = shared.metadata->m.color_encoding.IsGray();
  PassesDecoderState dec_state(memory_manager);
  JXL_RETURN_IF_ERROR(
      dec_state.output_encoding_info.SetFromMetadata(*shared.metadata));
  JXL_RETURN_IF_ERROR(dec_state.output_encoding_info.MaybeSetColorEncoding(
      ColorEncoding::LinearSRGB(is_gray)));
  dec_state.shared = &shared;
  JXL_RETURN_IF_ERROR(dec_state.Init(frame_header));

  ImageBundle decoded(memory_manager, &shared.metadata->m);
  decoded.origin = frame_header.frame_origin;
  JXL_ASSIGN_OR_RETURN(
      Image3F tmp,
      Image3F::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
  JXL_RETURN_IF_ERROR(decoded.SetFromImage(
      std::move(tmp), dec_state.output_encoding_info.color_encoding));

  // Pipeline options: no coalescing or spot-color rendering; render noise so
  // the reconstruction includes the synthesized-noise contribution.
  PassesDecoderState::PipelineOptions options;
  options.use_slow_render_pipeline = false;
  options.coalescing = false;
  options.render_spotcolors = false;
  options.render_noise = true;

  JXL_RETURN_IF_ERROR(dec_state.PreparePipeline(
      frame_header, &shared.metadata->m, &decoded, options));

  AlignedArray<GroupDecCache> group_dec_caches;
  // Per-thread setup: pipeline storage plus one GroupDecCache per thread.
  const auto allocate_storage = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(
        dec_state.render_pipeline->PrepareForThreads(num_threads,
                                                     /*use_group_ids=*/false));
    JXL_ASSIGN_OR_RETURN(group_dec_caches, AlignedArray<GroupDecCache>::Create(
                                               memory_manager, num_threads));
    return true;
  };
  // Decodes one group: optional EPF sigma, coefficient decode, optional noise.
  const auto process_group = [&](const uint32_t group_index,
                                 const size_t thread) -> Status {
    if (frame_header.loop_filter.epf_iters > 0) {
      JXL_RETURN_IF_ERROR(ComputeSigma(frame_header.loop_filter,
                                       frame_dim.BlockGroupRect(group_index),
                                       &dec_state));
    }
    RenderPipelineInput input =
        dec_state.render_pipeline->GetInputBuffers(group_index, thread);
    JXL_RETURN_IF_ERROR(DecodeGroupForRoundtrip(
        frame_header, coeffs, group_index, &dec_state,
        &group_dec_caches[thread], thread, input, nullptr, nullptr));
    if ((frame_header.flags & FrameHeader::kNoise) != 0) {
      PrepareNoiseInput(dec_state, shared.frame_dim, frame_header, group_index,
                        thread);
    }
    JXL_RETURN_IF_ERROR(input.Done());
    return true;
  };
  JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, frame_dim.num_groups, allocate_storage,
                                process_group, "ReconstructImage"));
  return std::move(*decoded.color());
}

// Returns a masked, channel-weighted sum of squared differences between
// images `a` and `b` over the 8x8 (kBlockDim) block at block coordinates
// (bx, by). `mask1x1` supplies a per-pixel weight (squared before use); the
// three channel errors are combined with the fixed weights kW below.
float ComputeBlockL2Distance(const Image3F& a, const Image3F& b,
                             const ImageF& mask1x1, size_t by, size_t bx) {
  // The Rect constructor clamps to the image bounds, handling edge blocks.
  Rect rect(bx * kBlockDim, by * kBlockDim, kBlockDim, kBlockDim, a.xsize(),
            a.ysize());
  float err2[3] = {0.0f};  // per-channel accumulated squared error
  for (size_t y = 0; y < rect.ysize(); ++y) {
    const float* row_a[3] = {
        rect.ConstPlaneRow(a, 0, y),
        rect.ConstPlaneRow(a, 1, y),
        rect.ConstPlaneRow(a, 2, y),
    };
    const float* row_b[3] = {
        rect.ConstPlaneRow(b, 0, y),
        rect.ConstPlaneRow(b, 1, y),
        rect.ConstPlaneRow(b, 2, y),
    };
    const float* row_mask = rect.ConstRow(mask1x1, y);
    for (size_t x = 0; x < rect.xsize(); ++x) {
      float mask = row_mask[x];
      float mask2 = mask * mask;
      for (int i = 0; i < 3; ++i) {
        float diff = row_a[i][x] - row_b[i][x];
        err2[i] += mask2 * diff * diff;
      }
    }
  }
  // Tuned per-channel weights for combining the three planes' errors.
  static const double kW[] = {
      12.339445295782363,
      1.0,
      0.2,
  };
  float retval = kW[0] * err2[0] + kW[1] * err2[1] + kW[2] * err2[2];
  return retval;
}

// Chooses the per-block adaptive-reconstruction (EPF sharpness) control
// field: for a small set of candidate EPF values it reconstructs the frame,
// measures per-block error, and then picks a value per block with
// context-based smoothing to keep the field cheap to encode.
Status ComputeARHeuristics(const FrameHeader& frame_header,
                           PassesEncoderState* enc_state,
                           const Image3F& orig_opsin, const Rect& rect,
                           ThreadPool* pool) {
  const CompressParams& cparams = enc_state->cparams;
  PassesSharedState& shared = enc_state->shared;
  const FrameDimensions& frame_dim =
      shared.frame_dim;
  const ImageF& initial_quant_masking1x1 = enc_state->initial_quant_masking1x1;
  ImageB& epf_sharpness = shared.epf_sharpness;
  JxlMemoryManager* memory_manager = enc_state->memory_manager();

  float clamped_butteraugli = std::min(5.0f, cparams.butteraugli_distance);
  // Fast paths: when the distance is too low for dynamic AR, the speed tier
  // is faster than Wombat, or EPF is off, use a constant sharpness of 4.
  if (cparams.butteraugli_distance < kMinButteraugliForDynamicAR ||
      cparams.speed_tier > SpeedTier::kWombat ||
      frame_header.loop_filter.epf_iters == 0) {
    FillPlane(static_cast<uint8_t>(4), &epf_sharpness, Rect(epf_sharpness));
    return true;
  }

  // Candidate EPF sharpness values; fewer candidates at higher distances.
  std::vector<uint8_t> epf_steps;
  if (cparams.butteraugli_distance > 4.5f) {
    epf_steps.push_back(0);
    epf_steps.push_back(4);
  } else {
    epf_steps.push_back(0);
    epf_steps.push_back(2);
    epf_steps.push_back(7);
  }
  // Maps an EPF value (0..7) back to its index in epf_steps, used to form
  // the (top, left) context below.
  static const int kNumEPFVals = 8;
  size_t epf_steps_lut[kNumEPFVals] = {0};
  {
    for (size_t i = 0; i < epf_steps.size(); ++i) {
      epf_steps_lut[epf_steps[i]] = i;
    }
  }
  // For each candidate value: reconstruct the whole frame with that value
  // everywhere and record the per-block error against the original.
  std::array<ImageF, kNumEPFVals> error_images;
  for (uint8_t val : epf_steps) {
    FillPlane(val, &epf_sharpness, Rect(epf_sharpness));
    JXL_ASSIGN_OR_RETURN(
        Image3F decoded,
        ReconstructImage(frame_header, shared, enc_state->coeffs, pool));
    JXL_ASSIGN_OR_RETURN(error_images[val],
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
                                        frame_dim.ysize_blocks));
    for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
      float* error_row = error_images[val].Row(by);
      for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
        error_row[bx] = ComputeBlockL2Distance(
            orig_opsin, decoded, initial_quant_masking1x1, by, bx);
      }
    }
  }
  // First pass: greedy per-block choice with a bias towards reusing the
  // top/left neighbor's value, gathering per-context histograms for pass two.
  // Context is epf_steps index of (top, left) neighbors: 3*3 = 9 contexts.
  std::vector<std::vector<size_t>> histo(9, std::vector<size_t>(kNumEPFVals));
  // Totals start at 1 so the division in the second pass is always safe.
  std::vector<size_t> totals(9, 1);
  const float c5 = 0.007620386618483585f;
  const float c6 = 0.0083224805679680686f;
  const float c7 = 0.99663939685686753;
  for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
    uint8_t* JXL_RESTRICT out_row = epf_sharpness.Row(by);
    uint8_t* JXL_RESTRICT prev_row = epf_sharpness.Row(by > 0 ? by - 1 : 0);
    for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
      uint8_t best_val = 0;
      float best_error = std::numeric_limits<float>::max();
      // Out-of-bounds neighbors default to value 0.
      uint8_t top_val = by > 0 ? prev_row[bx] : 0;
      uint8_t left_val = bx > 0 ? out_row[bx - 1] : 0;
      float top_error = error_images[top_val].Row(by)[bx];
      float left_error = error_images[left_val].Row(by)[bx];
      for (uint8_t val : epf_steps) {
        float error = error_images[val].Row(by)[bx];
        if (val == 0) {
          // Slight preference for EPF value 0, growing with distance.
          error *= c7 - c5 * clamped_butteraugli;
        }
        if (error < best_error) {
          best_val = val;
          best_error = error;
        }
      }
      // Only deviate from a neighbor's value if clearly better; a coherent
      // field is preferred otherwise.
      if (best_error <
          (1.0 - c6 * clamped_butteraugli) * std::min(top_error, left_error)) {
        out_row[bx] = best_val;
      } else if (top_error < left_error) {
        out_row[bx] = top_val;
      } else {
        out_row[bx] = left_val;
      }
      int context = epf_steps_lut[top_val] * 3 + epf_steps_lut[left_val];
      ++histo[context][out_row[bx]];
      ++totals[context];
    }
  }
  // Second pass: re-decide each block, discounting the error of values that
  // are popular within the same (top, left) context.
  const float c1 = 0.059588212153340203f;
  const float c2 = 0.10599497107315753f;
  const float c3base = 0.97;
  const float c3 = pow(c3base, clamped_butteraugli);
  const float c4 = 1.247544678665836f;
  const float context_weight = c1 + c2 * clamped_butteraugli;
  for (size_t by = 0; by < frame_dim.ysize_blocks; by++) {
    uint8_t* JXL_RESTRICT out_row = epf_sharpness.Row(by);
    uint8_t* JXL_RESTRICT prev_row = epf_sharpness.Row(by > 0 ? by - 1 : 0);
    for (size_t bx = 0; bx < frame_dim.xsize_blocks; bx++) {
      uint8_t best_val = 0;
      float best_error = std::numeric_limits<float>::max();
      uint8_t top_val = by > 0 ? prev_row[bx] : 0;
      uint8_t left_val = bx > 0 ? out_row[bx - 1] : 0;
      int context = epf_steps_lut[top_val] * 3 + epf_steps_lut[left_val];
      const auto& ctx_histo = histo[context];
      for (uint8_t val : epf_steps) {
        // Divide by a bonus that grows with this value's popularity in the
        // current context (log1p keeps the bonus bounded and smooth).
        float error = error_images[val].Row(by)[bx] /
                      (c4 + std::log1p(ctx_histo[val] * context_weight /
                                       totals[context]));
        if (val == 0) {
          error *= c3;
        }
        if (error < best_error) {
          best_val = val;
          best_error = error;
        }
      }
      out_row[bx] = best_val;
    }
  }

  return true;
}

// Top-level driver for the lossy-encoding heuristics: splines, patches,
// initial quantization field, inverse gaborish, dequant matrices, AC
// strategy, chroma-from-luma, quantizer refinement and the AC block context
// model.
Status LossyFrameHeuristics(const FrameHeader& frame_header,
                            PassesEncoderState* enc_state,
                            ModularFrameEncoder* modular_frame_encoder,
                            const Image3F* linear, Image3F* opsin,
                            const Rect& rect, const JxlCmsInterface& cms,
                            ThreadPool* pool, AuxOut* aux_out) {
  const CompressParams& cparams = enc_state->cparams;
  const bool streaming_mode = enc_state->streaming_mode;
  const bool initialize_global_state = enc_state->initialize_global_state;
  PassesSharedState& shared = enc_state->shared;
  const FrameDimensions& frame_dim = shared.frame_dim;
  ImageFeatures& image_features = shared.image_features;
  DequantMatrices& matrices = shared.matrices;
  Quantizer& quantizer = shared.quantizer;
  ImageF& initial_quant_masking1x1 = enc_state->initial_quant_masking1x1;
  ImageI& raw_quant_field = shared.raw_quant_field;
  ColorCorrelationMap& cmap = shared.cmap;
  AcStrategyImage& ac_strategy = shared.ac_strategy;
  BlockCtxMap& block_ctx_map = shared.block_ctx_map;
  JxlMemoryManager* memory_manager = enc_state->memory_manager();

  // Find and subtract splines.
  if (cparams.custom_splines.HasAny()) {
    image_features.splines = cparams.custom_splines;
  }
  // Spline search only runs at Squirrel speed or slower, and not in
  // streaming mode; user-provided splines skip the search.
  if (!streaming_mode && cparams.speed_tier <= SpeedTier::kSquirrel) {
    if (!cparams.custom_splines.HasAny()) {
      image_features.splines = FindSplines(*opsin);
    }
    JXL_RETURN_IF_ERROR(image_features.splines.InitializeDrawCache(
        opsin->xsize(), opsin->ysize(), cmap.base()));
    image_features.splines.SubtractFrom(opsin);
  }

  // Find and subtract patches/dots.
  if (!streaming_mode &&
      ApplyOverride(cparams.patches,
                    cparams.speed_tier <= SpeedTier::kSquirrel)) {
    JXL_RETURN_IF_ERROR(
        FindBestPatchDictionary(*opsin, enc_state, cms, pool, aux_out));
    JXL_RETURN_IF_ERROR(
        PatchDictionaryEncoder::SubtractFrom(image_features.patches, opsin));
  }

  const float quant_dc = InitialQuantDC(cparams.butteraugli_distance);

  // TODO(veluca): we can now run all the code from here to FindBestQuantizer
  // (excluded) one rect at a time. Do that.

  // Dependency graph:
  //
  // input: either XYB or input image
  //
  // input image -> XYB [optional]
  // XYB -> initial quant field
  // XYB -> Gaborished XYB
  // Gaborished XYB -> CfL1
  // initial quant field, Gaborished XYB, CfL1 -> ACS
  // initial quant field, ACS, Gaborished XYB -> EPF control field
  // initial quant field -> adjusted initial quant field
  // adjusted initial quant field, ACS -> raw quant field
  // raw quant field, ACS, Gaborished XYB -> CfL2
  //
  // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field.

  AcStrategyHeuristics acs_heuristics(memory_manager, cparams);
  CfLHeuristics cfl_heuristics(memory_manager);
  ImageF initial_quant_field;
  ImageF initial_quant_masking;

  // Compute an initial estimate of the quantization field.
  // Call InitialQuantField only in Hare mode or slower. Otherwise, rely
  // on simple heuristics in FindBestAcStrategy, or set a constant for Falcon
  // mode.
  if (cparams.speed_tier > SpeedTier::kHare ||
      cparams.disable_perceptual_optimizations) {
    // Fast path: constant quant field and masking derived from the distance.
    JXL_ASSIGN_OR_RETURN(initial_quant_field,
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
                                        frame_dim.ysize_blocks));
    JXL_ASSIGN_OR_RETURN(initial_quant_masking,
                         ImageF::Create(memory_manager, frame_dim.xsize_blocks,
                                        frame_dim.ysize_blocks));
    float q = 0.79 / cparams.butteraugli_distance;
    FillImage(q, &initial_quant_field);
    float masking = 1.0f / (q + 0.001f);
    FillImage(masking, &initial_quant_masking);
    if (cparams.disable_perceptual_optimizations) {
      JXL_ASSIGN_OR_RETURN(
          initial_quant_masking1x1,
          ImageF::Create(memory_manager, frame_dim.xsize, frame_dim.ysize));
      FillImage(masking, &initial_quant_masking1x1);
    }
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
  } else {
    // Call this here, as it relies on pre-gaborish values.
    float butteraugli_distance_for_iqf = cparams.butteraugli_distance;
    if (!frame_header.loop_filter.gab) {
      // Compensate the distance when gaborish is disabled.
      butteraugli_distance_for_iqf *= 0.62f;
    }
    JXL_ASSIGN_OR_RETURN(
        initial_quant_field,
        InitialQuantField(butteraugli_distance_for_iqf, *opsin, rect, pool,
                          1.0f, &initial_quant_masking,
                          &initial_quant_masking1x1));
    float q = 0.39 / cparams.butteraugli_distance;
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
  }

  // TODO(veluca): do something about animations.

  // Apply inverse-gaborish.
  if (frame_header.loop_filter.gab) {
    // Changing the weight here to 0.99f would help to reduce ringing in
    // generation loss.
    float weight[3] = {
        1.0f,
        1.0f,
        1.0f,
    };
    JXL_RETURN_IF_ERROR(GaborishInverse(opsin, rect, weight, pool));
  }

  if (initialize_global_state) {
    JXL_RETURN_IF_ERROR(FindBestDequantMatrices(
        memory_manager, cparams, modular_frame_encoder, &matrices));
  }

  JXL_RETURN_IF_ERROR(cfl_heuristics.Init(rect));
  JXL_RETURN_IF_ERROR(acs_heuristics.Init(*opsin, rect, initial_quant_field,
                                          initial_quant_masking,
                                          initial_quant_masking1x1, &matrices));

  // Per-tile work: CfL (first pass), AC strategy selection, quant-field
  // adjustment, and CfL (second pass) once strategy/quantization are known.
  auto process_tile = [&](const uint32_t tid, const size_t thread) -> Status {
    // Map the linear tile id to tile coordinates, clamped to the frame.
    size_t n_enc_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks);
    size_t tx = tid % n_enc_tiles;
    size_t ty = tid / n_enc_tiles;
    size_t by0 = ty * kEncTileDimInBlocks;
    size_t by1 =
        std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks);
    size_t bx0 = tx * kEncTileDimInBlocks;
    size_t bx1 =
        std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks);
    Rect r(bx0, by0, bx1 - bx0, by1 - by0);

    // For speeds up to Wombat, we only compute the color correlation map
    // once we know the transform type and the quantization map.
    if (cparams.speed_tier <= SpeedTier::kSquirrel) {
      JXL_RETURN_IF_ERROR(cfl_heuristics.ComputeTile(
          r, *opsin, rect, matrices,
          /*ac_strategy=*/nullptr,
          /*raw_quant_field=*/nullptr,
          /*quantizer=*/nullptr, /*fast=*/false, thread, &cmap));
    }

    // Choose block sizes.
    JXL_RETURN_IF_ERROR(
        acs_heuristics.ProcessRect(r, cmap, &ac_strategy, thread));

    // Always set the initial quant field, so we can compute the CfL map with
    // more accuracy. The initial quant field might change in slower modes, but
    // adjusting the quant field with butteraugli when all the other encoding
    // parameters are fixed is likely a more reliable choice anyway.
    JXL_RETURN_IF_ERROR(AdjustQuantField(
        ac_strategy, r, cparams.butteraugli_distance, &initial_quant_field));
    quantizer.SetQuantFieldRect(initial_quant_field, r, &raw_quant_field);

    // Compute a non-default CfL map if we are at Hare speed, or slower.
    if (cparams.speed_tier <= SpeedTier::kHare) {
      JXL_RETURN_IF_ERROR(cfl_heuristics.ComputeTile(
          r, *opsin, rect, matrices, &ac_strategy, &raw_quant_field, &quantizer,
          /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread, &cmap));
    }
    return true;
  };
  size_t num_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) *
                     DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks);
  // Per-thread storage for the ACS and CfL heuristics.
  const auto prepare = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(acs_heuristics.PrepareForThreads(num_threads));
    JXL_RETURN_IF_ERROR(cfl_heuristics.PrepareForThreads(num_threads));
    return true;
  };
  JXL_RETURN_IF_ERROR(
      RunOnPool(pool, 0, num_tiles, prepare, process_tile, "Enc Heuristics"));

  JXL_RETURN_IF_ERROR(acs_heuristics.Finalize(frame_dim, ac_strategy, aux_out));

  // Refine quantization levels.
  if (!streaming_mode && !cparams.disable_perceptual_optimizations) {
    ImageB& epf_sharpness = shared.epf_sharpness;
    // Use a constant sharpness field while refining the quantizer.
    FillPlane(static_cast<uint8_t>(4), &epf_sharpness, Rect(epf_sharpness));
    JXL_RETURN_IF_ERROR(FindBestQuantizer(frame_header, linear, *opsin,
                                          initial_quant_field, enc_state, cms,
                                          pool, aux_out));
  }

  // Choose a context model that depends on the amount of quantization for AC.
  if (cparams.speed_tier < SpeedTier::kFalcon && initialize_global_state) {
    FindBestBlockEntropyModel(cparams, raw_quant_field, ac_strategy,
                              &block_ctx_map);
  }
  return true;
}

}  // namespace jxl