batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
small-potrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
5#include <batmat/loop.hpp>
6#include <batmat/lut.hpp>
8#include <batmat/ops/sqrt.hpp>
9#include <batmat/simd.hpp>
10#include <bit>
11
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
16template <class T, index_t NC> // number of columns to handle at once
17[[gnu::flatten, gnu::hot]]
19 using ops::rsqrt;
20 using ops::sqrt;
21 constexpr index_t NR = 8; // Number of rows in each sub-diagonal block
22 T Dr[NC * (NC + 1) / 2];
23 static constexpr auto index = [](index_t r, index_t c) { return c * (2 * NC - 1 - c) / 2 + r; };
24 /* Load diagonal block into (scalar) registers */
25 UNROLL_FOR (index_t j = 0; j < NC; ++j) // column
26 UNROLL_FOR (index_t i = j; i < NC; ++i) // row
27 Dr[index(i, j)] = A(i, j);
28 /* Cholesky factorization of diagonal block */
29 UNROLL_FOR (index_t j = 0; j < NC; ++j) { // column
30 const auto pivot = sqrt(Dr[index(j, j)]);
31 const auto inv_pivot = rsqrt(Dr[index(j, j)]);
32 Dr[index(j, j)] = inv_pivot;
33 UNROLL_FOR (index_t i = j + 1; i < NC; ++i)
34 Dr[index(i, j)] *= inv_pivot;
35 UNROLL_FOR (index_t kk = j + 1; kk < NC; ++kk) { // column syrk
36 const T fac = Dr[index(kk, j)];
37 UNROLL_FOR (index_t i = kk; i < NC; ++i)
38 Dr[index(i, kk)] -= Dr[index(i, j)] * fac;
39 }
40 L(j, j) = pivot;
41 UNROLL_FOR (index_t i = j + 1; i < NC; ++i) // row
42 L(i, j) = Dr[index(i, j)];
43 }
44 /* Multiply the sub-diagonal blocks by the inverse of the Cholesky factor */
45 auto trsm_tail = [&](auto &trsm_tail, index_t r, auto N) {
46 using simdN = datapar::deduced_simd<T, N>;
47 for (; r + N <= k; r += N) { // block row
48 simdN Xrx[NC];
49 UNROLL_FOR (index_t c = 0; c < NC; ++c) // column
50 Xrx[c] = datapar::unaligned_load<simdN>(&A(r, c));
51 UNROLL_FOR (index_t c = 0; c < NC; ++c) { // column
52 simdN &Xij = Xrx[c];
53 UNROLL_FOR (index_t kk = 0; kk < c; ++kk) { // column inner
54 const T Aik = Dr[index(c, kk)];
55 Xij -= Aik * Xrx[kk];
56 }
57 Xij *= Dr[index(c, c)];
58 datapar::unaligned_store(Xij, &L(r, c));
59 }
60 }
61 if constexpr (N > 1)
62 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
63 };
64 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
65}
66
67/// Outer product for updating the bottom right tail during Cholesky factorization.
68/// @param L21 k×ColsReg
69/// @param A22 k×RowsReg (the output L22 has the same shape)
70template <class T, index_t RowsReg, index_t ColsReg>
71[[gnu::flatten, gnu::hot]]
73 scalar_view<T> L22) noexcept {
74 constexpr index_t NR = 8; // Number of rows in each sub-diagonal block
75 // Pre-compute the offsets of the columns of A21 and A22
76 auto L21_cached = with_cached_access<0, ColsReg>(L21);
77 auto A22_cached = with_cached_access<0, RowsReg>(A22);
78 auto L22_cached = with_cached_access<0, RowsReg>(L22);
79 // Load matrix into registers
80 T A21_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
81 UNROLL_FOR (index_t i = 0; i < RowsReg; ++i)
82 UNROLL_FOR (index_t j = 0; j < ColsReg; ++j)
83 A21_reg[i][j] = L21_cached(i, j);
84 // Matrix multiplication of diagonal block
85 UNROLL_FOR (index_t i = 0; i < RowsReg; ++i) {
86 T A22ix[RowsReg];
87 UNROLL_FOR (index_t j = 0; j <= i; ++j)
88 A22ix[j] = A22_cached(i, j);
89 UNROLL_FOR (index_t j = 0; j <= i; ++j)
90 UNROLL_FOR (index_t kk = 0; kk < ColsReg; ++kk)
91 A22ix[j] -= A21_reg[i][kk] * A21_reg[j][kk];
92 UNROLL_FOR (index_t j = 0; j <= i; ++j)
93 L22_cached(i, j) = A22ix[j];
94 }
95 // Matrix multiplication of sub-diagonal block
96 auto gemm_tail = [&](auto &gemm_tail, index_t i, auto N) {
97 using simd = datapar::deduced_simd<T, N>;
98 for (; i + N <= k; i += N) { // block row
99 simd Aix[RowsReg];
100 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j)
101 Aix[j] = datapar::unaligned_load<simd>(&A22_cached(i, j));
102 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j)
103 UNROLL_FOR (index_t kk = 0; kk < ColsReg; ++kk) {
104 const simd A21ik = datapar::unaligned_load<simd>(&L21_cached(i, kk));
105 Aix[j] -= A21ik * A21_reg[j][kk];
106 }
107 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j)
108 datapar::unaligned_store(Aix[j], &L22_cached(i, j));
109 }
110 if constexpr (N > 1)
111 gemm_tail(gemm_tail, i, std::integral_constant<index_t, N / 2>());
112 };
113 gemm_tail(gemm_tail, RowsReg, std::integral_constant<index_t, NR>());
114}
115
116template <class T, index_t R>
118 index_t n) noexcept {
119 static const constinit auto microkernel_trsm_lut = make_1d_lut<R>(
121 static const constinit auto microkernel_syrk_lut = make_1d_lut<R>(
123 static const constinit auto microkernel_syrk_lut_2 =
126 });
127 (void)microkernel_syrk_lut; // Invalid GCC warning
128 (void)microkernel_syrk_lut_2;
129
130 const index_t m = L.rows(), N = L.cols();
131 if (n < 0)
132 n = N;
133 BATMAT_ASSUME(m >= N);
134 BATMAT_ASSUME((n == m && m == N) || (n == N && m >= N) || (n < m && m == N));
135
137 scalar_view<T> L_ = L;
138
139 // Compute the Cholesky factorization of the very last block (right before
140 // the Schur complement block), which has size r×r rather than R×R.
141 // If requested, also update the rows below the Cholesky factor, and the
142 // Schur complement to the bottom right of the given block.
143// These extra blocks always have sizes (m-n)×r and (m-n)×(m-n), respectively.
144 const auto process_bottom_right = [m, N, n](scalar_view<const T> Aii, scalar_view<T> Lii,
145 index_t r) {
146 // Cholesky of last block to be factorized + triangular solve with
147 // sub-diagonal block.
148 microkernel_trsm_lut[r - 1](r + m - n, Aii, Lii);
149 // Update the Schur complement (bottom right) with the outer product
150 // of the sub-diagonal block column.
151 if (n < N) {
152 auto L21 = Lii.middle_rows(r), L22 = Lii.block(r, r);
153 auto A22 = Aii.block(r, r);
155 0, m - n, index_constant<R>(),
156 [&](index_t j, auto rem) {
157 auto Lj1 = L21.middle_rows(j), Ljj = L22.middle_cols(j);
158 auto Ajj = A22.middle_cols(j);
159 microkernel_syrk_lut_2[rem - 1][r - 1](m - n - j, Lj1, Ajj, Ljj);
160 },
162 }
163 };
164
165 // Base case
166 if (n == 0) {
167 return;
168 } else if (n <= R) {
169 process_bottom_right(A_, L_, n);
170 return;
171 }
172// Loop over columns of A with block size R.
173 index_t i;
174 for (i = 0; i + R <= n; i += R) {
175 auto L11 = L_.block(i, i);
176 auto A11 = i == 0 ? A_.block(i, i) : decltype(A_){L11};
177 // Factor the diagonal block and update the subdiagonal block
178 potrf_trsm_microkernel<T, R>(m - i, A11, L11);
179 // Update the Schur complement (bottom right) with the outer product of
180 // the subdiagonal block.
182 i + R, N, index_constant<R>(),
183 [&](index_t j, auto rem) {
184 auto L21 = L_.block(j, i), L22 = L_.block(j, j);
185 auto A22 = i == 0 ? A_.block(j, j) : decltype(A_){L22};
186 microkernel_syrk_lut[rem - 1](m - j, L21, A22, L22);
187 },
189 // Loop backwards for cache locality (we'll use the next column in the
190// next iteration, so we want the syrk operation to leave it in cache).
191 // TODO: verify in benchmark.
192 }
193 const index_t rem = n - i;
194 if (rem > 0) {
195 auto Lii = L_.block(i, i);
196 auto Aii = i == 0 ? A_.block(i, i) : decltype(A_){Lii};
197 process_bottom_right(Aii, Lii, rem);
198 }
199}
200
201/// Left-looking variant of small_potrf, which updates the current block with the outer product of
202/// the previously computed part L21.
203/// @param L21 m×k
204/// @param A22 m×NC
205/// @param L22 m×NC
206template <class T, index_t NC, index_t NR>
207[[gnu::flatten, gnu::hot]]
209 scalar_view<const T> A22, scalar_view<T> L22) noexcept {
210 using ops::sqrt;
211
213 const auto load_mask = datapar::generate_mask_until<simd, NC>();
214
215 /* Load diagonal block into registers */
216 simd Dr[NC];
217 UNROLL_FOR (index_t j = 0; j < NC; ++j) // column
218 Dr[j] = NC == simd::size() ? datapar::unaligned_load<simd>(&A22(0, j))
220 /* Accumulate previous updates */
221 for (index_t l = 0; l < k; ++l) { // syrk update diagonal block
222 simd L21l = NC == simd::size() ? datapar::unaligned_load<simd>(&L21(0, l))
224 UNROLL_FOR (index_t j = 0; j < NC; ++j)
225 Dr[j] -= L21l * L21l[j];
226 }
227
228 /* Cholesky factorization of diagonal block */
229 T inv_pivots[NC];
230 auto store_mask = load_mask;
231 UNROLL_FOR (index_t j = 0; j < NC; ++j) { // column
232 const T Djj = Dr[j][j];
233 BATMAT_ASSUME(Djj > T{});
234 const T pivot = sqrt(Djj);
235 const T inv_pivot = 1 / pivot;
236 inv_pivots[j] = inv_pivot;
237 Dr[j] *= inv_pivot; // update current column
238 UNROLL_FOR (index_t i = j + 1; i < NC; ++i) // column syrk
239 Dr[i] -= Dr[j] * Dr[j][i];
240#if BATMAT_WITH_GSI_HPC_SIMD
241 const auto mask_j = datapar::generate_mask<simd>(j);
242 Dr[j] = datapar::select(mask_j, simd{pivot}, Dr[j]);
243 datapar::masked_unaligned_store(Dr[j], store_mask, &L22(0, j));
244 store_mask = store_mask && !mask_j;
245#else
246 Dr[j][j] = pivot;
247 datapar::masked_unaligned_store(Dr[j], store_mask, &L22(0, j));
248 store_mask[j] = false;
249#endif
250 }
251
252 /* Multiply the sub-diagonal blocks by the inverse of the Cholesky factor */
253 auto trsm_tail = [&](auto &trsm_tail, index_t r, auto N) {
254 using simd = datapar::deduced_simd<T, N>;
255 for (; r + N <= m; r += N) { // block row
256 simd Xrx[NC];
257 UNROLL_FOR (index_t c = 0; c < NC; ++c) // column
258 Xrx[c] = datapar::unaligned_load<simd>(&A22(r, c));
259 for (index_t l = 0; l < k; ++l) { // syrk update subdiagonal block
260 simd L21rl = datapar::unaligned_load<simd>(&L21(r, l));
261 UNROLL_FOR (index_t j = 0; j < NC; ++j)
262 Xrx[j] -= L21rl * L21(j, l);
263 }
264 UNROLL_FOR (index_t j = 0; j < NC; ++j) { // column
265 simd &Xij = Xrx[j];
266 UNROLL_FOR (index_t i = 0; i < j; ++i) // column inner
267 Xij -= Dr[i][j] * Xrx[i];
268 Xij *= inv_pivots[j];
269 datapar::unaligned_store(Xij, &L22(r, j));
270 }
271 }
272 if constexpr (N > 1)
273 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
274 };
275 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
276}
277
278template <class T, index_t R, index_t S>
280 view<T, datapar::scalar_abi<T>> L) noexcept {
281 static const constinit auto microkernel_lut =
284 });
285 (void)microkernel_lut; // Invalid GCC warning
286
287 const index_t m = L.rows(), N = L.cols();
288 BATMAT_ASSUME(m >= N);
289
291 scalar_view<T> L_ = L;
292
293// Loop over columns of A with block size R.
295 0, N, index_constant<R>(),
296 [&](index_t i) {
297 auto L22 = L_.block(i, i);
298 auto A22 = A_.block(i, i);
299 auto L21 = L_.block(i, 0);
300 syrk_potrf_trsm_microkernel<T, R, S>(m - i, i, L21, A22, L22);
301 },
302 [&](index_t i, auto rem) {
303 auto L22 = L_.block(i, i);
304 auto A22 = A_.block(i, i);
305 auto L21 = L_.block(i, 0);
306 microkernel_lut[rem - 1](m - i, i, L21, A22, L22);
307 });
308}
309
310} // namespace batmat::linalg::micro_kernels::small_potrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
T rsqrt(T x)
Inverse square root.
Definition rsqrt.hpp:15
T sqrt(T x)
Square root.
Definition sqrt.hpp:15
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
Definition loop.hpp:20
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
void masked_unaligned_store(V v, typename V::mask_type m, typename V::value_type *p)
Definition simd.hpp:194
V unaligned_load(const typename V::value_type *p)
Definition simd.hpp:155
V partial_load(const typename V::value_type *p)
Definition simd.hpp:221
deduced_abi< Tp, 1 > scalar_abi
Definition simd.hpp:239
auto generate_mask()
Definition simd.hpp:199
void unaligned_store(V v, typename V::value_type *p)
Definition simd.hpp:165
simd< Tp, deduced_abi< Tp, Np > > deduced_simd
Definition simd.hpp:152
auto select(auto cond, auto t, auto f)
Definition simd.hpp:245
auto generate_mask_until()
Definition simd.hpp:213
uview< T, datapar::scalar_abi< std::remove_const_t< T > >, StorageOrder::ColMajor > scalar_view
void potrf_syrk_microkernel(index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Outer product for updating the bottom right tail during Cholesky factorization.
void potrf_trsm_microkernel(index_t k, scalar_view< const T > A, scalar_view< T > L) noexcept
void small_potrf(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L, index_t n=-1) noexcept
void syrk_potrf_trsm_microkernel(index_t m, index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Left-looking variant of small_potrf, which updates the current block with the outer product of the pr...
void small_potrf_left(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L) noexcept
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
int index_t
Definition config.hpp:13
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110