10#include <guanaqo/trace.hpp>
13#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
17template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
23template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
29template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
35template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
41template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
42[[gnu::hot, gnu::flatten]]
void
53 static constexpr auto safe_min = std::numeric_limits<T>::min();
58 for (
index_t l = 0; l < kA; ++l) {
59 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0))
60 : A.load(j, l) * diag.load(l, 0);
62 bb[i] += A.load(i, l) * Ajl;
65 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
66 const simd abs_L̃jj = sqrt(Ljj * Ljj + α2);
67 const simd L̃jj = copysign(abs_L̃jj, Ljj), β = Ljj + L̃jj;
68 simd γoβ =
datapar::select(abs_L̃jj > safe_min, simd{1} / L̃jj, simd{0}), γ = β * γoβ,
70 L_cached.store(L̃jj, j, j);
73 simd Lij = L_cached.load(i, j);
74 bb[i] = γ * Lij + bb[i] * γoβ;
75 L_cached.store(bb[i] - Lij, i, j);
78 for (
index_t l = 0; l < kA; ++l) {
79 simd Ajl = A.load(j, l) * inv_β;
82 simd Ail = A.load(i, l);
97template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
98[[gnu::hot, gnu::flatten]]
void
108 static constexpr auto safe_min = std::numeric_limits<T>::min();
113 for (
index_t l = 0; l < kA; ++l) {
114 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0))
115 : A.load(j, l) * diag.load(l, 0);
117 bb[i] += A.load(i, l) * Ajl;
120 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
121 const simd abs_L̃jj = sqrt(Ljj * Ljj + α2);
122 const simd L̃jj = copysign(abs_L̃jj, Ljj), β = Ljj + L̃jj;
123 simd γoβ =
datapar::select(abs_L̃jj > safe_min, simd{1} / L̃jj, simd{0}), γ = β * γoβ,
125 L_cached.store(L̃jj, j, j);
128 simd Lij = L_cached.load(i, j);
129 bb[i] = γ * Lij + bb[i] * γoβ;
130 L_cached.store(bb[i] - Lij, i, j);
133 for (
index_t l = 0; l < kA; ++l) {
134 simd Ajl = A.load(j, l) * inv_β;
137 simd Ail = A.load(i, l);
147template <
class T,
class Abi,
int S>
153template <
class T,
class Abi>
177 for (
index_t lA = 0; lA < kA_in; ++lA) {
178 index_t lB = lA + kA_in_offset;
180 auto Bjl = Conf.sign_only ? cneg(B.load(j, lB), diag.load(lB, 0))
181 : B.load(j, lB) * diag.load(lB, 0);
183 V[i][j] += A_in.load(i, lA) * Bjl;
195 Wj[i] = W.load(i, j);
197 simd Lij = L_cached.load(i, j);
200 V[i][j] -= V[i][l] * Wj[l];
201 V[i][j] *= W.load(j, j);
203 L_cached.store(Lij, i, j);
211 Wj[i] = W.load(i, j);
214 V[i][j] -= V[i][l] * Wj[l];
215 V[i][j] *= W.load(j, j);
223 Wj[i] = W.load(i, j);
227 Lij = L_cached.load(i, j);
231 V[i][j] -= V[i][l] * Wj[l];
232 V[i][j] *= W.load(j, j);
235 L_cached.store(Lij, i, j);
243 const auto update_A = [&] [[gnu::always_inline]] (
auto s) {
245 for (
index_t lB = 0; lB < kA_in_offset; ++lB) [[unlikely]] {
247 Bjl[j] = B.load(j, lB);
251 Ail -= V[i][j] * Bjl[j];
255 for (
index_t lB = kA_in_offset + kA_in; lB < k; ++lB) [[unlikely]] {
257 Bjl[j] = B.load(j, lB);
261 Ail -= V[i][j] * Bjl[j];
265 for (
index_t lA = 0; lA < kA_in; ++lA) [[likely]] {
266 index_t lB = lA + kA_in_offset;
268 Bjl[j] = B.load(j, lB);
270 auto Ail = A_in.load(i, lA);
272 Ail -= V[i][j] * Bjl[j];
277#if defined(__AVX512F__) && 0
282 update_A(std::integral_constant<int, 0>{});
284 case -1: update_A(std::integral_constant<int, -1>{});
break;
293template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
305 alignas(W_t::alignment()) T W[W_t::size()];
313 if (L.rows() == L.cols()) {
319 auto Ad = A_.middle_rows(j);
320 auto Ld = L_.block(j, j);
322 hyhound_diag_diag_microkernel<T, Abi, Conf, R, OL, OA>(k, W, Ld, Ad, D_);
324 foreach_chunked_merged(
326 [&](index_t i, auto rem_i) {
327 auto As = A_.middle_rows(i);
328 auto Ls = L_.block(i, j);
329 microkernel_tail_lut<T, Abi, Conf, OL, OA, OA>[rem_i - 1](
330 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
336 auto Ld = L_.
block(j, j);
343 auto Ad = A_.middle_rows(j);
344 auto Ld = L_.block(j, j);
346 microkernel_diag_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, W, Ld, Ad, D_);
348 foreach_chunked_merged(
349 j + rem_j, L.rows(), S,
350 [&](index_t i, auto rem_i) {
351 auto As = A_.middle_rows(i);
352 auto Ls = L_.block(i, j);
353 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[rem_j - 1][rem_i - 1](
354 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
363template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
384 static constexpr index_constant<SizeS<T, Abi>> S;
387 auto Ad = A_.middle_rows(j);
388 auto Ld = L_.block(j, j);
389 auto Wd = W_t{W_.middle_cols(j / R).data};
396 auto As = A_.middle_rows(i);
397 auto Ls = L_.block(i, j);
398 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
399 0, k, k, Wd, Ls, As, As, Ad, D_, Structure::General, 0);
406template <
class T,
class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
410 index_t kA_in_offset)
noexcept {
411 const index_t k_in = Ain.cols(), k = Aout.cols();
435 static constexpr index_constant<SizeS<T, Abi>> S;
438 auto Ad = B_.middle_rows(j);
439 auto Wd = W_t{W_.middle_cols(j / R).data};
444 auto Aini = j == 0 ? Ain_.middle_rows(i) : Aout_.middle_rows(i);
445 auto Aouti = Aout_.middle_rows(i);
446 auto Ls = L_.block(i, j);
447 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
448 j == 0 ? kA_in_offset : 0, j == 0 ? k_in : k, k, Wd, Ls, Aini, Aouti, Ad, D_,
449 Structure::General, 0);
471 alignas(W_t::alignment()) T W[W_t::size()];
482 static constexpr index_constant<SizeS<T, Abi>> S;
485 auto Ad = A1_.middle_rows(j);
486 auto Ld = L11_.block(j, j);
488 microkernel_diag_lut<T, Abi, Conf, OL1, OA1>[nj - 1](k, W, Ld, Ad, D_);
490 foreach_chunked_merged(
491 j + nj, L11.rows(), S,
492 [&](index_t i, auto ni) {
493 auto As = A1_.middle_rows(i);
494 auto Ls = L11_.block(i, j);
495 microkernel_tail_lut_2<T, Abi, Conf, OL1, OA1, OA1>[nj - 1][ni - 1](
496 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
502 auto As = A2_.middle_rows(i);
503 auto Ls = L21_.block(i, j);
504 microkernel_tail_lut_2<T, Abi, Conf, OL2, OA2, OA1>[nj - 1][ni - 1](
505 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
529 const index_t k = A1.cols(), k1 = A31.cols(), k2 = A22.cols();
544 alignas(W_t::alignment()) T W[W_t::size()];
559 static constexpr index_constant<SizeS<T, Abi>> S;
562 auto Ad = A1_.middle_rows(j);
563 auto Ld = L11_.block(j, j);
565 microkernel_diag_lut<T, Abi, Conf, OL, OW>[nj - 1](k, W, Ld, Ad, D_);
567 foreach_chunked_merged(
568 j + nj, L11.rows(), S,
569 [&](index_t i, auto ni) {
570 auto As = A1_.middle_rows(i);
571 auto Ls = L11_.block(i, j);
572 microkernel_tail_lut_2<T, Abi, Conf, OL, OW, OW>[nj - 1][ni - 1](
573 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
579 auto As_out = A2_out_.middle_rows(i);
580 auto As = j == 0 ? A22_.middle_rows(i) : As_out;
581 auto Ls = L21_.block(i, j);
583 index_t offset_s = j == 0 ? k1 : 0, k_s = j == 0 ? k2 : k;
584 microkernel_tail_lut_2<T, Abi, Conf, OY, OW, OW>[nj - 1][ni - 1](
585 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
591 auto As_out = A3_out_.middle_rows(i);
592 auto As = j == 0 ? A31_.middle_rows(i) : As_out;
593 auto Ls = L31_.block(i, j);
595 index_t offset_s = 0, k_s = j == 0 ? k1 : k;
596 microkernel_tail_lut_2<T, Abi, Conf, OU, OW, OW>[nj - 1][ni - 1](
597 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
618 bool shift_A_out)
noexcept {
634 static_assert(R == S);
636 alignas(W_t::alignment()) T W[W_t::size()];
650 const bool do_shift = shift_A_out && j + nj == L11.cols();
653 auto Ad = A1_.middle_rows(j);
654 auto Ld = L11_.block(j, j);
656 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, W, Ld, Ad, D_);
658 foreach_chunked_merged(
659 j + nj, L11.rows(), S,
660 [&](index_t i, auto ni) {
661 auto As = A1_.middle_rows(i);
662 auto Ls = L11_.block(i, j);
663 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
664 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
670 auto As_out = A2_out_.middle_rows(i);
671 auto As = j == 0 ? A2_.middle_rows(i) : As_out;
672 auto Ls = L21_.block(i, j);
673 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
674 0, k, k, W, Ls, As, As_out, Ad, D_, Structure::General, do_shift ? -1 : 0);
680 auto As_out = Au_out_.middle_rows(i);
682 auto Ls = Lu1_.block(i, j);
684 const auto struc = i == j ? Structure::Upper
685 : i < j ? Structure::General
687 microkernel_tail_lut_2<T, Abi, Conf, OLu, OAu, OA>[nj - 1][ni - 1](
688 0, j == 0 ? 0 : k, k, W, Ls, As, As_out, Ad, D_, struc, do_shift ? -1 : 0);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
Rotate the elements of x to the right by s positions.
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each full chunk and func_rem for the remaining elements, if any.
consteval auto make_1d_lut(F f)
Returns an array of the form:
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk (including the last chunk, which may be smaller).
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
auto select(auto cond, auto t, auto f)
stdx::simd< Tp, Abi > simd
auto rotate(datapar::simd< T, Abi > x, std::integral_constant< int, S >)
void hyhound_diag_full_microkernel(index_t kA, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
const constinit auto microkernel_full_lut
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< const T, Abi, OA > A_in, uview< T, Abi, OA > A_out, uview< const T, Abi, OB > B, uview< const T, Abi, StorageOrder::ColMajor > diag, Structure struc_L, int rotate_A) noexcept
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
void hyhound_diag_apply_register(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0) noexcept
Apply a block hyperbolic Householder transformation.
const constinit auto microkernel_tail_lut_2
const constinit auto microkernel_tail_lut
const constinit auto microkernel_diag_lut
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i.e. only the sign bit of each element may be set).
std::integral_constant< index_t, I > index_constant
Self block(this const Self &self, index_t r, index_t c) noexcept
Self middle_rows(this const Self &self, index_t r) noexcept