10#include <guanaqo/trace.hpp>
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
16template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
22template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
28template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
34template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
40template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
41[[gnu::hot, gnu::flatten]]
void
48 static constexpr auto safe_min = std::numeric_limits<T>::min();
51 const bool use_A = j == 0;
54 for (
index_t l = j + 1; l < k; ++l) {
55 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
57 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
61 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
62 bb[j] += aa[j] * aa[j];
64 const simd abs_ãjj = sqrt(bb[j]);
65 const simd ãjj = copysign(abs_ãjj, aa[j]), β = aa[j] + ãjj;
71 bb[i] = bb[i] * inv_β + aa[i];
77 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ;
78 D.store(aa[i] - bb[i], j, i);
81 for (
index_t l = j + 1; l < k; ++l) {
82 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
86 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
96template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
104 static constexpr auto safe_min = std::numeric_limits<T>::min();
107 const bool use_A = j == 0;
112 for (
index_t l = j + 1; l < k; ++l) {
113 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
115 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
119 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
120 bb[j] += aa[j] * aa[j];
122 const simd abs_ãjj = sqrt(bb[j]);
123 const simd ãjj = copysign(abs_ãjj, aa[j]), β = aa[j] + ãjj;
129 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ;
130 D.store(aa[i] - bb[i], j, i);
133 for (
index_t l = j + 1; l < k; ++l) {
134 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
138 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
171 V[l][i] = A.load(l, i);
173 V[j][i] += B.load(l, j) * A.load(l, i);
176 for (
index_t l = R; l < k; ++l)
178 auto Blj = B.load(l, j);
180 V[j][i] += Blj * A.load(l, i);
188 V[j][i] -= W.load(j, l) * V[l][i];
189 V[j][i] *= W.load(j, j);
196 V[j][i] -= W.load(l, j) * V[l][i];
197 V[j][i] *= W.load(j, j);
205 Bl[j] = B.load(l, j);
207 simd Dli = A.load(l, i) - V[l][i];
209 Dli -= V[j][i] * Bl[j];
214 for (
index_t l = R; l < k; ++l) {
216 Bl[j] = B.load(l, j);
218 simd Dli = A.load(l, i);
220 Dli -= V[j][i] * Bl[j];
228template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
238 BATMAT_ASSUME(W.rows() == 0 || (W.cols() == 1 && W.rows() == A.cols()) ||
242 alignas(W_t::alignment()) T W_sto[W_t::size()];
251 if (A.rows() == A.cols() && W.rows() == 0) {
252 auto Wj = W_t{W_sto};
256 auto Djj = D_.block(j, j);
260 geqrf_diag_microkernel<T, Abi, Conf, R, OA, OD>(k, Wj, A_, Djj);
262 foreach_chunked_merged(
264 [&](index_t i, auto rem_i) {
265 auto Dji = D_.block(j, i);
266 microkernel_tail_lut<T, Abi, Conf, OA, OD, OD>[rem_i - 1](
267 k, true, Wj, A_.block(j, i), Dji, Djj);
272 geqrf_diag_microkernel<T, Abi, Conf, R, OD, OD>(k - j, Wj, Djj, Djj);
274 foreach_chunked_merged(
276 [&](index_t i, auto rem_i) {
277 auto Dji = D_.block(j, i);
278 microkernel_tail_lut<T, Abi, Conf, OD, OD, OD>[rem_i - 1](
279 k - j, true, Wj, Dji, Dji, Djj);
285 auto Djj = D_.block(j, j);
293 auto Wj = store_full_W ? W_t{W_.middle_cols(j / R).data} : W_t{W_sto};
294 auto Djj = D_.block(j, j);
298 microkernel_diag_lut<T, Abi, Conf, OA, OD>[rem_j - 1](k, Wj, A_, Djj);
303 auto Dji = D_.block(j, i);
304 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OD>[rem_j - 1][rem_i - 1](
305 k, true, Wj, A_.block(j, i), Dji, Djj);
310 microkernel_diag_lut<T, Abi, Conf, OD, OD>[rem_j - 1](k - j, Wj, Djj, Djj);
312 foreach_chunked_merged(
315 auto Dji = D_.block(j, i);
316 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OD>[rem_j - 1][rem_i - 1](
317 k - j, true, Wj, Dji, Dji, Djj);
321 if (!store_full_W && W.rows() > 0) [[unlikely]]
322 for (
index_t l = 0; l < rem_j; ++l)
323 W_.store(Wj.load(l, l), j + l, 0);
329template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
332 bool transposed,
bool reversed)
noexcept {
350 const bool forward = transposed ^ reversed;
354 const bool first = forward ? j == 0 : j + nj >= B.cols();
355 static constexpr index_constant<SizeS<T, Abi>> S;
357 auto Bjj = B_.block(j, j);
358 auto Wj = W_t{W_.middle_cols(j / R).data};
363 auto Dji = D_.block(j, i);
365 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OB>[nj - 1][ni - 1](
366 k - j, transposed, Wj, A_.block(j, i), Dji, Bjj);
368 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OB>[nj - 1][ni - 1](
369 k - j, transposed, Wj, Dji, Dji, Bjj);
371 if (first && !transposed && D_.data != A_.data)
372 for (index_t l = 0; l < j; ++l)
373 for (index_t ii = i; ii < i + ni; ++ii)
374 D_.store(A_.load(l, ii), l, ii);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
consteval auto make_1d_lut(F f)
Returns an array of the form:
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
auto select(auto cond, auto t, auto f)
stdx::simd< Tp, Abi > simd
void geqrf_copy_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< T, Abi > W) noexcept
Block hyperbolic Householder factorization update using register blocking.
constexpr std::pair< index_t, index_t > geqrf_W_size(view< T, Abi, OA > A)
void geqrf_diag_microkernel(index_t k, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
const constinit auto microkernel_full_lut
void geqrf_tail_microkernel(index_t k, bool transposed, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D, uview< const T, Abi, OB > B) noexcept
const constinit auto microkernel_tail_lut_2
const constinit auto microkernel_tail_lut
const constinit auto microkernel_diag_lut
void geqrf_full_microkernel(index_t k, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
A (k×R) D (k×R).
void geqrf_apply_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< const T, Abi, OB > B, view< const T, Abi > W, bool transposed, bool reversed) noexcept
Apply a block Householder transformation.
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
std::integral_constant< index_t, I > index_constant