10#include <guanaqo/trace.hpp>
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
16template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
22template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
28template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
34template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
40template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
41[[gnu::hot, gnu::flatten]]
void
50 const bool use_A = j == 0;
53 for (
index_t l = j + 1; l < k; ++l) {
54 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
56 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
60 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
61 bb[j] += aa[j] * aa[j];
63 const simd ãjj = copysign(sqrt(bb[j]), aa[j]), β = aa[j] + ãjj;
64 simd inv_τ = β / ãjj, inv_β = simd{1} / β;
68 bb[i] = bb[i] * inv_β + aa[i];
74 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ;
75 D.store(aa[i] - bb[i], j, i);
78 for (
index_t l = j + 1; l < k; ++l) {
79 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
83 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
93template <
class T,
class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
102 const bool use_A = j == 0;
107 for (
index_t l = j + 1; l < k; ++l) {
108 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
110 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
114 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
115 bb[j] += aa[j] * aa[j];
117 const simd ãjj = copysign(sqrt(bb[j]), aa[j]), β = aa[j] + ãjj;
118 simd inv_τ = β / ãjj, inv_β = simd{1} / β;
122 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ;
123 D.store(aa[i] - bb[i], j, i);
126 for (
index_t l = j + 1; l < k; ++l) {
127 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
131 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
164 V[l][i] = A.load(l, i);
166 V[j][i] += B.load(l, j) * A.load(l, i);
169 for (
index_t l = R; l < k; ++l)
171 auto Blj = B.load(l, j);
173 V[j][i] += Blj * A.load(l, i);
181 V[j][i] -= W.load(j, l) * V[l][i];
182 V[j][i] *= W.load(j, j);
189 V[j][i] -= W.load(l, j) * V[l][i];
190 V[j][i] *= W.load(j, j);
198 Bl[j] = B.load(l, j);
200 simd Dli = A.load(l, i) - V[l][i];
202 Dli -= V[j][i] * Bl[j];
207 for (
index_t l = R; l < k; ++l) {
209 Bl[j] = B.load(l, j);
211 simd Dli = A.load(l, i);
213 Dli -= V[j][i] * Bl[j];
221template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
231 BATMAT_ASSUME(W.rows() == 0 || (W.cols() == 1 && W.rows() == A.cols()) ||
235 alignas(W_t::alignment()) T W_sto[W_t::size()];
244 if (A.rows() == A.cols() && W.rows() == 0) {
245 auto Wj = W_t{W_sto};
249 auto Djj = D_.block(j, j);
253 geqrf_diag_microkernel<T, Abi, Conf, R, OA, OD>(k, Wj, A_, Djj);
255 foreach_chunked_merged(
257 [&](index_t i, auto rem_i) {
258 auto Dji = D_.block(j, i);
259 microkernel_tail_lut<T, Abi, Conf, OA, OD, OD>[rem_i - 1](
260 k, true, Wj, A_.block(j, i), Dji, Djj);
265 geqrf_diag_microkernel<T, Abi, Conf, R, OD, OD>(k - j, Wj, Djj, Djj);
267 foreach_chunked_merged(
269 [&](index_t i, auto rem_i) {
270 auto Dji = D_.block(j, i);
271 microkernel_tail_lut<T, Abi, Conf, OD, OD, OD>[rem_i - 1](
272 k - j, true, Wj, Dji, Dji, Djj);
278 auto Djj = D_.block(j, j);
286 auto Wj = store_full_W ? W_t{W_.middle_cols(j / R).data} : W_t{W_sto};
287 auto Djj = D_.block(j, j);
291 microkernel_diag_lut<T, Abi, Conf, OA, OD>[rem_j - 1](k, Wj, A_, Djj);
296 auto Dji = D_.block(j, i);
297 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OD>[rem_j - 1][rem_i - 1](
298 k, true, Wj, A_.block(j, i), Dji, Djj);
303 microkernel_diag_lut<T, Abi, Conf, OD, OD>[rem_j - 1](k - j, Wj, Djj, Djj);
305 foreach_chunked_merged(
308 auto Dji = D_.block(j, i);
309 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OD>[rem_j - 1][rem_i - 1](
310 k - j, true, Wj, Dji, Dji, Djj);
314 if (!store_full_W && W.rows() > 0) [[unlikely]]
315 for (
index_t l = 0; l < rem_j; ++l)
316 W_.store(Wj.load(l, l), j + l, 0);
322template <
class T,
class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
325 bool transposed,
bool reversed)
noexcept {
343 const bool forward = transposed ^ reversed;
347 const bool first = forward ? j == 0 : j + nj >= B.cols();
348 static constexpr index_constant<SizeS<T, Abi>> S;
350 auto Bjj = B_.block(j, j);
351 auto Wj = W_t{W_.middle_cols(j / R).data};
356 auto Dji = D_.block(j, i);
358 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OB>[nj - 1][ni - 1](
359 k - j, transposed, Wj, A_.block(j, i), Dji, Bjj);
361 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OB>[nj - 1][ni - 1](
362 k - j, transposed, Wj, Dji, Dji, Bjj);
364 if (first && !transposed && D_.data != A_.data)
365 for (index_t l = 0; l < j; ++l)
366 for (index_t ii = i; ii < i + ni; ++ii)
367 D_.store(A_.load(l, ii), l, ii);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each full chunk and func_rem for the remaining partial chunk, if any.
consteval auto make_1d_lut(F f)
Returns an array of the form:
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
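Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk, including a final chunk that may be smaller than chunk_size (the actual chunk size is passed to func_chunk).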
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
stdx::simd< Tp, Abi > simd
void geqrf_copy_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< T, Abi > W) noexcept
Block hyperbolic Householder factorization update using register blocking.
constexpr std::pair< index_t, index_t > geqrf_W_size(view< T, Abi, OA > A)
void geqrf_diag_microkernel(index_t k, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
const constinit auto microkernel_full_lut
void geqrf_tail_microkernel(index_t k, bool transposed, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D, uview< const T, Abi, OB > B) noexcept
const constinit auto microkernel_tail_lut_2
const constinit auto microkernel_tail_lut
const constinit auto microkernel_diag_lut
void geqrf_full_microkernel(index_t k, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
Dimensions: A is k×R and D is k×R.
void geqrf_apply_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< const T, Abi, OB > B, view< const T, Abi > W, bool transposed, bool reversed) noexcept
Apply a block Householder transformation.
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
std::integral_constant< index_t, I > index_constant