Linear Algebra  arduino
Accessible implementations of linear algebra algorithms
RowPivotLU.cpp
Go to the documentation of this file.
1 #ifndef ARDUINO
2 #include <linalg/RowPivotLU.hpp>
3 #else
5 #endif
6 
7 #include <cassert>
8 
30  // For the intermediate calculations, we'll be working with LU.
31  // It is initialized to the square n×n matrix to be factored.
32 
33  assert(LU.rows() == LU.cols());
34  assert(P.size() == LU.rows());
35 
36  // The goal of the LU factorization algorithm is to repeatedly apply
37  // transformations Lₖ to the matrix A to eventually end up with an upper-
38  // triangular matrix U. When row pivoting is used, the rows of A are
39  // permuted using a permutation matrix P:
40  //
41  // Lₙ'⋯L₂'L₁'PA = U
42  //
43  // The main steps of the algorithm are exactly the same as the original
44  // LU algorithm explained in LU.cpp, and will not be repeated here.
45  // The only difference is that instead of using the diagonal element as the
46  // pivot, rows are swapped so that the element with the largest magnitude
47  // ends up on the diagonal and can be used as the pivot.
48 
49  // Loop over all columns of A:
50  for (size_t k = 0; k < LU.cols(); ++k) {
51  // In the following comments, k = [1, n], because this is more intuitive
52  // and it follows the usual mathematical convention.
53  // In the code, however, array indices start at zero, so k = [0, n-1].
54 
55  // On each iteration, the largest element on or below the diagonal in
56  // the current (k-th) column will be used as the pivot.
57  // To this end, the k-th row and the row that contains the largest
58  // element are swapped, and the swapping is stored in the permutation
59  // matrix, so that it can later be undone, when solving systems of
60  // equations for example.
61 
62  // Find the largest element (in absolute value)
63  double max_elem = std::abs(LU(k, k));
64  size_t max_index = k;
65  for (size_t i = k + 1; i < LU.rows(); ++i) {
66  double abs_elem = std::abs(LU(i, k));
67  if (abs_elem > max_elem) {
68  max_elem = abs_elem;
69  max_index = i;
70  }
71  }
72 
73  // Select the index of the element that is largest in absolute value as
74  // the new pivot index.
75  // If this index is not the diagonal element, rows have to be swapped:
76  if (max_index != k) {
77  P(k) = max_index; // save the permutation
78  LU.swap_rows(k, max_index); // actually perfrom the permutation
79  }
80 
81  // Note how all columns of the two rows are permuted, not just the
82  // columns greater than k. You might wonder how that'll ever work out
83  // correctly in the end.
84  // Recall that for the LU factorization without pivoting, the result was
85  // Lₙ⋯L₂L₁A = U.
86  // When pivoting is used, however, the matrix is permuted before each
87  // elimination step:
88  //
89  // LₙPₙ⋯L₂P₂L₁P₁A = U
90  //
91  // Luckily, the product can be reordered, and all permutation matrices
92  // Pₖ can be grouped together.
93  // Without loss of generality, consider the pivoted LU factorization of
94  // a 3×3 matrix:
95  //
96  // L₃P₃L₂P₂L₁P₁A = U
97  //
98  // Now introduce the following permuted matrices Lₖ':
99  //
100  // L₃' = L₃
101  // L₂' = P₃L₂P₃⁻¹
102  // L₁' = P₃P₂L₁P₂⁻¹P₃⁻¹
103  //
104  // You can then easily see that the following equation holds:
105  //
106  // L₃'L₂'L₁'P₂P₃P₁A = U
107  // ⇔ L₃(P₃L₂P₃⁻¹)(P₃P₂L₁P₂⁻¹P₃⁻¹)P₃P₂P₁A = U
108  // ⇔ L₃P₃L₂P₂L₁P₁A = U
109  //
110  // Furthermore, the matrices Lₖ' have the same structure as Lₖ, because
111  // only rows below the k-th pivot are permuted.
112  //
113  // All of this allows us to group all row permutations into a single
114  // permutation matrix P, and use the same algorithm and storage format
115  // as for the unpivoted case.
116  //
117  // The factors Lₖ' are computed implicitly by applying the row
118  // permutations to the entire matrix that stores both the U and L
119  // factors, rather than just to the elements of the trailing submatrix.
120 
121  // The rest of the algorithm is identical to the one explained in
122  // NoPivotLU.cpp.
123 
124  double pivot = LU(k, k);
125 
126  // Compute the k-th column of L, the coefficients lᵢₖ:
127  for (size_t i = k + 1; i < LU.rows(); ++i)
128  LU(i, k) /= pivot;
129 
130  // Update the trailing submatrix A'(k+1:n,k+1:n) = LₖA(k+1:n,k+1:n):
131  for (size_t c = k + 1; c < LU.cols(); ++c)
132  // Subtract lᵢₖ times the current pivot row A(k,:):
133  for (size_t i = k + 1; i < LU.rows(); ++i)
134  // A'(i,c) = 1·A(i,c) - lᵢₖ·A(k,c)
135  LU(i, c) -= LU(i, k) * LU(k, c);
136 
137  // Because of the row pivoting, zero pivots are no longer an issue,
138  // since the pivot is always chosen to be the largest possible element.
139  // When the matrix is singular, the algorithm will still fail, of
140  // course.
141  }
142  state = Factored;
143  valid_LU = true;
144  valid_P = true;
145 }
147 
153 void RowPivotLU::back_subs(const Matrix &B, Matrix &X) const {
154  // Solve upper triangular system UX = B by solving each column of B as a
155  // vector system Uxᵢ = bᵢ
156  //
157  // ┌ ┐┌ ┐ ┌ ┐
158  // │ u₁₁ u₁₂ u₁₃ u₁₄ ││ x₁ᵢ │ │ b₁ᵢ │
159  // │ u₂₂ u₂₃ u₂₄ ││ x₂ᵢ │ = │ b₂ᵢ │
160  // │ u₃₃ u₃₄ ││ x₃ᵢ │ │ b₃ᵢ │
161  // │ u₄₄ ││ x₄ᵢ │ │ b₄ᵢ │
162  // └ ┘└ ┘ └ ┘
163  //
164  // b₄ᵢ = u₄₄·x₄ᵢ ⟺ x₄ᵢ = b₄ᵢ/u₄₄
165  // b₃ᵢ = u₃₃·x₃ᵢ + u₃₄·x₄ᵢ ⟺ x₃ᵢ = (b₃ᵢ - u₃₄·x₄ᵢ)/u₃₃
166  // b₂ᵢ = u₂₂·x₂ᵢ + u₂₃·x₃ᵢ + u₂₄·x₄ᵢ ⟺ x₂ᵢ = (b₂ᵢ - u₂₃·x₃ᵢ + u₂₄·x₄ᵢ)/u₂₂
167  // ...
168 
169  for (size_t i = 0; i < B.cols(); ++i) {
170  for (size_t r = LU.rows(); r-- > 0;) {
171  X(r, i) = B(r, i);
172  for (size_t c = r + 1; c < LU.cols(); ++c)
173  X(r, i) -= LU(r, c) * X(c, i);
174  X(r, i) /= LU(r, r);
175  }
176  }
177 }
179 
185 void RowPivotLU::forward_subs(const Matrix &B, Matrix &X) const {
186  // Solve lower triangular system LX = B by solving each column of B as a
187  // vector system Lxᵢ = bᵢ.
188  // The diagonal is always 1, due to the construction of the L matrix in the
189  // LU algorithm.
190  //
191  // ┌ ┐┌ ┐ ┌ ┐
192  // │ 1 ││ x₁ᵢ │ │ b₁ᵢ │
193  // │ l₂₁ 1 ││ x₂ᵢ │ = │ b₂ᵢ │
194  // │ l₃₁ l₃₂ 1 ││ x₃ᵢ │ │ b₃ᵢ │
195  // │ l₄₁ l₄₂ l₄₃ 1 ││ x₄ᵢ │ │ b₄ᵢ │
196  // └ ┘└ ┘ └ ┘
197  //
198  // b₁ᵢ = 1·x₁ᵢ ⟺ x₁ᵢ = b₁ᵢ
199  // b₂ᵢ = l₂₁·x₁ᵢ + 1·x₂ᵢ ⟺ x₂ᵢ = b₂ᵢ - l₂₁·x₁ᵢ
200  // b₃ᵢ = l₃₁·x₁ᵢ + l₃₂·x₂ᵢ + 1·x₃ᵢ ⟺ x₃ᵢ = b₃ᵢ - l₃₂·x₂ᵢ - l₃₁·x₁ᵢ
201  // ...
202 
203  for (size_t i = 0; i < B.cols(); ++i) {
204  for (size_t r = 0; r < LU.rows(); ++r) {
205  X(r, i) = B(r, i);
206  for (size_t c = 0; c < r; ++c)
207  X(r, i) -= LU(r, c) * X(c, i);
208  }
209  }
210 }
212 
219  // Solve the system AX = B, PAX = PB or LUX = PB.
220  //
221  // Let UX = Z, and first solve LZ = PB, which is a simple lower-triangular
222  // system of equations.
223  // Now that Z is known, solve UX = Z, which is a simple upper-triangular
224  // system of equations.
225  assert(is_factored());
226 
227  P.permute_rows(B);
228  forward_subs(B, B); // overwrite B with Z
229  back_subs(B, B); // overwrite B (Z) with X
230 }
232 
233 // All implementations of the less interesting functions can be found here:
General matrix class.
Definition: Matrix.hpp:34
void swap_rows(size_t a, size_t b)
Swap two rows of the matrix.
Definition: Matrix.cpp:189
size_t rows() const
Get the number of rows of the matrix.
Definition: Matrix.hpp:81
size_t cols() const
Get the number of columns of the matrix.
Definition: Matrix.hpp:83
size_t size() const
Get the size of the permutation matrix.
void permute_rows(Matrix &A) const
Apply the permutation to the rows of matrix A.
SquareMatrix LU
Result of an LU factorization: stores the upper-triangular matrix U and the strict lower-triangular part of L (whose unit diagonal is implicit).
Definition: RowPivotLU.hpp:159
void back_subs(const Matrix &B, Matrix &X) const
Back substitution algorithm for solving upper-triangular systems UX = B.
Definition: RowPivotLU.cpp:153
void solve_inplace(Matrix &B) const
Solve the system AX = B, i.e. LUX = PB.
Definition: RowPivotLU.cpp:218
bool is_factored() const
Check if this object contains a factorization.
Definition: RowPivotLU.hpp:121
void forward_subs(const Matrix &B, Matrix &X) const
Forward substitution algorithm for solving lower-triangular systems LX = B.
Definition: RowPivotLU.cpp:185
void compute_factorization()
The actual LU factorization algorithm.
Definition: RowPivotLU.cpp:29
PermutationMatrix P
The permutation of A that maximizes pivot size.
Definition: RowPivotLU.hpp:161
enum RowPivotLU::State state