batmat 0.0.13
Batched linear algebra routines
Loading...
Searching...
No Matches
potrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
9
// Shorthand for the fully unrolled loops used by the micro-kernels below. Every UNROLL_FOR in
// this file has a trip count fixed by a template parameter (RowsReg / ColsReg) or by an
// enclosing unrolled index.
// NOTE(review): BATMAT_FULLY_UNROLLED_FOR is presumably provided by <batmat/loop.hpp>
// (included above) — confirm. The cross-reference index also lists a definition of UNROLL_FOR
// in gemm-diag.tpp, and no #undef is visible in this listing, so the two definitions must
// stay token-identical to avoid a macro-redefinition diagnostic.
10#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
11
13
/// @brief Register-blocked Cholesky micro-kernel for one RowsReg×RowsReg diagonal block.
///
/// Computes the lower-triangular factor of
///     C + regularization·I ± A1 A1ᵀ − A2 A2ᵀ
/// in registers and writes it to the lower triangle of @p D. The reciprocal of each diagonal
/// entry of the factor is stored to @p invD. The sign of the A1 term is "−" when
/// Conf.negate_A is set and "+" otherwise. The A2 term is always subtracted.
/// All arithmetic is lane-wise on datapar::simd values, so every lane is factorized
/// independently.
///
/// C is loaded completely into registers before anything is stored to D. Only the lower
/// triangle of C is read and only the lower triangle of D is written. No check is made that
/// the pivots are positive: a non-positive pivot is passed directly to sqrt/rsqrt
/// (NOTE(review): presumably producing NaN/Inf in that lane — confirm against ops::rsqrt).
///
14/// @param A1 RowsReg×k1.
15/// @param A2 RowsReg×k2.
16/// @param C RowsReg×RowsReg.
17/// @param D RowsReg×RowsReg.
18/// @param invD Inverse diagonal of @p D.
///             Written with datapar::aligned_store at invD + j·simd::size() for j < RowsReg,
///             so it must provide RowsReg·simd::size() elements with the alignment that
///             aligned_store requires.
19/// @param k1 Number of columns in A1.
20/// @param k2 Number of columns in A2.
21/// @param regularization Regularization added to the diagonal of C before factorization.
22template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder O1, StorageOrder O2>
23[[gnu::hot, gnu::flatten]] void
25 const uview<const T, Abi, O2> C, const uview<T, Abi, O2> D, T *const invD,
26 const index_t k1, const index_t k2, T regularization) noexcept {
27 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
28 static_assert(RowsReg > 0);
29 using ops::rsqrt;
30 using simd = datapar::simd<T, Abi>;
31 // Pre-compute the offsets of the columns of C
32 const auto C_cached = with_cached_access<RowsReg, RowsReg>(C);
33 // Load matrix into registers
34 simd C_reg[RowsReg * (RowsReg + 1) / 2]; // NOLINT(*-c-arrays)
 // C_reg stores the lower triangle packed column by column. Column c begins at offset
 // c·(2·RowsReg − c + 1)/2, and entry (r, c) with r ≥ c follows at + (r − c). Combining the
 // two gives the formula below; the largest index is RowsReg·(RowsReg + 1)/2 − 1.
35 const auto index = [](index_t r, index_t c) { return c * (2 * RowsReg - 1 - c) / 2 + r; };
 // Only the lower triangle is loaded. Regularization is added to the diagonal entries only.
36 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
37 UNROLL_FOR (index_t jj = 0; jj < ii; ++jj)
38 C_reg[index(ii, jj)] = C_cached.load(ii, jj);
39 C_reg[index(ii, ii)] = C_cached.load(ii, ii) + simd(regularization);
40 }
41 // Perform syrk operation of A
 // Only the lower triangle (jj ≤ ii) of each rank-k update is accumulated, matching C_reg.
 // The sign of the A1 term is chosen at compile time by Conf.negate_A.
42 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
43 for (index_t l = 0; l < k1; ++l)
44 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
45 simd Ail = A1_cached.load(ii, l);
46 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
47 simd &Cij = C_reg[index(ii, jj)];
48 simd Blj = A1_cached.load(jj, l);
49 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
50 }
51 }
 // The A2 term is always subtracted. potrf_copy_register passes the already computed columns
 // of the factor here, with k2 equal to the number of such columns.
