batmat main
Batched linear algebra routines
Loading...
Searching...
No Matches
trtri.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
8#include <batmat/lut.hpp>
9
10#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
11
13
/// Compile-time lookup table of trtri_copy_microkernel instantiations,
/// indexed with a single zero-based block size (callers below use
/// trtri_copy_lut<...>[nj - 1]).
/// NOTE(review): the make_*_lut invocation that initializes this table
/// (original lines 16-17) was lost when this listing was extracted — only
/// the closing `});` survived. Recover the generator lambda (cf. make_1d_lut
/// in lut.hpp) from the upstream source.
template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
inline const constinit auto trtri_copy_lut =
    });
19
/// Two-dimensional compile-time lookup table of trmm_microkernel
/// instantiations; callers index it as trmm_lut<...>[ni - 1][nj - 1]
/// (see trtri_copy_register below). Built from a generic lambda over the
/// (Row, Col) register block sizes, presumably via make_2d_lut (lut.hpp).
/// NOTE(review): the declaration line (original line 21) and the lambda body
/// (original line 23) were lost in extraction — only the lambda's signature
/// and the closing `});` remain. Recover them from the upstream source.
template <class T, class Abi, KernelConfig Conf, StorageOrder OD>
    []<index_t Row, index_t Col>(index_constant<Row>, index_constant<Col>) {
    });
