12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
16template <
class T, index_t NC>
17[[gnu::flatten, gnu::hot]]
22 T Dr[NC * (NC + 1) / 2];
23 static constexpr auto index = [](
index_t r,
index_t c) {
return c * (2 * NC - 1 - c) / 2 + r; };
27 Dr[index(i, j)] = A(i, j);
30 const auto pivot = sqrt(Dr[index(j, j)]);
31 const auto inv_pivot = rsqrt(Dr[index(j, j)]);
32 Dr[index(j, j)] = inv_pivot;
34 Dr[index(i, j)] *= inv_pivot;
36 const T fac = Dr[index(kk, j)];
38 Dr[index(i, kk)] -= Dr[index(i, j)] * fac;
42 L(i, j) = Dr[index(i, j)];
45 auto trsm_tail = [&](
auto &trsm_tail,
index_t r,
auto N) {
47 for (; r + N <= k; r += N) {
54 const T Aik = Dr[index(c, kk)];
57 Xij *= Dr[index(c, c)];
62 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
64 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
70template <
class T, index_t RowsReg, index_t ColsReg>
71[[gnu::flatten, gnu::hot]]
80 T A21_reg[RowsReg][ColsReg];
83 A21_reg[i][j] = L21_cached(i, j);
88 A22ix[j] = A22_cached(i, j);
91 A22ix[j] -= A21_reg[i][kk] * A21_reg[j][kk];
93 L22_cached(i, j) = A22ix[j];
96 auto gemm_tail = [&](
auto &gemm_tail,
index_t i,
auto N) {
98 for (; i + N <= k; i += N) {
105 Aix[j] -= A21ik * A21_reg[j][kk];
111 gemm_tail(gemm_tail, i, std::integral_constant<index_t, N / 2>());
113 gemm_tail(gemm_tail, RowsReg, std::integral_constant<index_t, NR>());
116template <
class T, index_t R>
119 static const constinit auto microkernel_trsm_lut =
make_1d_lut<R>(
121 static const constinit auto microkernel_syrk_lut =
make_1d_lut<R>(
123 static const constinit auto microkernel_syrk_lut_2 =
127 (void)microkernel_syrk_lut;
128 (void)microkernel_syrk_lut_2;
130 const index_t m = L.rows(), N = L.cols();
134 BATMAT_ASSUME((n == m && m == N) || (n == N && m >= N) || (n < m && m == N));
148 microkernel_trsm_lut[r - 1](r + m - n, Aii, Lii);
152 auto L21 = Lii.middle_rows(r), L22 = Lii.block(r, r);
153 auto A22 = Aii.block(r, r);
157 auto Lj1 = L21.middle_rows(j), Ljj = L22.middle_cols(j);
158 auto Ajj = A22.middle_cols(j);
159 microkernel_syrk_lut_2[rem - 1][r - 1](m - n - j, Lj1, Ajj, Ljj);
169 process_bottom_right(A_, L_, n);
174 for (i = 0; i + R <= n; i += R) {
175 auto L11 = L_.
block(i, i);
176 auto A11 = i == 0 ? A_.
block(i, i) :
decltype(A_){L11};
184 auto L21 = L_.
block(j, i), L22 = L_.
block(j, j);
185 auto A22 = i == 0 ? A_.
block(j, j) :
decltype(A_){L22};
186 microkernel_syrk_lut[rem - 1](m - j, L21, A22, L22);
195 auto Lii = L_.
block(i, i);
196 auto Aii = i == 0 ? A_.
block(i, i) :
decltype(A_){Lii};
197 process_bottom_right(Aii, Lii, rem);
206template <
class T, index_t NC, index_t NR>
207[[gnu::flatten, gnu::hot]]
221 for (
index_t l = 0; l < k; ++l) {
225 Dr[j] -= L21l * L21l[j];
230 auto store_mask = load_mask;
232 const T Djj = Dr[j][j];
234 const T pivot = sqrt(Djj);
235 const T inv_pivot = 1 / pivot;
236 inv_pivots[j] = inv_pivot;
239 Dr[i] -= Dr[j] * Dr[j][i];
240#if BATMAT_WITH_GSI_HPC_SIMD
244 store_mask = store_mask && !mask_j;
248 store_mask[j] =
false;
253 auto trsm_tail = [&](
auto &trsm_tail,
index_t r,
auto N) {
255 for (; r + N <= m; r += N) {
259 for (
index_t l = 0; l < k; ++l) {
262 Xrx[j] -= L21rl * L21(j, l);
267 Xij -= Dr[i][j] * Xrx[i];
268 Xij *= inv_pivots[j];
273 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
275 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
278template <
class T, index_t R, index_t S>
281 static const constinit auto microkernel_lut =
285 (void)microkernel_lut;
287 const index_t m = L.rows(), N = L.cols();
297 auto L22 = L_.
block(i, i);
298 auto A22 = A_.
block(i, i);
299 auto L21 = L_.
block(i, 0);
303 auto L22 = L_.
block(i, i);
304 auto A22 = A_.
block(i, i);
305 auto L21 = L_.
block(i, 0);
306 microkernel_lut[rem - 1](m - i, i, L21, A22, L22);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
T rsqrt(T x)
Inverse square root.
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
consteval auto make_1d_lut(F f)
Returns an array of the form:
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
void masked_unaligned_store(V v, typename V::mask_type m, typename V::value_type *p)
V unaligned_load(const typename V::value_type *p)
V partial_load(const typename V::value_type *p)
deduced_abi< Tp, 1 > scalar_abi
void unaligned_store(V v, typename V::value_type *p)
simd< Tp, deduced_abi< Tp, Np > > deduced_simd
auto select(auto cond, auto t, auto f)
auto generate_mask_until()
uview< T, datapar::scalar_abi< std::remove_const_t< T > >, StorageOrder::ColMajor > scalar_view
void potrf_syrk_microkernel(index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Outer product for updating the bottom right tail during Cholesky factorization.
void potrf_trsm_microkernel(index_t k, scalar_view< const T > A, scalar_view< T > L) noexcept
void small_potrf(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L, index_t n=-1) noexcept
void syrk_potrf_trsm_microkernel(index_t m, index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Left-looking variant of small_potrf, which updates the current block with the outer product of the pr...
void small_potrf_left(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L) noexcept
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
std::integral_constant< index_t, I > index_constant
Self block(this const Self &self, index_t r, index_t c) noexcept