batmat 0.0.17
Batched linear algebra routines
Loading...
Searching...
No Matches
potrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
8#include <batmat/lut.hpp>
9#include <batmat/ops/cneg.hpp>
10#include <batmat/ops/rsqrt.hpp>
11
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
// 1-D lookup table of potrf_copy_microkernel instantiations, indexed below as
// potrf_microkernel[nj - 1] — i.e. entry k handles a (k+1)×(k+1) register block,
// letting the driver dispatch on the runtime remainder size without a switch.
// NOTE(review): the make_1d_lut(...) call and the lambda returning the kernel
// pointer were lost in extraction — only the closing "});" survives; restore
// from the canonical source.
template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OC>
inline const constinit auto potrf_copy_lut =
    });
21
// 2-D lookup table of trsm_copy_microkernel instantiations, indexed below as
// trsm_microkernel[ni - 1][nj - 1] — entry [r][c] handles an (r+1)×(c+1)
// register block. NOTE(review): the make_2d_lut(...) line and the lambda body
// were lost in extraction — only the lambda signature and closing "});"
// survive; restore from the canonical source.
template <class T, class Abi, KernelConfig Conf, StorageOrder O1, StorageOrder O2>
    []<index_t Row, index_t Col>(index_constant<Row>, index_constant<Col>) {
    });