52 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
53 for (index_t l = 0; l < k2; ++l)
54 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
55 simd Ail = A2_cached.load(ii, l);
56 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
57 simd &Cij = C_reg[index(ii, jj)];
58 simd Blj = A2_cached.load(jj, l);
59 Cij -= Ail * Blj;
60 }
61 }
 // Three formulations of the same in-register factorization follow. Only the first one
 // (#if 1) is compiled; the other two are kept as alternatives.
62#if 1
63 // Actual Cholesky kernel (Cholesky–Crout)
64 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
 // Left-looking: subtract the contributions of all earlier columns k < j from the pivot.
65 UNROLL_FOR (index_t k = 0; k < j; ++k)
66 C_reg[index(j, j)] -= C_reg[index(j, k)] * C_reg[index(j, k)];
 // inv_pivot = 1 / L(j, j). It scales column j below and is exported through invD.
67 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
68 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
69 datapar::aligned_store(inv_pivot, invD + j * simd::size());
 // Finish the sub-diagonal entries of column j: apply earlier columns, then divide by L(j, j).
70 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i) {
71 UNROLL_FOR (index_t k = 0; k < j; ++k)
72 C_reg[index(i, j)] -= C_reg[index(i, k)] * C_reg[index(j, k)];
73 C_reg[index(i, j)] = inv_pivot * C_reg[index(i, j)];
74 }
75 }
76#elif 0
77 // Actual Cholesky kernel (naive, sqrt/rsqrt in critical path)
78 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
79 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
80 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
81 datapar::aligned_store(inv_pivot, invD + j * simd::size());
 // Right-looking: scale column j, then immediately update the trailing lower triangle.
82 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
83 C_reg[index(i, j)] *= inv_pivot;
84 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
85 UNROLL_FOR (index_t k = j + 1; k <= i; ++k)
86 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
87 }
88#else
89 // Actual Cholesky kernel (naive, but hiding the latency of sqrt/rsqrt)
90 simd inv_pivot = rsqrt(C_reg[index(0, 0)]);
91 C_reg[index(0, 0)] = sqrt(C_reg[index(0, 0)]);
92 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
93 datapar::aligned_store(inv_pivot, invD + j * simd::size());
94 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
95 C_reg[index(i, j)] *= inv_pivot;
96 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
97 UNROLL_FOR (index_t k = j + 1; k <= i; ++k) {
98 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
 // The next pivot (j + 1) has just received its last update. Its sqrt/rsqrt is started
 // here so that their latency overlaps with the remaining trailing updates.
99 if (k == j + 1 && i == j + 1) {
100 inv_pivot = rsqrt(C_reg[index(j + 1, j + 1)]);
101 C_reg[index(j + 1, j + 1)] = sqrt(C_reg[index(j + 1, j + 1)]);
102 }
103 }
104 }
105#endif
106 // Store result to memory
 // Only the lower triangle is written; the strict upper triangle of D is left untouched.
107 auto D_cached = with_cached_access<RowsReg, RowsReg>(D);
108 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
109 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj)
110 D_cached.store(C_reg[index(ii, jj)], ii, jj);
111}
112
/// @brief Register-blocked triangular-solve micro-kernel for one RowsReg×ColsReg block below
///        the diagonal of a Cholesky factor.
///
/// Computes
///     D = (C ± A1 B1ᵀ − A2 B2ᵀ) L⁻ᵀ
/// in registers. The sign of the A1 B1ᵀ term is "−" when Conf.negate_A is set and "+"
/// otherwise; the A2 B2ᵀ term is always subtracted. Only the strictly lower part of @p L is
/// read. Its diagonal is supplied as reciprocals in @p invL, so the solve uses only
/// multiplications. All arithmetic is lane-wise on datapar::simd values. C is loaded
/// completely into registers before anything is stored to D.
///
113/// @param A1 RowsReg×k1.
114/// @param B1 ColsReg×k1.
115/// @param A2 RowsReg×k2.
116/// @param B2 ColsReg×k2.
117/// @param L ColsReg×ColsReg.
118/// @param invL ColsReg (inverted diagonal of L).
///             Read with datapar::aligned_load at invL + i·simd::size() for i < ColsReg.
///             potrf_copy_register passes the invD produced by the potrf micro-kernel.
119/// @param C RowsReg×ColsReg.
120/// @param D RowsReg×ColsReg.
121/// @param k1 Number of columns in A1 and B1.
122/// @param k2 Number of columns in A2 and B2.
123template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder O1,
124 StorageOrder O2>
125[[gnu::hot, gnu::flatten]]
128 const uview<const T, Abi, O2> L, const T *invL,
130 const index_t k1, const index_t k2) noexcept {
131 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
132 static_assert(RowsReg > 0 && ColsReg > 0);
 // NOTE(review): rsqrt is not called anywhere in this kernel; this using-declaration looks unused.
133 using ops::rsqrt;
134 using simd = datapar::simd<T, Abi>;
135 // Pre-compute the offsets of the columns of C
136 const auto C_cached = with_cached_access<RowsReg, ColsReg>(C);
137 // Load matrix into registers
138 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
139 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
140 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
141 C_reg[ii][jj] = C_cached.load(ii, jj);
142 // Perform gemm operation of A and B
 // The sign of the A1 B1ᵀ term is chosen at compile time by Conf.negate_A.
143 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
144 const auto B1_cached = with_cached_access<ColsReg, 0>(B1);
145 for (index_t l = 0; l < k1; ++l)
146 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
147 simd Ail = A1_cached.load(ii, l);
148 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
149 simd &Cij = C_reg[ii][jj];
150 simd Blj = B1_cached.load(jj, l);
151 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
152 }
153 }
 // The A2 B2ᵀ term is always subtracted.
