batmat 0.0.24
Batched linear algebra routines
Loading...
Searching...
No Matches
small-potrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
5#include <batmat/loop.hpp>
6#include <batmat/lut.hpp>
8#include <batmat/ops/sqrt.hpp>
9#include <batmat/simd.hpp>
10#include <bit>
11
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
// Right-looking micro-kernel: Cholesky factorization of an NC×NC diagonal block,
// followed by a triangular solve applied to the rows below that block.
// NOTE(review): the signature line (listing line 18) is absent from this listing;
// according to the declaration index at the end of the page it is
//   void potrf_trsm_microkernel(index_t k, scalar_view<const T> A, scalar_view<T> L) noexcept
// The body reads the lower triangle of the leading NC×NC block of A and the
// rows NC..k-1 of its first NC columns, and writes the results to the same
// positions of L.
16template <class T, index_t NC> // number of columns to handle at once
17[[gnu::flatten, gnu::hot]]
19 using ops::rsqrt;
20 using ops::sqrt;
// Widest chunk of rows handled per iteration of the tail loop below; leftover
// rows are handled with chunk widths NR/2, NR/4, ..., 1.
21 constexpr index_t NR = 8; // Number of rows in each sub-diagonal block
// Lower triangle of the diagonal block, packed column by column
// (NC + (NC-1) + ... + 1 = NC*(NC+1)/2 entries).
22 T Dr[NC * (NC + 1) / 2];
// Position of element (r, c), r >= c, inside the packed array Dr.
23 static constexpr auto index = [](index_t r, index_t c) { return c * (2 * NC - 1 - c) / 2 + r; };
24 /* Load diagonal block into (scalar) registers */
25 UNROLL_FOR (index_t j = 0; j < NC; ++j) // column
26 UNROLL_FOR (index_t i = j; i < NC; ++i) // row
27 Dr[index(i, j)] = A(i, j);
28 /* Cholesky factorization of diagonal block */
29 UNROLL_FOR (index_t j = 0; j < NC; ++j) { // column
30 const auto pivot = sqrt(Dr[index(j, j)]);
31 const auto inv_pivot = rsqrt(Dr[index(j, j)]);
// Keep the *inverse* pivot on the diagonal of Dr: the tail loop below can then
// multiply by it instead of dividing. The pivot itself is written to L(j, j).
32 Dr[index(j, j)] = inv_pivot;
// Scale the rest of column j by the inverse pivot.
33 UNROLL_FOR (index_t i = j + 1; i < NC; ++i)
34 Dr[index(i, j)] *= inv_pivot;
// Subtract the outer product of column j from the remaining columns kk > j.
35 UNROLL_FOR (index_t kk = j + 1; kk < NC; ++kk) { // column syrk
36 const T fac = Dr[index(kk, j)];
37 UNROLL_FOR (index_t i = kk; i < NC; ++i)
38 Dr[index(i, kk)] -= Dr[index(i, j)] * fac;
39 }
// Column j is final: write it to the output.
40 L(j, j) = pivot;
41 UNROLL_FOR (index_t i = j + 1; i < NC; ++i) // row
42 L(i, j) = Dr[index(i, j)];
43 }
44 /* Multiply the sub-diagonal blocks by the inverse of the Cholesky factor */
// Generic lambda that receives itself as first argument so it can recurse with
// a smaller compile-time chunk width N. Starting at row r, it processes as many
// chunks of N rows as fit below row k, then hands the remaining rows to the
// N/2 instantiation (down to N == 1).
45 auto trsm_tail = [&](auto &trsm_tail, index_t r, auto N) {
46 using simdN = datapar::deduced_simd<T, N>;
47 for (; r + N <= k; r += N) { // block row
48 simdN Xrx[NC];
// Load N values starting at &A(r, c) for each of the NC columns
// (assumes the N rows r..r+N-1 of a column are contiguous — TODO confirm).
49 UNROLL_FOR (index_t c = 0; c < NC; ++c) // column
50 Xrx[c] = datapar::unaligned_load<simdN>(&A(r, c));
// Forward substitution, one column at a time: remove the contributions of the
// already solved columns kk < c, then scale by the stored inverse pivot.
51 UNROLL_FOR (index_t c = 0; c < NC; ++c) { // column
52 simdN &Xij = Xrx[c];
53 UNROLL_FOR (index_t kk = 0; kk < c; ++kk) { // column inner
54 const T Aik = Dr[index(c, kk)];
55 Xij -= Aik * Xrx[kk];
56 }
57 Xij *= Dr[index(c, c)];
58 datapar::unaligned_store(Xij, &L(r, c));
59 }
60 }
61 if constexpr (N > 1)
62 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
63 };
// Rows 0..NC-1 are the diagonal block handled above; the tail starts at row NC.
64 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
65}
66
67/// Outer product for updating the bottom right tail during Cholesky factorization.
68/// @param L21 rows×ColsReg (input; its first RowsReg rows are cached in A21_reg)
69/// @param A22 rows×RowsReg (input block that the outer product is subtracted from)
/// L22 is the output: it is written at the same (row, column) positions that are
/// read from A22. Only the lower triangle of the leading RowsReg×RowsReg block is
/// touched; below it, all RowsReg columns are updated up to row k.
/// NOTE(review): the first signature line (listing line 72) is absent from this
/// listing; per the declaration index at the end of the page the full signature is
///   void potrf_syrk_microkernel(index_t k, scalar_view<const T> L21,
///                               scalar_view<const T> A22, scalar_view<T> L22) noexcept
70template <class T, index_t RowsReg, index_t ColsReg>
71[[gnu::flatten, gnu::hot]]
73 scalar_view<T> L22) noexcept {
// Widest chunk of rows handled per iteration of the tail loop below; leftover
// rows are handled with chunk widths NR/2, NR/4, ..., 1.
74 constexpr index_t NR = 8; // Number of rows in each sub-diagonal block
75 // Pre-compute the offsets of the columns of L21, A22 and L22
76 auto L21_cached = with_cached_access<0, ColsReg>(L21);
77 auto A22_cached = with_cached_access<0, RowsReg>(A22);
78 auto L22_cached = with_cached_access<0, RowsReg>(L22);
79 // Load matrix into registers
// A21_reg holds the top RowsReg×ColsReg block of L21; it is reused by every
// row chunk in the loops below.
80 T A21_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
81 UNROLL_FOR (index_t i = 0; i < RowsReg; ++i)
82 UNROLL_FOR (index_t j = 0; j < ColsReg; ++j)
83 A21_reg[i][j] = L21_cached(i, j);
84 // Matrix multiplication of diagonal block
// Row by row, lower triangle only (j <= i):
//   L22(i, j) = A22(i, j) - sum_kk A21_reg[i][kk] * A21_reg[j][kk]
85 UNROLL_FOR (index_t i = 0; i < RowsReg; ++i) {
86 T A22ix[RowsReg];
87 UNROLL_FOR (index_t j = 0; j <= i; ++j)
88 A22ix[j] = A22_cached(i, j);
89 UNROLL_FOR (index_t j = 0; j <= i; ++j)
90 UNROLL_FOR (index_t kk = 0; kk < ColsReg; ++kk)
91 A22ix[j] -= A21_reg[i][kk] * A21_reg[j][kk];
92 UNROLL_FOR (index_t j = 0; j <= i; ++j)
93 L22_cached(i, j) = A22ix[j];
94 }
95 // Matrix multiplication of sub-diagonal block
// Generic lambda that receives itself as first argument so it can recurse with
// a smaller compile-time chunk width N. Starting at row i, it processes as many
// chunks of N rows as fit below row k, then hands the remaining rows to the
// N/2 instantiation (down to N == 1).
96 auto gemm_tail = [&](auto &gemm_tail, index_t i, auto N) {
97 using simd = datapar::deduced_simd<T, N>;
98 using simd_index_t = decltype(simd::size());
99 for (; i + N <= k; i += N) { // block row
100 simd Aix[RowsReg];
// Load N values starting at &A22(i, j) for each of the RowsReg columns.
101 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j)
102 Aix[j] = datapar::unaligned_load<simd>(&A22_cached(i, j));
// Aix[j] -= L21(i.., kk) * L21(j, kk), the scalar factor coming from A21_reg.
103 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j)
104 UNROLL_FOR (index_t kk = 0; kk < ColsReg; ++kk) {
105 const simd A21ik = datapar::unaligned_load<simd>(&L21_cached(i, kk));
106 Aix[j] -= A21ik * A21_reg[j][static_cast<simd_index_t>(kk)];
107 }
108 UNROLL_FOR (index_t j = 0; j < RowsReg; ++j)
109 datapar::unaligned_store(Aix[j], &L22_cached(i, j));
110 }
111 if constexpr (N > 1)
112 gemm_tail(gemm_tail, i, std::integral_constant<index_t, N / 2>());
113 };
// Rows 0..RowsReg-1 are the diagonal block handled above; the tail starts at
// row RowsReg.
114 gemm_tail(gemm_tail, RowsReg, std::integral_constant<index_t, NR>());
115}
116
// Blocked (block size R) right-looking Cholesky driver built on the two
// micro-kernels above.
// NOTE(review): several lines of this function are absent from this listing
// (118, 121, 123, 125-126, 137, 155, 161-162, 182, 189): the start of the
// signature, the bodies of the lookup tables, the declaration of A_, and the
// lines that open/close the chunked-loop calls. Per the declaration index at the
// end of the page the signature is
//   void small_potrf(view<const T, datapar::scalar_abi<T>> A,
//                    view<T, datapar::scalar_abi<T>> L, index_t n = -1) noexcept
// n is the number of columns to factorize; a negative n means all L.cols().
117template <class T, index_t R>
119 index_t n) noexcept {
// Lookup tables of micro-kernel instantiations, indexed by (size - 1), so that
// blocks narrower than R can be dispatched at run time.
120 static const constinit auto microkernel_trsm_lut = make_1d_lut<R>(
122 static const constinit auto microkernel_syrk_lut = make_1d_lut<R>(
124 static const constinit auto microkernel_syrk_lut_2 =
127 });
128 (void)microkernel_syrk_lut; // Invalid GCC warning
129 (void)microkernel_syrk_lut_2;
130
131 const index_t m = L.rows(), N = L.cols();
132 if (n < 0)
133 n = N;
134 BATMAT_ASSUME(m >= N);
// Supported shapes: full square factorization, tall matrix with all columns
// factorized, or square matrix with only the first n columns factorized.
135 BATMAT_ASSUME((n == m && m == N) || (n == N && m >= N) || (n < m && m == N));
136
// NOTE(review): A_ is used below but its declaration (listing line 137) is absent
// from this listing; presumably a scalar_view<const T> of A — confirm.
138 scalar_view<T> L_ = L;
139
140 // Compute the Cholesky factorization of the very last block (right before
141 // the Schur complement block), which has size r×r rather than R×R.
142 // If requested, also update the rows below the Cholesky factor, and the
143 // Schur complement to the bottom right of the given block.
144 // These extra blocks are always sizes (m-n)×r and (m-n)×(m-n) respectively.
145 const auto process_bottom_right = [m, N, n](scalar_view<const T> Aii, scalar_view<T> Lii,
146 index_t r) {
147 // Cholesky of last block to be factorized + triangular solve with
148 // sub-diagonal block.
// Kernel selected by block width r; r + m - n is the number of rows it covers
// (the r×r diagonal block plus the m - n rows below it).
149 microkernel_trsm_lut[r - 1](r + m - n, Aii, Lii);
150 // Update the Schur complement (bottom right) with the outer product
151 // of the sub-diagonal block column.
152 if (n < N) {
153 auto L21 = Lii.middle_rows(r), L22 = Lii.block(r, r);
154 auto A22 = Aii.block(r, r);
// Arguments and callback of a chunked loop over [0, m - n) with chunk size R;
// the line naming the loop helper is absent from this listing.
156 0, m - n, index_constant<R>(),
157 [&](index_t j, auto rem) {
158 auto Lj1 = L21.middle_rows(j), Ljj = L22.middle_cols(j);
159 auto Ajj = A22.middle_cols(j);
// 2-D table: first index is the chunk width (rem), second the width r of the
// block column that was just factorized.
160 microkernel_syrk_lut_2[rem - 1][r - 1](m - n - j, Lj1, Ajj, Ljj);
161 },
163 }
164 };
165
166 // Base case
167 if (n == 0) {
168 return;
169 } else if (n <= R) {
170 process_bottom_right(A_, L_, n);
171 return;
172 }
173 // Loop over the columns to factorize in blocks of R columns.
174 index_t i;
175 for (i = 0; i + R <= n; i += R) {
176 auto L11 = L_.block(i, i);
// Only the first block column is read from A. From the second block onwards the
// input is read back from L, where the preceding Schur-complement updates wrote
// their results.
177 auto A11 = i == 0 ? A_.block(i, i) : decltype(A_){L11};
178 // Factor the diagonal block and update the subdiagonal block
179 potrf_trsm_microkernel<T, R>(m - i, A11, L11);
180 // Update the Schur complement (bottom right) with the outer product of
181 // the subdiagonal block.
// Arguments and callback of a chunked loop over [i + R, N) with chunk size R;
// the line naming the loop helper is absent from this listing.
183 i + R, N, index_constant<R>(),
184 [&](index_t j, auto rem) {
185 auto L21 = L_.block(j, i), L22 = L_.block(j, j);
186 auto A22 = i == 0 ? A_.block(j, j) : decltype(A_){L22};
187 microkernel_syrk_lut[rem - 1](m - j, L21, A22, L22);
188 },
190 // Loop backwards for cache locality (we'll use the next column in the
191 // next iteration, so we want the syrk operation to leave it in cache).
192 // TODO: verify in benchmark.
193 }
// Whatever is left after the full R-wide blocks (fewer than R columns) is
// handled by the variable-width path, which also takes care of the rows and
// Schur complement beyond column n.
194 const index_t rem = n - i;
195 if (rem > 0) {
196 auto Lii = L_.block(i, i);
197 auto Aii = i == 0 ? A_.block(i, i) : decltype(A_){Lii};
198 process_bottom_right(Aii, Lii, rem);
199 }
200}
201
202/// Left-looking variant of small_potrf, which updates the current block with the outer product of
203/// the previously computed part L21.
204/// @param L21 m×k
205/// @param A22 m×NC
206/// @param L22 m×NC
/// Here m is the number of rows of the current block column and k the number of
/// previously factorized columns whose contribution is subtracted first.
/// NOTE(review): the first signature line (listing line 209), the `simd` alias
/// (213) and the else-branches of the two conditional loads (221, 225) are absent
/// from this listing. Per the declaration index at the end of the page the
/// signature is
///   void syrk_potrf_trsm_microkernel(index_t m, index_t k, scalar_view<const T> L21,
///                                    scalar_view<const T> A22, scalar_view<T> L22) noexcept
207template <class T, index_t NC, index_t NR>
208[[gnu::flatten, gnu::hot]]
210 scalar_view<const T> A22, scalar_view<T> L22) noexcept {
211 using ops::sqrt;
212
214 using simd_index_t = decltype(simd::size());
// presumably a mask selecting the first NC lanes — confirm in simd.hpp
215 const auto load_mask = datapar::generate_mask_until<simd, NC>();
216
217 /* Load diagonal block into registers */
// Dr[j] holds column j of the diagonal block, one row per lane.
218 simd Dr[NC];
219 UNROLL_FOR (index_t j = 0; j < NC; ++j) // column
220 Dr[j] = NC == simd::size() ? datapar::unaligned_load<simd>(&A22(0, j))
222 /* Accumulate previous updates */
// For every previous column l: Dr[j] -= L21(:, l) * L21(j, l).
223 for (index_t l = 0; l < k; ++l) { // syrk update diagonal block
224 simd L21l = NC == simd::size() ? datapar::unaligned_load<simd>(&L21(0, l))
226 UNROLL_FOR (index_t j = 0; j < NC; ++j)
227 Dr[j] -= L21l * L21l[static_cast<simd_index_t>(j)];
228 }
229
230 /* Cholesky factorization of diagonal block */
// Inverse pivots are kept so the tail loop can multiply instead of divide.
231 T inv_pivots[NC];
// Lane j is cleared from store_mask after column j has been stored, so each
// later column is written only from its diagonal entry downwards.
232 auto store_mask = load_mask;
233 UNROLL_FOR (index_t j = 0; j < NC; ++j) { // column
234 const T Djj = Dr[j][static_cast<simd_index_t>(j)];
// Undefined behavior if the pivot is not strictly positive (see BATMAT_ASSUME).
235 BATMAT_ASSUME(Djj > T{});
236 const T pivot = sqrt(Djj);
237 const T inv_pivot = 1 / pivot;
238 inv_pivots[j] = inv_pivot;
239 Dr[j] *= inv_pivot; // update current column
240 UNROLL_FOR (index_t i = j + 1; i < NC; ++i) // column syrk
241 Dr[i] -= Dr[j] * Dr[j][static_cast<simd_index_t>(i)];
// Both branches do the same three things: put the pivot in lane j of the
// column, store the column under store_mask, then remove lane j from the mask.
// The first branch uses select/mask operations instead of per-lane assignment.
242#if BATMAT_WITH_GSI_HPC_SIMD
243 const auto mask_j = datapar::generate_mask<simd>(static_cast<simd_index_t>(j));
244 Dr[j] = datapar::select(mask_j, simd{pivot}, Dr[j]);
245 datapar::masked_unaligned_store(Dr[j], store_mask, &L22(0, j));
246 store_mask = store_mask && !mask_j;
247#else
248 Dr[j][static_cast<simd_index_t>(j)] = pivot;
249 datapar::masked_unaligned_store(Dr[j], store_mask, &L22(0, j));
250 store_mask[static_cast<simd_index_t>(j)] = false;
251#endif
252 }
253
254 /* Multiply the sub-diagonal blocks by the inverse of the Cholesky factor */
// Generic lambda that receives itself as first argument so it can recurse with
// a smaller compile-time chunk width N. Starting at row r, it processes as many
// chunks of N rows as fit below row m, then hands the remaining rows to the
// N/2 instantiation (down to N == 1).
255 auto trsm_tail = [&](auto &trsm_tail, index_t r, auto N) {
256 using simdN = datapar::deduced_simd<T, N>;
257 for (; r + N <= m; r += N) { // block row
258 simdN Xrx[NC];
259 UNROLL_FOR (index_t c = 0; c < NC; ++c) // column
260 Xrx[c] = datapar::unaligned_load<simdN>(&A22(r, c));
// Subtract the contribution of the k previously computed columns.
261 for (index_t l = 0; l < k; ++l) { // syrk update subdiagonal block
262 simdN L21rl = datapar::unaligned_load<simdN>(&L21(r, l));
263 UNROLL_FOR (index_t j = 0; j < NC; ++j)
264 Xrx[j] -= L21rl * L21(j, l);
265 }
// Forward substitution: lane j of Dr[i] (i < j) is the below-diagonal entry
// (j, i) of the factor computed above.
266 UNROLL_FOR (index_t j = 0; j < NC; ++j) { // column
267 simdN &Xij = Xrx[j];
268 UNROLL_FOR (index_t i = 0; i < j; ++i) // column inner
269 Xij -= Dr[i][static_cast<simd_index_t>(j)] * Xrx[i];
270 Xij *= inv_pivots[j];
271 datapar::unaligned_store(Xij, &L22(r, j));
272 }
273 }
274 if constexpr (N > 1)
275 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
276 };
// Rows 0..NC-1 are the diagonal block handled above; the tail starts at row NC.
277 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
278}
279
// Left-looking blocked Cholesky driver: for every block column it calls
// syrk_potrf_trsm_microkernel, which first subtracts the contribution of the
// already factorized columns to its left and then factorizes the block column.
// NOTE(review): several lines are absent from this listing (281, 284-285, 292,
// 296): the start of the signature, the lookup-table body, the declaration of A_
// and the line naming the chunked-loop helper. Per the declaration index at the
// end of the page the signature is
//   void small_potrf_left(view<const T, datapar::scalar_abi<T>> A,
//                         view<T, datapar::scalar_abi<T>> L) noexcept
280template <class T, index_t R, index_t S>
282 view<T, datapar::scalar_abi<T>> L) noexcept {
// Lookup table of micro-kernel instantiations indexed by (width - 1), used for
// the final block column when it is narrower than R.
283 static const constinit auto microkernel_lut =
286 });
287 (void)microkernel_lut; // Invalid GCC warning
288
289 const index_t m = L.rows(), N = L.cols();
290 BATMAT_ASSUME(m >= N);
291
// NOTE(review): A_ is used below but its declaration (listing line 292) is absent
// from this listing; presumably a scalar_view<const T> of A — confirm.
293 scalar_view<T> L_ = L;
294
295 // Loop over the columns [0, N) in blocks of R columns.
// The first callback handles full blocks of width R; the second handles a
// remainder of width rem, dispatched through the lookup table.
297 0, N, index_constant<R>(),
298 [&](index_t i) {
// L22/A22: block column starting at the diagonal (i, i); L21: the i columns
// already factorized to its left, starting at the same row.
299 auto L22 = L_.block(i, i);
300 auto A22 = A_.block(i, i);
301 auto L21 = L_.block(i, 0);
302 syrk_potrf_trsm_microkernel<T, R, S>(m - i, i, L21, A22, L22);
303 },
304 [&](index_t i, auto rem) {
305 auto L22 = L_.block(i, i);
306 auto A22 = A_.block(i, i);
307 auto L21 = L_.block(i, 0);
308 microkernel_lut[rem - 1](m - i, i, L21, A22, L22);
309 });
310}
311
312} // namespace batmat::linalg::micro_kernels::small_potrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
T rsqrt(T x)
Inverse square root.
Definition rsqrt.hpp:15
T sqrt(T x)
Square root.
Definition sqrt.hpp:15
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
Definition loop.hpp:20
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
void masked_unaligned_store(V v, typename V::mask_type m, typename V::value_type *p)
Definition simd.hpp:194
V unaligned_load(const typename V::value_type *p)
Definition simd.hpp:155
V partial_load(const typename V::value_type *p)
Definition simd.hpp:221
deduced_abi< Tp, 1 > scalar_abi
Definition simd.hpp:239
auto generate_mask()
Definition simd.hpp:199
void unaligned_store(V v, typename V::value_type *p)
Definition simd.hpp:165
simd< Tp, deduced_abi< Tp, Np > > deduced_simd
Definition simd.hpp:152
auto select(auto cond, auto t, auto f)
Definition simd.hpp:245
auto generate_mask_until()
Definition simd.hpp:213
uview< T, datapar::scalar_abi< std::remove_const_t< T > >, StorageOrder::ColMajor > scalar_view
void potrf_syrk_microkernel(index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Outer product for updating the bottom right tail during Cholesky factorization.
void potrf_trsm_microkernel(index_t k, scalar_view< const T > A, scalar_view< T > L) noexcept
void small_potrf(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L, index_t n=-1) noexcept
void syrk_potrf_trsm_microkernel(index_t m, index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Left-looking variant of small_potrf, which updates the current block with the outer product of the pr...
void small_potrf_left(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L) noexcept
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
int index_t
Definition config.hpp:13
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110