batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
potrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
8#include <batmat/ops/cneg.hpp>
10
11#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
12
14
15template <KernelConfig Conf>
16auto load_diag(auto diag, index_t l) noexcept {
17 if constexpr (Conf.with_diag())
18 return diag.load(l);
19 else
20 return std::false_type{};
21}
22
23template <KernelConfig Conf>
24auto apply_diag(auto x, auto d) noexcept {
25 using ops::cneg;
26 if constexpr (Conf.diag_A == Conf.diag_sign_only)
27 return cneg(x, d);
28 else if constexpr (Conf.diag_A == Conf.diag)
29 return x * d;
30 else
31 return x;
32}
33
/// Register-blocked Cholesky (POTRF) micro-kernel.
/// Loads one RowsReg×RowsReg lower-triangular block of C into SIMD registers,
/// accumulates ±A1·Σ·A1ᵀ (Σ is the optional diagonal handled by
/// load_diag/apply_diag; the sign is chosen at compile time by Conf.negate_A)
/// and subtracts A2·A2ᵀ, factorizes the result in registers, and stores the
/// Cholesky factor into D. The reciprocal square roots of the pivots are
/// written to invD (one SIMD lane-vector per pivot, aligned) for use by the
/// companion triangular-solve kernel.
/// NOTE(review): this listing is a Doxygen extraction; original source line 44
/// (the function name and leading parameters, presumably
/// `potrf_copy_microkernel(const uview<const T, Abi, O1> A1, ...`) is missing
/// here — see the declaration in potrf.hpp for the full signature.
34/// @param A1 RowsReg×k1.
35/// @param A2 RowsReg×k2.
36/// @param C RowsReg×RowsReg.
37/// @param D RowsReg×RowsReg.
38/// @param invD Inverse diagonal of @p D.
39/// @param k1 Number of columns in A1.
40/// @param k2 Number of columns in A2.
41/// @param regularization Regularization added to the diagonal of C before factorization.
42template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder O1, StorageOrder O2>
43[[gnu::hot, gnu::flatten]] void
45 const uview<const T, Abi, O2> C, const uview<T, Abi, O2> D, T *const invD,
46 const index_t k1, const index_t k2, T regularization,
47 const diag_uview_type<const T, Abi, Conf> diag) noexcept {
48 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
49 static_assert(RowsReg > 0);
50 using ops::rsqrt;
51 using simd = datapar::simd<T, Abi>;
52 // Pre-compute the offsets of the columns of C
53 const auto C_cached = with_cached_access<RowsReg, RowsReg>(C);
// Only the lower triangle is kept, in packed column-major order; `index`
// maps (row r, col c), r >= c, to the packed slot.
54 // Load matrix into registers
55 simd C_reg[RowsReg * (RowsReg + 1) / 2]; // NOLINT(*-c-arrays)
56 const auto index = [](index_t r, index_t c) { return c * (2 * RowsReg - 1 - c) / 2 + r; };
57 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
58 UNROLL_FOR (index_t jj = 0; jj < ii; ++jj)
59 C_reg[index(ii, jj)] = C_cached.load(ii, jj);
// Regularization is added to the diagonal only, before factorization.
60 C_reg[index(ii, ii)] = C_cached.load(ii, ii) + simd(regularization);
61 }
// Rank-k1 update with A1: C ± A1·Σ·A1ᵀ (compile-time sign via Conf.negate_A).
62 // Perform syrk operation of A
63 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
64 for (index_t l = 0; l < k1; ++l) {
65 auto dl = load_diag<Conf>(diag, l);
66 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
67 simd Ail = apply_diag<Conf>(A1_cached.load(ii, l), dl);
68 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
69 simd &Cij = C_reg[index(ii, jj)];
70 simd Blj = A1_cached.load(jj, l);
71 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
72 }
73 }
74 }
// Rank-k2 downdate with A2: always subtracted, no diagonal scaling.
75 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
76 for (index_t l = 0; l < k2; ++l)
77 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
78 simd Ail = A2_cached.load(ii, l);
79 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
80 simd &Cij = C_reg[index(ii, jj)];
81 simd Blj = A2_cached.load(jj, l);
82 Cij -= Ail * Blj;
83 }
84 }
// Three interchangeable factorization variants follow; only the first is
// compiled (#if 1). The others are presumably kept for benchmarking /
// latency experiments — TODO confirm.
85#if 1
86 // Actual Cholesky kernel (Cholesky–Crout)
87 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
88 UNROLL_FOR (index_t k = 0; k < j; ++k)
89 C_reg[index(j, j)] -= C_reg[index(j, k)] * C_reg[index(j, k)];
// The pivot reciprocal is computed once via rsqrt so the column scaling
// below needs multiplies only (no divisions on the hot path).
90 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
91 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
92 datapar::aligned_store(inv_pivot, invD + j * simd::size());
93 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i) {
94 UNROLL_FOR (index_t k = 0; k < j; ++k)
95 C_reg[index(i, j)] -= C_reg[index(i, k)] * C_reg[index(j, k)];
96 C_reg[index(i, j)] = inv_pivot * C_reg[index(i, j)];
97 }
98 }
99#elif 0
100 // Actual Cholesky kernel (naive, sqrt/rsqrt in critical path)
101 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
102 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
103 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
104 datapar::aligned_store(inv_pivot, invD + j * simd::size());
105 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
106 C_reg[index(i, j)] *= inv_pivot;
107 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
108 UNROLL_FOR (index_t k = j + 1; k <= i; ++k)
109 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
110 }
111#else
112 // Actual Cholesky kernel (naive, but hiding the latency of sqrt/rsqrt)
// The next pivot's sqrt/rsqrt is kicked off as soon as its diagonal entry
// is fully updated (i == k == j + 1), overlapping it with the remaining
// trailing-submatrix updates of iteration j.
113 simd inv_pivot = rsqrt(C_reg[index(0, 0)]);
114 C_reg[index(0, 0)] = sqrt(C_reg[index(0, 0)]);
115 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
116 datapar::aligned_store(inv_pivot, invD + j * simd::size());
117 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
118 C_reg[index(i, j)] *= inv_pivot;
119 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
120 UNROLL_FOR (index_t k = j + 1; k <= i; ++k) {
121 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
122 if (k == j + 1 && i == j + 1) {
123 inv_pivot = rsqrt(C_reg[index(j + 1, j + 1)]);
124 C_reg[index(j + 1, j + 1)] = sqrt(C_reg[index(j + 1, j + 1)]);
125 }
126 }
127 }
128#endif
// Write back the factored lower triangle; the strict upper part of D is
// left untouched.
129 // Store result to memory
130 auto D_cached = with_cached_access<RowsReg, RowsReg>(D);
131 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
132 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj)
133 D_cached.store(C_reg[index(ii, jj)], ii, jj);
134}
135
/// Register-blocked triangular-solve (TRSM) micro-kernel.
/// Computes one RowsReg×ColsReg block D = (C ± A1·Σ·B1ᵀ − A2·B2ᵀ) · L⁻ᵀ by
/// forward substitution against the lower-triangular factor L, using the
/// precomputed reciprocal diagonal invL so the hot path contains no divisions.
/// The ± sign and the optional diagonal Σ follow the same Conf.negate_A /
/// load_diag / apply_diag conventions as the POTRF kernel above.
/// NOTE(review): this listing is a Doxygen extraction; original source lines
/// 149, 150 and 152 (the return type, the name `trsm_copy_microkernel`, and
/// the A1/B1/A2/B2 and C/D parameters) are missing here — see the declaration
/// in potrf.hpp for the full signature.
136/// @param A1 RowsReg×k1.
137/// @param B1 ColsReg×k1.
138/// @param A2 RowsReg×k2.
139/// @param B2 ColsReg×k2.
140/// @param L ColsReg×ColsReg.
141/// @param invL ColsReg (inverted diagonal of L).
142/// @param C RowsReg×ColsReg.
143/// @param D RowsReg×ColsReg.
144/// @param k1 Number of columns in A1 and B1.
145/// @param k2 Number of columns in A2 and B2.
146template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder O1,
147 StorageOrder O2>
148[[gnu::hot, gnu::flatten]]
151 const uview<const T, Abi, O2> L, const T *invL,
153 const index_t k1, const index_t k2,
154 const diag_uview_type<const T, Abi, Conf> diag) noexcept {
155 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
156 static_assert(RowsReg > 0 && ColsReg > 0);
157 using ops::rsqrt;
158 using simd = datapar::simd<T, Abi>;
159 // Pre-compute the offsets of the columns of C
160 const auto C_cached = with_cached_access<RowsReg, ColsReg>(C);
// Full (non-triangular) block here, so a plain 2-D register array suffices.
161 // Load matrix into registers
162 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
163 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
164 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
165 C_reg[ii][jj] = C_cached.load(ii, jj);
// Rank-k1 update: C ± A1·Σ·B1ᵀ (compile-time sign via Conf.negate_A).
166 // Perform gemm operation of A and B
167 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
168 const auto B1_cached = with_cached_access<ColsReg, 0>(B1);
169 for (index_t l = 0; l < k1; ++l) {
170 const auto dl = load_diag<Conf>(diag, l);
171 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
172 simd Ail = apply_diag<Conf>(A1_cached.load(ii, l), dl);
173 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
174 simd &Cij = C_reg[ii][jj];
175 simd Blj = B1_cached.load(jj, l);
176 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
177 }
178 }
179 }
// Rank-k2 downdate: C -= A2·B2ᵀ, always subtracted, no diagonal scaling.
180 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
181 const auto B2_cached = with_cached_access<ColsReg, 0>(B2);
182 for (index_t l = 0; l < k2; ++l)
183 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
184 simd Ail = A2_cached.load(ii, l);
185 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
186 simd &Cij = C_reg[ii][jj];
187 simd Blj = B2_cached.load(jj, l);
188 Cij -= Ail * Blj;
189 }
190 }
// Forward substitution, column by column of the solution: each column ii
// subtracts the contributions of the already-solved columns kk < ii, then
// scales by the precomputed pivot reciprocal (multiply, not divide).
191 // Triangular solve
192 UNROLL_FOR (index_t jj = 0; jj < RowsReg; ++jj)
193 UNROLL_FOR (index_t ii = 0; ii < ColsReg; ++ii) {
194 simd &Xij = C_reg[jj][ii];
195 simd inv_piv = datapar::aligned_load<simd>(invL + ii * simd::size());
196 UNROLL_FOR (index_t kk = 0; kk < ii; ++kk) {
197 simd Aik = L.load(ii, kk);
198 simd Xkj = C_reg[jj][kk];
199 Xij -= Aik * Xkj;
200 }
201 Xij *= inv_piv;
202 }
203 // Store result to memory
204 auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
205 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
206 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
207 D_cached.store(C_reg[ii][jj], ii, jj);
208}
209
/// Driver for the register-blocked Cholesky factorization of C ± A·Σ·Aᵀ into
/// D: partitions the J columns of C into Cols-wide panels, factorizes each
/// diagonal block with a POTRF micro-kernel from the LUT, then updates and
/// solves the subdiagonal row blocks of the panel with TRSM micro-kernels.
/// The micro-kernels fuse the trailing update (using the already-factored
/// D panels as the rank-k downdate terms) with the factorization/solve, so C
/// itself is never modified.
/// NOTE(review): this listing is a Doxygen extraction; original source line
/// 211 (the return type `void potrf_copy_register(` and the A/C parameters)
/// and lines 236–237 (which presumably define the `invD` pivot buffer and the
/// `d_` unsized diagonal view used below — TODO confirm against the full
/// source) are missing here.
210template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OCD>
212 const view<T, Abi, OCD> D, T regularization,
213 const diag_view_type<const T, Abi, Conf> d) noexcept {
214 using enum MatrixStructure;
215 static_assert(Conf.struc_C == LowerTriangular); // TODO
// Register-block dimensions chosen per architecture (e.g. avx-512.hpp).
216 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
217 // Check dimensions
218 const index_t I = C.rows(), K = A.cols(), J = C.cols();
219 BATMAT_ASSUME(I >= J);
220 BATMAT_ASSUME(A.cols() == 0 || A.rows() == I);
221 BATMAT_ASSUME(D.rows() == I);
222 BATMAT_ASSUME(D.cols() == J);
223 if constexpr (Conf.with_diag()) {
224 BATMAT_ASSUME(d.rows() == K);
225 BATMAT_ASSUME(d.cols() == 1);
226 }
227 BATMAT_ASSUME(I > 0);
228 BATMAT_ASSUME(J > 0);
// Look-up tables of micro-kernel instantiations, indexed by block size - 1.
229 static const auto potrf_microkernel = potrf_copy_lut<T, Abi, Conf, OA, OCD>;
230 static const auto trsm_microkernel = trsm_copy_lut<T, Abi, Conf, OA, OCD>;
231 (void)trsm_microkernel; // GCC incorrectly warns about unused variable
232 // Sizeless views to partition and pass to the micro-kernels
233 const uview<const T, Abi, OA> A_ = A;
234 const uview<const T, Abi, OCD> C_ = C;
235 const uview<T, Abi, OCD> D_ = D;
238
// Single-block fast path: the whole matrix fits in one diagonal kernel
// call, so skip the panel loops entirely (k2 = 0: no trailing downdate).
239 // Optimization for very small matrices
240 if (I <= Rows && J <= Cols && I == J)
241 return potrf_microkernel[J - 1](A_, C_, C_, D_, invD, K, 0, regularization, d_);
242
243 foreach_chunked_merged( // Loop over the diagonal blocks of C
244 0, J, Cols, [&](index_t j, auto nj) {
245 const auto Aj = A_.middle_rows(j);
246 const auto Dj = D_.middle_rows(j);
247 const auto Djj = D_.block(j, j);
248 // Djj = chol(Cjj ± Aj Ajᵀ - Dj Djᵀ)
// The first j columns of D (already factored) supply the trailing
// downdate, hence k2 = j here and below.
249 potrf_microkernel[nj - 1](Aj, Dj, C_.block(j, j), Djj, invD, K, j, regularization, d_);
250 foreach_chunked_merged( // Loop over the subdiagonal rows
251 j + nj, I, Rows, [&](index_t i, auto ni) {
252 const auto Ai = A_.middle_rows(i);
253 const auto Di = D_.middle_rows(i);
254 const auto Cij = C_.block(i, j);
255 const auto Dij = D_.block(i, j);
256 // Dij = (Cij ± Ai Ajᵀ - Di Djᵀ) Djj⁻ᵀ
257 trsm_microkernel[ni - 1][nj - 1](Ai, Aj, Di, Dj, Djj, invD, Cij, Dij, K, j, d_);
258 });
259 });
260}
261
262} // namespace batmat::linalg::micro_kernels::potrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
T rsqrt(T x)
Inverse square root.
Definition rsqrt.hpp:15
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
void aligned_store(V v, typename V::value_type *p)
Definition simd.hpp:121
stdx::memory_alignment< simd< Tp, Abi > > simd_align
Definition simd.hpp:139
stdx::simd_size< Tp, Abi > simd_size
Definition simd.hpp:137
V aligned_load(const typename V::value_type *p)
Definition simd.hpp:111
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
auto apply_diag(auto x, auto d) noexcept
Definition potrf.tpp:24
const constinit auto potrf_copy_lut
Definition potrf.hpp:50
std::conditional_t< Conf.with_diag(), uview_vec< T, Abi >, std::false_type > diag_uview_type
Definition potrf.hpp:22
void trsm_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O1 > B1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > B2, uview< const T, Abi, O2 > L, const T *invL, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, index_t k1, index_t k2, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:149
auto load_diag(auto diag, index_t l) noexcept
Definition potrf.tpp:16
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, diag_view_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:211
const constinit auto trsm_copy_lut
Definition potrf.hpp:56
std::conditional_t< Conf.with_diag(), view< T, Abi >, std::false_type > diag_view_type
Definition potrf.hpp:24
void potrf_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, T *invD, index_t k1, index_t k2, T regularization, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:44
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i....
Definition cneg.hpp:42
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114