26 const index_t k1,
const index_t k2, T regularization)
noexcept {
35 const auto index = [](index_t r, index_t c) {
return c * (2 *
RowsReg - 1 - c) / 2 + r; };
38 C_reg[index(ii, jj)] = C_cached.load(ii, jj);
39 C_reg[index(ii, ii)] = C_cached.load(ii, ii) + simd(regularization);
43 for (index_t l = 0; l < k1; ++l)
45 simd Ail = A1_cached.load(ii, l);
47 simd &Cij = C_reg[index(ii, jj)];
48 simd Blj = A1_cached.load(jj, l);
49 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
53 for (index_t l = 0; l < k2; ++l)
55 simd Ail = A2_cached.load(ii, l);
57 simd &Cij = C_reg[index(ii, jj)];
58 simd Blj = A2_cached.load(jj, l);
66 C_reg[index(j, j)] -= C_reg[index(j, k)] * C_reg[index(j, k)];
67 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
68 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
72 C_reg[index(i, j)] -= C_reg[index(i, k)] * C_reg[index(j, k)];
73 C_reg[index(i, j)] = inv_pivot * C_reg[index(i, j)];
79 simd inv_pivot = rsqrt(C_reg[index(j, j)]);
80 C_reg[index(j, j)] = sqrt(C_reg[index(j, j)]);
83 C_reg[index(i, j)] *= inv_pivot;
86 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
90 simd inv_pivot = rsqrt(C_reg[index(0, 0)]);
91 C_reg[index(0, 0)] = sqrt(C_reg[index(0, 0)]);
95 C_reg[index(i, j)] *= inv_pivot;
98 C_reg[index(i, k)] -= C_reg[index(i, j)] * C_reg[index(k, j)];
99 if (k == j + 1 && i == j + 1) {
100 inv_pivot = rsqrt(C_reg[index(j + 1, j + 1)]);
101 C_reg[index(j + 1, j + 1)] = sqrt(C_reg[index(j + 1, j + 1)]);
110 D_cached.store(C_reg[index(ii, jj)], ii, jj);
130 const index_t k1,
const index_t k2)
noexcept {
141 C_reg[ii][jj] = C_cached.load(ii, jj);
145 for (index_t l = 0; l < k1; ++l)
147 simd Ail = A1_cached.load(ii, l);
149 simd &Cij = C_reg[ii][jj];
150 simd Blj = B1_cached.load(jj, l);
151 Conf.negate_A ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
156 for (index_t l = 0; l < k2; ++l)
158 simd Ail = A2_cached.load(ii, l);
160 simd &Cij = C_reg[ii][jj];
161 simd Blj = B2_cached.load(jj, l);
168 simd &Xij = C_reg[jj][ii];
171 simd Aik = L.load(ii, kk);
172 simd Xkj = C_reg[jj][kk];
181 D_cached.store(C_reg[ii][jj], ii, jj);
191 const index_t I = C.rows(), K = A.cols(), J = C.cols();
200 (void)trsm_microkernel;
208 if (I <= Rows && J <= Cols && I == J)
209 return potrf_microkernel[J - 1](A_, C_, C_, D_, invD, K, 0, regularization);
212 0, J, Cols, [&](index_t j,
auto nj) {
215 const auto Djj = D_.
block(j, j);
217 potrf_microkernel[nj - 1](Aj, Dj, C_.
block(j, j), Djj, invD, K, j, regularization);
219 j + nj, I, Rows, [&](index_t i,
auto ni) {
222 const auto Cij = C_.
block(i, j);
223 const auto Dij = D_.
block(i, j);
225 trsm_microkernel[ni - 1][nj - 1](Ai, Aj, Di, Dj, Djj, invD, Cij, Dij, K, j);
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk; the final chunk may be smaller if the range length is not a multiple of chunk_size. The dir argument selects forward or reverse traversal of the chunks.
void potrf_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, T *invD, index_t k1, index_t k2, T regularization) noexcept
void trsm_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O1 > B1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > B2, uview< const T, Abi, O2 > L, const T *invL, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, index_t k1, index_t k2) noexcept