enc_noise.cc (13342B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_noise.h" 7 8 #include <algorithm> 9 #include <cstdint> 10 #include <cstdlib> 11 #include <numeric> 12 #include <utility> 13 14 #include "lib/jxl/base/status.h" 15 #include "lib/jxl/enc_aux_out.h" 16 #include "lib/jxl/enc_optimize.h" 17 18 namespace jxl { 19 namespace { 20 21 using OptimizeArray = optimize::Array<double, NoiseParams::kNumNoisePoints>; 22 23 float GetScoreSumsOfAbsoluteDifferences(const Image3F& opsin, const int x, 24 const int y, const int block_size) { 25 const int small_bl_size_x = 3; 26 const int small_bl_size_y = 4; 27 const int kNumSAD = 28 (block_size - small_bl_size_x) * (block_size - small_bl_size_y); 29 // block_size x block_size reference pixels 30 int counter = 0; 31 const int offset = 2; 32 33 std::vector<float> sad(kNumSAD, 0); 34 for (int y_bl = 0; y_bl + small_bl_size_y < block_size; ++y_bl) { 35 for (int x_bl = 0; x_bl + small_bl_size_x < block_size; ++x_bl) { 36 float sad_sum = 0; 37 // size of the center patch, we compare all the patches inside window with 38 // the center one 39 for (int cy = 0; cy < small_bl_size_y; ++cy) { 40 for (int cx = 0; cx < small_bl_size_x; ++cx) { 41 float wnd = 0.5f * (opsin.PlaneRow(1, y + y_bl + cy)[x + x_bl + cx] + 42 opsin.PlaneRow(0, y + y_bl + cy)[x + x_bl + cx]); 43 float center = 44 0.5f * (opsin.PlaneRow(1, y + offset + cy)[x + offset + cx] + 45 opsin.PlaneRow(0, y + offset + cy)[x + offset + cx]); 46 sad_sum += std::abs(center - wnd); 47 } 48 } 49 sad[counter++] = sad_sum; 50 } 51 } 52 const int kSamples = (kNumSAD) / 2; 53 // As with ROAD (rank order absolute distance), we keep the smallest half of 54 // the values in SAD (we use here the more robust patch SAD instead of 55 // absolute single-pixel differences). 56 std::sort(sad.begin(), sad.end()); 57 const float total_sad_sum = 58 std::accumulate(sad.begin(), sad.begin() + kSamples, 0.0f); 59 return total_sad_sum / kSamples; 60 } 61 62 class NoiseHistogram { 63 public: 64 static constexpr int kBins = 256; 65 66 NoiseHistogram() { std::fill(bins, bins + kBins, 0); } 67 68 void Increment(const float x) { bins[Index(x)] += 1; } 69 int Get(const float x) const { return bins[Index(x)]; } 70 int Bin(const size_t bin) const { return bins[bin]; } 71 72 int Mode() const { 73 size_t max_idx = 0; 74 for (size_t i = 0; i < kBins; i++) { 75 if (bins[i] > bins[max_idx]) max_idx = i; 76 } 77 return max_idx; 78 } 79 80 double Quantile(double q01) const { 81 const int64_t total = std::accumulate(bins, bins + kBins, int64_t{1}); 82 const int64_t target = static_cast<int64_t>(q01 * total); 83 // Until sum >= target: 84 int64_t sum = 0; 85 size_t i = 0; 86 for (; i < kBins; ++i) { 87 sum += bins[i]; 88 // Exact match: assume middle of bin i 89 if (sum == target) { 90 return i + 0.5; 91 } 92 if (sum > target) break; 93 } 94 95 // Next non-empty bin (in case histogram is sparsely filled) 96 size_t next = i + 1; 97 while (next < kBins && bins[next] == 0) { 98 ++next; 99 } 100 101 // Linear interpolation according to how far into next we went 102 const double excess = target - sum; 103 const double weight_next = bins[Index(next)] / excess; 104 return ClampX(next * weight_next + i * (1.0 - weight_next)); 105 } 106 107 // Inter-quartile range 108 double IQR() const { return Quantile(0.75) - Quantile(0.25); } 109 110 private: 111 template <typename T> 112 T ClampX(const T x) const { 113 return std::min(std::max(static_cast<T>(0), x), static_cast<T>(kBins - 1)); 114 } 115 size_t Index(const float x) const { return ClampX(static_cast<int>(x)); } 116 117 uint32_t bins[kBins]; 118 }; 119 120 std::vector<float> GetSADScoresForPatches(const Image3F& opsin, 121 const size_t block_s, 122 const size_t num_bin, 123 NoiseHistogram* sad_histogram) { 124 std::vector<float> sad_scores( 125 (opsin.ysize() / block_s) * (opsin.xsize() / block_s), 0.0f); 126 127 int block_index = 0; 128 129 for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) { 130 for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) { 131 float sad_sc = GetScoreSumsOfAbsoluteDifferences(opsin, x, y, block_s); 132 sad_scores[block_index++] = sad_sc; 133 sad_histogram->Increment(sad_sc * num_bin); 134 } 135 } 136 return sad_scores; 137 } 138 139 float GetSADThreshold(const NoiseHistogram& histogram, const int num_bin) { 140 // Here we assume that the most patches with similar SAD value is a "flat" 141 // patches. However, some images might contain regular texture part and 142 // generate second strong peak at the histogram 143 // TODO(user) handle bimodal and heavy-tailed case 144 const int mode = histogram.Mode(); 145 return static_cast<float>(mode) / NoiseHistogram::kBins; 146 } 147 148 // loss = sum asym * (F(x) - nl)^2 + kReg * num_points * sum (w[i] - w[i+1])^2 149 // where asym = 1 if F(x) < nl, kAsym if F(x) > nl. 150 struct LossFunction { 151 explicit LossFunction(std::vector<NoiseLevel> nl0) : nl(std::move(nl0)) {} 152 153 double Compute(const OptimizeArray& w, OptimizeArray* df, 154 bool skip_regularization = false) const { 155 constexpr double kReg = 0.005; 156 constexpr double kAsym = 1.1; 157 double loss_function = 0; 158 for (size_t i = 0; i < w.size(); i++) { 159 (*df)[i] = 0; 160 } 161 for (auto ind : nl) { 162 std::pair<int, float> pos = IndexAndFrac(ind.intensity); 163 JXL_DASSERT(pos.first >= 0 && static_cast<size_t>(pos.first) < 164 NoiseParams::kNumNoisePoints - 1); 165 double low = w[pos.first]; 166 double hi = w[pos.first + 1]; 167 double val = low * (1.0f - pos.second) + hi * pos.second; 168 double dist = val - ind.noise_level; 169 if (dist > 0) { 170 loss_function += kAsym * dist * dist; 171 (*df)[pos.first] -= kAsym * (1.0f - pos.second) * dist; 172 (*df)[pos.first + 1] -= kAsym * pos.second * dist; 173 } else { 174 loss_function += dist * dist; 175 (*df)[pos.first] -= (1.0f - pos.second) * dist; 176 (*df)[pos.first + 1] -= pos.second * dist; 177 } 178 } 179 if (skip_regularization) return loss_function; 180 for (size_t i = 0; i + 1 < w.size(); i++) { 181 double diff = w[i] - w[i + 1]; 182 loss_function += kReg * nl.size() * diff * diff; 183 (*df)[i] -= kReg * diff * nl.size(); 184 (*df)[i + 1] += kReg * diff * nl.size(); 185 } 186 return loss_function; 187 } 188 189 std::vector<NoiseLevel> nl; 190 }; 191 192 void OptimizeNoiseParameters(const std::vector<NoiseLevel>& noise_level, 193 NoiseParams* noise_params) { 194 constexpr double kMaxError = 1e-3; 195 static const double kPrecision = 1e-8; 196 static const int kMaxIter = 40; 197 198 float avg = 0; 199 for (const NoiseLevel& nl : noise_level) { 200 avg += nl.noise_level; 201 } 202 avg /= noise_level.size(); 203 204 LossFunction loss_function(noise_level); 205 OptimizeArray parameter_vector; 206 for (size_t i = 0; i < parameter_vector.size(); i++) { 207 parameter_vector[i] = avg; 208 } 209 210 parameter_vector = optimize::OptimizeWithScaledConjugateGradientMethod( 211 loss_function, parameter_vector, kPrecision, kMaxIter); 212 213 OptimizeArray df = parameter_vector; 214 float loss = loss_function.Compute(parameter_vector, &df, 215 /*skip_regularization=*/true) / 216 noise_level.size(); 217 218 // Approximation went too badly: escape with no noise at all. 219 if (loss > kMaxError) { 220 noise_params->Clear(); 221 return; 222 } 223 224 for (size_t i = 0; i < parameter_vector.size(); i++) { 225 noise_params->lut[i] = std::max(parameter_vector[i], 0.0); 226 } 227 } 228 229 std::vector<NoiseLevel> GetNoiseLevel( 230 const Image3F& opsin, const std::vector<float>& texture_strength, 231 const float threshold, const size_t block_s) { 232 std::vector<NoiseLevel> noise_level_per_intensity; 233 234 const int filt_size = 1; 235 static const float kLaplFilter[filt_size * 2 + 1][filt_size * 2 + 1] = { 236 {-0.25f, -1.0f, -0.25f}, 237 {-1.0f, 5.0f, -1.0f}, 238 {-0.25f, -1.0f, -0.25f}, 239 }; 240 241 // The noise model is built based on channel 0.5 * (X+Y) as we notice that it 242 // is similar to the model 0.5 * (Y-X) 243 size_t patch_index = 0; 244 245 for (size_t y = 0; y + block_s <= opsin.ysize(); y += block_s) { 246 for (size_t x = 0; x + block_s <= opsin.xsize(); x += block_s) { 247 if (texture_strength[patch_index] <= threshold) { 248 // Calculate mean value 249 float mean_int = 0; 250 for (size_t y_bl = 0; y_bl < block_s; ++y_bl) { 251 for (size_t x_bl = 0; x_bl < block_s; ++x_bl) { 252 mean_int += 0.5f * (opsin.PlaneRow(1, y + y_bl)[x + x_bl] + 253 opsin.PlaneRow(0, y + y_bl)[x + x_bl]); 254 } 255 } 256 mean_int /= block_s * block_s; 257 258 // Calculate Noise level 259 float noise_level = 0; 260 size_t count = 0; 261 for (size_t y_bl = 0; y_bl < block_s; ++y_bl) { 262 for (size_t x_bl = 0; x_bl < block_s; ++x_bl) { 263 float filtered_value = 0; 264 for (int y_f = -1 * filt_size; y_f <= filt_size; ++y_f) { 265 if ((static_cast<ssize_t>(y_bl) + y_f) >= 0 && 266 (y_bl + y_f) < block_s) { 267 for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) { 268 if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 && 269 (x_bl + x_f) < block_s) { 270 filtered_value += 271 0.5f * 272 (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl + x_f] + 273 opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl + x_f]) * 274 kLaplFilter[y_f + filt_size][x_f + filt_size]; 275 } else { 276 filtered_value += 277 0.5f * 278 (opsin.PlaneRow(1, y + y_bl + y_f)[x + x_bl - x_f] + 279 opsin.PlaneRow(0, y + y_bl + y_f)[x + x_bl - x_f]) * 280 kLaplFilter[y_f + filt_size][x_f + filt_size]; 281 } 282 } 283 } else { 284 for (int x_f = -1 * filt_size; x_f <= filt_size; ++x_f) { 285 if ((static_cast<ssize_t>(x_bl) + x_f) >= 0 && 286 (x_bl + x_f) < block_s) { 287 filtered_value += 288 0.5f * 289 (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl + x_f] + 290 opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl + x_f]) * 291 kLaplFilter[y_f + filt_size][x_f + filt_size]; 292 } else { 293 filtered_value += 294 0.5f * 295 (opsin.PlaneRow(1, y + y_bl - y_f)[x + x_bl - x_f] + 296 opsin.PlaneRow(0, y + y_bl - y_f)[x + x_bl - x_f]) * 297 kLaplFilter[y_f + filt_size][x_f + filt_size]; 298 } 299 } 300 } 301 } 302 noise_level += std::abs(filtered_value); 303 ++count; 304 } 305 } 306 noise_level /= count; 307 NoiseLevel nl; 308 nl.intensity = mean_int; 309 nl.noise_level = noise_level; 310 noise_level_per_intensity.push_back(nl); 311 } 312 ++patch_index; 313 } 314 } 315 return noise_level_per_intensity; 316 } 317 318 Status EncodeFloatParam(float val, float precision, BitWriter* writer) { 319 JXL_ENSURE(val >= 0); 320 const int absval_quant = static_cast<int>(std::lround(val * precision)); 321 JXL_ENSURE(absval_quant < (1 << 10)); 322 writer->Write(10, absval_quant); 323 return true; 324 } 325 326 } // namespace 327 328 Status GetNoiseParameter(const Image3F& opsin, NoiseParams* noise_params, 329 float quality_coef) { 330 // The size of a patch in decoder might be different from encoder's patch 331 // size. 332 // For encoder: the patch size should be big enough to estimate 333 // noise level, but, at the same time, it should be not too big 334 // to be able to estimate intensity value of the patch 335 const size_t block_s = 8; 336 const size_t kNumBin = 256; 337 NoiseHistogram sad_histogram; 338 std::vector<float> sad_scores = 339 GetSADScoresForPatches(opsin, block_s, kNumBin, &sad_histogram); 340 float sad_threshold = GetSADThreshold(sad_histogram, kNumBin); 341 // If threshold is too large, the image has a strong pattern. This pattern 342 // fools our model and it will add too much noise. Therefore, we do not add 343 // noise for such images 344 if (sad_threshold > 0.15f || sad_threshold <= 0.0f) { 345 noise_params->Clear(); 346 return false; 347 } 348 std::vector<NoiseLevel> nl = 349 GetNoiseLevel(opsin, sad_scores, sad_threshold, block_s); 350 351 OptimizeNoiseParameters(nl, noise_params); 352 for (float& i : noise_params->lut) { 353 i *= quality_coef * 1.4; 354 } 355 return noise_params->HasAny(); 356 } 357 358 Status EncodeNoise(const NoiseParams& noise_params, BitWriter* writer, 359 LayerType layer, AuxOut* aux_out) { 360 JXL_ENSURE(noise_params.HasAny()); 361 362 return writer->WithMaxBits( 363 NoiseParams::kNumNoisePoints * 16, layer, aux_out, [&]() -> Status { 364 for (float i : noise_params.lut) { 365 JXL_RETURN_IF_ERROR(EncodeFloatParam(i, kNoisePrecision, writer)); 366 } 367 return true; 368 }); 369 } 370 371 } // namespace jxl