154 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
155 const auto B2_cached = with_cached_access<ColsReg, 0>(B2);
156 for (index_t l = 0; l < k2; ++l)
157 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
158 simd Ail = A2_cached.load(ii, l);
159 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
160 simd &Cij = C_reg[ii][jj];
161 simd Blj = B2_cached.load(jj, l);
162 Cij -= Ail * Blj;
163 }
164 }
165 // Triangular solve
 // Forward substitution for X Lᵀ = C_reg, one row of X at a time, overwriting C_reg in place:
 //     X(r, c) = (C_reg(r, c) − Σ_{k<c} L(c, k)·X(r, k)) · invL(c).
 // Note that the index roles are swapped relative to the loops above: jj runs over the
 // RowsReg rows and ii over the ColsReg columns. L is loaded directly (no cached view) and
 // its diagonal is never read.
166 UNROLL_FOR (index_t jj = 0; jj < RowsReg; ++jj)
167 UNROLL_FOR (index_t ii = 0; ii < ColsReg; ++ii) {
168 simd &Xij = C_reg[jj][ii];
169 simd inv_piv = datapar::aligned_load<simd>(invL + ii * simd::size());
170 UNROLL_FOR (index_t kk = 0; kk < ii; ++kk) {
171 simd Aik = L.load(ii, kk);
172 simd Xkj = C_reg[jj][kk];
173 Xij -= Aik * Xkj;
174 }
175 Xij *= inv_piv;
176 }
177 // Store result to memory
 // The full RowsReg×ColsReg block is written.
