matrix_ops.h (3039B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #ifndef LIB_JXL_BASE_MATRIX_OPS_H_ 7 #define LIB_JXL_BASE_MATRIX_OPS_H_ 8 9 // 3x3 matrix operations. 10 11 #include <array> 12 #include <cmath> // abs 13 #include <cstddef> 14 15 #include "lib/jxl/base/status.h" 16 17 namespace jxl { 18 19 typedef std::array<float, 3> Vector3; 20 typedef std::array<double, 3> Vector3d; 21 typedef std::array<Vector3, 3> Matrix3x3; 22 typedef std::array<Vector3d, 3> Matrix3x3d; 23 24 // Computes C = A * B, where A, B, C are 3x3 matrices. 25 template <typename Matrix> 26 void Mul3x3Matrix(const Matrix& a, const Matrix& b, Matrix& c) { 27 for (size_t x = 0; x < 3; x++) { 28 alignas(16) Vector3d temp{b[0][x], b[1][x], b[2][x]}; // transpose 29 for (size_t y = 0; y < 3; y++) { 30 c[y][x] = a[y][0] * temp[0] + a[y][1] * temp[1] + a[y][2] * temp[2]; 31 } 32 } 33 } 34 35 // Computes C = A * B, where A is 3x3 matrix and B is vector. 36 template <typename Matrix, typename Vector> 37 void Mul3x3Vector(const Matrix& a, const Vector& b, Vector& c) { 38 for (size_t y = 0; y < 3; y++) { 39 double e = 0; 40 for (size_t x = 0; x < 3; x++) { 41 e += a[y][x] * b[x]; 42 } 43 c[y] = e; 44 } 45 } 46 47 // Inverts a 3x3 matrix in place. 48 template <typename Matrix> 49 Status Inv3x3Matrix(Matrix& matrix) { 50 // Intermediate computation is done in double precision. 51 Matrix3x3d temp; 52 temp[0][0] = static_cast<double>(matrix[1][1]) * matrix[2][2] - 53 static_cast<double>(matrix[1][2]) * matrix[2][1]; 54 temp[0][1] = static_cast<double>(matrix[0][2]) * matrix[2][1] - 55 static_cast<double>(matrix[0][1]) * matrix[2][2]; 56 temp[0][2] = static_cast<double>(matrix[0][1]) * matrix[1][2] - 57 static_cast<double>(matrix[0][2]) * matrix[1][1]; 58 temp[1][0] = static_cast<double>(matrix[1][2]) * matrix[2][0] - 59 static_cast<double>(matrix[1][0]) * matrix[2][2]; 60 temp[1][1] = static_cast<double>(matrix[0][0]) * matrix[2][2] - 61 static_cast<double>(matrix[0][2]) * matrix[2][0]; 62 temp[1][2] = static_cast<double>(matrix[0][2]) * matrix[1][0] - 63 static_cast<double>(matrix[0][0]) * matrix[1][2]; 64 temp[2][0] = static_cast<double>(matrix[1][0]) * matrix[2][1] - 65 static_cast<double>(matrix[1][1]) * matrix[2][0]; 66 temp[2][1] = static_cast<double>(matrix[0][1]) * matrix[2][0] - 67 static_cast<double>(matrix[0][0]) * matrix[2][1]; 68 temp[2][2] = static_cast<double>(matrix[0][0]) * matrix[1][1] - 69 static_cast<double>(matrix[0][1]) * matrix[1][0]; 70 double det = matrix[0][0] * temp[0][0] + matrix[0][1] * temp[1][0] + 71 matrix[0][2] * temp[2][0]; 72 if (std::abs(det) < 1e-10) { 73 return JXL_FAILURE("Matrix determinant is too close to 0"); 74 } 75 double idet = 1.0 / det; 76 for (size_t j = 0; j < 3; j++) { 77 for (size_t i = 0; i < 3; i++) { 78 matrix[j][i] = temp[j][i] * idet; 79 } 80 } 81 return true; 82 } 83 84 } // namespace jxl 85 86 #endif // LIB_JXL_BASE_MATRIX_OPS_H_