convolve-inl.h (11471B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #if defined(LIB_JXL_CONVOLVE_INL_H_) == defined(HWY_TARGET_TOGGLE) 7 #ifdef LIB_JXL_CONVOLVE_INL_H_ 8 #undef LIB_JXL_CONVOLVE_INL_H_ 9 #else 10 #define LIB_JXL_CONVOLVE_INL_H_ 11 #endif 12 13 #include <hwy/highway.h> 14 15 #include "lib/jxl/base/data_parallel.h" 16 #include "lib/jxl/base/rect.h" 17 #include "lib/jxl/base/status.h" 18 #include "lib/jxl/image_ops.h" 19 20 HWY_BEFORE_NAMESPACE(); 21 namespace jxl { 22 namespace HWY_NAMESPACE { 23 namespace { 24 25 // These templates are not found via ADL. 26 using hwy::HWY_NAMESPACE::Broadcast; 27 #if HWY_TARGET != HWY_SCALAR 28 using hwy::HWY_NAMESPACE::CombineShiftRightBytes; 29 #endif 30 using hwy::HWY_NAMESPACE::TableLookupLanes; 31 using hwy::HWY_NAMESPACE::Vec; 32 33 // Synthesizes left/right neighbors from a vector of center pixels. 34 class Neighbors { 35 public: 36 using D = HWY_CAPPED(float, 16); 37 using V = Vec<D>; 38 39 // Returns l[i] == c[Mirror(i - 1)]. 40 HWY_INLINE HWY_MAYBE_UNUSED static V FirstL1(const V c) { 41 #if HWY_CAP_GE256 42 const D d; 43 HWY_ALIGN constexpr int32_t lanes[16] = {0, 0, 1, 2, 3, 4, 5, 6, 44 7, 8, 9, 10, 11, 12, 13, 14}; 45 const auto indices = SetTableIndices(d, lanes); 46 // c = PONM'LKJI 47 return TableLookupLanes(c, indices); // ONML'KJII 48 #elif HWY_TARGET == HWY_SCALAR 49 return c; // Same (the first mirrored value is the last valid one) 50 #else // 128 bit 51 // c = LKJI 52 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) 53 return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(2, 1, 0, 0))}; // KJII 54 #else 55 const D d; 56 // TODO(deymo): Figure out if this can be optimized using a single vsri 57 // instruction to convert LKJI to KJII. 58 HWY_ALIGN constexpr int lanes[4] = {0, 0, 1, 2}; // KJII 59 const auto indices = SetTableIndices(d, lanes); 60 return TableLookupLanes(c, indices); 61 #endif 62 #endif 63 } 64 65 // Returns l[i] == c[Mirror(i - 2)]. 66 HWY_INLINE HWY_MAYBE_UNUSED static V FirstL2(const V c) { 67 #if HWY_CAP_GE256 68 const D d; 69 HWY_ALIGN constexpr int32_t lanes[16] = {1, 0, 0, 1, 2, 3, 4, 5, 70 6, 7, 8, 9, 10, 11, 12, 13}; 71 const auto indices = SetTableIndices(d, lanes); 72 // c = PONM'LKJI 73 return TableLookupLanes(c, indices); // NMLK'JIIJ 74 #elif HWY_TARGET == HWY_SCALAR 75 const D d; 76 JXL_DEBUG_ABORT("Unsupported"); 77 return Zero(d); 78 #else // 128 bit 79 // c = LKJI 80 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) 81 return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(1, 0, 0, 1))}; // JIIJ 82 #else 83 const D d; 84 HWY_ALIGN constexpr int lanes[4] = {1, 0, 0, 1}; // JIIJ 85 const auto indices = SetTableIndices(d, lanes); 86 return TableLookupLanes(c, indices); 87 #endif 88 #endif 89 } 90 91 // Returns l[i] == c[Mirror(i - 3)]. 92 HWY_INLINE HWY_MAYBE_UNUSED static V FirstL3(const V c) { 93 #if HWY_CAP_GE256 94 const D d; 95 HWY_ALIGN constexpr int32_t lanes[16] = {2, 1, 0, 0, 1, 2, 3, 4, 96 5, 6, 7, 8, 9, 10, 11, 12}; 97 const auto indices = SetTableIndices(d, lanes); 98 // c = PONM'LKJI 99 return TableLookupLanes(c, indices); // MLKJ'IIJK 100 #elif HWY_TARGET == HWY_SCALAR 101 const D d; 102 JXL_DEBUG_ABORT("Unsipported"); 103 return Zero(d); 104 #else // 128 bit 105 // c = LKJI 106 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) 107 return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(0, 0, 1, 2))}; // IIJK 108 #else 109 const D d; 110 HWY_ALIGN constexpr int lanes[4] = {2, 1, 0, 0}; // IIJK 111 const auto indices = SetTableIndices(d, lanes); 112 return TableLookupLanes(c, indices); 113 #endif 114 #endif 115 } 116 }; 117 118 #if HWY_TARGET != HWY_SCALAR 119 120 // Returns indices for SetTableIndices such that TableLookupLanes on the 121 // rightmost unaligned vector (rightmost sample in its most-significant lane) 122 // returns the mirrored values, with the mirror outside the last valid sample. 123 inline const int32_t* MirrorLanes(const size_t mod) { 124 const HWY_CAPPED(float, 16) d; 125 constexpr size_t kN = MaxLanes(d); 126 // typo:off 127 // For mod = `image width mod 16` 0..15: 128 // last full vec mirrored (mem order) loadedVec mirrorVec idxVec 129 // 0123456789abcdef| fedcba9876543210 fed..210 012..def 012..def 130 // 0123456789abcdef|0 0fedcba98765432 0fe..321 234..f00 123..eff 131 // 0123456789abcdef|01 10fedcba987654 10f..432 456..110 234..ffe 132 // 0123456789abcdef|012 210fedcba9876 210..543 67..2210 34..ffed 133 // 0123456789abcdef|0123 3210fedcba98 321..654 8..33210 4..ffedc 134 // 0123456789abcdef|01234 43210fedcba 135 // 0123456789abcdef|012345 543210fedc 136 // 0123456789abcdef|0123456 6543210fe 137 // 0123456789abcdef|01234567 76543210 138 // 0123456789abcdef|012345678 8765432 139 // 0123456789abcdef|0123456789 987654 140 // 0123456789abcdef|0123456789A A9876 141 // 0123456789abcdef|0123456789AB BA98 142 // 0123456789abcdef|0123456789ABC CBA 143 // 0123456789abcdef|0123456789ABCD DC 144 // 0123456789abcdef|0123456789ABCDE E EDC..10f EED..210 ffe..321 145 // typo:on 146 #if HWY_CAP_GE512 147 HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { 148 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, // 149 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; 150 #elif HWY_CAP_GE256 151 HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { 152 1, 2, 3, 4, 5, 6, 7, 7, // 153 6, 5, 4, 3, 2, 1, 0}; 154 #else // 128-bit 155 HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {1, 2, 3, 3, // 156 2, 1, 0}; 157 #endif 158 return idx_lanes + kN - 1 - mod; 159 } 160 161 #endif // HWY_TARGET != HWY_SCALAR 162 163 // Single entry point for convolution. 164 // "Strategy" (Direct*/Separable*) decides kernel size and how to evaluate it. 165 template <class Strategy> 166 class ConvolveT { 167 static constexpr int64_t kRadius = Strategy::kRadius; 168 using Simd = HWY_CAPPED(float, 16); 169 170 public: 171 static size_t MinWidth() { 172 #if HWY_TARGET == HWY_SCALAR 173 // First/Last use mirrored loads of up to +/- kRadius. 174 return 2 * kRadius; 175 #else 176 return Lanes(Simd()) + kRadius; 177 #endif 178 } 179 180 // "Image" is ImageF or Image3F. 181 template <class Image, class Weights> 182 static void Run(const Image& in, const Rect& rect, const Weights& weights, 183 ThreadPool* pool, Image* out) { 184 JXL_DASSERT(SameSize(rect, *out)); 185 JXL_DASSERT(rect.xsize() >= MinWidth()); 186 187 static_assert(static_cast<int64_t>(kRadius) <= 3, 188 "Must handle [0, kRadius) and >= kRadius"); 189 switch (rect.xsize() % Lanes(Simd())) { 190 case 0: 191 return RunRows<0>(in, rect, weights, pool, out); 192 case 1: 193 return RunRows<1>(in, rect, weights, pool, out); 194 case 2: 195 return RunRows<2>(in, rect, weights, pool, out); 196 default: 197 return RunRows<3>(in, rect, weights, pool, out); 198 } 199 } 200 201 private: 202 template <size_t kSizeModN, class WrapRow, class Weights> 203 static JXL_INLINE void RunRow(const float* JXL_RESTRICT in, 204 const size_t xsize, const int64_t stride, 205 const WrapRow& wrap_row, const Weights& weights, 206 float* JXL_RESTRICT out) { 207 Strategy::template ConvolveRow<kSizeModN>(in, xsize, stride, wrap_row, 208 weights, out); 209 } 210 211 template <size_t kSizeModN, class Weights> 212 static JXL_INLINE void RunBorderRows(const ImageF& in, const Rect& rect, 213 const int64_t ybegin, const int64_t yend, 214 const Weights& weights, ImageF* out) { 215 const int64_t stride = in.PixelsPerRow(); 216 const WrapRowMirror wrap_row(in, rect.ysize()); 217 for (int64_t y = ybegin; y < yend; ++y) { 218 RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, wrap_row, 219 weights, out->Row(y)); 220 } 221 } 222 223 // Image3F. 224 template <size_t kSizeModN, class Weights> 225 static JXL_INLINE void RunBorderRows(const Image3F& in, const Rect& rect, 226 const int64_t ybegin, const int64_t yend, 227 const Weights& weights, Image3F* out) { 228 const int64_t stride = in.PixelsPerRow(); 229 for (int64_t y = ybegin; y < yend; ++y) { 230 for (size_t c = 0; c < 3; ++c) { 231 const WrapRowMirror wrap_row(in.Plane(c), rect.ysize()); 232 RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), stride, 233 wrap_row, weights, out->PlaneRow(c, y)); 234 } 235 } 236 } 237 238 template <size_t kSizeModN, class Weights> 239 static JXL_INLINE void RunInteriorRows(const ImageF& in, const Rect& rect, 240 const int64_t ybegin, 241 const int64_t yend, 242 const Weights& weights, 243 ThreadPool* pool, ImageF* out) { 244 const int64_t stride = in.PixelsPerRow(); 245 const auto process_row = [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { 246 RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, 247 WrapRowUnchanged(), weights, out->Row(y)); 248 return true; 249 }; 250 Status status = RunOnPool(pool, ybegin, yend, ThreadPool::NoInit, 251 process_row, "Convolve"); 252 (void)status; 253 JXL_DASSERT(status); 254 } 255 256 // Image3F. 257 template <size_t kSizeModN, class Weights> 258 static JXL_INLINE void RunInteriorRows(const Image3F& in, const Rect& rect, 259 const int64_t ybegin, 260 const int64_t yend, 261 const Weights& weights, 262 ThreadPool* pool, Image3F* out) { 263 const int64_t stride = in.PixelsPerRow(); 264 const auto process_row = [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { 265 for (size_t c = 0; c < 3; ++c) { 266 RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), stride, 267 WrapRowUnchanged(), weights, out->PlaneRow(c, y)); 268 } 269 return true; 270 }; 271 Status status = RunOnPool(pool, ybegin, yend, ThreadPool::NoInit, 272 process_row, "Convolve3"); 273 (void)status; 274 JXL_DASSERT(status); 275 } 276 277 template <size_t kSizeModN, class Image, class Weights> 278 static JXL_INLINE void RunRows(const Image& in, const Rect& rect, 279 const Weights& weights, ThreadPool* pool, 280 Image* out) { 281 const int64_t ysize = rect.ysize(); 282 RunBorderRows<kSizeModN>(in, rect, 0, 283 std::min(static_cast<int64_t>(kRadius), ysize), 284 weights, out); 285 if (ysize > 2 * static_cast<int64_t>(kRadius)) { 286 RunInteriorRows<kSizeModN>(in, rect, static_cast<int64_t>(kRadius), 287 ysize - static_cast<int64_t>(kRadius), weights, 288 pool, out); 289 } 290 if (ysize > static_cast<int64_t>(kRadius)) { 291 RunBorderRows<kSizeModN>(in, rect, ysize - static_cast<int64_t>(kRadius), 292 ysize, weights, out); 293 } 294 } 295 }; 296 297 } // namespace 298 // NOLINTNEXTLINE(google-readability-namespace-comments) 299 } // namespace HWY_NAMESPACE 300 } // namespace jxl 301 HWY_AFTER_NAMESPACE(); 302 303 #endif // LIB_JXL_CONVOLVE_INL_H_