squeeze.cc (18538B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/modular/transform/squeeze.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <cstdlib> 11 12 #include "lib/jxl/base/common.h" 13 #include "lib/jxl/base/data_parallel.h" 14 #include "lib/jxl/base/printf_macros.h" 15 #include "lib/jxl/modular/modular_image.h" 16 #include "lib/jxl/modular/transform/transform.h" 17 #undef HWY_TARGET_INCLUDE 18 #define HWY_TARGET_INCLUDE "lib/jxl/modular/transform/squeeze.cc" 19 #include <hwy/foreach_target.h> 20 #include <hwy/highway.h> 21 22 #include "lib/jxl/simd_util-inl.h" 23 24 HWY_BEFORE_NAMESPACE(); 25 namespace jxl { 26 namespace HWY_NAMESPACE { 27 28 // These templates are not found via ADL. 29 using hwy::HWY_NAMESPACE::Abs; 30 using hwy::HWY_NAMESPACE::Add; 31 using hwy::HWY_NAMESPACE::And; 32 using hwy::HWY_NAMESPACE::Gt; 33 using hwy::HWY_NAMESPACE::IfThenElse; 34 using hwy::HWY_NAMESPACE::IfThenZeroElse; 35 using hwy::HWY_NAMESPACE::Lt; 36 using hwy::HWY_NAMESPACE::MulEven; 37 using hwy::HWY_NAMESPACE::Ne; 38 using hwy::HWY_NAMESPACE::Neg; 39 using hwy::HWY_NAMESPACE::OddEven; 40 using hwy::HWY_NAMESPACE::RebindToUnsigned; 41 using hwy::HWY_NAMESPACE::ShiftLeft; 42 using hwy::HWY_NAMESPACE::ShiftRight; 43 using hwy::HWY_NAMESPACE::Sub; 44 using hwy::HWY_NAMESPACE::Xor; 45 46 #if HWY_TARGET != HWY_SCALAR 47 48 JXL_INLINE void FastUnsqueeze(const pixel_type *JXL_RESTRICT p_residual, 49 const pixel_type *JXL_RESTRICT p_avg, 50 const pixel_type *JXL_RESTRICT p_navg, 51 const pixel_type *p_pout, 52 pixel_type *JXL_RESTRICT p_out, 53 pixel_type *p_nout) { 54 const HWY_CAPPED(pixel_type, 8) d; 55 const RebindToUnsigned<decltype(d)> du; 56 const size_t N = Lanes(d); 57 auto onethird = Set(d, 0x55555556); 58 for (size_t x = 0; x < 8; x += N) { 59 auto avg = Load(d, p_avg + x); 60 auto next_avg = Load(d, p_navg + x); 61 auto top = Load(d, p_pout + x); 62 // Equivalent to SmoothTendency(top,avg,next_avg), but without branches 63 // typo:off 64 auto Ba = Sub(top, avg); 65 auto an = Sub(avg, next_avg); 66 auto nonmono = Xor(Ba, an); 67 auto absBa = Abs(Ba); 68 auto absan = Abs(an); 69 auto absBn = Abs(Sub(top, next_avg)); 70 // Compute a3 = absBa / 3 71 auto a3e = BitCast(d, ShiftRight<32>(MulEven(absBa, onethird))); 72 auto a3oi = MulEven(Reverse(d, absBa), onethird); 73 auto a3o = BitCast( 74 d, Reverse(hwy::HWY_NAMESPACE::Repartition<pixel_type_w, decltype(d)>(), 75 a3oi)); 76 auto a3 = OddEven(a3o, a3e); 77 a3 = Add(a3, Add(absBn, Set(d, 2))); 78 auto absdiff = ShiftRight<2>(a3); 79 auto skipdiff = Ne(Ba, Zero(d)); 80 skipdiff = And(skipdiff, Ne(an, Zero(d))); 81 skipdiff = And(skipdiff, Lt(nonmono, Zero(d))); 82 auto absBa2 = Add(ShiftLeft<1>(absBa), And(absdiff, Set(d, 1))); 83 absdiff = IfThenElse(Gt(absdiff, absBa2), 84 Add(ShiftLeft<1>(absBa), Set(d, 1)), absdiff); 85 // typo:on 86 auto absan2 = ShiftLeft<1>(absan); 87 absdiff = IfThenElse(Gt(Add(absdiff, And(absdiff, Set(d, 1))), absan2), 88 absan2, absdiff); 89 auto diff1 = IfThenElse(Lt(top, next_avg), Neg(absdiff), absdiff); 90 auto tendency = IfThenZeroElse(skipdiff, diff1); 91 92 auto diff_minus_tendency = Load(d, p_residual + x); 93 auto diff = Add(diff_minus_tendency, tendency); 94 auto out = 95 Add(avg, ShiftRight<1>( 96 Add(diff, BitCast(d, ShiftRight<31>(BitCast(du, diff)))))); 97 Store(out, d, p_out + x); 98 Store(Sub(out, diff), d, p_nout + x); 99 } 100 } 101 102 #endif 103 104 Status InvHSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { 105 JXL_ENSURE(c < input.channel.size()); 106 JXL_ENSURE(rc < input.channel.size()); 107 Channel &chin = input.channel[c]; 108 const Channel &chin_residual = input.channel[rc]; 109 // These must be valid since we ran MetaApply already. 110 JXL_ENSURE(chin.w == DivCeil(chin.w + chin_residual.w, 2)); 111 JXL_ENSURE(chin.h == chin_residual.h); 112 JxlMemoryManager *memory_manager = input.memory_manager(); 113 114 if (chin_residual.w == 0) { 115 // Short-circuit: output channel has same dimensions as input. 116 input.channel[c].hshift--; 117 return true; 118 } 119 120 // Note: chin.w >= chin_residual.w and at most 1 different. 121 JXL_ASSIGN_OR_RETURN(Channel chout, 122 Channel::Create(memory_manager, chin.w + chin_residual.w, 123 chin.h, chin.hshift - 1, chin.vshift)); 124 JXL_DEBUG_V(4, 125 "Undoing horizontal squeeze of channel %i using residuals in " 126 "channel %i (going from width %" PRIuS " to %" PRIuS ")", 127 c, rc, chin.w, chout.w); 128 129 if (chin_residual.h == 0) { 130 // Short-circuit: channel with no pixels. 131 input.channel[c] = std::move(chout); 132 return true; 133 } 134 auto unsqueeze_row = [&](size_t y, size_t x0) { 135 const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y); 136 const pixel_type *JXL_RESTRICT p_avg = chin.Row(y); 137 pixel_type *JXL_RESTRICT p_out = chout.Row(y); 138 for (size_t x = x0; x < chin_residual.w; x++) { 139 pixel_type_w diff_minus_tendency = p_residual[x]; 140 pixel_type_w avg = p_avg[x]; 141 pixel_type_w next_avg = (x + 1 < chin.w ? p_avg[x + 1] : avg); 142 pixel_type_w left = (x ? p_out[(x << 1) - 1] : avg); 143 pixel_type_w tendency = SmoothTendency(left, avg, next_avg); 144 pixel_type_w diff = diff_minus_tendency + tendency; 145 pixel_type_w A = avg + (diff / 2); 146 p_out[(x << 1)] = A; 147 pixel_type_w B = A - diff; 148 p_out[(x << 1) + 1] = B; 149 } 150 if (chout.w & 1) p_out[chout.w - 1] = p_avg[chin.w - 1]; 151 }; 152 153 // somewhat complicated trickery just to be able to SIMD this. 154 // Horizontal unsqueeze has horizontal data dependencies, so we do 155 // 8 rows at a time and treat it as a vertical unsqueeze of a 156 // transposed 8x8 block (or 9x8 for one input). 157 static constexpr const size_t kRowsPerThread = 8; 158 const auto unsqueeze_span = [&](const uint32_t task, 159 size_t /* thread */) -> Status { 160 const size_t y0 = task * kRowsPerThread; 161 const size_t rows = std::min(kRowsPerThread, chin.h - y0); 162 size_t x = 0; 163 164 #if HWY_TARGET != HWY_SCALAR 165 intptr_t onerow_in = chin.plane.PixelsPerRow(); 166 intptr_t onerow_inr = chin_residual.plane.PixelsPerRow(); 167 intptr_t onerow_out = chout.plane.PixelsPerRow(); 168 const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y0); 169 const pixel_type *JXL_RESTRICT p_avg = chin.Row(y0); 170 pixel_type *JXL_RESTRICT p_out = chout.Row(y0); 171 HWY_ALIGN pixel_type b_p_avg[9 * kRowsPerThread]; 172 HWY_ALIGN pixel_type b_p_residual[8 * kRowsPerThread]; 173 HWY_ALIGN pixel_type b_p_out_even[8 * kRowsPerThread]; 174 HWY_ALIGN pixel_type b_p_out_odd[8 * kRowsPerThread]; 175 HWY_ALIGN pixel_type b_p_out_evenT[8 * kRowsPerThread]; 176 HWY_ALIGN pixel_type b_p_out_oddT[8 * kRowsPerThread]; 177 const HWY_CAPPED(pixel_type, 8) d; 178 const size_t N = Lanes(d); 179 if (chin_residual.w > 16 && rows == kRowsPerThread) { 180 for (; x < chin_residual.w - 9; x += 8) { 181 Transpose8x8Block(p_residual + x, b_p_residual, onerow_inr); 182 Transpose8x8Block(p_avg + x, b_p_avg, onerow_in); 183 for (size_t y = 0; y < kRowsPerThread; y++) { 184 b_p_avg[8 * 8 + y] = p_avg[x + 8 + onerow_in * y]; 185 } 186 for (size_t i = 0; i < 8; i++) { 187 FastUnsqueeze( 188 b_p_residual + 8 * i, b_p_avg + 8 * i, b_p_avg + 8 * (i + 1), 189 (x + i ? b_p_out_odd + 8 * ((x + i - 1) & 7) : b_p_avg + 8 * i), 190 b_p_out_even + 8 * i, b_p_out_odd + 8 * i); 191 } 192 193 Transpose8x8Block(b_p_out_even, b_p_out_evenT, 8); 194 Transpose8x8Block(b_p_out_odd, b_p_out_oddT, 8); 195 for (size_t y = 0; y < kRowsPerThread; y++) { 196 for (size_t i = 0; i < kRowsPerThread; i += N) { 197 auto even = Load(d, b_p_out_evenT + 8 * y + i); 198 auto odd = Load(d, b_p_out_oddT + 8 * y + i); 199 StoreInterleaved(d, even, odd, 200 p_out + ((x + i) << 1) + onerow_out * y); 201 } 202 } 203 } 204 } 205 #endif 206 for (size_t y = 0; y < rows; y++) { 207 unsqueeze_row(y0 + y, x); 208 } 209 return true; 210 }; 211 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.h, kRowsPerThread), 212 ThreadPool::NoInit, unsqueeze_span, 213 "InvHorizontalSqueeze")); 214 input.channel[c] = std::move(chout); 215 return true; 216 } 217 218 Status InvVSqueeze(Image &input, uint32_t c, uint32_t rc, ThreadPool *pool) { 219 JXL_ENSURE(c < input.channel.size()); 220 JXL_ENSURE(rc < input.channel.size()); 221 const Channel &chin = input.channel[c]; 222 const Channel &chin_residual = input.channel[rc]; 223 // These must be valid since we ran MetaApply already. 224 JXL_ENSURE(chin.h == DivCeil(chin.h + chin_residual.h, 2)); 225 JXL_ENSURE(chin.w == chin_residual.w); 226 JxlMemoryManager *memory_manager = input.memory_manager(); 227 228 if (chin_residual.h == 0) { 229 // Short-circuit: output channel has same dimensions as input. 230 input.channel[c].vshift--; 231 return true; 232 } 233 234 // Note: chin.h >= chin_residual.h and at most 1 different. 235 JXL_ASSIGN_OR_RETURN( 236 Channel chout, 237 Channel::Create(memory_manager, chin.w, chin.h + chin_residual.h, 238 chin.hshift, chin.vshift - 1)); 239 JXL_DEBUG_V( 240 4, 241 "Undoing vertical squeeze of channel %i using residuals in channel " 242 "%i (going from height %" PRIuS " to %" PRIuS ")", 243 c, rc, chin.h, chout.h); 244 245 if (chin_residual.w == 0) { 246 // Short-circuit: channel with no pixels. 247 input.channel[c] = std::move(chout); 248 return true; 249 } 250 251 static constexpr const int kColsPerThread = 64; 252 const auto unsqueeze_slice = [&](const uint32_t task, 253 size_t /* thread */) -> Status { 254 const size_t x0 = task * kColsPerThread; 255 const size_t x1 = 256 std::min(static_cast<size_t>(task + 1) * kColsPerThread, chin.w); 257 const size_t w = x1 - x0; 258 // We only iterate up to std::min(chin_residual.h, chin.h) which is 259 // always chin_residual.h. 260 for (size_t y = 0; y < chin_residual.h; y++) { 261 const pixel_type *JXL_RESTRICT p_residual = chin_residual.Row(y) + x0; 262 const pixel_type *JXL_RESTRICT p_avg = chin.Row(y) + x0; 263 const pixel_type *JXL_RESTRICT p_navg = 264 chin.Row(y + 1 < chin.h ? y + 1 : y) + x0; 265 pixel_type *JXL_RESTRICT p_out = chout.Row(y << 1) + x0; 266 pixel_type *JXL_RESTRICT p_nout = chout.Row((y << 1) + 1) + x0; 267 const pixel_type *p_pout = y > 0 ? chout.Row((y << 1) - 1) + x0 : p_avg; 268 size_t x = 0; 269 #if HWY_TARGET != HWY_SCALAR 270 for (; x + 7 < w; x += 8) { 271 FastUnsqueeze(p_residual + x, p_avg + x, p_navg + x, p_pout + x, 272 p_out + x, p_nout + x); 273 } 274 #endif 275 for (; x < w; x++) { 276 pixel_type_w avg = p_avg[x]; 277 pixel_type_w next_avg = p_navg[x]; 278 pixel_type_w top = p_pout[x]; 279 pixel_type_w tendency = SmoothTendency(top, avg, next_avg); 280 pixel_type_w diff_minus_tendency = p_residual[x]; 281 pixel_type_w diff = diff_minus_tendency + tendency; 282 pixel_type_w out = avg + (diff / 2); 283 p_out[x] = out; 284 // If the chin_residual.h == chin.h, the output has an even number 285 // of rows so the next line is fine. Otherwise, this loop won't 286 // write to the last output row which is handled separately. 287 p_nout[x] = out - diff; 288 } 289 } 290 return true; 291 }; 292 JXL_RETURN_IF_ERROR(RunOnPool(pool, 0, DivCeil(chin.w, kColsPerThread), 293 ThreadPool::NoInit, unsqueeze_slice, 294 "InvVertSqueeze")); 295 296 if (chout.h & 1) { 297 size_t y = chin.h - 1; 298 const pixel_type *p_avg = chin.Row(y); 299 pixel_type *p_out = chout.Row(y << 1); 300 for (size_t x = 0; x < chin.w; x++) { 301 p_out[x] = p_avg[x]; 302 } 303 } 304 input.channel[c] = std::move(chout); 305 return true; 306 } 307 308 Status InvSqueeze(Image &input, const std::vector<SqueezeParams> ¶meters, 309 ThreadPool *pool) { 310 for (int i = parameters.size() - 1; i >= 0; i--) { 311 JXL_RETURN_IF_ERROR( 312 CheckMetaSqueezeParams(parameters[i], input.channel.size())); 313 bool horizontal = parameters[i].horizontal; 314 bool in_place = parameters[i].in_place; 315 uint32_t beginc = parameters[i].begin_c; 316 uint32_t endc = parameters[i].begin_c + parameters[i].num_c - 1; 317 uint32_t offset; 318 if (in_place) { 319 offset = endc + 1; 320 } else { 321 offset = input.channel.size() + beginc - endc - 1; 322 } 323 if (beginc < input.nb_meta_channels) { 324 // This is checked in MetaSqueeze. 325 JXL_ENSURE(input.nb_meta_channels > parameters[i].num_c); 326 input.nb_meta_channels -= parameters[i].num_c; 327 } 328 329 for (uint32_t c = beginc; c <= endc; c++) { 330 uint32_t rc = offset + c - beginc; 331 // MetaApply should imply that `rc` is within range, otherwise there's a 332 // programming bug. 333 JXL_ENSURE(rc < input.channel.size()); 334 if ((input.channel[c].w < input.channel[rc].w) || 335 (input.channel[c].h < input.channel[rc].h)) { 336 return JXL_FAILURE("Corrupted squeeze transform"); 337 } 338 if (horizontal) { 339 JXL_RETURN_IF_ERROR(InvHSqueeze(input, c, rc, pool)); 340 } else { 341 JXL_RETURN_IF_ERROR(InvVSqueeze(input, c, rc, pool)); 342 } 343 } 344 input.channel.erase(input.channel.begin() + offset, 345 input.channel.begin() + offset + (endc - beginc + 1)); 346 } 347 return true; 348 } 349 350 } // namespace HWY_NAMESPACE 351 } // namespace jxl 352 HWY_AFTER_NAMESPACE(); 353 354 #if HWY_ONCE 355 356 namespace jxl { 357 358 HWY_EXPORT(InvSqueeze); 359 Status InvSqueeze(Image &input, const std::vector<SqueezeParams> ¶meters, 360 ThreadPool *pool) { 361 return HWY_DYNAMIC_DISPATCH(InvSqueeze)(input, parameters, pool); 362 } 363 364 void DefaultSqueezeParameters(std::vector<SqueezeParams> *parameters, 365 const Image &image) { 366 int nb_channels = image.channel.size() - image.nb_meta_channels; 367 368 parameters->clear(); 369 size_t w = image.channel[image.nb_meta_channels].w; 370 size_t h = image.channel[image.nb_meta_channels].h; 371 JXL_DEBUG_V( 372 7, "Default squeeze parameters for %" PRIuS "x%" PRIuS " image: ", w, h); 373 374 // do horizontal first on wide images; vertical first on tall images 375 bool wide = (w > h); 376 377 if (nb_channels > 2 && image.channel[image.nb_meta_channels + 1].w == w && 378 image.channel[image.nb_meta_channels + 1].h == h) { 379 // assume channels 1 and 2 are chroma, and can be squeezed first for 4:2:0 380 // previews 381 JXL_DEBUG_V(7, "(4:2:0 chroma), %" PRIuS "x%" PRIuS " image", w, h); 382 SqueezeParams params; 383 // horizontal chroma squeeze 384 params.horizontal = true; 385 params.in_place = false; 386 params.begin_c = image.nb_meta_channels + 1; 387 params.num_c = 2; 388 parameters->push_back(params); 389 params.horizontal = false; 390 // vertical chroma squeeze 391 parameters->push_back(params); 392 } 393 SqueezeParams params; 394 params.begin_c = image.nb_meta_channels; 395 params.num_c = nb_channels; 396 params.in_place = true; 397 398 if (!wide) { 399 if (h > kMaxFirstPreviewSize) { 400 params.horizontal = false; 401 parameters->push_back(params); 402 h = (h + 1) / 2; 403 JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); 404 } 405 } 406 while (w > kMaxFirstPreviewSize || h > kMaxFirstPreviewSize) { 407 if (w > kMaxFirstPreviewSize) { 408 params.horizontal = true; 409 parameters->push_back(params); 410 w = (w + 1) / 2; 411 JXL_DEBUG_V(7, "Horizontal (%" PRIuS "x%" PRIuS "), ", w, h); 412 } 413 if (h > kMaxFirstPreviewSize) { 414 params.horizontal = false; 415 parameters->push_back(params); 416 h = (h + 1) / 2; 417 JXL_DEBUG_V(7, "Vertical (%" PRIuS "x%" PRIuS "), ", w, h); 418 } 419 } 420 JXL_DEBUG_V(7, "that's it"); 421 } 422 423 Status CheckMetaSqueezeParams(const SqueezeParams ¶meter, 424 int num_channels) { 425 int c1 = parameter.begin_c; 426 int c2 = parameter.begin_c + parameter.num_c - 1; 427 if (c1 < 0 || c1 >= num_channels || c2 < 0 || c2 >= num_channels || c2 < c1) { 428 return JXL_FAILURE("Invalid channel range"); 429 } 430 return true; 431 } 432 433 Status MetaSqueeze(Image &image, std::vector<SqueezeParams> *parameters) { 434 JxlMemoryManager *memory_manager = image.memory_manager(); 435 if (parameters->empty()) { 436 DefaultSqueezeParameters(parameters, image); 437 } 438 439 for (auto ¶meter : *parameters) { 440 JXL_RETURN_IF_ERROR( 441 CheckMetaSqueezeParams(parameter, image.channel.size())); 442 bool horizontal = parameter.horizontal; 443 bool in_place = parameter.in_place; 444 uint32_t beginc = parameter.begin_c; 445 uint32_t endc = parameter.begin_c + parameter.num_c - 1; 446 447 uint32_t offset; 448 if (beginc < image.nb_meta_channels) { 449 if (endc >= image.nb_meta_channels) { 450 return JXL_FAILURE("Invalid squeeze: mix of meta and nonmeta channels"); 451 } 452 if (!in_place) { 453 return JXL_FAILURE( 454 "Invalid squeeze: meta channels require in-place residuals"); 455 } 456 image.nb_meta_channels += parameter.num_c; 457 } 458 if (in_place) { 459 offset = endc + 1; 460 } else { 461 offset = image.channel.size(); 462 } 463 for (uint32_t c = beginc; c <= endc; c++) { 464 if (image.channel[c].hshift > 30 || image.channel[c].vshift > 30) { 465 return JXL_FAILURE("Too many squeezes: shift > 30"); 466 } 467 size_t w = image.channel[c].w; 468 size_t h = image.channel[c].h; 469 if (w == 0 || h == 0) return JXL_FAILURE("Squeezing empty channel"); 470 if (horizontal) { 471 image.channel[c].w = (w + 1) / 2; 472 if (image.channel[c].hshift >= 0) image.channel[c].hshift++; 473 w = w - (w + 1) / 2; 474 } else { 475 image.channel[c].h = (h + 1) / 2; 476 if (image.channel[c].vshift >= 0) image.channel[c].vshift++; 477 h = h - (h + 1) / 2; 478 } 479 JXL_RETURN_IF_ERROR(image.channel[c].shrink()); 480 JXL_ASSIGN_OR_RETURN(Channel placeholder, 481 Channel::Create(memory_manager, w, h)); 482 placeholder.hshift = image.channel[c].hshift; 483 placeholder.vshift = image.channel[c].vshift; 484 485 image.channel.insert(image.channel.begin() + offset + (c - beginc), 486 std::move(placeholder)); 487 JXL_DEBUG_V(8, "MetaSqueeze applied, current image: %s", 488 image.DebugString().c_str()); 489 } 490 } 491 return true; 492 } 493 494 } // namespace jxl 495 496 #endif