10#include <guanaqo/trace.hpp>
13#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
17template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
23template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
29template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
35template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
41template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
42[[gnu::hot, gnu::flatten]]
void
57 for (index_t l = 0; l < kA; ++l) {
58 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0))
59 : A.load(j, l) * diag.load(l, 0);
61 bb[i] += A.load(i, l) * Ajl;
64 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
65 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
66 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
67 L_cached.store(L̃jj, j, j);
70 simd Lij = L_cached.load(i, j);
71 bb[i] = γ * Lij + bb[i] * γoβ;
72 L_cached.store(bb[i] - Lij, i, j);
75 for (index_t l = 0; l < kA; ++l) {
76 simd Ajl = A.load(j, l) * inv_β;
79 simd Ail = A.load(i, l);
94template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
95[[gnu::hot, gnu::flatten]]
void
109 for (index_t l = 0; l < kA; ++l) {
110 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0))
111 : A.load(j, l) * diag.load(l, 0);
113 bb[i] += A.load(i, l) * Ajl;
116 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
117 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
118 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
119 L_cached.store(L̃jj, j, j);
122 simd Lij = L_cached.load(i, j);
123 bb[i] = γ * Lij + bb[i] * γoβ;
124 L_cached.store(bb[i] - Lij, i, j);
127 for (index_t l = 0; l < kA; ++l) {
128 simd Ajl = A.load(j, l) * inv_β;
131 simd Ail = A.load(i, l);
141template <
class T,
class Abi,
int S>
147template <
class T,
class Abi>
158template <
class T,
class Abi, KernelConfig Conf, index_t R, index_t S,
StorageOrder OL,
161 index_t kA_in_offset, index_t kA_in, index_t k,
171 for (index_t lA = 0; lA < kA_in; ++lA) {
172 index_t lB = lA + kA_in_offset;
174 auto Bjl = Conf.sign_only ? cneg(B.load(j, lB), diag.load(lB, 0))
175 : B.load(j, lB) * diag.load(lB, 0);
177 V[i][j] += A_in.load(i, lA) * Bjl;
189 Wj[i] = W.load(i, j);
191 simd Lij = L_cached.load(i, j);
194 V[i][j] -= V[i][l] * Wj[l];
195 V[i][j] *= W.load(j, j);
197 L_cached.store(Lij, i, j);
205 Wj[i] = W.load(i, j);
208 V[i][j] -= V[i][l] * Wj[l];
209 V[i][j] *= W.load(j, j);
217 Wj[i] = W.load(i, j);
221 Lij = L_cached.load(i, j);
225 V[i][j] -= V[i][l] * Wj[l];
226 V[i][j] *= W.load(j, j);
229 L_cached.store(Lij, i, j);
237 const auto update_A = [&] [[gnu::always_inline]] (
auto s) {
239 for (index_t lB = 0; lB < kA_in_offset; ++lB) [[unlikely]] {
241 Bjl[j] = B.load(j, lB);
245 Ail -= V[i][j] * Bjl[j];
249 for (index_t lB = kA_in_offset + kA_in; lB < k; ++lB) [[unlikely]] {
251 Bjl[j] = B.load(j, lB);
255 Ail -= V[i][j] * Bjl[j];
259 for (index_t lA = 0; lA < kA_in; ++lA) [[likely]] {
260 index_t lB = lA + kA_in_offset;
262 Bjl[j] = B.load(j, lB);
264 auto Ail = A_in.load(i, lA);
266 Ail -= V[i][j] * Bjl[j];
271#if defined(__AVX512F__) && 0
276 update_A(std::integral_constant<int, 0>{});
278 case -1: update_A(std::integral_constant<int, -1>{});
break;
287template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
292 const index_t k = A.cols();
299 alignas(W_t::alignment()) T W[W_t::size()];
307 if (L.rows() == L.cols()) {
313 auto Ad = A_.middle_rows(j);
314 auto Ld = L_.block(j, j);
316 hyhound_diag_diag_microkernel<T, Abi, Conf, R, OL, OA>(k, W, Ld, Ad, D_);
318 foreach_chunked_merged(
320 [&](index_t i, auto rem_i) {
321 auto As = A_.middle_rows(i);
322 auto Ls = L_.block(i, j);
323 microkernel_tail_lut<T, Abi, Conf, OL, OA, OA>[rem_i - 1](
324 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
328 [&](index_t j, index_t rem_j) {
330 auto Ld = L_.
block(j, j);
337 auto Ad = A_.middle_rows(j);
338 auto Ld = L_.block(j, j);
340 microkernel_diag_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, W, Ld, Ad, D_);
342 foreach_chunked_merged(
343 j + rem_j, L.rows(), S,
344 [&](index_t i, auto rem_i) {
345 auto As = A_.middle_rows(i);
346 auto Ls = L_.block(i, j);
347 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[rem_j - 1][rem_i - 1](
348 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
357template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
360 const index_t k = A.cols();
378 static constexpr index_constant<SizeS<T, Abi>> S;
381 auto Ad = A_.middle_rows(j);
382 auto Ld = L_.block(j, j);
383 auto Wd = W_t{W_.middle_cols(j / R).data};
389 [&](index_t i,
auto ni) {
390 auto As = A_.middle_rows(i);
391 auto Ls = L_.block(i, j);
392 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
393 0, k, k, Wd, Ls, As, As, Ad, D_, Structure::General, 0);
400template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
404 index_t kA_in_offset)
noexcept {
405 const index_t k_in = Ain.cols(), k = Aout.cols();
429 static constexpr index_constant<SizeS<T, Abi>> S;
432 auto Ad = B_.middle_rows(j);
433 auto Wd = W_t{W_.middle_cols(j / R).data};
437 [&](index_t i,
auto ni) {
438 auto Aini = j == 0 ? Ain_.middle_rows(i) : Aout_.middle_rows(i);
439 auto Aouti = Aout_.middle_rows(i);
440 auto Ls = L_.block(i, j);
441 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
442 j == 0 ? kA_in_offset : 0, j == 0 ? k_in : k, k, Wd, Ls, Aini, Aouti, Ad, D_,
443 Structure::General, 0);
455 const index_t k = A1.cols();
465 alignas(W_t::alignment()) T W[W_t::size()];
476 static constexpr index_constant<SizeS<T, Abi>> S;
479 auto Ad = A1_.middle_rows(j);
480 auto Ld = L11_.block(j, j);
482 microkernel_diag_lut<T, Abi, Conf, OL1, OA1>[nj - 1](k, W, Ld, Ad, D_);
484 foreach_chunked_merged(
485 j + nj, L11.rows(), S,
486 [&](index_t i, auto ni) {
487 auto As = A1_.middle_rows(i);
488 auto Ls = L11_.block(i, j);
489 microkernel_tail_lut_2<T, Abi, Conf, OL1, OA1, OA1>[nj - 1][ni - 1](
490 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
495 [&](index_t i,
auto ni) {
496 auto As = A2_.middle_rows(i);
497 auto Ls = L21_.block(i, j);
498 microkernel_tail_lut_2<T, Abi, Conf, OL2, OA2, OA1>[nj - 1][ni - 1](
499 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
523 const index_t k = A1.cols(), k1 = A31.cols(), k2 = A22.cols();
538 alignas(W_t::alignment()) T W[W_t::size()];
553 static constexpr index_constant<SizeS<T, Abi>> S;
556 auto Ad = A1_.middle_rows(j);
557 auto Ld = L11_.block(j, j);
559 microkernel_diag_lut<T, Abi, Conf, OL, OW>[nj - 1](k, W, Ld, Ad, D_);
561 foreach_chunked_merged(
562 j + nj, L11.rows(), S,
563 [&](index_t i, auto ni) {
564 auto As = A1_.middle_rows(i);
565 auto Ls = L11_.block(i, j);
566 microkernel_tail_lut_2<T, Abi, Conf, OL, OW, OW>[nj - 1][ni - 1](
567 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
572 [&](index_t i,
auto ni) {
573 auto As_out = A2_out_.middle_rows(i);
574 auto As = j == 0 ? A22_.middle_rows(i) : As_out;
575 auto Ls = L21_.block(i, j);
577 index_t offset_s = j == 0 ? k1 : 0, k_s = j == 0 ? k2 : k;
578 microkernel_tail_lut_2<T, Abi, Conf, OY, OW, OW>[nj - 1][ni - 1](
579 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
584 [&](index_t i,
auto ni) {
585 auto As_out = A3_out_.middle_rows(i);
586 auto As = j == 0 ? A31_.middle_rows(i) : As_out;
587 auto Ls = L31_.block(i, j);
589 index_t offset_s = 0, k_s = j == 0 ? k1 : k;
590 microkernel_tail_lut_2<T, Abi, Conf, OU, OW, OW>[nj - 1][ni - 1](
591 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
612 bool shift_A_out)
noexcept {
613 const index_t k = A1.cols();
628 static_assert(R == S);
630 alignas(W_t::alignment()) T W[W_t::size()];
644 const bool do_shift = shift_A_out && j + nj == L11.cols();
647 auto Ad = A1_.middle_rows(j);
648 auto Ld = L11_.block(j, j);
650 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, W, Ld, Ad, D_);
652 foreach_chunked_merged(
653 j + nj, L11.rows(), S,
654 [&](index_t i, auto ni) {
655 auto As = A1_.middle_rows(i);
656 auto Ls = L11_.block(i, j);
657 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
658 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
663 [&](index_t i,
auto ni) {
664 auto As_out = A2_out_.middle_rows(i);
665 auto As = j == 0 ? A2_.middle_rows(i) : As_out;
666 auto Ls = L21_.block(i, j);
667 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
668 0, k, k, W, Ls, As, As_out, Ad, D_, Structure::General, do_shift ? -1 : 0);
673 [&](index_t i,
auto ni) {
674 auto As_out = Au_out_.middle_rows(i);
676 auto Ls = Lu1_.block(i, j);
678 const auto struc = i == j ? Structure::Upper
679 : i < j ? Structure::General
681 microkernel_tail_lut_2<T, Abi, Conf, OLu, OAu, OA>[nj - 1][ni - 1](
682 0, j == 0 ? 0 : k, k, W, Ls, As, As_out, Ad, D_, struc, do_shift ? -1 : 0);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
Rotate the elements of x to the right by s positions.
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each full chunk and func_rem for any remaining elements.
consteval auto make_1d_lut(F f)
Returns an array of the form:
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk, including any final partial chunk.
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
stdx::simd< Tp, Abi > simd
auto rotate(datapar::simd< T, Abi > x, std::integral_constant< int, S >)
void hyhound_diag_full_microkernel(index_t kA, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
const constinit auto microkernel_full_lut
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< const T, Abi, OA > A_in, uview< T, Abi, OA > A_out, uview< const T, Abi, OB > B, uview< const T, Abi, StorageOrder::ColMajor > diag, Structure struc_L, int rotate_A) noexcept
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
void hyhound_diag_apply_register(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0) noexcept
Apply a block hyperbolic Householder transformation.
const constinit auto microkernel_tail_lut_2
const constinit auto microkernel_tail_lut
const constinit auto microkernel_diag_lut
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
T cneg(T x, T signs)
Conditionally flips the sign bit of x, depending on signs, which should contain only ±0 (i.e. all bits zero except possibly the sign bit).
std::integral_constant< index_t, I > index_constant
Self block(this const Self &self, index_t r, index_t c) noexcept
Self middle_rows(this const Self &self, index_t r) noexcept