11#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
15template <KernelConfig Conf>
17 if constexpr (Conf.with_diag())
20 return std::false_type{};
23template <KernelConfig Conf>
26 if constexpr (Conf.diag_A == Conf.diag_sign_only)
28 else if constexpr (Conf.diag_A == Conf.diag)
42template <
class T,
class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder O1, StorageOrder O2>
43[[gnu::hot, gnu::flatten]]
void
46 const index_t k1,
const index_t k2, T regularization,
56 const auto index = [](index_t r, index_t c) {
return c * (2 *
RowsReg - 1 - c) / 2 + r; };
59 C_reg[index(ii, jj)] = C_cached.load(ii, jj);
60 C_reg[index(ii, ii)] = C_cached.load(ii, ii) + simd(regularization);
64 for (index_t l = 0; l < k1; ++l) {
69 simd &Cij = C_reg[index(ii, jj)];
70 simd Blj = A1_cached.load(jj, l);
71 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
76 for (index_t l = 0; l < k2; ++l)
78 simd Ail = A2_cached.load(ii, l);
80 simd &Cij = C_reg[index(ii, jj)];
81 simd Blj = A2_cached.load(jj, l);
89 C_reg[index(j, j)] -= C_reg[index(j, k)] * C_reg[index(j, k)];
90 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
91 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
95 C_reg[index(i, j)] -= C_reg[index(i, k)] * C_reg[index(j, k)];
96 C_reg[index(i, j)] = inv_pivot * C_reg[index(i, j)];
102 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
103 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
106 C_reg[index(i, j)] *= inv_pivot;
109 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
113 simd inv_pivot = rsqrt(C_reg[index(0, 0)]);
114 C_reg[index(0, 0)] = sqrt(C_reg[index(0, 0)]);
118 C_reg[index(i, j)] *= inv_pivot;
121 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
122 if (k == j + 1 && i == j + 1) {
123 inv_pivot = rsqrt(C_reg[index(j + 1, j + 1)]);
124 C_reg[index(j + 1, j + 1)] = sqrt(C_reg[index(j + 1, j + 1)]);
133 D_cached.store(C_reg[index(ii, jj)], ii, jj);
148[[gnu::hot, gnu::flatten]]
153 const index_t k1,
const index_t k2,
165 C_reg[ii][jj] = C_cached.load(ii, jj);
169 for (index_t l = 0; l < k1; ++l) {
174 simd &Cij = C_reg[ii][jj];
175 simd Blj = B1_cached.load(jj, l);
176 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
182 for (index_t l = 0; l < k2; ++l)
184 simd Ail = A2_cached.load(ii, l);
186 simd &Cij = C_reg[ii][jj];
187 simd Blj = B2_cached.load(jj, l);
194 simd &Xij = C_reg[jj][ii];
197 simd Aik = L.load(ii, kk);
198 simd Xkj = C_reg[jj][kk];
207 D_cached.store(C_reg[ii][jj], ii, jj);
210template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OCD>
218 const index_t I = C.rows(), K = A.cols(), J = C.cols();
223 if constexpr (Conf.with_diag()) {
231 (void)trsm_microkernel;
240 if (I <= Rows && J <= Cols && I == J)
241 return potrf_microkernel[J - 1](A_, C_, C_, D_, invD, K, 0, regularization, d_);
244 0, J, Cols, [&](index_t j,
auto nj) {
247 const auto Djj = D_.
block(j, j);
249 potrf_microkernel[nj - 1](Aj, Dj, C_.
block(j, j), Djj, invD, K, j, regularization, d_);
251 j + nj, I, Rows, [&](index_t i,
auto ni) {
254 const auto Cij = C_.
block(i, j);
255 const auto Dij = D_.
block(i, j);
257 trsm_microkernel[ni - 1][nj - 1](Ai, Aj, Di, Dj, Djj, invD, Cij, Dij, K, j, d_);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
T rsqrt(T x)
Inverse square root.
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
void aligned_store(V v, typename V::value_type *p)
stdx::memory_alignment< simd< Tp, Abi > > simd_align
stdx::simd_size< Tp, Abi > simd_size
V aligned_load(const typename V::value_type *p)
stdx::simd< Tp, Abi > simd
constexpr index_t ColsReg
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
auto apply_diag(auto x, auto d) noexcept
const constinit auto potrf_copy_lut
std::conditional_t< Conf.with_diag(), uview_vec< T, Abi >, std::false_type > diag_uview_type
void trsm_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O1 > B1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > B2, uview< const T, Abi, O2 > L, const T *invL, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, index_t k1, index_t k2, diag_uview_type< const T, Abi, Conf > diag) noexcept
auto load_diag(auto diag, index_t l) noexcept
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, diag_view_type< const T, Abi, Conf > diag) noexcept
const constinit auto trsm_copy_lut
std::conditional_t< Conf.with_diag(), view< T, Abi >, std::false_type > diag_view_type
void potrf_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, T *invD, index_t k1, index_t k2, T regularization, diag_uview_type< const T, Abi, Conf > diag) noexcept
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i.e. +0.0 or -0.0 per lane).
Self block(this const Self &self, index_t r, index_t c) noexcept
Self middle_rows(this const Self &self, index_t r) noexcept