25
/// @param A k×RowsReg.
/// @param D k×RowsReg.
/// @param k Number of rows in A and D.
/// Invert the top block of A and store it in the top block of D. Then multiply the bottom blocks of
/// D by this block (on the right).
/// Writing A = [ A₁ ; A₂ ] with A₁ the RowsReg×RowsReg lower-triangular top
/// block, on return D = [ A₁⁻¹ ; -A₂ A₁⁻¹ ] (the bottom block is the one
/// needed by Fact 2.17.1 below). Each simd lane processes one matrix of the
/// batch; A is read-only, D is write-only, so A and D may not alias.
template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder OA, StorageOrder OD>
[[gnu::hot, gnu::flatten]] void trtri_copy_microkernel(const uview<const T, Abi, OA> A,
                                                       const uview<T, Abi, OD> D,
                                                       const index_t k) noexcept {
    static_assert(Conf.struc == MatrixStructure::LowerTriangular); // TODO
    static_assert(RowsReg > 0);
    // NOTE(review): one source line was lost here when this listing was
    // extracted (original line 37) — recover it from the upstream source.
    using simd = datapar::simd<T, Abi>;
    // Pre-compute the offsets of the columns of A
    const auto A1_cached = with_cached_access<RowsReg, RowsReg>(A);
    const auto A_cached = with_cached_access<0, RowsReg>(A);
    // Load matrix into registers.
    // The lower triangle of A₁ is kept packed column-major: column c holds
    // rows c..RowsReg-1, so element (r, c), r ≥ c, lives at index
    // c·(2·RowsReg - 1 - c)/2 + r. Only r ≥ c is ever accessed.
    simd A1_reg[RowsReg * (RowsReg + 1) / 2]; // NOLINT(*-c-arrays)
    auto A1r = [&A1_reg](index_t r, index_t c) -> simd & {
        return A1_reg[c * (2 * RowsReg - 1 - c) / 2 + r];
    };
    UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
        UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj)
            A1r(ii, jj) = A1_cached.load(ii, jj);

    // Invert A₁ in registers, in place.
    // Recursively apply Fact 2.17.1 from Bernstein 2009 - Matrix mathematics
    // theory, facts, and formulas.
    //     [ L₁₁  0   ]⁻¹ = [  L₁₁⁻¹          0     ]
    //     [ L₂₁  L₂₂ ]     [ -L₂₂⁻¹ L₂₁ L₁₁⁻¹  L₂₂⁻¹ ]
    // First apply it to the last column:
    //     [ l₁₁             ]
    //     [ l₂₁ l₂₂         ]
    //     [ l₃₁ l₃₂ l₃₃     ]
    //     [ l₄₁ l₄₂ l₄₃ l₄₄⁻¹ ]
    // Then to the bottom right 2×2 block:
    //     [ l₁₁                      ]   [ l₁₁              ]
    //     [ l₂₁ l₂₂                  ]   [ l₂₁ l₂₂          ]
    //     [ l₃₁ l₃₂ l₃₃⁻¹            ] = [ l₃₁ l₃₂ [        ] ]
    //     [ l₄₁ l₄₂ -l₄₄⁻¹ l₄₃ l₃₃⁻¹ l₄₄⁻¹ ]   [ l₄₁ l₄₂ [ L₃₃⁻¹ ] ]
    // Then to the bottom right 3×3 block, and so on.
    // Columns are processed right to left; when column jj is handled, the
    // trailing block below/right of it already holds its own inverse.
    UNROLL_FOR (index_t jj = RowsReg - 1; jj >= 0; --jj) {
        // Invert diagonal element.
        A1r(jj, jj) = simd{1} / A1r(jj, jj);
        // Multiply current diagonal element with column j.
        // -ℓ₂₁ ℓ₁₁⁻¹
        UNROLL_FOR (index_t ii = RowsReg - 1; ii > jj; --ii)
            A1r(ii, jj) *= -A1r(jj, jj);
        // Triangular matrix-vector product of bottom right block with column j.
        // -L₂₂⁻¹ ℓ₂₁ ℓ₁₁⁻¹
        // (bottom-up so each A1r(ii, jj) is consumed before it is scaled)
        UNROLL_FOR (index_t ll = RowsReg - 1; ll > jj; --ll) {
            UNROLL_FOR (index_t ii = RowsReg - 1; ii > ll; --ii)
                A1r(ii, jj) += A1r(ii, ll) * A1r(ll, jj);
            A1r(ll, jj) *= A1r(ll, ll);
        }
    }

    // Pre-compute the offsets of the columns of D
    const auto D1_cached = with_cached_access<RowsReg, RowsReg>(D);
    const auto D_cached = with_cached_access<0, RowsReg>(D);
    // Store matrix A₁⁻¹ to D₁
    UNROLL_FOR (index_t i = 0; i < RowsReg; ++i)
        UNROLL_FOR (index_t j = 0; j <= i; ++j)
            D1_cached.store(A1r(i, j), i, j);

    // Multiply A₂ by -A₁⁻¹ and store in D₂
    // (rows RowsReg..k-1; loop is empty when k == RowsReg)
    for (index_t l = RowsReg; l < k; ++l) {
        simd A2r[RowsReg]; // NOLINT(*-c-arrays)
        UNROLL_FOR (index_t i = 0; i < RowsReg; ++i)
            A2r[i] = A_cached.load(l, i);
        // Row-vector × lower-triangular product, in place: ascending i only
        // reads A2r[j] with j > i, which still holds the original row of A₂.
        UNROLL_FOR (index_t i = 0; i < RowsReg; ++i) {
            A2r[i] *= -A1r(i, i);
            UNROLL_FOR (index_t j = i + 1; j < RowsReg; ++j)
                A2r[i] -= A2r[j] * A1r(j, i);
            D_cached.store(A2r[i], l, i);
        }
    }
}
104
105/// @param Dr RowsReg×k lower trapezoidal
106/// @param D k×ColsReg
107/// @param k Number of rows in D.
108/// Compute product Dr D and store the result in the bottom block of D
109template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OD>
110[[gnu::hot, gnu::flatten]]
112 const index_t k) noexcept {
113 static_assert(Conf.struc == MatrixStructure::LowerTriangular); // TODO
114 static_assert(RowsReg > 0 && ColsReg > 0);
116 using simd = datapar::simd<T, Abi>;
117 // Clear accumulator
118 simd D_reg[RowsReg][ColsReg]{}; // NOLINT(*-c-arrays)
119 // Perform gemm
120 const auto A1_cached = with_cached_access<RowsReg, 0>(Dr);
121 const auto B1_cached = with_cached_access<0, ColsReg>(D);
122 for (index_t l = 0; l < k - RowsReg; ++l)
123 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
124 simd Ail = A1_cached.load(ii, l);
125 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
126 simd &Cij = D_reg[ii][jj];
127 simd Blj = B1_cached.load(l, jj);
128 Cij += Ail * Blj;
129 }
130 }
131 // Perform trmm
132 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) // TODO: move before gemm
133 UNROLL_FOR (index_t ll = 0; ll <= ii; ++ll) {
134 simd Ail = A1_cached.load(ii, k - RowsReg + ll);
135 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
136 D_reg[ii][jj] += Ail * B1_cached.load(k - RowsReg + ll, jj);
137 }
138 // Store result to memory
139 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
140 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
141 D.store(D_reg[ii][jj], k - RowsReg + ii, jj);
142}
143
144template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
146 using enum MatrixStructure;
147 static_assert(Conf.struc == LowerTriangular); // TODO
148 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
149 // Check dimensions
150 const index_t I = A.rows();
151 BATMAT_ASSUME(A.rows() == A.cols());
152 BATMAT_ASSUME(A.rows() == D.rows());
153 BATMAT_ASSUME(A.cols() == D.cols());
154 static const auto trtri_microkernel = trtri_copy_lut<T, Abi, Conf, OA, OD>;
156 (void)trmm_microkernel; // GCC incorrectly warns about unused variable
157 // Sizeless views to partition and pass to the micro-kernels
158 const uview<const T, Abi, OA> A_ = A;
159 const uview<T, Abi, OD> D_ = D;
160
161 // Optimization for very small matrices
162 if (I <= Rows)
163 return trtri_microkernel[I - 1](A_, D_, I);
164
165 // Partition:
166 // [ ... ] [ ... ]
167 // A = [ ... Ajj ] D = [ ... Djj ] with the invariant Dp = Ap⁻¹
168 // [ ... Aj Ap ] [ ... Dj Dp ]
169
170 foreach_chunked_merged( // Loop over the diagonal blocks of A, in reverse
171 0, I, Cols,
172 [&](index_t j, auto nj) {
173 const auto jp = j + nj;
174 const auto Ajj = A_.block(j, j);
175 const auto Djj = D_.block(j, j);
176 const auto Dj = D_.block(jp, j);
177 // Invert Djj = Ajj⁻¹ and multiply Dj = Aj Djj
178 trtri_microkernel[nj - 1](Ajj, Djj, I - j);
179 // Multiply Dp Dj (with Dp lower triangular)
180 foreach_chunked_merged( // Loop over the block rows of Dj, in reverse
181 jp, I, Rows,
182 [&](index_t i, auto ni) {
183 // Block row of already inverted bottom right corner
184 const auto Dpi = D_.block(i, jp);
185 // Current subdiagonal column to be multiplied by Dp
186 trmm_microkernel[ni - 1][nj - 1](Dpi, Dj, i + ni - jp);
187 },
189 },
191}
192
193} // namespace batmat::linalg::micro_kernels::trtri
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk; the dir argument selects forward or reverse iteration order.
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
const constinit auto trmm_lut
Definition trtri.tpp:21
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
const constinit auto trtri_copy_lut
Definition trtri.tpp:15
void trmm_microkernel(uview< const T, Abi, OD > Dr, uview< T, Abi, OD > D, index_t k) noexcept
Definition trtri.tpp:111
void trtri_copy_register(view< const T, Abi, OA > A, view< T, Abi, OD > D) noexcept
Definition trtri.tpp:145
void trtri_copy_microkernel(uview< const T, Abi, OA > A, uview< T, Abi, OD > D, index_t k) noexcept
Definition trtri.tpp:32
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110