LCOV - code coverage report
Current view: top level - src/include/linalg - PermutationMatrix.hpp (source / functions) Hit Total Coverage
Test: 77f3e5efbbe58a833b8bba78631aab024522bbc3 Lines: 122 122 100.0 %
Date: 2021-02-20 15:40:15 Functions: 36 36 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include "Matrix.hpp"
       4             : 
       5             : /// @addtogroup MatVec
       6             : /// @{
       7             : 
       8             : /// Class that represents matrices that permute the rows or columns of other
       9             : /// matrices.
      10             : /// Stored in an efficient manner with O(n) memory requirements.
      11             : class PermutationMatrix {
      12             : 
      13             :     /// Container to store the elements of the permutation matrix internally.
      14             :     using storage_t = util::storage_t<size_t>;
      15             : 
      16             :   public:
      17             :     enum Type {
      18             :         Unspecified = 0,       ///< Can be used for permuting rows or columns.
      19             :         RowPermutation = 1,    ///< Can be used for permuting rows only.
      20             :         ColumnPermutation = 2, ///< Can be used for permuting columns only.
      21             :     };
      22             : 
      23             :   public:
      24             :     /// @name   Constructors and assignment
      25             :     /// @{
      26             : 
      27             :     /// Default constructor.
      28             :     PermutationMatrix() = default;
      29             : 
      30             :     /// Create an empty permutation matrix with the given type.
      31          14 :     PermutationMatrix(Type type) : type(type) {}
      32             : 
      33             :     /// Create an identity permutation matrix (one that permutes nothing).
      34           3 :     PermutationMatrix(size_t rows, Type type = Unspecified)
      35           3 :         : storage(rows), type(type) {
      36           3 :         fill_identity();
      37           3 :     }
      38             : 
      39             :     /// Create a permutation matrix with the given permutation.
      40             :     PermutationMatrix(std::initializer_list<size_t> init,
      41             :                       Type type = Unspecified);
      42             :     /// Assign the given permutation to the matrix.
      43             :     PermutationMatrix &operator=(std::initializer_list<size_t> init);
      44             : 
      45             :     /// Default copy constructor.
      46           5 :     PermutationMatrix(const PermutationMatrix &) = default;
      47             :     /// Move constructor.
      48             :     PermutationMatrix(PermutationMatrix &&);
      49             : 
      50             :     /// Default copy assignment.
      51             :     PermutationMatrix &operator=(const PermutationMatrix &) = default;
      52             :     /// Move assignment.
      53             :     PermutationMatrix &operator=(PermutationMatrix &&);
      54             : 
      55             :     /// @}
      56             : 
      57             :   public:
      58             :     /// @name   Matrix size
      59             :     /// @{
      60             : 
      61             :     /// Get the size of the permutation matrix.
      62        1309 :     size_t size() const { return storage.size(); }
      63             :     /// Get the number of rows of the permutation matrix.
      64             :     size_t rows() const { return size(); }
      65             :     /// Get the number of columns of the permutation matrix.
      66             :     size_t cols() const { return size(); }
      67             :     /// Get the number of elements in the matrix.
      68             :     size_t num_elems() const { return size(); }
      69             :     /// Resize the permutation matrix.
      70          32 :     void resize(size_t size) { storage.resize(size); }
      71             : 
      72             :     /// @}
      73             : 
      74             :   public:
      75             :     /// @name   Element access
      76             :     /// @{
      77             : 
      78             :     /// Get the element at the given position in the swap sequence.
      79             :     /// If the k-th element is i, i.e. `P(k) == i`, then the k-th step of the
      80             :     /// swapping algorithm swaps rows (or columns) `i` and `k`.
      81        2270 :     size_t &operator()(size_t index) { return storage[index]; }
      82             :     /// @copydoc    operator()(size_t)
      83        4466 :     const size_t &operator()(size_t index) const { return storage[index]; }
      84             : 
      85             :     /// @}
      86             : 
      87             :   public:
      88             :     /// @name   Transposition
      89             :     /// @{
      90             : 
      91             :     /// Reverse the order of the permutations.
      92           5 :     void reverse() { reverse_ = !reverse_; }
      93             : 
      94             :     /// Transpose or invert the permutation matrix.
      95           5 :     void transpose_inplace() { reverse(); }
      96             : 
      97             :     /// Check if the permutation should be applied in reverse.
      98          26 :     bool is_reversed() const { return reverse_; }
      99             : 
     100             :     /// Get the type of permutation matrix (whether it permutes rows or columns,
     101             :     /// or unspecified).
     102          23 :     Type get_type() const { return type; }
     103             :     /// Set the type of permutation matrix (whether it permutes rows or columns,
     104             :     /// or unspecified).
     105             :     void set_type(Type type) { this->type = type; }
     106             : 
     107             :     /// @}
     108             : 
     109             :   public:
     110             :     /// @name   Conversion to a full matrix or a permutation
     111             :     /// @{
     112             : 
     113             :     /// Convert a permutation matrix into a full matrix.
     114             :     SquareMatrix to_matrix(Type type = Unspecified) const;
     115             : 
     116             :     /// Type that represents a permutation (in the mathematical sense, a
     117             :     /// permutation of the integers 0 through n-1).
     118             :     using Permutation = storage_t;
     119             : 
     120             :     /// Convert a permutation matrix into a mathematical permutation.
     121             :     Permutation to_permutation() const;
     122             : 
     123             :     /// @}
     124             : 
     125             :   public:
     126             :     /// @name   Applying the permutation to matrices
     127             :     /// @{
     128             : 
     129             :     /// Apply the permutation to the columns of matrix A.
     130             :     void permute_columns(Matrix &A) const;
     131             :     /// Apply the permutation to the rows of matrix A.
     132             :     void permute_rows(Matrix &A) const;
     133             : 
     134             :     /// @}
     135             : 
     136             :   public:
     137             :     /// @name   Memory management
     138             :     /// @{
     139             : 
     140             :     /// Set the size to zero, and deallocate the storage.
     141             :     void clear_and_deallocate() {
     142             :         storage_t().swap(this->storage); // replace storage with empty storage
     143             :         // temporary storage goes out of scope and deallocates original storage
     144             :     }
     145             : 
     146             :     /// @}
     147             : 
     148             :   public:
     149             :     /// @name    Generating permutations
     150             :     /// @{
     151             : 
     152             :     /// Return a random permutation of the integers 0 through length-1.
     153             :     static Permutation
     154             :     random_permutation(size_t length,
     155             :                        std::default_random_engine::result_type seed =
     156             :                            std::default_random_engine::default_seed);
     157             : 
     158             :     /// Return the identity permutation (0, 1, 2, 3, ..., length-1).
     159             :     static Permutation identity_permutation(size_t length);
     160             : 
     161             :     /// @}
     162             : 
     163             :   public:
     164             :     /// @name   Filling matrices
     165             :     /// @{
     166             : 
     167             :     /// Fill the matrix as an identity permutation.
     168          17 :     void fill_identity() { std::iota(begin(), end(), size_t(0)); }
     169             : 
     170             :     /// Create a permutation matrix from the given permutation.
     171             :     /// @note   This isn't a very fast method, it's mainly used for tests.
     172             :     ///         Internally, the permutation matrix is represented by a sequence
     173             :     ///         of swap operations. Converting from this representation to a
     174             :     ///         mathematical permutation is fast, but the other way around
     175             :     ///         requires O(n²) operations (with the naive implementation used
     176             :     ///         here, anyway).
     177             :     void fill_from_permutation(Permutation permutation);
     178             : 
     179             :     /// Fill the matrix with a random permutation.
     180             :     /// @note   This isn't a very fast method, it's mainly used for tests.
     181           1 :     void fill_random(std::default_random_engine::result_type seed =
     182             :                          std::default_random_engine::default_seed) {
     183           1 :         fill_from_permutation(random_permutation(size(), seed));
     184           1 :     }
     185             : 
     186             :     /// @}
     187             : 
     188             :   public:
     189             :     /// @name   Create special matrices
     190             :     /// @{
     191             : 
     192             :     /// Create an identity permutation matrix.
     193             :     static PermutationMatrix identity(size_t rows, Type type = Unspecified) {
     194             :         PermutationMatrix p(rows, type);
     195             :         return p;
     196             :     }
     197             : 
     198             :     /// @copydoc    fill_from_permutation
     199           2 :     static PermutationMatrix from_permutation(Permutation permutation,
     200             :                                               Type type = Unspecified) {
     201           2 :         PermutationMatrix p(permutation.size(), type);
     202           2 :         p.fill_from_permutation(std::move(permutation));
     203           2 :         return p;
     204             :     }
     205             : 
     206             :     /// Create a random permutation matrix.
     207             :     /// @note   This isn't a very fast method, it's mainly used for tests.
     208             :     static PermutationMatrix
     209           1 :     random(size_t rows, Type type = Unspecified,
     210             :            std::default_random_engine::result_type seed =
     211             :                std::default_random_engine::default_seed) {
     212           1 :         PermutationMatrix p(rows, type);
     213           1 :         p.fill_random(seed);
     214           1 :         return p;
     215             :     }
     216             : 
     217             :     /// @}
     218             : 
     219             :   public:
     220             :     /// @name   Iterators
     221             :     /// @{
     222             : 
     223             :     /// Get the iterator to the first element of the swapping sequence.
     224          17 :     storage_t::iterator begin() { return storage.begin(); }
     225             :     /// Get the constant iterator to the first element of the swapping sequence.
     226             :     storage_t::const_iterator begin() const { return storage.begin(); }
     227             :     /// Get the constant iterator to the first element of the swapping sequence.
     228             :     storage_t::const_iterator cbegin() const { return storage.begin(); }
     229             : 
     230             :     /// Get the iterator to the element past the end of the swapping sequence.
     231          17 :     storage_t::iterator end() { return storage.end(); }
     232             :     /// Get the constant iterator past the end of the swapping sequence.
     233             :     storage_t::const_iterator end() const { return storage.end(); }
     234             :     /// Get the constant iterator past the end of the swapping sequence.
     235             :     storage_t::const_iterator cend() const { return storage.end(); }
     236             : 
     237             :     /// @}
     238             : 
     239             :   public:
     240             :     /// @name   Printing
     241             :     /// @{
     242             : 
     243             :     /// Print a permutation matrix.
     244             :     /// @param  os
     245             :     ///         The stream to print to.
     246             :     /// @param  precision
     247             :     ///         The number of significant figures to print.
     248             :     ///         (0 = auto)
     249             :     /// @param  width
     250             :     ///         The width of each element (number of characters).
     251             :     ///         (0 = auto)
     252             :     void print(std::ostream &os, uint8_t precision = 0,
     253             :                uint8_t width = 0) const;
     254             : 
     255             :     /// @}
     256             : 
     257             :   protected:
     258             :     storage_t storage;
     259             :     bool reverse_ = false;
     260             :     Type type = Unspecified;
     261             : };
     262             : 
     263             : /// @}
     264             : 
     265             : /// Print a permutation matrix.
     266             : /// @related    PermutationMatrix
     267             : std::ostream &operator<<(std::ostream &os, const PermutationMatrix &M);
     268             : 
     269             : // :::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: //
     270             : 
     271             : /// @addtogroup MatMul
     272             : /// @{
     273             : 
     274             : /// Left application of permutation matrix (P permutes rows of A).
     275           3 : inline Matrix operator*(const PermutationMatrix &P, const Matrix &A) {
     276           3 :     Matrix result = A;
     277           3 :     P.permute_rows(result);
     278           3 :     return result;
     279             : }
     280             : /// Left application of permutation matrix (P permutes rows of A in place).
     281           1 : inline Matrix &&operator*(const PermutationMatrix &P, Matrix &&A) {
     282           1 :     P.permute_rows(A);
     283           1 :     return std::move(A);
     284             : }
     285             : /// Right application of permutation matrix (P permutes columns of A).
     286           2 : inline Matrix operator*(const Matrix &A, const PermutationMatrix &P) {
     287           2 :     Matrix result = A;
     288           2 :     P.permute_columns(result);
     289           2 :     return result;
     290             : }
     291             : /// Right application of permutation matrix (P permutes columns of A in place).
     292           1 : inline Matrix &&operator*(Matrix &&A, const PermutationMatrix &P) {
     293           1 :     P.permute_columns(A);
     294           1 :     return std::move(A);
     295             : }
     296             : 
     297             : /// Left application of permutation matrix (P permutes rows of A).
     298           2 : inline SquareMatrix operator*(const PermutationMatrix &P,
     299             :                               const SquareMatrix &A) {
     300           2 :     SquareMatrix result = A;
     301           2 :     P.permute_rows(result);
     302           2 :     return result;
     303             : }
     304             : /// Left application of permutation matrix (P permutes rows of A in place).
     305           3 : inline SquareMatrix &&operator*(const PermutationMatrix &P, SquareMatrix &&A) {
     306           3 :     P.permute_rows(A);
     307           3 :     return std::move(A);
     308             : }
     309             : /// Right application of permutation matrix (P permutes columns of A).
     310           1 : inline SquareMatrix operator*(const SquareMatrix &A,
     311             :                               const PermutationMatrix &P) {
     312           1 :     SquareMatrix result = A;
     313           1 :     P.permute_columns(result);
     314           1 :     return result;
     315             : }
     316             : /// Right application of permutation matrix (P permutes columns of A in place).
     317           1 : inline SquareMatrix &&operator*(SquareMatrix &&A, const PermutationMatrix &P) {
     318           1 :     P.permute_columns(A);
     319           1 :     return std::move(A);
     320             : }
     321             : 
     322             : /// Left application of permutation matrix (P permutes rows of v).
     323           1 : inline Vector operator*(const PermutationMatrix &P, const Vector &v) {
     324           1 :     Vector result = v;
     325           1 :     P.permute_rows(result);
     326           1 :     return result;
     327             : }
     328             : /// Left application of permutation matrix (P permutes rows of v in place).
     329           1 : inline Vector &&operator*(const PermutationMatrix &P, Vector &&v) {
     330           1 :     P.permute_rows(v);
     331           1 :     return std::move(v);
     332             : }
     333             : 
     334             : /// Right application of permutation matrix (P permutes columns of v).
     335           1 : inline RowVector operator*(const RowVector &v, const PermutationMatrix &P) {
     336           1 :     RowVector result = v;
     337           1 :     P.permute_columns(result);
     338           1 :     return result;
     339             : }
     340             : /// Right application of permutation matrix (P permutes columns of v in place).
     341           1 : inline RowVector &&operator*(RowVector &&v, const PermutationMatrix &P) {
     342           1 :     P.permute_columns(v);
     343           1 :     return std::move(v);
     344             : }
     345             : 
     346             : /// @}
     347             : 
     348             : /// @addtogroup MatTrans
     349             : /// @{
     350             : 
     351             : /// Transpose a permutation matrix (inverse permutation).
     352           3 : inline PermutationMatrix transpose(const PermutationMatrix &P) {
     353           3 :     PermutationMatrix result = P;
     354           3 :     result.transpose_inplace();
     355           3 :     return result;
     356             : }
     357             : /// Transpose a permutation matrix in place (inverse permutation).
     358           2 : inline PermutationMatrix &&transpose(PermutationMatrix &&P) {
     359           2 :     P.transpose_inplace();
     360           2 :     return std::move(P);
     361             : }
     362             : 
     363             : /// @}
     364             : 
     365             : //                              Implementations                               //
     366             : // -------------------------------------------------------------------------- //
     367             : 
     368             : inline PermutationMatrix::PermutationMatrix(PermutationMatrix &&other) {
     369             :     *this = std::move(other);
     370             : }
     371             : 
     372             : inline PermutationMatrix &
     373             : PermutationMatrix::operator=(PermutationMatrix &&other) {
     374             :     // By explicitly defining move assignment, we can be sure that the object
     375             :     // that's being moved from has a consistent state.
     376             :     this->storage = std::move(other.storage);
     377             :     std::swap(this->type, other.type);
     378             :     std::swap(this->reverse_, other.reverse_);
     379             :     other.clear_and_deallocate();
     380             :     return *this;
     381             : }
     382             : 
     383          15 : inline PermutationMatrix::PermutationMatrix(std::initializer_list<size_t> init,
     384          15 :                                             Type type)
     385          15 :     : type(type) {
     386          15 :     *this = init;
     387          15 : }
     388             : 
     389             : inline SquareMatrix PermutationMatrix::to_matrix(Type type) const {
     390             :     // TODO: I'm sure this can be sped up
     391             :     Type actual_type = type == Unspecified ? this->type : type;
     392             :     assert(actual_type != Unspecified);
     393             :     if (actual_type == RowPermutation) {
     394             :         SquareMatrix P = SquareMatrix::identity(size());
     395             :         permute_rows(P);
     396             :         return P;
     397             :     } else if (actual_type == ColumnPermutation) {
     398             :         SquareMatrix P = SquareMatrix::identity(size());
     399             :         permute_columns(P);
     400             :         return P;
     401             :     }
     402             :     assert(false);
     403             :     return {};
     404             : }
     405             : 
     406             : inline PermutationMatrix::Permutation
     407           3 : PermutationMatrix::to_permutation() const {
     408           3 :     Permutation p = identity_permutation(size());
     409           3 :     auto &This = *this;
     410           3 :     if (is_reversed()) {
     411             :         // Count down
     412        1025 :         for (size_t i = size(); i-- > 0;)
     413        1024 :             if (i != This(i))
     414        1018 :                 std::swap(p[i], p[This(i)]);
     415             :     } else {
     416             :         // Count up
     417        1154 :         for (size_t i = 0; i < size(); ++i)
     418        1152 :             if (i != This(i))
     419        1139 :                 std::swap(p[i], p[This(i)]);
     420             :     }
     421           3 :     return p;
     422             : }
     423             : 
     424             : inline PermutationMatrix &
     425          15 : PermutationMatrix::operator=(std::initializer_list<size_t> init) {
     426          30 :     storage_t permutation(init.size());
     427          15 :     std::copy(init.begin(), init.end(), permutation.begin());
     428          15 :     fill_from_permutation(std::move(permutation));
     429             : 
     430          30 :     return *this;
     431             : }
     432             : 
     433           7 : inline void PermutationMatrix::permute_columns(Matrix &A) const {
     434           7 :     assert(A.cols() == size());
     435           7 :     assert(get_type() != RowPermutation);
     436           7 :     auto &This = *this;
     437           7 :     if (is_reversed()) {
     438             :         // Count down
     439           5 :         for (size_t i = size(); i-- > 0;)
     440           4 :             if (i != This(i))
     441           2 :                 A.swap_columns(i, This(i));
     442             :     } else {
     443             :         // Count up
     444          30 :         for (size_t i = 0; i < size(); ++i)
     445          24 :             if (i != This(i))
     446          12 :                 A.swap_columns(i, This(i));
     447             :     }
     448           7 : }
     449             : 
     450          16 : inline void PermutationMatrix::permute_rows(Matrix &A) const {
     451          16 :     assert(A.rows() == size());
     452          16 :     assert(get_type() != ColumnPermutation);
     453          16 :     auto &This = *this;
     454          16 :     if (is_reversed()) {
     455             :         // Count down
     456          16 :         for (size_t i = size(); i-- > 0;)
     457          13 :             if (i != This(i))
     458           7 :                 A.swap_rows(i, This(i));
     459             :     } else {
     460             :         // Count up
     461          61 :         for (size_t i = 0; i < size(); ++i)
     462          48 :             if (i != This(i))
     463          23 :                 A.swap_rows(i, This(i));
     464             :     }
     465          16 : }

Generated by: LCOV version 1.15