12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
16template <
class T, index_t NC>
17[[gnu::flatten, gnu::hot]]
22 T Dr[NC * (NC + 1) / 2];
23 static constexpr auto index = [](
index_t r,
index_t c) {
return c * (2 * NC - 1 - c) / 2 + r; };
27 Dr[index(i, j)] = A(i, j);
30 const auto pivot = sqrt(Dr[index(j, j)]);
31 const auto inv_pivot = rsqrt(Dr[index(j, j)]);
32 Dr[index(j, j)] = inv_pivot;
34 Dr[index(i, j)] *= inv_pivot;
36 const T fac = Dr[index(kk, j)];
38 Dr[index(i, kk)] -= Dr[index(i, j)] * fac;
42 L(i, j) = Dr[index(i, j)];
45 auto trsm_tail = [&](
auto &trsm_tail,
index_t r,
auto N) {
47 for (; r + N <= k; r += N) {
54 const T Aik = Dr[index(c, kk)];
57 Xij *= Dr[index(c, c)];
62 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
64 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
70template <
class T, index_t RowsReg, index_t ColsReg>
71[[gnu::flatten, gnu::hot]]
80 T A21_reg[RowsReg][ColsReg];
83 A21_reg[i][j] = L21_cached(i, j);
88 A22ix[j] = A22_cached(i, j);
91 A22ix[j] -= A21_reg[i][kk] * A21_reg[j][kk];
93 L22_cached(i, j) = A22ix[j];
96 auto gemm_tail = [&](
auto &gemm_tail,
index_t i,
auto N) {
98 using simd_index_t =
decltype(simd::size());
99 for (; i + N <= k; i += N) {
106 Aix[j] -= A21ik * A21_reg[j][
static_cast<simd_index_t
>(kk)];
112 gemm_tail(gemm_tail, i, std::integral_constant<index_t, N / 2>());
114 gemm_tail(gemm_tail, RowsReg, std::integral_constant<index_t, NR>());
117template <
class T, index_t R>
120 static const constinit auto microkernel_trsm_lut =
make_1d_lut<R>(
122 static const constinit auto microkernel_syrk_lut =
make_1d_lut<R>(
124 static const constinit auto microkernel_syrk_lut_2 =
128 (void)microkernel_syrk_lut;
129 (void)microkernel_syrk_lut_2;
131 const index_t m = L.rows(), N = L.cols();
135 BATMAT_ASSUME((n == m && m == N) || (n == N && m >= N) || (n < m && m == N));
149 microkernel_trsm_lut[r - 1](r + m - n, Aii, Lii);
153 auto L21 = Lii.middle_rows(r), L22 = Lii.block(r, r);
154 auto A22 = Aii.block(r, r);
158 auto Lj1 = L21.middle_rows(j), Ljj = L22.middle_cols(j);
159 auto Ajj = A22.middle_cols(j);
160 microkernel_syrk_lut_2[rem - 1][r - 1](m - n - j, Lj1, Ajj, Ljj);
170 process_bottom_right(A_, L_, n);
175 for (i = 0; i + R <= n; i += R) {
176 auto L11 = L_.
block(i, i);
177 auto A11 = i == 0 ? A_.
block(i, i) :
decltype(A_){L11};
185 auto L21 = L_.
block(j, i), L22 = L_.
block(j, j);
186 auto A22 = i == 0 ? A_.
block(j, j) :
decltype(A_){L22};
187 microkernel_syrk_lut[rem - 1](m - j, L21, A22, L22);
196 auto Lii = L_.
block(i, i);
197 auto Aii = i == 0 ? A_.
block(i, i) :
decltype(A_){Lii};
198 process_bottom_right(Aii, Lii, rem);
207template <
class T, index_t NC, index_t NR>
208[[gnu::flatten, gnu::hot]]
214 using simd_index_t =
decltype(simd::size());
223 for (
index_t l = 0; l < k; ++l) {
227 Dr[j] -= L21l * L21l[
static_cast<simd_index_t
>(j)];
232 auto store_mask = load_mask;
234 const T Djj = Dr[j][
static_cast<simd_index_t
>(j)];
236 const T pivot = sqrt(Djj);
237 const T inv_pivot = 1 / pivot;
238 inv_pivots[j] = inv_pivot;
241 Dr[i] -= Dr[j] * Dr[j][
static_cast<simd_index_t
>(i)];
242#if BATMAT_WITH_GSI_HPC_SIMD
246 store_mask = store_mask && !mask_j;
248 Dr[j][
static_cast<simd_index_t
>(j)] = pivot;
250 store_mask[
static_cast<simd_index_t
>(j)] =
false;
255 auto trsm_tail = [&](
auto &trsm_tail,
index_t r,
auto N) {
257 for (; r + N <= m; r += N) {
261 for (
index_t l = 0; l < k; ++l) {
264 Xrx[j] -= L21rl * L21(j, l);
269 Xij -= Dr[i][
static_cast<simd_index_t
>(j)] * Xrx[i];
270 Xij *= inv_pivots[j];
275 trsm_tail(trsm_tail, r, std::integral_constant<index_t, N / 2>());
277 trsm_tail(trsm_tail, NC, std::integral_constant<index_t, NR>());
280template <
class T, index_t R, index_t S>
283 static const constinit auto microkernel_lut =
287 (void)microkernel_lut;
289 const index_t m = L.rows(), N = L.cols();
299 auto L22 = L_.
block(i, i);
300 auto A22 = A_.
block(i, i);
301 auto L21 = L_.
block(i, 0);
305 auto L22 = L_.
block(i, i);
306 auto A22 = A_.
block(i, i);
307 auto L21 = L_.
block(i, 0);
308 microkernel_lut[rem - 1](m - i, i, L21, A22, L22);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
T rsqrt(T x)
Inverse square root.
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
consteval auto make_1d_lut(F f)
Returns an array of the form:
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
void masked_unaligned_store(V v, typename V::mask_type m, typename V::value_type *p)
V unaligned_load(const typename V::value_type *p)
V partial_load(const typename V::value_type *p)
deduced_abi< Tp, 1 > scalar_abi
void unaligned_store(V v, typename V::value_type *p)
simd< Tp, deduced_abi< Tp, Np > > deduced_simd
auto select(auto cond, auto t, auto f)
auto generate_mask_until()
uview< T, datapar::scalar_abi< std::remove_const_t< T > >, StorageOrder::ColMajor > scalar_view
void potrf_syrk_microkernel(index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Outer product for updating the bottom right tail during Cholesky factorization.
void potrf_trsm_microkernel(index_t k, scalar_view< const T > A, scalar_view< T > L) noexcept
void small_potrf(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L, index_t n=-1) noexcept
void syrk_potrf_trsm_microkernel(index_t m, index_t k, scalar_view< const T > L21, scalar_view< const T > A22, scalar_view< T > L22) noexcept
Left-looking variant of small_potrf, which updates the current block with the outer product of the pr...
void small_potrf_left(view< const T, datapar::scalar_abi< T > > A, view< T, datapar::scalar_abi< T > > L) noexcept
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
std::integral_constant< index_t, I > index_constant
Self block(this const Self &self, index_t r, index_t c) noexcept