27
28template <KernelConfig Conf>
29auto load_diag(auto diag, index_t l) noexcept {
30 if constexpr (Conf.with_diag())
31 return diag.load(l);
32 else
33 return std::false_type{};
34}
35
36template <KernelConfig Conf>
37auto apply_diag(auto x, auto d) noexcept {
38 using ops::cneg;
39 if constexpr (Conf.diag_A == Conf.diag_sign_only)
40 return cneg(x, d);
41 else if constexpr (Conf.diag_A == Conf.diag)
42 return x * d;
43 else
44 return x;
45}
46
47/// @param A1 RowsReg×k1.
48/// @param A2 RowsReg×k2.
49/// @param C RowsReg×RowsReg.
50/// @param D RowsReg×RowsReg.
51/// @param invD Inverse diagonal of @p D.
52/// @param k1 Number of columns in A1.
53/// @param k2 Number of columns in A2.
54/// @param regularization Regularization added to the diagonal of C before factorization.
55/// @param diag k1-vector that scales the columns of A1 before multiplying by its transpose.
56/// Used only if enabled in the kernel config.
57template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder O1, StorageOrder O2>
58[[gnu::hot, gnu::flatten]] void
60 const uview<const T, Abi, O2> C, const uview<T, Abi, O2> D, T *const invD,
61 const index_t k1, const index_t k2, T regularization,
62 const diag_uview_type<const T, Abi, Conf> diag) noexcept {
63 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
64 static_assert(RowsReg > 0);
65 using ops::rsqrt;
66 using simd = datapar::simd<T, Abi>;
67 // Pre-compute the offsets of the columns of C
68 const auto C_cached = with_cached_access<RowsReg, RowsReg>(C);
69 // Load matrix into registers
70 simd C_reg[RowsReg * (RowsReg + 1) / 2]; // NOLINT(*-c-arrays)
71 const auto index = [](index_t r, index_t c) { return c * (2 * RowsReg - 1 - c) / 2 + r; };
72 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
73 UNROLL_FOR (index_t jj = 0; jj < ii; ++jj)
74 C_reg[index(ii, jj)] = C_cached.load(ii, jj);
75 C_reg[index(ii, ii)] = C_cached.load(ii, ii) + simd(regularization);
76 }
77 // Perform syrk operation of A
78 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
79 for (index_t l = 0; l < k1; ++l) {
80 auto dl = load_diag<Conf>(diag, l);
81 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
82 simd Ail = apply_diag<Conf>(A1_cached.load(ii, l), dl);
83 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
84 simd &Cij = C_reg[index(ii, jj)];
85 simd Blj = A1_cached.load(jj, l);
86 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
87 }
88 }
89 }
90 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
91 for (index_t l = 0; l < k2; ++l)
92 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
93 simd Ail = A2_cached.load(ii, l);
94 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj) {
95 simd &Cij = C_reg[index(ii, jj)];
96 simd Blj = A2_cached.load(jj, l);
97 Cij -= Ail * Blj;
98 }
99 }
100#if 1
101 // Actual Cholesky kernel (Cholesky–Crout)
102 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
103 UNROLL_FOR (index_t k = 0; k < j; ++k)
104 C_reg[index(j, j)] -= C_reg[index(j, k)] * C_reg[index(j, k)];
105 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
106 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
107 datapar::aligned_store(inv_pivot, invD + j * simd::size());
108 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i) {
109 UNROLL_FOR (index_t k = 0; k < j; ++k)
110 C_reg[index(i, j)] -= C_reg[index(i, k)] * C_reg[index(j, k)];
111 C_reg[index(i, j)] = inv_pivot * C_reg[index(i, j)];
112 }
113 }
114#elif 0
115 // Actual Cholesky kernel (naive, sqrt/rsqrt in critical path)
116 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
117 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
118 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
119 datapar::aligned_store(inv_pivot, invD + j * simd::size());
120 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
121 C_reg[index(i, j)] *= inv_pivot;
122 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
123 UNROLL_FOR (index_t k = j + 1; k <= i; ++k)
124 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
125 }
126#else
127 // Actual Cholesky kernel (naive, but hiding the latency of sqrt/rsqrt)
128 simd inv_pivot = rsqrt(C_reg[index(0, 0)]);
129 C_reg[index(0, 0)] = sqrt(C_reg[index(0, 0)]);
130 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j) {
131 datapar::aligned_store(inv_pivot, invD + j * simd::size());
132 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
133 C_reg[index(i, j)] *= inv_pivot;
134 UNROLL_FOR (index_t i = j + 1; i < RowsReg; ++i)
135 UNROLL_FOR (index_t k = j + 1; k <= i; ++k) {
136 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
137 if (k == j + 1 && i == j + 1) {
138 inv_pivot = rsqrt(C_reg[index(j + 1, j + 1)]);
139 C_reg[index(j + 1, j + 1)] = sqrt(C_reg[index(j + 1, j + 1)]);
140 }
141 }
142 }
143#endif
144 // Store result to memory
145 auto D_cached = with_cached_access<RowsReg, RowsReg>(D);
146 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
147 UNROLL_FOR (index_t jj = 0; jj <= ii; ++jj)
148 D_cached.store(C_reg[index(ii, jj)], ii, jj);
149}
150
151/// @param A1 RowsReg×k1.
152/// @param B1 ColsReg×k1.
153/// @param A2 RowsReg×k2.
154/// @param B2 ColsReg×k2.
155/// @param L ColsReg×ColsReg.
156/// @param invL ColsReg (inverted diagonal of L).
157/// @param C RowsReg×ColsReg.
158/// @param D RowsReg×ColsReg.
159/// @param k1 Number of columns in A1 and B1.
160/// @param k2 Number of columns in A2 and B2.
161/// @param diag k1-vector that scales the columns of A1 before multiplying by its transpose.
162/// Used only if enabled in the kernel config.
163template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder O1,
164 StorageOrder O2>
165[[gnu::hot, gnu::flatten]]
168 const uview<const T, Abi, O2> L, const T *invL,
170 const index_t k1, const index_t k2,
171 const diag_uview_type<const T, Abi, Conf> diag) noexcept {
172 static_assert(Conf.struc_C == MatrixStructure::LowerTriangular); // TODO
173 static_assert(RowsReg > 0 && ColsReg > 0);
174 using ops::rsqrt;
175 using simd = datapar::simd<T, Abi>;
176 // Pre-compute the offsets of the columns of C
177 const auto C_cached = with_cached_access<RowsReg, ColsReg>(C);
178 // Load matrix into registers
179 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
180 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
181 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
182 C_reg[ii][jj] = C_cached.load(ii, jj);
183 // Perform gemm operation of A and B
184 const auto A1_cached = with_cached_access<RowsReg, 0>(A1);
185 const auto B1_cached = with_cached_access<ColsReg, 0>(B1);
186 for (index_t l = 0; l < k1; ++l) {
187 const auto dl = load_diag<Conf>(diag, l);
188 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
189 simd Ail = apply_diag<Conf>(A1_cached.load(ii, l), dl);
190 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
191 simd &Cij = C_reg[ii][jj];
192 simd Blj = B1_cached.load(jj, l);
193 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
194 }
195 }
196 }
197 const auto A2_cached = with_cached_access<RowsReg, 0>(A2);
198 const auto B2_cached = with_cached_access<ColsReg, 0>(B2);
199 for (index_t l = 0; l < k2; ++l)
200 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
201 simd Ail = A2_cached.load(ii, l);
202 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
203 simd &Cij = C_reg[ii][jj];
204 simd Blj = B2_cached.load(jj, l);
205 Cij -= Ail * Blj;
206 }
207 }
208 // Triangular solve
209 UNROLL_FOR (index_t jj = 0; jj < RowsReg; ++jj)
210 UNROLL_FOR (index_t ii = 0; ii < ColsReg; ++ii) {
211 simd &Xij = C_reg[jj][ii];
212 simd inv_piv = datapar::aligned_load<simd>(invL + ii * simd::size());
213 UNROLL_FOR (index_t kk = 0; kk < ii; ++kk) {
214 simd Aik = L.load(ii, kk);
215 simd Xkj = C_reg[jj][kk];
216 Xij -= Aik * Xkj;
217 }
218 Xij *= inv_piv;
219 }
220 // Store result to memory
221 auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
222 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
223 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
224 D_cached.store(C_reg[ii][jj], ii, jj);
225}
226
227template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OCD>
229 const view<T, Abi, OCD> D, T regularization,
230 const diag_view_type<const T, Abi, Conf> d) noexcept {
231 using enum MatrixStructure;
232 static_assert(Conf.struc_C == LowerTriangular); // TODO
233 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
234 // Check dimensions
235 const index_t I = C.rows(), K = A.cols(), J = C.cols();
236 BATMAT_ASSUME(I >= J);
237 BATMAT_ASSUME(A.cols() == 0 || A.rows() == I);
238 BATMAT_ASSUME(D.rows() == I);
239 BATMAT_ASSUME(D.cols() == J);
240 if constexpr (Conf.with_diag()) {
241 BATMAT_ASSUME(d.rows() == K);
242 BATMAT_ASSUME(d.cols() == 1);
243 }
244 BATMAT_ASSUME(I > 0);
245 BATMAT_ASSUME(J > 0);
246 static const auto potrf_microkernel = potrf_copy_lut<T, Abi, Conf, OA, OCD>;
247 static const auto trsm_microkernel = trsm_copy_lut<T, Abi, Conf, OA, OCD>;
248 (void)trsm_microkernel; // GCC incorrectly warns about unused variable
249 // Sizeless views to partition and pass to the micro-kernels
250 const uview<const T, Abi, OA> A_ = A;
251 const uview<const T, Abi, OCD> C_ = C;
252 const uview<T, Abi, OCD> D_ = D;
255
256 // Optimization for very small matrices
257 if (I <= Rows && J <= Cols && I == J)
258 return potrf_microkernel[J - 1](A_, C_, C_, D_, invD, K, 0, regularization, d_);
259
260 foreach_chunked_merged( // Loop over the diagonal blocks of C
261 0, J, Cols, [&](index_t j, auto nj) {
262 const auto Aj = A_.middle_rows(j);
263 const auto Dj = D_.middle_rows(j);
264 const auto Djj = D_.block(j, j);
265 // Djj = chol(Cjj ± Aj Ajᵀ - Dj Djᵀ)
266 potrf_microkernel[nj - 1](Aj, Dj, C_.block(j, j), Djj, invD, K, j, regularization, d_);
267 foreach_chunked_merged( // Loop over the subdiagonal rows
268 j + nj, I, Rows, [&](index_t i, auto ni) {
269 const auto Ai = A_.middle_rows(i);
270 const auto Di = D_.middle_rows(i);
271 const auto Cij = C_.block(i, j);
272 const auto Dij = D_.block(i, j);
273 // Dij = (Cij ± Ai Ajᵀ - Di Djᵀ) Djj⁻ᵀ
274 trsm_microkernel[ni - 1][nj - 1](Ai, Aj, Di, Dj, Djj, invD, Cij, Dij, K, j, d_);
275 });
276 });
277}
278
279} // namespace batmat::linalg::micro_kernels::potrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
T rsqrt(T x)
Inverse square root.
Definition rsqrt.hpp:15
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
void aligned_store(V v, typename V::value_type *p)
Definition simd.hpp:121
stdx::memory_alignment< simd< Tp, Abi > > simd_align
Definition simd.hpp:139
stdx::simd_size< Tp, Abi > simd_size
Definition simd.hpp:137
V aligned_load(const typename V::value_type *p)
Definition simd.hpp:111
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
auto apply_diag(auto x, auto d) noexcept
Definition potrf.tpp:37
const constinit auto potrf_copy_lut
Definition potrf.tpp:17
std::conditional_t< Conf.with_diag(), uview_vec< T, Abi >, std::false_type > diag_uview_type
Definition potrf.hpp:22
void trsm_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O1 > B1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > B2, uview< const T, Abi, O2 > L, const T *invL, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, index_t k1, index_t k2, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:166
auto load_diag(auto diag, index_t l) noexcept
Definition potrf.tpp:29
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, diag_view_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:228
const constinit auto trsm_copy_lut
Definition potrf.tpp:23
std::conditional_t< Conf.with_diag(), view< T, Abi >, std::false_type > diag_view_type
Definition potrf.hpp:24
void potrf_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, T *invD, index_t k1, index_t k2, T regularization, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:59
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i....
Definition cneg.hpp:42
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114