compressed_dc.cc (11319B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/compressed_dc.h" 7 8 #include <jxl/memory_manager.h> 9 10 #include <algorithm> 11 #include <cstdint> 12 #include <cstdlib> 13 #include <cstring> 14 #include <vector> 15 16 #undef HWY_TARGET_INCLUDE 17 #define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc" 18 #include <hwy/foreach_target.h> 19 #include <hwy/highway.h> 20 21 #include "lib/jxl/base/compiler_specific.h" 22 #include "lib/jxl/base/data_parallel.h" 23 #include "lib/jxl/base/rect.h" 24 #include "lib/jxl/base/status.h" 25 #include "lib/jxl/image.h" 26 HWY_BEFORE_NAMESPACE(); 27 namespace jxl { 28 namespace HWY_NAMESPACE { 29 30 using D = HWY_FULL(float); 31 using DScalar = HWY_CAPPED(float, 1); 32 33 // These templates are not found via ADL. 34 using hwy::HWY_NAMESPACE::Abs; 35 using hwy::HWY_NAMESPACE::Add; 36 using hwy::HWY_NAMESPACE::Div; 37 using hwy::HWY_NAMESPACE::Max; 38 using hwy::HWY_NAMESPACE::Mul; 39 using hwy::HWY_NAMESPACE::MulAdd; 40 using hwy::HWY_NAMESPACE::Rebind; 41 using hwy::HWY_NAMESPACE::Sub; 42 using hwy::HWY_NAMESPACE::Vec; 43 using hwy::HWY_NAMESPACE::ZeroIfNegative; 44 45 // TODO(veluca): optimize constants. 46 const float w1 = 0.20345139757231578f; 47 const float w2 = 0.0334829185968739f; 48 const float w0 = 1.0f - 4.0f * (w1 + w2); 49 50 template <class V> 51 V MaxWorkaround(V a, V b) { 52 #if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG <= 800 53 // Prevents "Do not know how to split the result of this operator" error 54 return IfThenElse(a > b, a, b); 55 #else 56 return Max(a, b); 57 #endif 58 } 59 60 template <typename D> 61 JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor, 62 const float* JXL_RESTRICT row_top, 63 const float* JXL_RESTRICT row, 64 const float* JXL_RESTRICT row_bottom, 65 Vec<D>* JXL_RESTRICT mc, 66 Vec<D>* JXL_RESTRICT sm, 67 Vec<D>* JXL_RESTRICT gap, size_t x) { 68 const auto tl = LoadU(d, row_top + x - 1); 69 const auto tc = Load(d, row_top + x); 70 const auto tr = LoadU(d, row_top + x + 1); 71 72 const auto ml = LoadU(d, row + x - 1); 73 *mc = Load(d, row + x); 74 const auto mr = LoadU(d, row + x + 1); 75 76 const auto bl = LoadU(d, row_bottom + x - 1); 77 const auto bc = Load(d, row_bottom + x); 78 const auto br = LoadU(d, row_bottom + x + 1); 79 80 const auto w_center = Set(d, w0); 81 const auto w_side = Set(d, w1); 82 const auto w_corner = Set(d, w2); 83 84 const auto corner = Add(Add(tl, tr), Add(bl, br)); 85 const auto side = Add(Add(ml, mr), Add(tc, bc)); 86 *sm = MulAdd(corner, w_corner, MulAdd(side, w_side, Mul(*mc, w_center))); 87 88 const auto dc_quant = Set(d, dc_factor); 89 *gap = MaxWorkaround(*gap, Abs(Div(Sub(*mc, *sm), dc_quant))); 90 } 91 92 template <typename D> 93 JXL_INLINE void ComputePixel( 94 const float* JXL_RESTRICT dc_factors, 95 const float* JXL_RESTRICT* JXL_RESTRICT rows_top, 96 const float* JXL_RESTRICT* JXL_RESTRICT rows, 97 const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom, 98 float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) { 99 const D d; 100 auto mc_x = Undefined(d); 101 auto mc_y = Undefined(d); 102 auto mc_b = Undefined(d); 103 auto sm_x = Undefined(d); 104 auto sm_y = Undefined(d); 105 auto sm_b = Undefined(d); 106 auto gap = Set(d, 0.5f); 107 ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0], 108 &mc_x, &sm_x, &gap, x); 109 ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1], 110 &mc_y, &sm_y, &gap, x); 111 ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2], 112 &mc_b, &sm_b, &gap, x); 113 auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f)); 114 factor = ZeroIfNegative(factor); 115 116 auto out = MulAdd(Sub(sm_x, mc_x), factor, mc_x); 117 Store(out, d, out_rows[0] + x); 118 out = MulAdd(Sub(sm_y, mc_y), factor, mc_y); 119 Store(out, d, out_rows[1] + x); 120 out = MulAdd(Sub(sm_b, mc_b), factor, mc_b); 121 Store(out, d, out_rows[2] + x); 122 } 123 124 Status AdaptiveDCSmoothing(JxlMemoryManager* memory_manager, 125 const float* dc_factors, Image3F* dc, 126 ThreadPool* pool) { 127 const size_t xsize = dc->xsize(); 128 const size_t ysize = dc->ysize(); 129 if (ysize <= 2 || xsize <= 2) return true; 130 131 // TODO(veluca): use tile-based processing? 132 // TODO(veluca): decide if changes to the y channel should be propagated to 133 // the x and b channels through color correlation. 134 JXL_ENSURE(w1 + w2 < 0.25f); 135 136 JXL_ASSIGN_OR_RETURN(Image3F smoothed, 137 Image3F::Create(memory_manager, xsize, ysize)); 138 // Fill in borders that the loop below will not. First and last are unused. 139 for (size_t c = 0; c < 3; c++) { 140 for (size_t y : {static_cast<size_t>(0), ysize - 1}) { 141 memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y), 142 xsize * sizeof(float)); 143 } 144 } 145 auto process_row = [&](const uint32_t y, size_t /*thread*/) -> Status { 146 const float* JXL_RESTRICT rows_top[3]{ 147 dc->ConstPlaneRow(0, y - 1), 148 dc->ConstPlaneRow(1, y - 1), 149 dc->ConstPlaneRow(2, y - 1), 150 }; 151 const float* JXL_RESTRICT rows[3] = { 152 dc->ConstPlaneRow(0, y), 153 dc->ConstPlaneRow(1, y), 154 dc->ConstPlaneRow(2, y), 155 }; 156 const float* JXL_RESTRICT rows_bottom[3] = { 157 dc->ConstPlaneRow(0, y + 1), 158 dc->ConstPlaneRow(1, y + 1), 159 dc->ConstPlaneRow(2, y + 1), 160 }; 161 float* JXL_RESTRICT rows_out[3] = { 162 smoothed.PlaneRow(0, y), 163 smoothed.PlaneRow(1, y), 164 smoothed.PlaneRow(2, y), 165 }; 166 for (size_t x : {static_cast<size_t>(0), xsize - 1}) { 167 for (size_t c = 0; c < 3; c++) { 168 rows_out[c][x] = rows[c][x]; 169 } 170 } 171 172 size_t x = 1; 173 // First pixels 174 const size_t N = Lanes(D()); 175 for (; x < std::min(N, xsize - 1); x++) { 176 ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out, 177 x); 178 } 179 // Full vectors. 180 for (; x + N <= xsize - 1; x += N) { 181 ComputePixel<D>(dc_factors, rows_top, rows, rows_bottom, rows_out, x); 182 } 183 // Last pixels. 184 for (; x < xsize - 1; x++) { 185 ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out, 186 x); 187 } 188 return true; 189 }; 190 JXL_RETURN_IF_ERROR(RunOnPool(pool, 1, ysize - 1, ThreadPool::NoInit, 191 process_row, "DCSmoothingRow")); 192 dc->Swap(smoothed); 193 return true; 194 } 195 196 // DC dequantization. 197 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, 198 const float* dc_factors, float mul, const float* cfl_factors, 199 const YCbCrChromaSubsampling& chroma_subsampling, 200 const BlockCtxMap& bctx) { 201 const HWY_FULL(float) df; 202 const Rebind<pixel_type, HWY_FULL(float)> di; // assumes pixel_type <= float 203 if (chroma_subsampling.Is444()) { 204 const auto fac_x = Set(df, dc_factors[0] * mul); 205 const auto fac_y = Set(df, dc_factors[1] * mul); 206 const auto fac_b = Set(df, dc_factors[2] * mul); 207 const auto cfl_fac_x = Set(df, cfl_factors[0]); 208 const auto cfl_fac_b = Set(df, cfl_factors[2]); 209 for (size_t y = 0; y < r.ysize(); y++) { 210 float* dec_row_x = r.PlaneRow(dc, 0, y); 211 float* dec_row_y = r.PlaneRow(dc, 1, y); 212 float* dec_row_b = r.PlaneRow(dc, 2, y); 213 const int32_t* quant_row_x = in.channel[1].plane.Row(y); 214 const int32_t* quant_row_y = in.channel[0].plane.Row(y); 215 const int32_t* quant_row_b = in.channel[2].plane.Row(y); 216 for (size_t x = 0; x < r.xsize(); x += Lanes(di)) { 217 const auto in_q_x = Load(di, quant_row_x + x); 218 const auto in_q_y = Load(di, quant_row_y + x); 219 const auto in_q_b = Load(di, quant_row_b + x); 220 const auto in_x = Mul(ConvertTo(df, in_q_x), fac_x); 221 const auto in_y = Mul(ConvertTo(df, in_q_y), fac_y); 222 const auto in_b = Mul(ConvertTo(df, in_q_b), fac_b); 223 Store(in_y, df, dec_row_y + x); 224 Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x); 225 Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x); 226 } 227 } 228 } else { 229 for (size_t c : {1, 0, 2}) { 230 Rect rect(r.x0() >> chroma_subsampling.HShift(c), 231 r.y0() >> chroma_subsampling.VShift(c), 232 r.xsize() >> chroma_subsampling.HShift(c), 233 r.ysize() >> chroma_subsampling.VShift(c)); 234 const auto fac = Set(df, dc_factors[c] * mul); 235 const Channel& ch = in.channel[c < 2 ? c ^ 1 : c]; 236 for (size_t y = 0; y < rect.ysize(); y++) { 237 const int32_t* quant_row = ch.plane.Row(y); 238 float* row = rect.PlaneRow(dc, c, y); 239 for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) { 240 const auto in_q = Load(di, quant_row + x); 241 const auto in = Mul(ConvertTo(df, in_q), fac); 242 Store(in, df, row + x); 243 } 244 } 245 } 246 } 247 if (bctx.num_dc_ctxs <= 1) { 248 for (size_t y = 0; y < r.ysize(); y++) { 249 uint8_t* qdc_row = r.Row(quant_dc, y); 250 memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize()); 251 } 252 } else { 253 for (size_t y = 0; y < r.ysize(); y++) { 254 uint8_t* qdc_row_val = r.Row(quant_dc, y); 255 const int32_t* quant_row_x = 256 in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0)); 257 const int32_t* quant_row_y = 258 in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1)); 259 const int32_t* quant_row_b = 260 in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2)); 261 for (size_t x = 0; x < r.xsize(); x++) { 262 int bucket_x = 0; 263 int bucket_y = 0; 264 int bucket_b = 0; 265 for (int t : bctx.dc_thresholds[0]) { 266 if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++; 267 } 268 for (int t : bctx.dc_thresholds[1]) { 269 if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++; 270 } 271 for (int t : bctx.dc_thresholds[2]) { 272 if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++; 273 } 274 int bucket = bucket_x; 275 bucket *= bctx.dc_thresholds[2].size() + 1; 276 bucket += bucket_b; 277 bucket *= bctx.dc_thresholds[1].size() + 1; 278 bucket += bucket_y; 279 qdc_row_val[x] = bucket; 280 } 281 } 282 } 283 } 284 285 // NOLINTNEXTLINE(google-readability-namespace-comments) 286 } // namespace HWY_NAMESPACE 287 } // namespace jxl 288 HWY_AFTER_NAMESPACE(); 289 290 #if HWY_ONCE 291 namespace jxl { 292 293 HWY_EXPORT(DequantDC); 294 HWY_EXPORT(AdaptiveDCSmoothing); 295 Status AdaptiveDCSmoothing(JxlMemoryManager* memory_manager, 296 const float* dc_factors, Image3F* dc, 297 ThreadPool* pool) { 298 return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(memory_manager, dc_factors, 299 dc, pool); 300 } 301 302 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in, 303 const float* dc_factors, float mul, const float* cfl_factors, 304 const YCbCrChromaSubsampling& chroma_subsampling, 305 const BlockCtxMap& bctx) { 306 HWY_DYNAMIC_DISPATCH(DequantDC) 307 (r, dc, quant_dc, in, dc_factors, mul, cfl_factors, chroma_subsampling, bctx); 308 } 309 310 } // namespace jxl 311 #endif // HWY_ONCE