178 auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
179 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
180 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
181 D_cached.store(C_reg[ii][jj], ii, jj);
182}
183
/// @brief Blocked Cholesky factorization built from the potrf and trsm micro-kernels.
///
/// The columns of C are processed in blocks: foreach_chunked_merged runs over [0, J) with
/// chunk size ColsReg. For each diagonal block, and for each block of rows below it:
///     Djj = chol(Cjj + regularization·I ± Aj Ajᵀ − Dj Djᵀ)
///     Dij = (Cij ± Ai Ajᵀ − Di Djᵀ) Djj⁻ᵀ
/// Here Dj and Di are the first j (already finished) columns of D, restricted to the rows of
/// the block. The sign of the A terms is "−" when Conf.negate_A is set and "+" otherwise.
///
/// For I == J the result is the lower Cholesky factor of C + regularization·I ± A Aᵀ. For
/// I > J, C holds the leading J columns (on and below the diagonal) of the matrix, and D
/// receives the corresponding columns of its factor. Only the lower trapezoid of C is read.
/// The strict upper triangle of D's top J×J block is not written.
///
/// @param A  I×K update matrix; when K == 0 its row count is not checked.
/// @param C  I×J input block (I ≥ J).
/// @param D  I×J output factor.
/// @param regularization Added to the diagonal of every diagonal block before factorizing.
184template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OCD>
186 const view<T, Abi, OCD> D, T regularization) noexcept {
187 using enum MatrixStructure;
188 static_assert(Conf.struc_C == LowerTriangular); // TODO
189 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
190 // Check dimensions
 // BATMAT_ASSUME states assumptions for the optimizer; it is not a runtime check. Per its
 // documentation, a violated condition invokes undefined behavior.
191 const index_t I = C.rows(), K = A.cols(), J = C.cols();
192 BATMAT_ASSUME(I >= J);
193 BATMAT_ASSUME(A.cols() == 0 || A.rows() == I);
194 BATMAT_ASSUME(D.rows() == I);
195 BATMAT_ASSUME(D.cols() == J);
196 BATMAT_ASSUME(I > 0);
197 BATMAT_ASSUME(J > 0);
 // Look-up tables of micro-kernels, indexed by block size minus one (see the calls below):
 // potrf by [nj - 1], trsm by [ni - 1][nj - 1].
198 static const auto potrf_microkernel = potrf_copy_lut<T, Abi, Conf, OA, OCD>;
199 static const auto trsm_microkernel = trsm_copy_lut<T, Abi, Conf, OA, OCD>;
200 (void)trsm_microkernel; // GCC incorrectly warns about unused variable
201 // Sizeless views to partition and pass to the micro-kernels
202 const uview<const T, Abi, OA> A_ = A;
203 const uview<const T, Abi, OCD> C_ = C;
204 const uview<T, Abi, OCD> D_ = D;
 // NOTE(review): invD is declared on a line missing from this listing. Judging from its use,
 // it is presumably scratch space for up to Cols·simd::size() reciprocal pivots. Each potrf
 // call overwrites it, and the trsm calls of the same block column read it back.
206
207 // Optimization for very small matrices
 // A square matrix that fits in a single register block needs only one potrf call. There are
 // no previously computed columns, so k2 = 0 and C_ merely fills the A2 slot (it is never
 // loaded from).
208 if (I <= Rows && J <= Cols && I == J)
209 return potrf_microkernel[J - 1](A_, C_, C_, D_, invD, K, 0, regularization);
210
 // Left-looking block loop: each block column is computed from C, A and the block columns
 // of D finished in earlier iterations, which are passed with k2 = j.
211 foreach_chunked_merged( // Loop over the diagonal blocks of C
212 0, J, Cols, [&](index_t j, auto nj) {
213 const auto Aj = A_.middle_rows(j);
214 const auto Dj = D_.middle_rows(j);
215 const auto Djj = D_.block(j, j);
216 // Djj = chol(Cjj ± Aj Ajᵀ - Dj Djᵀ)
217 potrf_microkernel[nj - 1](Aj, Dj, C_.block(j, j), Djj, invD, K, j, regularization);
218 foreach_chunked_merged( // Loop over the subdiagonal rows
219 j + nj, I, Rows, [&](index_t i, auto ni) {
220 const auto Ai = A_.middle_rows(i);
221 const auto Di = D_.middle_rows(i);
222 const auto Cij = C_.block(i, j);
223 const auto Dij = D_.block(i, j);
224 // Dij = (Cij ± Ai Ajᵀ - Di Djᵀ) Djj⁻ᵀ
 // invD still holds the reciprocal diagonal of Djj from the potrf call above.
225 trsm_microkernel[ni - 1][nj - 1](Ai, Aj, Di, Dj, Djj, invD, Cij, Dij, K, j);
226 });
227 });
228}
229
230} // namespace batmat::linalg::micro_kernels::potrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
T rsqrt(T x)
Inverse square root.
Definition rsqrt.hpp:15
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
void aligned_store(V v, typename V::value_type *p)
Definition simd.hpp:121
stdx::memory_alignment< simd< Tp, Abi > > simd_align
Definition simd.hpp:139
stdx::simd_size< Tp, Abi > simd_size
Definition simd.hpp:137
V aligned_load(const typename V::value_type *p)
Definition simd.hpp:111
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
void potrf_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, T *invD, index_t k1, index_t k2, T regularization) noexcept
Definition potrf.tpp:24
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
const constinit auto potrf_copy_lut
Definition potrf.hpp:37
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization) noexcept
Definition potrf.tpp:185
void trsm_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O1 > B1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > B2, uview< const T, Abi, O2 > L, const T *invL, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, index_t k1, index_t k2) noexcept
Definition potrf.tpp:126
const constinit auto trsm_copy_lut
Definition potrf.hpp:43
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114