// enc_chroma_from_luma.cc (16260B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_chroma_from_luma.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <algorithm> 11 #include <cfloat> 12 #include <cmath> 13 #include <cstdlib> 14 #include <hwy/base.h> // HWY_ALIGN_MAX 15 16 #undef HWY_TARGET_INCLUDE 17 #define HWY_TARGET_INCLUDE "lib/jxl/enc_chroma_from_luma.cc" 18 #include <hwy/foreach_target.h> 19 #include <hwy/highway.h> 20 21 #include "lib/jxl/base/common.h" 22 #include "lib/jxl/base/rect.h" 23 #include "lib/jxl/base/status.h" 24 #include "lib/jxl/cms/opsin_params.h" 25 #include "lib/jxl/dec_transforms-inl.h" 26 #include "lib/jxl/enc_aux_out.h" 27 #include "lib/jxl/enc_params.h" 28 #include "lib/jxl/enc_transforms-inl.h" 29 #include "lib/jxl/quantizer.h" 30 #include "lib/jxl/simd_util.h" 31 HWY_BEFORE_NAMESPACE(); 32 namespace jxl { 33 namespace HWY_NAMESPACE { 34 35 // These templates are not found via ADL. 36 using hwy::HWY_NAMESPACE::Abs; 37 using hwy::HWY_NAMESPACE::Ge; 38 using hwy::HWY_NAMESPACE::GetLane; 39 using hwy::HWY_NAMESPACE::IfThenElse; 40 using hwy::HWY_NAMESPACE::Lt; 41 42 static HWY_FULL(float) df; 43 44 struct CFLFunction { 45 static constexpr float kCoeff = 1.f / 3; 46 static constexpr float kThres = 100.0f; 47 static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; 48 CFLFunction(const float* values_m, const float* values_s, size_t num, 49 float base, float distance_mul) 50 : values_m(values_m), 51 values_s(values_s), 52 num(num), 53 base(base), 54 distance_mul(distance_mul) { 55 JXL_DASSERT(num % Lanes(df) == 0); 56 } 57 58 // Returns f'(x), where f is 1/3 * sum ((|color residual| + 1)^2-1) + 59 // distance_mul * x^2 * num. 
60 float Compute(float x, float eps, float* fpeps, float* fmeps) const { 61 float first_derivative = 2 * distance_mul * num * x; 62 float first_derivative_peps = 2 * distance_mul * num * (x + eps); 63 float first_derivative_meps = 2 * distance_mul * num * (x - eps); 64 65 const auto inv_color_factor = Set(df, kInvColorFactor); 66 const auto thres = Set(df, kThres); 67 const auto coeffx2 = Set(df, kCoeff * 2.0f); 68 const auto one = Set(df, 1.0f); 69 const auto zero = Set(df, 0.0f); 70 const auto base_v = Set(df, base); 71 const auto x_v = Set(df, x); 72 const auto xpe_v = Set(df, x + eps); 73 const auto xme_v = Set(df, x - eps); 74 auto fd_v = Zero(df); 75 auto fdpe_v = Zero(df); 76 auto fdme_v = Zero(df); 77 78 for (size_t i = 0; i < num; i += Lanes(df)) { 79 // color residual = ax + b 80 const auto a = Mul(inv_color_factor, Load(df, values_m + i)); 81 const auto b = 82 Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); 83 const auto v = MulAdd(a, x_v, b); 84 const auto vpe = MulAdd(a, xpe_v, b); 85 const auto vme = MulAdd(a, xme_v, b); 86 const auto av = Abs(v); 87 const auto avpe = Abs(vpe); 88 const auto avme = Abs(vme); 89 const auto acoeffx2 = Mul(coeffx2, a); 90 auto d = Mul(acoeffx2, Add(av, one)); 91 auto dpe = Mul(acoeffx2, Add(avpe, one)); 92 auto dme = Mul(acoeffx2, Add(avme, one)); 93 d = IfThenElse(Lt(v, zero), Sub(zero, d), d); 94 dpe = IfThenElse(Lt(vpe, zero), Sub(zero, dpe), dpe); 95 dme = IfThenElse(Lt(vme, zero), Sub(zero, dme), dme); 96 const auto above = Ge(av, thres); 97 // TODO(eustas): use IfThenElseZero 98 fd_v = Add(fd_v, IfThenElse(above, zero, d)); 99 fdpe_v = Add(fdpe_v, IfThenElse(above, zero, dpe)); 100 fdme_v = Add(fdme_v, IfThenElse(above, zero, dme)); 101 } 102 103 *fpeps = first_derivative_peps + GetLane(SumOfLanes(df, fdpe_v)); 104 *fmeps = first_derivative_meps + GetLane(SumOfLanes(df, fdme_v)); 105 return first_derivative + GetLane(SumOfLanes(df, fd_v)); 106 } 107 108 const float* JXL_RESTRICT values_m; 109 
const float* JXL_RESTRICT values_s; 110 size_t num; 111 float base; 112 float distance_mul; 113 }; 114 115 // Chroma-from-luma search, values_m will have luma -- and values_s chroma. 116 int32_t FindBestMultiplier(const float* values_m, const float* values_s, 117 size_t num, float base, float distance_mul, 118 bool fast) { 119 if (num == 0) { 120 return 0; 121 } 122 float x; 123 if (fast) { 124 static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; 125 auto ca = Zero(df); 126 auto cb = Zero(df); 127 const auto inv_color_factor = Set(df, kInvColorFactor); 128 const auto base_v = Set(df, base); 129 for (size_t i = 0; i < num; i += Lanes(df)) { 130 // color residual = ax + b 131 const auto a = Mul(inv_color_factor, Load(df, values_m + i)); 132 const auto b = 133 Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); 134 ca = MulAdd(a, a, ca); 135 cb = MulAdd(a, b, cb); 136 } 137 // + distance_mul * x^2 * num 138 x = -GetLane(SumOfLanes(df, cb)) / 139 (GetLane(SumOfLanes(df, ca)) + num * distance_mul * 0.5f); 140 } else { 141 constexpr float eps = 100; 142 constexpr float kClamp = 20.0f; 143 CFLFunction fn(values_m, values_s, num, base, distance_mul); 144 x = 0; 145 // Up to 20 Newton iterations, with approximate derivatives. 146 // Derivatives are approximate due to the high amount of noise in the exact 147 // derivatives. 148 for (size_t i = 0; i < 20; i++) { 149 float dfpeps; 150 float dfmeps; 151 float df = fn.Compute(x, eps, &dfpeps, &dfmeps); 152 float ddf = (dfpeps - dfmeps) / (2 * eps); 153 float kExperimentalInsignificantStabilizer = 0.85; 154 float step = df / (ddf + kExperimentalInsignificantStabilizer); 155 x -= std::min(kClamp, std::max(-kClamp, step)); 156 if (std::abs(step) < 3e-3) break; 157 } 158 } 159 // CFL seems to be tricky for larger transforms for HF components 160 // close to zero. This heuristic brings the solutions closer to zero 161 // and reduces red-green oscillations. 
A better approach would 162 // look into variance of the multiplier within separate (e.g. 8x8) 163 // areas and only apply this heuristic where there is a high variance. 164 // This would give about 1 % more compression density. 165 float towards_zero = 2.6; 166 if (x >= towards_zero) { 167 x -= towards_zero; 168 } else if (x <= -towards_zero) { 169 x += towards_zero; 170 } else { 171 x = 0; 172 } 173 return std::max(-128.0f, std::min(127.0f, roundf(x))); 174 } 175 176 Status InitDCStorage(JxlMemoryManager* memory_manager, size_t num_blocks, 177 ImageF* dc_values) { 178 // First row: Y channel 179 // Second row: X channel 180 // Third row: Y channel 181 // Fourth row: B channel 182 JXL_ASSIGN_OR_RETURN( 183 *dc_values, 184 ImageF::Create(memory_manager, RoundUpTo(num_blocks, Lanes(df)), 4)); 185 186 JXL_ENSURE(dc_values->xsize() != 0); 187 // Zero-fill the last lanes 188 for (size_t y = 0; y < 4; y++) { 189 for (size_t x = dc_values->xsize() - Lanes(df); x < dc_values->xsize(); 190 x++) { 191 dc_values->Row(y)[x] = 0; 192 } 193 } 194 return true; 195 } 196 197 Status ComputeTile(const Image3F& opsin, const Rect& opsin_rect, 198 const DequantMatrices& dequant, 199 const AcStrategyImage* ac_strategy, 200 const ImageI* raw_quant_field, const Quantizer* quantizer, 201 const Rect& rect, bool fast, bool use_dct8, ImageSB* map_x, 202 ImageSB* map_b, ImageF* dc_values, float* mem) { 203 static_assert(kEncTileDimInBlocks == kColorTileDimInBlocks, 204 "Invalid color tile dim"); 205 size_t xsize_blocks = opsin_rect.xsize() / kBlockDim; 206 constexpr float kDistanceMultiplierAC = 1e-9f; 207 const size_t dct_scratch_size = 208 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; 209 210 const size_t y0 = rect.y0(); 211 const size_t x0 = rect.x0(); 212 const size_t x1 = rect.x0() + rect.xsize(); 213 const size_t y1 = rect.y0() + rect.ysize(); 214 215 int ty = y0 / kColorTileDimInBlocks; 216 int tx = x0 / kColorTileDimInBlocks; 217 218 int8_t* JXL_RESTRICT row_out_x 
= map_x->Row(ty); 219 int8_t* JXL_RESTRICT row_out_b = map_b->Row(ty); 220 221 float* JXL_RESTRICT dc_values_yx = dc_values->Row(0); 222 float* JXL_RESTRICT dc_values_x = dc_values->Row(1); 223 float* JXL_RESTRICT dc_values_yb = dc_values->Row(2); 224 float* JXL_RESTRICT dc_values_b = dc_values->Row(3); 225 226 // All are aligned. 227 float* HWY_RESTRICT block_y = mem; 228 float* HWY_RESTRICT block_x = block_y + AcStrategy::kMaxCoeffArea; 229 float* HWY_RESTRICT block_b = block_x + AcStrategy::kMaxCoeffArea; 230 float* HWY_RESTRICT coeffs_yx = block_b + AcStrategy::kMaxCoeffArea; 231 float* HWY_RESTRICT coeffs_x = coeffs_yx + kColorTileDim * kColorTileDim; 232 float* HWY_RESTRICT coeffs_yb = coeffs_x + kColorTileDim * kColorTileDim; 233 float* HWY_RESTRICT coeffs_b = coeffs_yb + kColorTileDim * kColorTileDim; 234 float* HWY_RESTRICT scratch_space = coeffs_b + kColorTileDim * kColorTileDim; 235 float* scratch_space_end = 236 scratch_space + 2 * AcStrategy::kMaxCoeffArea + dct_scratch_size; 237 JXL_ENSURE(scratch_space_end == block_y + CfLHeuristics::ItemsPerThread()); 238 (void)scratch_space_end; 239 240 // Small (~256 bytes each) 241 HWY_ALIGN_MAX float 242 dc_y[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; 243 HWY_ALIGN_MAX float 244 dc_x[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; 245 HWY_ALIGN_MAX float 246 dc_b[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; 247 size_t num_ac = 0; 248 249 for (size_t y = y0; y < y1; ++y) { 250 const float* JXL_RESTRICT row_y = 251 opsin_rect.ConstPlaneRow(opsin, 1, y * kBlockDim); 252 const float* JXL_RESTRICT row_x = 253 opsin_rect.ConstPlaneRow(opsin, 0, y * kBlockDim); 254 const float* JXL_RESTRICT row_b = 255 opsin_rect.ConstPlaneRow(opsin, 2, y * kBlockDim); 256 size_t stride = opsin.PixelsPerRow(); 257 258 for (size_t x = x0; x < x1; x++) { 259 AcStrategy acs = use_dct8 260 ? 
AcStrategy::FromRawStrategy(AcStrategyType::DCT) 261 : ac_strategy->ConstRow(y)[x]; 262 if (!acs.IsFirstBlock()) continue; 263 size_t xs = acs.covered_blocks_x(); 264 TransformFromPixels(acs.Strategy(), row_y + x * kBlockDim, stride, 265 block_y, scratch_space); 266 DCFromLowestFrequencies(acs.Strategy(), block_y, dc_y, xs); 267 TransformFromPixels(acs.Strategy(), row_x + x * kBlockDim, stride, 268 block_x, scratch_space); 269 DCFromLowestFrequencies(acs.Strategy(), block_x, dc_x, xs); 270 TransformFromPixels(acs.Strategy(), row_b + x * kBlockDim, stride, 271 block_b, scratch_space); 272 DCFromLowestFrequencies(acs.Strategy(), block_b, dc_b, xs); 273 const float* const JXL_RESTRICT qm_x = 274 dequant.InvMatrix(acs.Strategy(), 0); 275 const float* const JXL_RESTRICT qm_b = 276 dequant.InvMatrix(acs.Strategy(), 2); 277 float q_dc_x = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(0); 278 float q_dc_b = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(2); 279 280 // Copy DCs in dc_values. 281 for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { 282 for (size_t ix = 0; ix < xs; ix++) { 283 dc_values_yx[(iy + y) * xsize_blocks + ix + x] = 284 dc_y[iy * xs + ix] * q_dc_x; 285 dc_values_x[(iy + y) * xsize_blocks + ix + x] = 286 dc_x[iy * xs + ix] * q_dc_x; 287 dc_values_yb[(iy + y) * xsize_blocks + ix + x] = 288 dc_y[iy * xs + ix] * q_dc_b; 289 dc_values_b[(iy + y) * xsize_blocks + ix + x] = 290 dc_b[iy * xs + ix] * q_dc_b; 291 } 292 } 293 294 // Do not use this block for computing AC CfL. 295 if (acs.covered_blocks_x() + x0 > x1 || 296 acs.covered_blocks_y() + y0 > y1) { 297 continue; 298 } 299 300 // Copy AC coefficients in the local block. The order in which 301 // coefficients get stored does not matter. 302 size_t cx = acs.covered_blocks_x(); 303 size_t cy = acs.covered_blocks_y(); 304 CoefficientLayout(&cy, &cx); 305 // Zero out LFs. 
This introduces terms in the optimization loop that 306 // don't affect the result, as they are all 0, but allow for simpler 307 // SIMDfication. 308 for (size_t iy = 0; iy < cy; iy++) { 309 for (size_t ix = 0; ix < cx; ix++) { 310 block_y[cx * kBlockDim * iy + ix] = 0; 311 block_x[cx * kBlockDim * iy + ix] = 0; 312 block_b[cx * kBlockDim * iy + ix] = 0; 313 } 314 } 315 // Unclear why this is like it is. (This works slightly better 316 // than the previous approach which was also a hack.) 317 const float qq = 318 (raw_quant_field == nullptr) ? 1.0f : raw_quant_field->Row(y)[x]; 319 // Experimentally values 128-130 seem best -- I don't know why we 320 // need this multiplier. 321 const float kStrangeMultiplier = 128; 322 float q = use_dct8 ? 1 : quantizer->Scale() * kStrangeMultiplier * qq; 323 const auto qv = Set(df, q); 324 for (size_t i = 0; i < cx * cy * 64; i += Lanes(df)) { 325 const auto b_y = Load(df, block_y + i); 326 const auto b_x = Load(df, block_x + i); 327 const auto b_b = Load(df, block_b + i); 328 const auto qqm_x = Mul(qv, Load(df, qm_x + i)); 329 const auto qqm_b = Mul(qv, Load(df, qm_b + i)); 330 Store(Mul(b_y, qqm_x), df, coeffs_yx + num_ac); 331 Store(Mul(b_x, qqm_x), df, coeffs_x + num_ac); 332 Store(Mul(b_y, qqm_b), df, coeffs_yb + num_ac); 333 Store(Mul(b_b, qqm_b), df, coeffs_b + num_ac); 334 num_ac += Lanes(df); 335 } 336 } 337 } 338 JXL_ENSURE(num_ac % Lanes(df) == 0); 339 row_out_x[tx] = FindBestMultiplier(coeffs_yx, coeffs_x, num_ac, 0.0f, 340 kDistanceMultiplierAC, fast); 341 row_out_b[tx] = 342 FindBestMultiplier(coeffs_yb, coeffs_b, num_ac, jxl::cms::kYToBRatio, 343 kDistanceMultiplierAC, fast); 344 return true; 345 } 346 347 // NOLINTNEXTLINE(google-readability-namespace-comments) 348 } // namespace HWY_NAMESPACE 349 } // namespace jxl 350 HWY_AFTER_NAMESPACE(); 351 352 #if HWY_ONCE 353 namespace jxl { 354 355 HWY_EXPORT(InitDCStorage); 356 HWY_EXPORT(ComputeTile); 357 358 Status CfLHeuristics::Init(const Rect& rect) { 359 size_t 
xsize_blocks = rect.xsize() / kBlockDim; 360 size_t ysize_blocks = rect.ysize() / kBlockDim; 361 return HWY_DYNAMIC_DISPATCH(InitDCStorage)( 362 memory_manager, xsize_blocks * ysize_blocks, &dc_values); 363 } 364 365 Status CfLHeuristics::ComputeTile(const Rect& r, const Image3F& opsin, 366 const Rect& opsin_rect, 367 const DequantMatrices& dequant, 368 const AcStrategyImage* ac_strategy, 369 const ImageI* raw_quant_field, 370 const Quantizer* quantizer, bool fast, 371 size_t thread, ColorCorrelationMap* cmap) { 372 bool use_dct8 = ac_strategy == nullptr; 373 return HWY_DYNAMIC_DISPATCH(ComputeTile)( 374 opsin, opsin_rect, dequant, ac_strategy, raw_quant_field, quantizer, r, 375 fast, use_dct8, &cmap->ytox_map, &cmap->ytob_map, &dc_values, 376 mem.address<float>() + thread * ItemsPerThread()); 377 } 378 379 Status ColorCorrelationEncodeDC(const ColorCorrelation& color_correlation, 380 BitWriter* writer, LayerType layer, 381 AuxOut* aux_out) { 382 float color_factor = color_correlation.GetColorFactor(); 383 float base_correlation_x = color_correlation.GetBaseCorrelationX(); 384 float base_correlation_b = color_correlation.GetBaseCorrelationB(); 385 int32_t ytox_dc = color_correlation.GetYToXDC(); 386 int32_t ytob_dc = color_correlation.GetYToBDC(); 387 388 return writer->WithMaxBits( 389 1 + 2 * kBitsPerByte + 12 + 32, layer, aux_out, [&]() -> Status { 390 if (ytox_dc == 0 && ytob_dc == 0 && 391 color_factor == kDefaultColorFactor && base_correlation_x == 0.0f && 392 base_correlation_b == jxl::cms::kYToBRatio) { 393 writer->Write(1, 1); 394 return true; 395 } 396 writer->Write(1, 0); 397 JXL_RETURN_IF_ERROR( 398 U32Coder::Write(kColorFactorDist, color_factor, writer)); 399 JXL_RETURN_IF_ERROR(F16Coder::Write(base_correlation_x, writer)); 400 JXL_RETURN_IF_ERROR(F16Coder::Write(base_correlation_b, writer)); 401 writer->Write(kBitsPerByte, 402 ytox_dc - std::numeric_limits<int8_t>::min()); 403 writer->Write(kBitsPerByte, 404 ytob_dc - 
std::numeric_limits<int8_t>::min()); 405 return true; 406 }); 407 } 408 409 } // namespace jxl 410 #endif // HWY_ONCE