tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

matrix_utils.h (17458B)


      1 //
      2 // Copyright 2015 The ANGLE Project Authors. All rights reserved.
      3 // Use of this source code is governed by a BSD-style license that can be
      4 // found in the LICENSE file.
      5 //
      6 // Matrix:
      7 //   Utility class implementing various matrix operations.
      8 //   Supports matrices with between 1 and 4 rows and between 1 and 4 columns.
      9 //
     10 // TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this
     11 // implementation.
     12 // TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util.
     13 
     14 #ifndef COMMON_MATRIX_UTILS_H_
     15 #define COMMON_MATRIX_UTILS_H_
     16 
     17 #include <vector>
     18 
     19 #include "common/debug.h"
     20 #include "common/mathutil.h"
     21 #include "common/vector_utils.h"
     22 
     23 namespace angle
     24 {
     25 
     26 template <typename T>
     27 class Matrix
     28 {
     29  public:
     30    Matrix(const std::vector<T> &elements, const unsigned int numRows, const unsigned int numCols)
     31        : mElements(elements), mRows(numRows), mCols(numCols)
     32    {
     33        ASSERT(rows() >= 1 && rows() <= 4);
     34        ASSERT(columns() >= 1 && columns() <= 4);
     35    }
     36 
     37    Matrix(const std::vector<T> &elements, const unsigned int size)
     38        : mElements(elements), mRows(size), mCols(size)
     39    {
     40        ASSERT(rows() >= 1 && rows() <= 4);
     41        ASSERT(columns() >= 1 && columns() <= 4);
     42    }
     43 
     44    Matrix(const T *elements, const unsigned int size) : mRows(size), mCols(size)
     45    {
     46        ASSERT(rows() >= 1 && rows() <= 4);
     47        ASSERT(columns() >= 1 && columns() <= 4);
     48        for (size_t i = 0; i < size * size; i++)
     49            mElements.push_back(elements[i]);
     50    }
     51 
     52    const T &operator()(const unsigned int rowIndex, const unsigned int columnIndex) const
     53    {
     54        ASSERT(rowIndex < mRows);
     55        ASSERT(columnIndex < mCols);
     56        return mElements[rowIndex * columns() + columnIndex];
     57    }
     58 
     59    T &operator()(const unsigned int rowIndex, const unsigned int columnIndex)
     60    {
     61        ASSERT(rowIndex < mRows);
     62        ASSERT(columnIndex < mCols);
     63        return mElements[rowIndex * columns() + columnIndex];
     64    }
     65 
     66    const T &at(const unsigned int rowIndex, const unsigned int columnIndex) const
     67    {
     68        ASSERT(rowIndex < mRows);
     69        ASSERT(columnIndex < mCols);
     70        return operator()(rowIndex, columnIndex);
     71    }
     72 
     73    Matrix<T> operator*(const Matrix<T> &m)
     74    {
     75        ASSERT(columns() == m.rows());
     76 
     77        unsigned int resultRows = rows();
     78        unsigned int resultCols = m.columns();
     79        Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols);
     80        for (unsigned int i = 0; i < resultRows; i++)
     81        {
     82            for (unsigned int j = 0; j < resultCols; j++)
     83            {
     84                T tmp = 0.0f;
     85                for (unsigned int k = 0; k < columns(); k++)
     86                    tmp += at(i, k) * m(k, j);
     87                result(i, j) = tmp;
     88            }
     89        }
     90 
     91        return result;
     92    }
     93 
     94    void operator*=(const Matrix<T> &m)
     95    {
     96        ASSERT(columns() == m.rows());
     97        Matrix<T> res  = (*this) * m;
     98        size_t numElts = res.elements().size();
     99        mElements.resize(numElts);
    100        memcpy(mElements.data(), res.data(), numElts * sizeof(float));
    101    }
    102 
    103    bool operator==(const Matrix<T> &m) const
    104    {
    105        ASSERT(columns() == m.columns());
    106        ASSERT(rows() == m.rows());
    107        return mElements == m.elements();
    108    }
    109 
    110    bool operator!=(const Matrix<T> &m) const { return !(mElements == m.elements()); }
    111 
    112    bool nearlyEqual(T epsilon, const Matrix<T> &m) const
    113    {
    114        ASSERT(columns() == m.columns());
    115        ASSERT(rows() == m.rows());
    116        const auto &otherElts = m.elements();
    117        for (size_t i = 0; i < otherElts.size(); i++)
    118        {
    119            if ((mElements[i] - otherElts[i] > epsilon) && (otherElts[i] - mElements[i] > epsilon))
    120                return false;
    121        }
    122        return true;
    123    }
    124 
    125    unsigned int size() const
    126    {
    127        ASSERT(rows() == columns());
    128        return rows();
    129    }
    130 
    131    unsigned int rows() const { return mRows; }
    132 
    133    unsigned int columns() const { return mCols; }
    134 
    135    std::vector<T> elements() const { return mElements; }
    136    T *data() { return mElements.data(); }
    137    const T *constData() const { return mElements.data(); }
    138 
    139    Matrix<T> compMult(const Matrix<T> &mat1) const
    140    {
    141        Matrix result(std::vector<T>(mElements.size()), rows(), columns());
    142        for (unsigned int i = 0; i < rows(); i++)
    143        {
    144            for (unsigned int j = 0; j < columns(); j++)
    145            {
    146                T lhs        = at(i, j);
    147                T rhs        = mat1(i, j);
    148                result(i, j) = rhs * lhs;
    149            }
    150        }
    151 
    152        return result;
    153    }
    154 
    155    Matrix<T> outerProduct(const Matrix<T> &mat1) const
    156    {
    157        unsigned int cols = mat1.columns();
    158        Matrix result(std::vector<T>(rows() * cols), rows(), cols);
    159        for (unsigned int i = 0; i < rows(); i++)
    160            for (unsigned int j = 0; j < cols; j++)
    161                result(i, j) = at(i, 0) * mat1(0, j);
    162 
    163        return result;
    164    }
    165 
    166    Matrix<T> transpose() const
    167    {
    168        Matrix result(std::vector<T>(mElements.size()), columns(), rows());
    169        for (unsigned int i = 0; i < columns(); i++)
    170            for (unsigned int j = 0; j < rows(); j++)
    171                result(i, j) = at(j, i);
    172 
    173        return result;
    174    }
    175 
    176    T determinant() const
    177    {
    178        ASSERT(rows() == columns());
    179 
    180        switch (size())
    181        {
    182            case 2:
    183                return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0);
    184 
    185            case 3:
    186                return at(0, 0) * at(1, 1) * at(2, 2) + at(0, 1) * at(1, 2) * at(2, 0) +
    187                       at(0, 2) * at(1, 0) * at(2, 1) - at(0, 2) * at(1, 1) * at(2, 0) -
    188                       at(0, 1) * at(1, 0) * at(2, 2) - at(0, 0) * at(1, 2) * at(2, 1);
    189 
    190            case 4:
    191            {
    192                const float minorMatrices[4][3 * 3] = {{
    193                                                           at(1, 1),
    194                                                           at(2, 1),
    195                                                           at(3, 1),
    196                                                           at(1, 2),
    197                                                           at(2, 2),
    198                                                           at(3, 2),
    199                                                           at(1, 3),
    200                                                           at(2, 3),
    201                                                           at(3, 3),
    202                                                       },
    203                                                       {
    204                                                           at(1, 0),
    205                                                           at(2, 0),
    206                                                           at(3, 0),
    207                                                           at(1, 2),
    208                                                           at(2, 2),
    209                                                           at(3, 2),
    210                                                           at(1, 3),
    211                                                           at(2, 3),
    212                                                           at(3, 3),
    213                                                       },
    214                                                       {
    215                                                           at(1, 0),
    216                                                           at(2, 0),
    217                                                           at(3, 0),
    218                                                           at(1, 1),
    219                                                           at(2, 1),
    220                                                           at(3, 1),
    221                                                           at(1, 3),
    222                                                           at(2, 3),
    223                                                           at(3, 3),
    224                                                       },
    225                                                       {
    226                                                           at(1, 0),
    227                                                           at(2, 0),
    228                                                           at(3, 0),
    229                                                           at(1, 1),
    230                                                           at(2, 1),
    231                                                           at(3, 1),
    232                                                           at(1, 2),
    233                                                           at(2, 2),
    234                                                           at(3, 2),
    235                                                       }};
    236                return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() -
    237                       at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() +
    238                       at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() -
    239                       at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant();
    240            }
    241 
    242            default:
    243                UNREACHABLE();
    244                break;
    245        }
    246 
    247        return T();
    248    }
    249 
    250    Matrix<T> inverse() const
    251    {
    252        ASSERT(rows() == columns());
    253 
    254        Matrix<T> cof(std::vector<T>(mElements.size()), rows(), columns());
    255        switch (size())
    256        {
    257            case 2:
    258                cof(0, 0) = at(1, 1);
    259                cof(0, 1) = -at(1, 0);
    260                cof(1, 0) = -at(0, 1);
    261                cof(1, 1) = at(0, 0);
    262                break;
    263 
    264            case 3:
    265                cof(0, 0) = at(1, 1) * at(2, 2) - at(2, 1) * at(1, 2);
    266                cof(0, 1) = -(at(1, 0) * at(2, 2) - at(2, 0) * at(1, 2));
    267                cof(0, 2) = at(1, 0) * at(2, 1) - at(2, 0) * at(1, 1);
    268                cof(1, 0) = -(at(0, 1) * at(2, 2) - at(2, 1) * at(0, 2));
    269                cof(1, 1) = at(0, 0) * at(2, 2) - at(2, 0) * at(0, 2);
    270                cof(1, 2) = -(at(0, 0) * at(2, 1) - at(2, 0) * at(0, 1));
    271                cof(2, 0) = at(0, 1) * at(1, 2) - at(1, 1) * at(0, 2);
    272                cof(2, 1) = -(at(0, 0) * at(1, 2) - at(1, 0) * at(0, 2));
    273                cof(2, 2) = at(0, 0) * at(1, 1) - at(1, 0) * at(0, 1);
    274                break;
    275 
    276            case 4:
    277                cof(0, 0) = at(1, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(1, 3) +
    278                            at(3, 1) * at(1, 2) * at(2, 3) - at(1, 1) * at(3, 2) * at(2, 3) -
    279                            at(2, 1) * at(1, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(1, 3);
    280                cof(0, 1) = -(at(1, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(1, 3) +
    281                              at(3, 0) * at(1, 2) * at(2, 3) - at(1, 0) * at(3, 2) * at(2, 3) -
    282                              at(2, 0) * at(1, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(1, 3));
    283                cof(0, 2) = at(1, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(1, 3) +
    284                            at(3, 0) * at(1, 1) * at(2, 3) - at(1, 0) * at(3, 1) * at(2, 3) -
    285                            at(2, 0) * at(1, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(1, 3);
    286                cof(0, 3) = -(at(1, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(1, 2) +
    287                              at(3, 0) * at(1, 1) * at(2, 2) - at(1, 0) * at(3, 1) * at(2, 2) -
    288                              at(2, 0) * at(1, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(1, 2));
    289                cof(1, 0) = -(at(0, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(0, 3) +
    290                              at(3, 1) * at(0, 2) * at(2, 3) - at(0, 1) * at(3, 2) * at(2, 3) -
    291                              at(2, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(0, 3));
    292                cof(1, 1) = at(0, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(0, 3) +
    293                            at(3, 0) * at(0, 2) * at(2, 3) - at(0, 0) * at(3, 2) * at(2, 3) -
    294                            at(2, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(0, 3);
    295                cof(1, 2) = -(at(0, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(0, 3) +
    296                              at(3, 0) * at(0, 1) * at(2, 3) - at(0, 0) * at(3, 1) * at(2, 3) -
    297                              at(2, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(0, 3));
    298                cof(1, 3) = at(0, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(0, 2) +
    299                            at(3, 0) * at(0, 1) * at(2, 2) - at(0, 0) * at(3, 1) * at(2, 2) -
    300                            at(2, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(0, 2);
    301                cof(2, 0) = at(0, 1) * at(1, 2) * at(3, 3) + at(1, 1) * at(3, 2) * at(0, 3) +
    302                            at(3, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(3, 2) * at(1, 3) -
    303                            at(1, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(1, 2) * at(0, 3);
    304                cof(2, 1) = -(at(0, 0) * at(1, 2) * at(3, 3) + at(1, 0) * at(3, 2) * at(0, 3) +
    305                              at(3, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(3, 2) * at(1, 3) -
    306                              at(1, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(1, 2) * at(0, 3));
    307                cof(2, 2) = at(0, 0) * at(1, 1) * at(3, 3) + at(1, 0) * at(3, 1) * at(0, 3) +
    308                            at(3, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(3, 1) * at(1, 3) -
    309                            at(1, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(1, 1) * at(0, 3);
    310                cof(2, 3) = -(at(0, 0) * at(1, 1) * at(3, 2) + at(1, 0) * at(3, 1) * at(0, 2) +
    311                              at(3, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(3, 1) * at(1, 2) -
    312                              at(1, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(1, 1) * at(0, 2));
    313                cof(3, 0) = -(at(0, 1) * at(1, 2) * at(2, 3) + at(1, 1) * at(2, 2) * at(0, 3) +
    314                              at(2, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(2, 2) * at(1, 3) -
    315                              at(1, 1) * at(0, 2) * at(2, 3) - at(2, 1) * at(1, 2) * at(0, 3));
    316                cof(3, 1) = at(0, 0) * at(1, 2) * at(2, 3) + at(1, 0) * at(2, 2) * at(0, 3) +
    317                            at(2, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(2, 2) * at(1, 3) -
    318                            at(1, 0) * at(0, 2) * at(2, 3) - at(2, 0) * at(1, 2) * at(0, 3);
    319                cof(3, 2) = -(at(0, 0) * at(1, 1) * at(2, 3) + at(1, 0) * at(2, 1) * at(0, 3) +
    320                              at(2, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(2, 1) * at(1, 3) -
    321                              at(1, 0) * at(0, 1) * at(2, 3) - at(2, 0) * at(1, 1) * at(0, 3));
    322                cof(3, 3) = at(0, 0) * at(1, 1) * at(2, 2) + at(1, 0) * at(2, 1) * at(0, 2) +
    323                            at(2, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(2, 1) * at(1, 2) -
    324                            at(1, 0) * at(0, 1) * at(2, 2) - at(2, 0) * at(1, 1) * at(0, 2);
    325                break;
    326 
    327            default:
    328                UNREACHABLE();
    329                break;
    330        }
    331 
    332        // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the
    333        // determinant of A.
    334        Matrix<T> adjugateMatrix(cof.transpose());
    335        T det = determinant();
    336        Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns());
    337        for (unsigned int i = 0; i < rows(); i++)
    338            for (unsigned int j = 0; j < columns(); j++)
    339                result(i, j) = (det != static_cast<T>(0)) ? adjugateMatrix(i, j) / det : T();
    340 
    341        return result;
    342    }
    343 
    344    void setToIdentity()
    345    {
    346        ASSERT(rows() == columns());
    347 
    348        const auto one  = T(1);
    349        const auto zero = T(0);
    350 
    351        for (auto &e : mElements)
    352            e = zero;
    353 
    354        for (unsigned int i = 0; i < rows(); ++i)
    355        {
    356            const auto pos = i * columns() + (i % columns());
    357            mElements[pos] = one;
    358        }
    359    }
    360 
    361    template <unsigned int Size>
    362    static void setToIdentity(T (&matrix)[Size])
    363    {
    364        static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square.");
    365 
    366        const auto cols = gl::iSquareRoot<Size>();
    367        const auto one  = T(1);
    368        const auto zero = T(0);
    369 
    370        for (auto &e : matrix)
    371            e = zero;
    372 
    373        for (unsigned int i = 0; i < cols; ++i)
    374        {
    375            const auto pos = i * cols + (i % cols);
    376            matrix[pos]    = one;
    377        }
    378    }
    379 
    380  protected:
    381    std::vector<T> mElements;
    382    unsigned int mRows;
    383    unsigned int mCols;
    384 };
    385 
    386 class Mat4 : public Matrix<float>
    387 {
    388  public:
    389    Mat4();
    390    Mat4(const Matrix<float> generalMatrix);
    391    Mat4(const std::vector<float> &elements);
    392    Mat4(const float *elements);
    393    Mat4(float m00,
    394         float m01,
    395         float m02,
    396         float m03,
    397         float m10,
    398         float m11,
    399         float m12,
    400         float m13,
    401         float m20,
    402         float m21,
    403         float m22,
    404         float m23,
    405         float m30,
    406         float m31,
    407         float m32,
    408         float m33);
    409 
    410    static Mat4 Rotate(float angle, const Vector3 &axis);
    411    static Mat4 Translate(const Vector3 &t);
    412    static Mat4 Scale(const Vector3 &s);
    413    static Mat4 Frustum(float l, float r, float b, float t, float n, float f);
    414    static Mat4 Perspective(float fov, float aspectRatio, float n, float f);
    415    static Mat4 Ortho(float l, float r, float b, float t, float n, float f);
    416 
    417    Mat4 product(const Mat4 &m);
    418    Vector4 product(const Vector4 &b);
    419    void dump();
    420 };
    421 
    422 }  // namespace angle
    423 
    424 #endif  // COMMON_MATRIX_UTILS_H_