batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
trtri.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
8
// Force full compile-time unrolling of the annotated for-loop.
#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
10
12
/// Invert the top RowsReg×RowsReg lower-triangular block of A and store it in
/// the top block of D. Then multiply the bottom (k−RowsReg) rows of A by
/// (minus) this inverse on the right and store them in the bottom rows of D.
/// @param A k×RowsReg input panel (top block lower triangular).
/// @param D k×RowsReg output panel.
/// @param k Number of rows in A and D.
template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder OA, StorageOrder OD>
[[gnu::hot, gnu::flatten]] void trtri_copy_microkernel(const uview<const T, Abi, OA> A,
                                                       const uview<T, Abi, OD> D,
                                                       const index_t k) noexcept {
    static_assert(Conf.struc == MatrixStructure::LowerTriangular); // TODO
    static_assert(RowsReg > 0);
    using simd = datapar::simd<T, Abi>;
    // Pre-compute the offsets of the columns of A
    const auto A1_cached = with_cached_access<RowsReg, RowsReg>(A);
    const auto A_cached = with_cached_access<0, RowsReg>(A);
    // Load the triangular block A₁ into registers, packed column by column:
    // entry (r, c) with r ≥ c lives at index c·(2·RowsReg−1−c)/2 + r.
    simd A1_reg[RowsReg * (RowsReg + 1) / 2]; // NOLINT(*-c-arrays)
    // Accessor mapping a (row, column) pair of the lower triangle to its
    // register in the packed array above.
    auto A1r = [&A1_reg](index_t r, index_t c) -> simd & {
        return A1_reg[c * (2 * RowsReg - 1 - c) / 2 + r];
    };
    UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
        UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj)
            A1r(ii, jj) = A1_cached.load(ii, jj);

    // Invert A₁.
    // Recursively apply Fact 2.17.1 from Bernstein 2009 - Matrix mathematics
    // theory, facts, and formulas.
    // [ L₁₁  0  ]⁻¹ = [ L₁₁⁻¹           0     ]
    // [ L₂₁ L₂₂ ]     [ -L₂₂⁻¹ L₂₁ L₁₁⁻¹ L₂₂⁻¹ ]
    // First apply it to the last column:
    // [ l₁₁             ]
    // [ l₂₁ l₂₂         ]
    // [ l₃₁ l₃₂ l₃₃     ]
    // [ l₄₁ l₄₂ l₄₃ l₄₄⁻¹ ]
    // Then to the bottom right 2×2 block:
    // [ l₁₁                        ] = [ l₁₁                ]
    // [ l₂₁ l₂₂                    ]   [ l₂₁ l₂₂            ]
    // [ l₃₁ l₃₂ l₃₃⁻¹              ]   [ l₃₁ l₃₂ [        ] ]
    // [ l₄₁ l₄₂ -l₄₄⁻¹ l₄₃ l₃₃⁻¹ l₄₄⁻¹ ]   [ l₄₁ l₄₂ [ L₃₃⁻¹ ] ]
    // Then to the bottom right 3×3 block, and so on.
    // Columns are processed right to left; the trailing block is always the
    // already-inverted L₂₂ when column jj is eliminated.
    UNROLL_FOR (index_t jj = RowsReg - 1; jj >= 0; --jj) {
        // Invert diagonal element.
        A1r(jj, jj) = simd{1} / A1r(jj, jj);
        // Multiply current diagonal element with column j.
        // -ℓ₂₁ ℓ₁₁⁻¹
        UNROLL_FOR (index_t ii = RowsReg - 1; ii > jj; --ii)
            A1r(ii, jj) *= -A1r(jj, jj);
        // Triangular matrix-vector product of bottom right block with column j.
        // -L₂₂⁻¹ ℓ₂₁ ℓ₁₁⁻¹
        UNROLL_FOR (index_t ll = RowsReg - 1; ll > jj; --ll) {
            UNROLL_FOR (index_t ii = RowsReg - 1; ii > ll; --ii)
                A1r(ii, jj) += A1r(ii, ll) * A1r(ll, jj);
            A1r(ll, jj) *= A1r(ll, ll);
        }
    }

    // Pre-compute the offsets of the columns of D
    const auto D1_cached = with_cached_access<RowsReg, RowsReg>(D);
    const auto D_cached = with_cached_access<0, RowsReg>(D);
    // Store matrix A₁⁻¹ to D₁
    UNROLL_FOR (index_t i = 0; i < RowsReg; ++i)
        UNROLL_FOR (index_t j = 0; j <= i; ++j)
            D1_cached.store(A1r(i, j), i, j);

    // Multiply A₂ by -A₁⁻¹ and store in D₂
    // (one row of A₂ per iteration, kept entirely in registers).
    for (index_t l = RowsReg; l < k; ++l) {
        simd A2r[RowsReg]; // NOLINT(*-c-arrays)
        UNROLL_FOR (index_t i = 0; i < RowsReg; ++i)
            A2r[i] = A_cached.load(l, i);
        // Right-multiply the row by the packed triangular inverse, negated.
        UNROLL_FOR (index_t i = 0; i < RowsReg; ++i) {
            A2r[i] *= -A1r(i, i);
            UNROLL_FOR (index_t j = i + 1; j < RowsReg; ++j)
                A2r[i] -= A2r[j] * A1r(j, i);
            D_cached.store(A2r[i], l, i);
        }
    }
}
91
92/// @param Dr RowsReg×k lower trapezoidal
93/// @param D k×ColsReg
94/// @param k Number of rows in D.
95/// Compute product Dr D and store the result in the bottom block of D
96template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OD>
97[[gnu::hot, gnu::flatten]]
99 const index_t k) noexcept {
100 static_assert(Conf.struc == MatrixStructure::LowerTriangular); // TODO
101 static_assert(RowsReg > 0 && ColsReg > 0);
103 using simd = datapar::simd<T, Abi>;
104 // Clear accumulator
105 simd D_reg[RowsReg][ColsReg]{}; // NOLINT(*-c-arrays)
106 // Perform gemm
107 const auto A1_cached = with_cached_access<RowsReg, 0>(Dr);
108 const auto B1_cached = with_cached_access<0, ColsReg>(D);
109 for (index_t l = 0; l < k - RowsReg; ++l)
110 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
111 simd Ail = A1_cached.load(ii, l);
112 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
113 simd &Cij = D_reg[ii][jj];
114 simd Blj = B1_cached.load(l, jj);
115 Cij += Ail * Blj;
116 }
117 }
118 // Perform trmm
119 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) // TODO: move before gemm
120 UNROLL_FOR (index_t ll = 0; ll <= ii; ++ll) {
121 simd Ail = A1_cached.load(ii, k - RowsReg + ll);
122 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
123 D_reg[ii][jj] += Ail * B1_cached.load(k - RowsReg + ll, jj);
124 }
125 // Store result to memory
126 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
127 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
128 D.store(D_reg[ii][jj], k - RowsReg + ii, jj);
129}
130
131template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
133 using enum MatrixStructure;
134 static_assert(Conf.struc == LowerTriangular); // TODO
135 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
136 // Check dimensions
137 const index_t I = A.rows();
138 BATMAT_ASSUME(A.rows() == A.cols());
139 BATMAT_ASSUME(A.rows() == D.rows());
140 BATMAT_ASSUME(A.cols() == D.cols());
141 static const auto trtri_microkernel = trtri_copy_lut<T, Abi, Conf, OA, OD>;
143 (void)trmm_microkernel; // GCC incorrectly warns about unused variable
144 // Sizeless views to partition and pass to the micro-kernels
145 const uview<const T, Abi, OA> A_ = A;
146 const uview<T, Abi, OD> D_ = D;
147
148 // Optimization for very small matrices
149 if (I <= Rows)
150 return trtri_microkernel[I - 1](A_, D_, I);
151
152 // Partition:
153 // [ ... ] [ ... ]
154 // A = [ ... Ajj ] D = [ ... Djj ] with the invariant Dp = Ap⁻¹
155 // [ ... Aj Ap ] [ ... Dj Dp ]
156
157 foreach_chunked_merged( // Loop over the diagonal blocks of A, in reverse
158 0, I, Cols,
159 [&](index_t j, auto nj) {
160 const auto jp = j + nj;
161 const auto Ajj = A_.block(j, j);
162 const auto Djj = D_.block(j, j);
163 const auto Dj = D_.block(jp, j);
164 // Invert Djj = Ajj⁻¹ and multiply Dj = Aj Djj
165 trtri_microkernel[nj - 1](Ajj, Djj, I - j);
166 // Multiply Dp Dj (with Dp lower triangular)
167 foreach_chunked_merged( // Loop over the block rows of Dj, in reverse
168 jp, I, Rows,
169 [&](index_t i, auto ni) {
170 // Block row of already inverted bottom right corner
171 const auto Dpi = D_.block(i, jp);
172 // Current subdiagonal column to be multiplied by Dp
173 trmm_microkernel[ni - 1][nj - 1](Dpi, Dj, i + ni - jp);
174 },
176 },
178}
179
180} // namespace batmat::linalg::micro_kernels::trtri
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
const constinit auto trmm_lut
Definition trtri.hpp:35
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
const constinit auto trtri_copy_lut
Definition trtri.hpp:29
void trmm_microkernel(uview< const T, Abi, OD > Dr, uview< T, Abi, OD > D, index_t k) noexcept
Definition trtri.tpp:98
void trtri_copy_register(view< const T, Abi, OA > A, view< T, Abi, OD > D) noexcept
Definition trtri.tpp:132
void trtri_copy_microkernel(uview< const T, Abi, OA > A, uview< T, Abi, OD > D, index_t k) noexcept
Definition trtri.tpp:19
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110