batmat 0.0.16
Batched linear algebra routines
Loading...
Searching...
No Matches
potrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
8#include <batmat/ops/cneg.hpp>
10
11#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
12
14
15template <KernelConfig Conf>
16auto load_diag(auto diag, index_t l) noexcept {
17 if constexpr (Conf.with_diag())
18 return diag.load(l);
19 else
20 return std::false_type{};
21}
22
23template <KernelConfig Conf>
24auto apply_diag(auto x, auto d) noexcept {
25 using ops::cneg;
26 if constexpr (Conf.diag_A == Conf.diag_sign_only)
27 return cneg(x, d);
28 else if constexpr (Conf.diag_A == Conf.diag)
29 return x * d;
30 else
31 return x;
32}
33
/// @brief Register-blocked "copy" Cholesky micro-kernel for one RowsReg×RowsReg diagonal block.
///
/// Computes the lower-triangular factor
///     D = chol(C + regularization·I ± A1·diag(diag)·A1ᵀ − A2·A2ᵀ)
/// in registers and writes it to @p D. All arithmetic is lane-wise on `datapar::simd`
/// values, so every SIMD lane carries an independent problem of the batch.
/// The ± sign is `−` when `Conf.negate_A` is set and `+` otherwise. The A2·A2ᵀ term is
/// always subtracted.
/// Only the lower triangle (row ≥ column) of @p C is read, and only the lower triangle of
/// @p D is written.
///
34/// @param A1 RowsReg×k1.
35/// @param A2 RowsReg×k2.
36/// @param C RowsReg×RowsReg.
37/// @param D RowsReg×RowsReg.
38/// @param invD Inverse diagonal of @p D.
39/// @param k1 Number of columns in A1.
40/// @param k2 Number of columns in A2.
41/// @param regularization Regularization added to the diagonal of C before factorization.
42/// @param diag k1-vector that scales the columns of A1 before multiplying by its transpose.
43/// Used only if enabled in the kernel config.
/// @note The pivots are not checked. Each one is passed directly to `rsqrt` and `sqrt`, so a
///       non-positive pivot is not reported. NOTE(review): presumably yields NaN/Inf — confirm.
/// @note @p invD receives one SIMD vector per column (the rsqrt of that column's pivot)
///       through `datapar::aligned_store`. It presumably must be SIMD-aligned and hold at
///       least RowsReg·simd::size() elements — confirm with callers.
44template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder O1, StorageOrder O2>
45[[gnu::hot, gnu::flatten]] void
47 const uview<const T, Abi, O2> C, const uview<T, Abi, O2> D, T *const invD,
48 const index_t k1, const index_t k2, T regularization,
49 const diag_uview_type<const T, Abi, Conf> diag) noexcept {
50 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
51 static_assert(RowsReg > 0);
52 using ops::rsqrt;
53 using simd = datapar::simd<T, Abi>;
54 // Pre-compute the offsets of the columns of C
55 const auto C_cached = with_cached_access<RowsReg, RowsReg>(C);
56 // Load matrix into registers
// Only the lower triangle is kept, as a packed array of RowsReg·(RowsReg+1)/2 vectors.
57 simd C_reg[RowsReg * (RowsReg + 1) / 2]; // NOLINT(*-c-arrays)
// index(r, c), valid for r ≥ c, uses packed column-major lower-triangular storage:
//  - column c starts at offset c·(2·RowsReg − c + 1)/2;
//  - (r, c) lies r − c entries further down;
//  - together this simplifies to c·(2·RowsReg − 1 − c)/2 + r.
// The product c·(2·RowsReg − 1 − c) is always even, so the division is exact.
58 const auto index = [](index_t r, index_t c) { return c * (2 * RowsReg - 1 - c) / 2 + r; };
// Strictly-lower entries are copied as-is. The scalar regularization is broadcast to all
// lanes and added to each diagonal entry.
59 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
60 UNROLL_FOR (index_t jj = 0; jj < ii; ++jj)
61 C_reg[index(ii, jj)] = C_cached.load(ii, jj);
62 C_reg[index(ii, ii)] = C_cached.load(ii, ii) + simd(regularization);
63 }
64 // Perform syrk operation of A
// C ±= A1·diag(diag)·A1ᵀ, lower triangle only. apply_diag scales only the left factor
// Ail; the right factor Blj is the unscaled A1 entry. The sign follows Conf.negate_A.
65 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
66 for (index_t l = 0; l < k1; ++l) {
67 auto dl = load_diag<Conf>(diag, l);
68 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
69 simd Ail = apply_diag<Conf>(A1_cached.load(ii, l), dl);
70 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
71 simd &Cij = C_reg[index(ii, jj)];
72 simd Blj = A1_cached.load(jj, l);
73 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
74 }
75 }
76 }
// C −= A2·A2ᵀ, lower triangle only. This term is always subtracted, whatever Conf.negate_A
// says. With k2 == 0 the loop does not execute and A2 is never read.
77 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
78 for (index_t l = 0; l < k2; ++l)
79 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
80 simd Ail = A2_cached.load(ii, l);
81 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
82 simd &Cij = C_reg[index(ii, jj)];
83 simd Blj = A2_cached.load(jj, l);
84 Cij -= Ail * Blj;
85 }
86 }
// Three variants of the in-register factorization follow. Only the first one (`#if 1`) is
// compiled; the other two are kept as disabled alternatives.
87#if 1
88 // Actual Cholesky kernel (Cholesky–Crout)
// For each column j, in order:
//  1. subtract from the pivot the squares of the finished entries of row j (columns k < j);
//  2. replace the pivot by its square root and store rsqrt(pivot) to invD;
//  3. for every row i > j, subtract the contributions of columns k < j, then scale by
//     rsqrt(pivot).
89 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
90 UNROLL_FOR (index_t k = 0; k < j; ++k)
91 C_reg[index(j, j)] -= C_reg[index(j, k)] * C_reg[index(j, k)];
92 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
93 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
94 datapar::aligned_store(inv_pivot, invD + j * simd::size());
95 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i) {
96 UNROLL_FOR (index_t k = 0; k < j; ++k)
97 C_reg[index(i, j)] -= C_reg[index(i, k)] * C_reg[index(j, k)];
98 C_reg[index(i, j)] = inv_pivot * C_reg[index(i, j)];
99 }
100 }
101#elif 0
102 // Actual Cholesky kernel (naive, sqrt/rsqrt in critical path)
103 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
104 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
105 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
106 datapar::aligned_store(inv_pivot, invD + j * simd::size());
107 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
108 C_reg[index(i, j)] *= inv_pivot;
109 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
110 UNROLL_FOR (index_t k = j + 1; k <= i; ++k)
111 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
112 }
113#else
114 // Actual Cholesky kernel (naive, but hiding the latency of sqrt/rsqrt)
115 simd inv_pivot = rsqrt(C_reg[index(0, 0)]);
116 C_reg[index(0, 0)] = sqrt(C_reg[index(0, 0)]);
117 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
118 datapar::aligned_store(inv_pivot, invD + j * simd::size());
119 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
120 C_reg[index(i, j)] *= inv_pivot;
121 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
122 UNROLL_FOR (index_t k = j + 1; k <= i; ++k) {
123 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
124 if (k == j + 1 && i == j + 1) {
125 inv_pivot = rsqrt(C_reg[index(j + 1, j + 1)]);
126 C_reg[index(j + 1, j + 1)] = sqrt(C_reg[index(j + 1, j + 1)]);
127 }
128 }
129 }
130#endif
131 // Store result to memory
// Only the lower triangle (jj <= ii) of the factor is written to D.
132 auto D_cached = with_cached_access<RowsReg, RowsReg>(D);
133 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
134 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj)
135 D_cached.store(C_reg[index(ii, jj)], ii, jj);
136}
137
/// @brief Register-blocked "copy" triangular-solve micro-kernel for one RowsReg×ColsReg block.
///
/// Computes
///     D = (C ± A1·diag(diag)·B1ᵀ − A2·B2ᵀ)·L⁻ᵀ
/// in registers and writes the full RowsReg×ColsReg result to @p D.
/// The ± sign is `−` when `Conf.negate_A` is set and `+` otherwise. The A2·B2ᵀ term is
/// always subtracted.
/// The solve reads only the strictly lower triangle of @p L. The reciprocals of its diagonal
/// are taken from @p invL instead.
///
138/// @param A1 RowsReg×k1.
139/// @param B1 ColsReg×k1.
140/// @param A2 RowsReg×k2.
141/// @param B2 ColsReg×k2.
142/// @param L ColsReg×ColsReg.
143/// @param invL ColsReg (inverted diagonal of L).
144/// @param C RowsReg×ColsReg.
145/// @param D RowsReg×ColsReg.
146/// @param k1 Number of columns in A1 and B1.
147/// @param k2 Number of columns in A2 and B2.
148/// @param diag k1-vector that scales the columns of A1 before multiplying by B1ᵀ.
149/// Used only if enabled in the kernel config.
/// @note @p invL is read as ColsReg SIMD vectors with `datapar::aligned_load`. It presumably
///       must be SIMD-aligned, e.g. the invD buffer filled by the Cholesky micro-kernel.
150template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder O1,
151 StorageOrder O2>
152[[gnu::hot, gnu::flatten]]
155 const uview<const T, Abi, O2> L, const T *invL,
157 const index_t k1, const index_t k2,
158 const diag_uview_type<const T, Abi, Conf> diag) noexcept {
159 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
160 static_assert(RowsReg > 0 && ColsReg > 0);
// NOTE(review): rsqrt is not used by this kernel; this using-declaration looks redundant.
161 using ops::rsqrt;
162 using simd = datapar::simd<T, Abi>;
163 // Pre-compute the offsets of the columns of C
164 const auto C_cached = with_cached_access<RowsReg, ColsReg>(C);
165 // Load matrix into registers
166 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
167 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
168 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
169 C_reg[ii][jj] = C_cached.load(ii, jj);
170 // Perform gemm operation of A and B
// C ±= A1·diag(diag)·B1ᵀ. apply_diag scales only the A1 factor, and the sign follows
// Conf.negate_A.
171 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
172 const auto B1_cached = with_cached_access<ColsReg, 0>(B1);
173 for (index_t l = 0; l < k1; ++l) {
174 const auto dl = load_diag<Conf>(diag, l);
175 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
176 simd Ail = apply_diag<Conf>(A1_cached.load(ii, l), dl);
177 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
178 simd &Cij = C_reg[ii][jj];
179 simd Blj = B1_cached.load(jj, l);
180 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
181 }
182 }
183 }
// C −= A2·B2ᵀ. This term is always subtracted; the loop is skipped entirely when k2 == 0.
184 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
185 const auto B2_cached = with_cached_access<ColsReg, 0>(B2);
186 for (index_t l = 0; l < k2; ++l)
187 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
188 simd Ail = A2_cached.load(ii, l);
189 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
190 simd &Cij = C_reg[ii][jj];
191 simd Blj = B2_cached.load(jj, l);
192 Cij -= Ail * Blj;
193 }
194 }
195 // Triangular solve
// Solve X·Lᵀ = C in place, row by row, by forward substitution over the columns:
//   X(r, c) = (C(r, c) − Σ_{k<c} L(c, k)·X(r, k)) · invL[c].
// Note the naming: here jj indexes the rows of C (RowsReg) and ii its columns (ColsReg).
196 UNROLL_FOR (index_t jj = 0; jj < RowsReg; ++jj)
197 UNROLL_FOR (index_t ii = 0; ii < ColsReg; ++ii) {
198 simd &Xij = C_reg[jj][ii];
199 simd inv_piv = datapar::aligned_load<simd>(invL + ii * simd::size());
200 UNROLL_FOR (index_t kk = 0; kk < ii; ++kk) {
201 simd Aik = L.load(ii, kk);
202 simd Xkj = C_reg[jj][kk];
203 Xij -= Aik * Xkj;
204 }
205 Xij *= inv_piv;
206 }
207 // Store result to memory
208 auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
209 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
210 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
211 D_cached.store(C_reg[ii][jj], ii, jj);
212}
213
/// @brief Register-blocked Cholesky driver ("copy" variant: reads @p C, writes @p D).
///
/// Computes the first J columns D of the lower Cholesky factor of
///     H = C + regularization·I ± A·diag(d)·Aᵀ,
/// where C holds the first J columns of H (I = C.rows() ≥ J = C.cols()).
///
/// The algorithm:
///  - Each diagonal block (width ≤ Cols) is factored by a potrf micro-kernel.
///  - The blocks below it (height ≤ Rows) are solved by trsm micro-kernels.
///  - For column block j, the already computed columns [0, j) of D are passed as the
///    A2/B2 operands (k2 = j), so their contribution is subtracted.
///
/// @param A I×K (K may be 0).
/// @param C I×J. Within each diagonal block, the potrf micro-kernel only reads the lower
///          triangle.
/// @param D I×J output.
/// @param regularization Added to the diagonal of H before factorization.
/// @param d K-vector scaling the columns of A. Used only if enabled in the kernel config.
/// @pre The dimension checks below use BATMAT_ASSUME, so violating them is undefined
///      behavior rather than a reported error.
214template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OCD>
216 const view<T, Abi, OCD> D, T regularization,
217 const diag_view_type<const T, Abi, Conf> d) noexcept {
218 using enum MatrixStructure;
219 static_assert(Conf.struc_C == LowerTriangular); // TODO
220 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
221 // Check dimensions
222 const index_t I = C.rows(), K = A.cols(), J = C.cols();
223 BATMAT_ASSUME(I >= J);
224 BATMAT_ASSUME(A.cols() == 0 || A.rows() == I);
225 BATMAT_ASSUME(D.rows() == I);
226 BATMAT_ASSUME(D.cols() == J);
227 if constexpr (Conf.with_diag()) {
228 BATMAT_ASSUME(d.rows() == K);
229 BATMAT_ASSUME(d.cols() == 1);
230 }
231 BATMAT_ASSUME(I > 0);
232 BATMAT_ASSUME(J > 0);
// Lookup tables of micro-kernel instantiations. They are indexed by block size minus one:
// potrf by the block width, trsm by [rows − 1][cols − 1].
233 static const auto potrf_microkernel = potrf_copy_lut<T, Abi, Conf, OA, OCD>;
234 static const auto trsm_microkernel = trsm_copy_lut<T, Abi, Conf, OA, OCD>;
235 (void)trsm_microkernel; // GCC incorrectly warns about unused variable
236 // Sizeless views to partition and pass to the micro-kernels
237 const uview<const T, Abi, OA> A_ = A;
238 const uview<const T, Abi, OCD> C_ = C;
239 const uview<T, Abi, OCD> D_ = D;
242
243 // Optimization for very small matrices
// When the matrix is square and fits in one register block, a single potrf call suffices.
// C_ is passed as a dummy A2 operand; with k2 = 0 it is never read as such.
244 if (I <= Rows && J <= Cols && I == J)
245 return potrf_microkernel[J - 1](A_, C_, C_, D_, invD, K, 0, regularization, d_);
246
247 foreach_chunked_merged( // Loop over the diagonal blocks of C
248 0, J, Cols, [&](index_t j, auto nj) {
249 const auto Aj = A_.middle_rows(j);
250 const auto Dj = D_.middle_rows(j);
251 const auto Djj = D_.block(j, j);
252 // Djj = chol(Cjj + reg·I ± Aj diag(d) Ajᵀ - Dj Djᵀ)
// Only the first j columns of Dj (the already computed ones) are used: k2 = j.
// The inverse pivots of Djj are written to invD.
253 potrf_microkernel[nj - 1](Aj, Dj, C_.block(j, j), Djj, invD, K, j, regularization, d_);
254 foreach_chunked_merged( // Loop over the subdiagonal rows
255 j + nj, I, Rows, [&](index_t i, auto ni) {
256 const auto Ai = A_.middle_rows(i);
257 const auto Di = D_.middle_rows(i);
258 const auto Cij = C_.block(i, j);
259 const auto Dij = D_.block(i, j);
260 // Dij = (Cij ± Ai diag(d) Ajᵀ - Di Djᵀ) Djj⁻ᵀ
// invD still holds the inverse pivots of Djj, computed by the potrf call above.
261 trsm_microkernel[ni - 1][nj - 1](Ai, Aj, Di, Dj, Djj, invD, Cij, Dij, K, j, d_);
262 });
263 });
264}
265
266} // namespace batmat::linalg::micro_kernels::potrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
T rsqrt(T x)
Inverse square root.
Definition rsqrt.hpp:15
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
Definition loop.hpp:43
void aligned_store(V v, typename V::value_type *p)
Definition simd.hpp:121
stdx::memory_alignment< simd< Tp, Abi > > simd_align
Definition simd.hpp:139
stdx::simd_size< Tp, Abi > simd_size
Definition simd.hpp:137
V aligned_load(const typename V::value_type *p)
Definition simd.hpp:111
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
auto apply_diag(auto x, auto d) noexcept
Definition potrf.tpp:24
const constinit auto potrf_copy_lut
Definition potrf.hpp:50
std::conditional_t< Conf.with_diag(), uview_vec< T, Abi >, std::false_type > diag_uview_type
Definition potrf.hpp:22
void trsm_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O1 > B1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > B2, uview< const T, Abi, O2 > L, const T *invL, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, index_t k1, index_t k2, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:153
auto load_diag(auto diag, index_t l) noexcept
Definition potrf.tpp:16
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, diag_view_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:215
const constinit auto trsm_copy_lut
Definition potrf.hpp:56
std::conditional_t< Conf.with_diag(), view< T, Abi >, std::false_type > diag_view_type
Definition potrf.hpp:24
void potrf_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, T *invD, index_t k1, index_t k2, T regularization, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:46
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i....
Definition cneg.hpp:42
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114