9#include <guanaqo/trace.hpp>
13#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
17template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
18[[gnu::hot, gnu::flatten]]
void
33 for (index_t l = 0; l < kA; ++l) {
34 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0))
35 : A.load(j, l) * diag.load(l, 0);
37 bb[i] += A.load(i, l) * Ajl;
40 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
41 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
42 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
43 L_cached.store(L̃jj, j, j);
46 simd Lij = L_cached.load(i, j);
47 bb[i] = γ * Lij + bb[i] * γoβ;
48 L_cached.store(bb[i] - Lij, i, j);
51 for (index_t l = 0; l < kA; ++l) {
52 simd Ajl = A.load(j, l) * inv_β;
55 simd Ail = A.load(i, l);
70template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
71[[gnu::hot, gnu::flatten]]
void
85 for (index_t l = 0; l < kA; ++l) {
86 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0))
87 : A.load(j, l) * diag.load(l, 0);
89 bb[i] += A.load(i, l) * Ajl;
92 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
93 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
94 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
95 L_cached.store(L̃jj, j, j);
98 simd Lij = L_cached.load(i, j);
99 bb[i] = γ * Lij + bb[i] * γoβ;
100 L_cached.store(bb[i] - Lij, i, j);
103 for (index_t l = 0; l < kA; ++l) {
104 simd Ajl = A.load(j, l) * inv_β;
107 simd Ail = A.load(i, l);
117template <
class T,
class Abi,
int S>
123template <
class T,
class Abi>
134template <
class T,
class Abi, KernelConfig Conf, index_t R, index_t S,
StorageOrder OL,
137 index_t kA_in_offset, index_t kA_in, index_t k,
147 for (index_t lA = 0; lA < kA_in; ++lA) {
148 index_t lB = lA + kA_in_offset;
150 auto Bjl = Conf.sign_only ? cneg(B.load(j, lB), diag.load(lB, 0))
151 : B.load(j, lB) * diag.load(lB, 0);
153 V[i][j] += A_in.load(i, lA) * Bjl;
165 Wj[i] = W.load(i, j);
167 simd Lij = L_cached.load(i, j);
170 V[i][j] -= V[i][l] * Wj[l];
171 V[i][j] *= W.load(j, j);
173 L_cached.store(Lij, i, j);
181 Wj[i] = W.load(i, j);
184 V[i][j] -= V[i][l] * Wj[l];
185 V[i][j] *= W.load(j, j);
193 Wj[i] = W.load(i, j);
197 Lij = L_cached.load(i, j);
201 V[i][j] -= V[i][l] * Wj[l];
202 V[i][j] *= W.load(j, j);
205 L_cached.store(Lij, i, j);
213 const auto update_A = [&] [[gnu::always_inline]] (
auto s) {
215 for (index_t lB = 0; lB < kA_in_offset; ++lB) [[unlikely]] {
217 Bjl[j] = B.load(j, lB);
221 Ail -= V[i][j] * Bjl[j];
225 for (index_t lB = kA_in_offset + kA_in; lB < k; ++lB) [[unlikely]] {
227 Bjl[j] = B.load(j, lB);
231 Ail -= V[i][j] * Bjl[j];
235 for (index_t lA = 0; lA < kA_in; ++lA) [[likely]] {
236 index_t lB = lA + kA_in_offset;
238 Bjl[j] = B.load(j, lB);
240 auto Ail = A_in.load(i, lA);
242 Ail -= V[i][j] * Bjl[j];
247#if defined(__AVX512F__) && 0
252 update_A(std::integral_constant<int, 0>{});
254 case -1: update_A(std::integral_constant<int, -1>{});
break;
263template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
268 const index_t k = A.cols();
275 alignas(W_t::alignment()) T W[W_t::size()];
283 if (L.rows() == L.cols()) {
289 auto Ad = A_.middle_rows(j);
290 auto Ld = L_.block(j, j);
292 hyhound_diag_diag_microkernel<T, Abi, Conf, R, OL, OA>(k, W, Ld, Ad, D_);
294 foreach_chunked_merged(
296 [&](index_t i, auto rem_i) {
297 auto As = A_.middle_rows(i);
298 auto Ls = L_.block(i, j);
299 microkernel_tail_lut<T, Abi, Conf, OL, OA, OA>[rem_i - 1](
300 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
304 [&](index_t j, index_t rem_j) {
306 auto Ld = L_.
block(j, j);
313 auto Ad = A_.middle_rows(j);
314 auto Ld = L_.block(j, j);
316 microkernel_diag_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, W, Ld, Ad, D_);
318 foreach_chunked_merged(
319 j + rem_j, L.rows(), S,
320 [&](index_t i, auto rem_i) {
321 auto As = A_.middle_rows(i);
322 auto Ls = L_.block(i, j);
323 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[rem_j - 1][rem_i - 1](
324 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
333template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
336 const index_t k = A.cols();
354 static constexpr index_constant<SizeS<T, Abi>> S;
357 auto Ad = A_.middle_rows(j);
358 auto Ld = L_.block(j, j);
359 auto Wd = W_t{W_.middle_cols(j / R).data};
365 [&](index_t i,
auto ni) {
366 auto As = A_.middle_rows(i);
367 auto Ls = L_.block(i, j);
368 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
369 0, k, k, Wd, Ls, As, As, Ad, D_, Structure::General, 0);
376template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
380 index_t kA_in_offset)
noexcept {
381 const index_t k_in = Ain.cols(), k = Aout.cols();
405 static constexpr index_constant<SizeS<T, Abi>> S;
408 auto Ad = B_.middle_rows(j);
409 auto Wd = W_t{W_.middle_cols(j / R).data};
413 [&](index_t i,
auto ni) {
414 auto Aini = j == 0 ? Ain_.middle_rows(i) : Aout_.middle_rows(i);
415 auto Aouti = Aout_.middle_rows(i);
416 auto Ls = L_.block(i, j);
417 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
418 j == 0 ? kA_in_offset : 0, j == 0 ? k_in : k, k, Wd, Ls, Aini, Aouti, Ad, D_,
419 Structure::General, 0);
431 const index_t k = A1.cols();
441 alignas(W_t::alignment()) T W[W_t::size()];
452 static constexpr index_constant<SizeS<T, Abi>> S;
455 auto Ad = A1_.middle_rows(j);
456 auto Ld = L11_.block(j, j);
458 microkernel_diag_lut<T, Abi, Conf, OL1, OA1>[nj - 1](k, W, Ld, Ad, D_);
460 foreach_chunked_merged(
461 j + nj, L11.rows(), S,
462 [&](index_t i, auto ni) {
463 auto As = A1_.middle_rows(i);
464 auto Ls = L11_.block(i, j);
465 microkernel_tail_lut_2<T, Abi, Conf, OL1, OA1, OA1>[nj - 1][ni - 1](
466 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
471 [&](index_t i,
auto ni) {
472 auto As = A2_.middle_rows(i);
473 auto Ls = L21_.block(i, j);
474 microkernel_tail_lut_2<T, Abi, Conf, OL2, OA2, OA1>[nj - 1][ni - 1](
475 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
499 const index_t k = A1.cols(), k1 = A31.cols(), k2 = A22.cols();
514 alignas(W_t::alignment()) T W[W_t::size()];
529 static constexpr index_constant<SizeS<T, Abi>> S;
532 auto Ad = A1_.middle_rows(j);
533 auto Ld = L11_.block(j, j);
535 microkernel_diag_lut<T, Abi, Conf, OL, OW>[nj - 1](k, W, Ld, Ad, D_);
537 foreach_chunked_merged(
538 j + nj, L11.rows(), S,
539 [&](index_t i, auto ni) {
540 auto As = A1_.middle_rows(i);
541 auto Ls = L11_.block(i, j);
542 microkernel_tail_lut_2<T, Abi, Conf, OL, OW, OW>[nj - 1][ni - 1](
543 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
548 [&](index_t i,
auto ni) {
549 auto As_out = A2_out_.middle_rows(i);
550 auto As = j == 0 ? A22_.middle_rows(i) : As_out;
551 auto Ls = L21_.block(i, j);
553 index_t offset_s = j == 0 ? k1 : 0, k_s = j == 0 ? k2 : k;
554 microkernel_tail_lut_2<T, Abi, Conf, OY, OW, OW>[nj - 1][ni - 1](
555 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
560 [&](index_t i,
auto ni) {
561 auto As_out = A3_out_.middle_rows(i);
562 auto As = j == 0 ? A31_.middle_rows(i) : As_out;
563 auto Ls = L31_.block(i, j);
565 index_t offset_s = 0, k_s = j == 0 ? k1 : k;
566 microkernel_tail_lut_2<T, Abi, Conf, OU, OW, OW>[nj - 1][ni - 1](
567 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
588 bool shift_A_out)
noexcept {
589 const index_t k = A1.cols();
604 static_assert(R == S);
606 alignas(W_t::alignment()) T W[W_t::size()];
620 const bool do_shift = shift_A_out && j + nj == L11.cols();
623 auto Ad = A1_.middle_rows(j);
624 auto Ld = L11_.block(j, j);
626 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, W, Ld, Ad, D_);
628 foreach_chunked_merged(
629 j + nj, L11.rows(), S,
630 [&](index_t i, auto ni) {
631 auto As = A1_.middle_rows(i);
632 auto Ls = L11_.block(i, j);
633 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
634 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
639 [&](index_t i,
auto ni) {
640 auto As_out = A2_out_.middle_rows(i);
641 auto As = j == 0 ? A2_.middle_rows(i) : As_out;
642 auto Ls = L21_.block(i, j);
643 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
644 0, k, k, W, Ls, As, As_out, Ad, D_, Structure::General, do_shift ? -1 : 0);
649 [&](index_t i,
auto ni) {
650 auto As_out = Au_out_.middle_rows(i);
652 auto Ls = Lu1_.block(i, j);
654 const auto struc = i == j ? Structure::Upper
655 : i < j ? Structure::General
657 microkernel_tail_lut_2<T, Abi, Conf, OLu, OAu, OA>[nj - 1][ni - 1](
658 0, j == 0 ? 0 : k, k, W, Ls, As, As_out, Ad, D_, struc, do_shift ? -1 : 0);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
Rotate the elements of x to the right by s positions.
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
stdx::simd< Tp, Abi > simd
auto rotate(datapar::simd< T, Abi > x, std::integral_constant< int, S >)
void hyhound_diag_full_microkernel(index_t kA, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
const constinit auto microkernel_full_lut
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< const T, Abi, OA > A_in, uview< T, Abi, OA > A_out, uview< const T, Abi, OB > B, uview< const T, Abi, StorageOrder::ColMajor > diag, Structure struc_L, int rotate_A) noexcept
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
void hyhound_diag_apply_register(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0) noexcept
Apply a block hyperbolic Householder transformation.
const constinit auto microkernel_diag_lut
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i....
std::integral_constant< index_t, I > index_constant
Self block(this const Self &self, index_t r, index_t c) noexcept
Self middle_rows(this const Self &self, index_t r) noexcept