batmat main
Batched linear algebra routines
Loading...
Searching...
No Matches
hyhound.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
8#include <batmat/ops/cneg.hpp>
10#include <guanaqo/trace.hpp>
11#include <type_traits>
12
// File-local shorthand: UNROLL_FOR (init; cond; step) expands to BATMAT_FULLY_UNROLLED_FOR with the
// same arguments. NOTE(review): BATMAT_FULLY_UNROLLED_FOR is defined outside this listing;
// presumably it emits a `for` loop with a full-unroll request — confirm in its header.
13#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
14
16
// Compile-time lookup table of diagonal-block micro-kernels that also produce W.
// Callers in this file index it with (block size - 1) and invoke the selected entry as
// entry(k, W, Ld, Ad, D_).
// NOTE(review): the initializer (listing lines 19-20) is absent from this extract; presumably it
// is built with make_1d_lut over hyhound_diag_diag_microkernel — confirm against the real file.
17template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
18inline const constinit auto microkernel_diag_lut =
21 });
22
// Compile-time lookup table of diagonal-block micro-kernels that do not produce W.
// The only caller in this file indexes it with (remaining block size - 1) and invokes the entry
// as entry(k, Ld, Ad, D_).
// NOTE(review): the initializer (listing lines 25-26) is absent from this extract; presumably it
// is built with make_1d_lut over hyhound_diag_full_microkernel — confirm against the real file.
23template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
24inline const constinit auto microkernel_full_lut =
27 });
28
// Compile-time lookup table of tail micro-kernels, indexed by (number of rows - 1) only.
// The caller in this file uses it below a full-size diagonal block and invokes the entry as
// entry(kA_in_offset, kA_in, k, W, Ls, A_in, A_out, B, D_, structure, rotate_A).
// NOTE(review): the initializer (listing lines 31-32) is absent from this extract; presumably the
// column count is fixed to the full block size R — confirm against the real file.
29template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
30inline const constinit auto microkernel_tail_lut =
33 });
34
// Compile-time 2D lookup table of tail micro-kernels. Callers index it as [nj - 1][ni - 1], i.e.
// first by the diagonal block size and then by the number of rows in the tail chunk; the
// generator lambda below receives the two sizes as index_constant<NR> and index_constant<NS>.
// NOTE(review): listing lines 36 (the variable declaration) and 38 (the lambda body) are absent
// from this extract; presumably the table is built with make_2d_lut — confirm.
35template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
37 []<index_t NR, index_t NS>(index_constant<NR>, index_constant<NS>) {
39 });
40
// Diagonal-block micro-kernel that also records the block Householder representation W.
// NOTE(review): listing lines 43-46 (function name and parameter list) are absent from this
// extract. The reference section at the end of this page gives the signature as
//   void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor<T, Abi, SizeR<T, Abi>> W,
//                                      uview<T, Abi, OL> L, uview<T, Abi, OA> A,
//                                      uview<const T, Abi, StorageOrder::ColMajor> diag) noexcept
// For every column j of the R-by-R block of L, the body below:
//   1. accumulates bb[i] = sum over l of A(i,l) * (A(j,l) * diag(l)) for all R rows i
//      (when Conf.sign_only is set, diag is applied through cneg instead of a multiplication);
//   2. derives the new diagonal entry and the coefficients β, γ/β, γ and 1/β from L(j,j) and bb[j];
//   3. overwrites column j of L (rows j..R-1) and updates rows j..R-1 of A in place;
//   4. stores bb[0..j] as column j of W, with γ in the diagonal position.
41template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
42[[gnu::hot, gnu::flatten]] void
47 using std::copysign;
48 using std::sqrt;
49 using simd = datapar::simd<T, Abi>;
50 // Pre-compute the offsets of the columns of L
51 auto L_cached = with_cached_access<R, R>(L);
// Optimizer hint only: per its documentation, BATMAT_ASSUME invokes undefined behavior when the
// condition is false, so callers must guarantee kA > 0.
52 BATMAT_ASSUME(kA > 0);
53
54 UNROLL_FOR (index_t j = 0; j < R; ++j) {
55 // Compute all inner products between A and a
// All R rows are accumulated (not only i >= j): the entries with i < j are needed for W below.
56 simd bb[R]{};
57 for (index_t l = 0; l < kA; ++l) {
58 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0)) //
59 : A.load(j, l) * diag.load(l, 0);
60 UNROLL_FOR (index_t i = 0; i < R; ++i)
61 bb[i] += A.load(i, l) * Ajl;
62 }
63 // Energy condition and Householder coefficients
// The new diagonal entry keeps the sign of the old one (copysign), so β = Ljj + L̃jj adds two
// values of equal sign.
64 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
65 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
66 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
67 L_cached.store(L̃jj, j, j);
68 // Compute L̃
// bb[i] is reused: after this loop it holds the multiplier applied to row i of A below.
69 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
70 simd Lij = L_cached.load(i, j);
71 bb[i] = γ * Lij + bb[i] * γoβ;
72 L_cached.store(bb[i] - Lij, i, j);
73 }
74 // Update A
75 for (index_t l = 0; l < kA; ++l) {
76 simd Ajl = A.load(j, l) * inv_β; // Scale Householder vector
77 A.store(Ajl, j, l);
78 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
79 simd Ail = A.load(i, l);
80 Ail -= bb[i] * Ajl;
81 A.store(Ail, i, l);
82 }
83 }
84 // Save block Householder matrix W
// Rows i < j of A were scaled in earlier iterations, so bb[i] (i < j) only needs the 1/β factor
// for row j. The loop also scales bb[j], but that value is overwritten with γ right after.
85 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
86 bb[i] *= inv_β;
87 bb[j] = γ; // inverse of diagonal
// Only the entries i <= j of column j are written (W is accessed as a triangular matrix).
88 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
89 W.store(bb[i], i, j);
90 // TODO: try moving this to before update of A
91 }
92}
93
// Diagonal-block micro-kernel that does NOT record W.
// NOTE(review): listing lines 96-98 (function name and parameter list) are absent from this
// extract. The reference section at the end of this page gives the signature as
//   void hyhound_diag_full_microkernel(index_t kA, uview<T, Abi, OL> L, uview<T, Abi, OA> A,
//                                      uview<const T, Abi, StorageOrder::ColMajor> diag) noexcept
// The per-column update of L and A is the same as in the W-producing diagonal kernel, but the
// inner products are only accumulated for rows i >= j because nothing is stored for i < j.
94template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
95[[gnu::hot, gnu::flatten]] void
99 using std::copysign;
100 using std::sqrt;
101 using simd = datapar::simd<T, Abi>;
102 // Pre-compute the offsets of the columns of L
103 auto L_cached = with_cached_access<R, R>(L);
// Optimizer hint only: BATMAT_ASSUME is documented as undefined behavior when the condition is
// false, so callers must guarantee kA > 0.
104 BATMAT_ASSUME(kA > 0);
105
106 UNROLL_FOR (index_t j = 0; j < R; ++j) {
107 // Compute some inner products between A and a
// bb[i] for i < j stays zero-initialized and is never read in this kernel.
108 simd bb[R]{};
109 for (index_t l = 0; l < kA; ++l) {
110 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0)) //
111 : A.load(j, l) * diag.load(l, 0);
112 UNROLL_FOR (index_t i = j; i < R; ++i)
113 bb[i] += A.load(i, l) * Ajl;
114 }
115 // Energy condition and Householder coefficients
// The new diagonal entry keeps the sign of the old one (copysign).
116 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
117 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
118 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
119 L_cached.store(L̃jj, j, j);
120 // Compute L̃
// bb[i] is reused: after this loop it holds the multiplier applied to row i of A below.
121 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
122 simd Lij = L_cached.load(i, j);
123 bb[i] = γ * Lij + bb[i] * γoβ;
124 L_cached.store(bb[i] - Lij, i, j);
125 }
126 // Update A
127 for (index_t l = 0; l < kA; ++l) {
128 simd Ajl = A.load(j, l) * inv_β; // Scale Householder vector
129 A.store(Ajl, j, l);
130 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
131 simd Ail = A.load(i, l);
132 Ail -= bb[i] * Ajl;
133 A.store(Ail, i, l);
134 }
135 }
136 }
137}
138
// Helpers that let the tail micro-kernel rotate a SIMD vector by either a compile-time or a
// run-time amount; the overload is selected by the type of the second argument.
139namespace detail {
140
// Compile-time amount: forwards to ops::rotr<S>, documented on this page as rotating the elements
// of x to the right by S positions. The integral_constant argument only carries S in its type.
141template <class T, class Abi, int S>
142[[gnu::always_inline]] inline auto rotate(datapar::simd<T, Abi> x, std::integral_constant<int, S>) {
143 using ops::rotr;
144 return rotr<S>(x);
145}
146
// Run-time amount: forwards to ops::rot(x, s), documented on this page as rotating the elements
// of x to the right by s positions.
147template <class T, class Abi>
148[[gnu::always_inline]] inline auto rotate(datapar::simd<T, Abi> x, int s) {
149 using ops::rot;
150 return rot(x, s);
151}
152
153} // namespace detail
154
// Tail micro-kernel: applies the block Householder transformation described by W and B to an
// S-by-R block of L and to S rows of A.
155// A_out and B have the same size. A_in has the same number of rows but may have a different number
156// of columns, which means that only a part of A_in is nonzero. The nonzero part is defined by kAin
157// and kAin_offset.
// NOTE(review): listing lines 159 and 162-163 (end of the template header and part of the
// parameter list) are absent from this extract. The reference section of this page gives the
// parameters as
//   (index_t kA_in_offset, index_t kA_in, index_t k,
//    triangular_accessor<const T, Abi, SizeR<T, Abi>> W, uview<T, Abi, OL> L,
//    uview<const T, Abi, OA> A_in, uview<T, Abi, OA> A_out, uview<const T, Abi, OB> B,
//    uview<const T, Abi, StorageOrder::ColMajor> diag, Structure struc_L, int rotate_A) noexcept
// Steps performed by the body:
//   1. V(i,j) = sum over lA of A_in(i,lA) * diag(lB) * B(j,lB), with lB = lA + kA_in_offset;
//   2. V is combined with the block of L and with W (column by column), and the updated block is
//      written back to L; which entries of L are read/written depends on struc_L;
//   3. A_out(i,lB) = rotate(A_in(i,lA) - sum over j of V(i,j) * B(j,lB)) for all k columns, where
//      columns outside [kA_in_offset, kA_in_offset + kA_in) use zero in place of A_in.
158template <class T, class Abi, KernelConfig Conf, index_t R, index_t S, StorageOrder OL,
160[[gnu::hot, gnu::flatten]] void hyhound_diag_tail_microkernel(
161 index_t kA_in_offset, index_t kA_in, index_t k,
164 uview<const T, Abi, StorageOrder::ColMajor> diag, Structure struc_L, int rotate_A) noexcept {
165 using batmat::ops::cneg;
166 using simd = datapar::simd<T, Abi>;
// Optimizer hint only (undefined behavior if violated). Note that kA_in itself may be zero, in
// which case the accumulation loop below does not run and V stays zero.
167 BATMAT_ASSUME(k > 0);
168
169 // Compute product W = A B
// NOTE: the accumulator is named V in the code (W is the stored Householder representation).
// With Conf.sign_only, diag is applied through cneg instead of a multiplication.
170 simd V[S][R]{};
171 for (index_t lA = 0; lA < kA_in; ++lA) {
172 index_t lB = lA + kA_in_offset;
173 UNROLL_FOR (index_t j = 0; j < R; ++j) {
174 auto Bjl = Conf.sign_only ? cneg(B.load(j, lB), diag.load(lB, 0)) //
175 : B.load(j, lB) * diag.load(lB, 0);
176 UNROLL_FOR (index_t i = 0; i < S; ++i)
177 V[i][j] += A_in.load(i, lA) * Bjl;
178 }
179 }
180
181 // Solve system V = (L+U)W⁻¹ (in-place)
// Columns are processed left to right; column j depends on the already finished columns l < j
// through the strictly upper entries W(l,j), and is finally scaled by W(j,j).
182 auto L_cached = with_cached_access<S, R>(L);
183 switch (struc_L) {
// General: every entry of the S-by-R block of L is read, added to V, and overwritten.
184 [[likely]]
185 case Structure::General: {
186 UNROLL_FOR (index_t j = 0; j < R; ++j) {
187 simd Wj[R];
188 UNROLL_FOR (index_t i = 0; i < j; ++i)
189 Wj[i] = W.load(i, j);
190 UNROLL_FOR (index_t i = 0; i < S; ++i) {
191 simd Lij = L_cached.load(i, j);
192 V[i][j] += Lij;
193 UNROLL_FOR (index_t l = 0; l < j; ++l)
194 V[i][j] -= V[i][l] * Wj[l];
195 V[i][j] *= W.load(j, j); // diagonal already inverted
196 Lij = V[i][j] - Lij;
197 L_cached.store(Lij, i, j);
198 }
199 }
200 } break;
// Zero: L is neither read nor written; only V is transformed.
201 case Structure::Zero: {
202 UNROLL_FOR (index_t j = 0; j < R; ++j) {
203 simd Wj[R];
204 UNROLL_FOR (index_t i = 0; i < j; ++i)
205 Wj[i] = W.load(i, j);
206 UNROLL_FOR (index_t i = 0; i < S; ++i) {
207 UNROLL_FOR (index_t l = 0; l < j; ++l)
208 V[i][j] -= V[i][l] * Wj[l];
209 V[i][j] *= W.load(j, j); // diagonal already inverted
210 }
211 }
212 } break;
// Upper: only the entries with i <= j of the block of L are read and written; Lij is left
// uninitialized for i > j but is not used in that case.
213 case Structure::Upper: {
214 UNROLL_FOR (index_t j = 0; j < R; ++j) {
215 simd Wj[R];
216 UNROLL_FOR (index_t i = 0; i < j; ++i)
217 Wj[i] = W.load(i, j);
218 UNROLL_FOR (index_t i = 0; i < S; ++i) {
219 simd Lij;
220 if (i <= j) {
221 Lij = L_cached.load(i, j);
222 V[i][j] += Lij;
223 }
224 UNROLL_FOR (index_t l = 0; l < j; ++l)
225 V[i][j] -= V[i][l] * Wj[l];
226 V[i][j] *= W.load(j, j); // diagonal already inverted
227 if (i <= j) {
228 Lij = V[i][j] - Lij;
229 L_cached.store(Lij, i, j);
230 }
231 }
232 }
233 } break;
// Any other Structure value is declared impossible to the optimizer (undefined behavior).
234 default: BATMAT_ASSUME(false);
235 }
236 // Update A -= V Bᵀ
// `s` is either an int or a std::integral_constant<int, S>; it selects the matching
// detail::rotate overload for the value written to A_out.
237 const auto update_A = [&] [[gnu::always_inline]] (auto s) {
238 simd Bjl[R];
// Columns before the nonzero window of A_in: the input is treated as zero.
239 for (index_t lB = 0; lB < kA_in_offset; ++lB) [[unlikely]] {
240 UNROLL_FOR (index_t j = 0; j < R; ++j)
241 Bjl[j] = B.load(j, lB);
242 UNROLL_FOR (index_t i = 0; i < S; ++i) {
243 simd Ail{0};
244 UNROLL_FOR (index_t j = 0; j < R; ++j)
245 Ail -= V[i][j] * Bjl[j];
246 A_out.store(detail::rotate(Ail, s), i, lB);
247 }
248 }
// Columns after the nonzero window of A_in: the input is treated as zero.
249 for (index_t lB = kA_in_offset + kA_in; lB < k; ++lB) [[unlikely]] {
250 UNROLL_FOR (index_t j = 0; j < R; ++j)
251 Bjl[j] = B.load(j, lB);
252 UNROLL_FOR (index_t i = 0; i < S; ++i) {
253 simd Ail{0};
254 UNROLL_FOR (index_t j = 0; j < R; ++j)
255 Ail -= V[i][j] * Bjl[j];
256 A_out.store(detail::rotate(Ail, s), i, lB);
257 }
258 }
// Columns inside the window: A_in column lA maps to A_out/B column lB = lA + kA_in_offset.
259 for (index_t lA = 0; lA < kA_in; ++lA) [[likely]] {
260 index_t lB = lA + kA_in_offset;
261 UNROLL_FOR (index_t j = 0; j < R; ++j)
262 Bjl[j] = B.load(j, lB);
263 UNROLL_FOR (index_t i = 0; i < S; ++i) {
264 auto Ail = A_in.load(i, lA);
265 UNROLL_FOR (index_t j = 0; j < R; ++j)
266 Ail -= V[i][j] * Bjl[j];
267 A_out.store(detail::rotate(Ail, s), i, lB);
268 }
269 }
270 };
// The `&& 0` disables the run-time-rotation path unconditionally, so the switch in the #else
// branch is what gets compiled, even on AVX-512 targets.
271#if defined(__AVX512F__) && 0
272 update_A(rotate_A);
273#else
// Only rotate_A == 0 and rotate_A == -1 are supported; any other value is declared impossible to
// the optimizer (undefined behavior).
274 switch (rotate_A) {
275 [[likely]] case 0:
276 update_A(std::integral_constant<int, 0>{});
277 break;
278 case -1: update_A(std::integral_constant<int, -1>{}); break;
279 // case 1: update_A(std::integral_constant<int, 1>{}); break;
280 default: BATMAT_ASSUME(false);
281 }
282#endif
283}
284
285/// Block hyperbolic Householder factorization update using register blocking.
286/// This variant does not store the Householder representation W.
// NOTE(review): listing lines 288, 298, 304 and 308 are absent from this extract. Per the
// reference section of this page the signature is
//   void hyhound_diag_register(view<T, Abi, OL> L, view<T, Abi, OA> A, view<const T, Abi> D) noexcept
// W_t, D_ and the loop call whose arguments start at listing line 309 are declared on the
// missing lines.
// Preconditions (stated as optimizer assumptions, undefined behavior if violated): A has at least
// one column, L has at least as many rows as columns, L and A have the same number of rows, and D
// has one row per column of A.
287template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
289 const view<const T, Abi> D) noexcept {
290 static constexpr index_constant<SizeR<T, Abi>> R;
291 static constexpr index_constant<SizeS<T, Abi>> S;
292 const index_t k = A.cols();
293 BATMAT_ASSUME(k > 0);
294 BATMAT_ASSUME(L.rows() >= L.cols());
295 BATMAT_ASSUME(L.rows() == A.rows());
296 BATMAT_ASSUME(A.cols() == D.rows());
297
// Local scratch buffer for W of one diagonal block; the same buffer is passed to every diagonal
// block, so each block overwrites the previous block's W.
299 alignas(W_t::alignment()) T W[W_t::size()];
300
301 // Sizeless views to partition and pass to the micro-kernels
302 const uview<T, Abi, OL> L_ = L;
303 const uview<T, Abi, OA> A_ = A;
305
306 // Process all diagonal blocks (in multiples of R, except the last).
// Square L: full-size blocks and the remainder block get separate callbacks. The remainder
// callback only runs the W-free diagonal kernel and no tail loop — presumably because the last
// diagonal block of a square L has no rows below it (confirm against foreach_chunked).
307 if (L.rows() == L.cols()) {
309 0, L.cols(), R,
310 [&](index_t j) {
311 // Part of A corresponding to this diagonal block
312 // TODO: packing
313 auto Ad = A_.middle_rows(j);
314 auto Ld = L_.block(j, j);
315 // Process the diagonal block itself
316 hyhound_diag_diag_microkernel<T, Abi, Conf, R, OL, OA>(k, W, Ld, Ad, D_);
317 // Process all rows below the diagonal block (in multiples of S).
// The tail kernel reads and writes the same rows of A (A_in == A_out == As) over all k columns.
318 foreach_chunked_merged(
319 j + R, L.rows(), S,
320 [&](index_t i, auto rem_i) {
321 auto As = A_.middle_rows(i);
322 auto Ls = L_.block(i, j);
323 microkernel_tail_lut<T, Abi, Conf, OL, OA, OA>[rem_i - 1](
324 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
325 },
326 LoopDir::Backward); // TODO: decide on order
327 },
328 [&](index_t j, index_t rem_j) {
329 auto Ad = A_.middle_rows(j);
330 auto Ld = L_.block(j, j);
331 microkernel_full_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, Ld, Ad, D_);
332 });
333 } else {
// Tall L: every diagonal block (full or partial, size rem_j) produces W and is followed by a
// tail pass over the rows below it, using the kernel table indexed by both block sizes.
334 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto rem_j) {
335 // Part of A corresponding to this diagonal block
336 // TODO: packing
337 auto Ad = A_.middle_rows(j);
338 auto Ld = L_.block(j, j);
339 // Process the diagonal block itself
340 microkernel_diag_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, W, Ld, Ad, D_);
341 // Process all rows below the diagonal block (in multiples of S).
342 foreach_chunked_merged(
343 j + rem_j, L.rows(), S,
344 [&](index_t i, auto rem_i) {
345 auto As = A_.middle_rows(i);
346 auto Ls = L_.block(i, j);
347 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[rem_j - 1][rem_i - 1](
348 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
349 },
350 LoopDir::Backward); // TODO: decide on order
351 });
352 }
353}
354
355/// Block hyperbolic Householder factorization update using register blocking.
356/// This variant stores the Householder representation W.
// NOTE(review): listing lines 358 (function name and first parameters), 368, 373-374 and 387 are
// absent from this extract; W_t, W_, D_ and the inner loop call whose arguments start at listing
// line 388 are declared on the missing lines. NOTE(review): the visible tail of the signature
// (D, W) suggests an overload of hyhound_diag_register taking an extra W view — confirm.
// Preconditions (optimizer assumptions, undefined behavior if violated): A has at least one
// column, L is at least as tall as wide, L and A have the same number of rows, D has k rows, and
// W has exactly the size returned by hyhound_W_size(L).
357template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
359 const view<const T, Abi> D, const view<T, Abi> W) noexcept {
360 const index_t k = A.cols();
361 BATMAT_ASSUME(k > 0);
362 BATMAT_ASSUME(L.rows() >= L.cols());
363 BATMAT_ASSUME(L.rows() == A.rows());
364 BATMAT_ASSUME(D.rows() == k);
365 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
366
367 static constexpr index_constant<SizeR<T, Abi>> R;
369
370 // Sizeless views to partition and pass to the micro-kernels
371 const uview<T, Abi, OL> L_ = L;
372 const uview<T, Abi, OA> A_ = A;
375
376 // Process all diagonal blocks (in multiples of R, except the last).
377 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto nj) {
378 static constexpr index_constant<SizeS<T, Abi>> S;
379 // Part of A corresponding to this diagonal block
380 // TODO: packing
381 auto Ad = A_.middle_rows(j);
382 auto Ld = L_.block(j, j);
// Each diagonal block gets its own slice of the caller-provided W, selected by the block index
// j / R, so W of all blocks is retained after the call.
383 auto Wd = W_t{W_.middle_cols(j / R).data};
384 // Process the diagonal block itself
385 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, Wd, Ld, Ad, D_);
386 // Process all rows below the diagonal block (in multiples of S).
// The tail kernel updates the rows of A in place (A_in == A_out == As) over all k columns.
388 j + nj, L.rows(), S,
389 [&](index_t i, auto ni) {
390 auto As = A_.middle_rows(i);
391 auto Ls = L_.block(i, j);
392 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
393 0, k, k, Wd, Ls, As, As, Ad, D_, Structure::General, 0);
394 },
395 LoopDir::Backward); // TODO: decide on order
396 });
397}
398
399/// Apply a block hyperbolic Householder transformation.
// NOTE(review): listing lines 401 (function name and first parameters), 416 and 424-425 are absent
// from this extract. Per the reference section of this page the signature is
//   void hyhound_diag_apply_register(view<T, Abi, OL> L, view<const T, Abi, OA> Ain,
//                                    view<T, Abi, OA> Aout, view<const T, Abi, OA> B,
//                                    view<const T, Abi> D, view<const T, Abi> W,
//                                    index_t kA_in_offset = 0) noexcept
// W_t, W_ and D_ are declared on the missing lines.
// Unlike the factorization routines, no diagonal kernel is run here: W and B are inputs, and
// every row chunk of L (starting at row 0) is passed through the tail kernel.
// Ain may have fewer columns than Aout; its columns correspond to columns
// [kA_in_offset, kA_in_offset + Ain.cols()) of Aout, the rest being treated as zero.
// Preconditions below are optimizer assumptions (undefined behavior if violated).
400template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
402 const view<T, Abi, OA> Aout, const view<const T, Abi, OA> B,
403 const view<const T, Abi> D, const view<const T, Abi> W,
404 index_t kA_in_offset) noexcept {
405 const index_t k_in = Ain.cols(), k = Aout.cols();
406 BATMAT_ASSUME(k > 0);
407 BATMAT_ASSUME(Aout.rows() == Ain.rows());
408 BATMAT_ASSUME(Ain.rows() == L.rows());
409 BATMAT_ASSUME(B.rows() == L.cols());
410 BATMAT_ASSUME(B.cols() == k);
411 BATMAT_ASSUME(D.rows() == k);
412 BATMAT_ASSUME(0 <= kA_in_offset);
413 BATMAT_ASSUME(kA_in_offset + k_in <= k);
414
415 static constexpr index_constant<SizeR<T, Abi>> R;
417 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
418
419 // Sizeless views to partition and pass to the micro-kernels
420 const uview<T, Abi, OL> L_ = L;
421 const uview<const T, Abi, OA> Ain_ = Ain;
422 const uview<T, Abi, OA> Aout_ = Aout;
423 const uview<const T, Abi, OA> B_ = B;
426
427 // Process all diagonal blocks (in multiples of R, except the last).
428 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto nj) {
429 static constexpr index_constant<SizeS<T, Abi>> S;
430 // Part of A corresponding to this diagonal block
431 // TODO: packing
// Despite the name Ad, these are rows of B (the Householder vectors), not of Ain/Aout.
432 auto Ad = B_.middle_rows(j);
// W slice belonging to column block j / R.
433 auto Wd = W_t{W_.middle_cols(j / R).data};
434 // Process all rows (in multiples of S).
435 foreach_chunked_merged( // TODO: swap loop order?
436 0, L.rows(), S,
437 [&](index_t i, auto ni) {
// First column block (j == 0): read from Ain, using its offset and width. Later blocks: Aout
// already holds the intermediate result over all k columns, so it is updated in place.
438 auto Aini = j == 0 ? Ain_.middle_rows(i) : Aout_.middle_rows(i);
439 auto Aouti = Aout_.middle_rows(i);
440 auto Ls = L_.block(i, j);
441 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
442 j == 0 ? kA_in_offset : 0, j == 0 ? k_in : k, k, Wd, Ls, Aini, Aouti, Ad, D_,
443 Structure::General, 0);
444 },
445 LoopDir::Backward); // TODO: decide on order
446 });
447}
448
449/// Same as hyhound_diag_register but for two block rows at once.
// NOTE(review): listing lines 452 (function name and first parameters), 464, 472 and 493 are absent
// from this extract. Per the reference section of this page the signature is
//   void hyhound_diag_2_register(view<T, Abi, OL1> L11, view<T, Abi, OA1> A1,
//                                view<T, Abi, OL2> L21, view<T, Abi, OA2> A2,
//                                view<const T, Abi> D) noexcept
// W_t, D_ and the second inner loop call (arguments start at listing line 494) are declared on
// the missing lines.
// (L11, A1) is factorized block by block; the same W of each diagonal block is then also applied
// to all rows of the second block row (L21, A2).
// Preconditions below are optimizer assumptions (undefined behavior if violated).
450template <class T, class Abi, KernelConfig Conf, StorageOrder OL1, StorageOrder OA1,
451 StorageOrder OL2, StorageOrder OA2>
453 const view<T, Abi, OL2> L21, const view<T, Abi, OA2> A2,
454 const view<const T, Abi> D) noexcept {
455 const index_t k = A1.cols();
456 BATMAT_ASSUME(k > 0);
457 BATMAT_ASSUME(L11.rows() >= L11.cols());
458 BATMAT_ASSUME(L11.rows() == A1.rows());
459 BATMAT_ASSUME(D.rows() == k);
460 BATMAT_ASSUME(A2.cols() == k);
461 BATMAT_ASSUME(L21.cols() == L11.cols());
462
463 static constexpr index_constant<SizeR<T, Abi>> R;
// Local scratch buffer for W of one diagonal block, overwritten by each diagonal block.
465 alignas(W_t::alignment()) T W[W_t::size()];
466
467 // Sizeless views to partition and pass to the micro-kernels
468 const uview<T, Abi, OL1> L11_ = L11;
469 const uview<T, Abi, OA1> A1_ = A1;
470 const uview<T, Abi, OL2> L21_ = L21;
471 const uview<T, Abi, OA2> A2_ = A2;
473
474 // Process all diagonal blocks (in multiples of R, except the last).
475 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
476 static constexpr index_constant<SizeS<T, Abi>> S;
477 // Part of A corresponding to this diagonal block
478 // TODO: packing
479 auto Ad = A1_.middle_rows(j);
480 auto Ld = L11_.block(j, j);
481 // Process the diagonal block itself
482 microkernel_diag_lut<T, Abi, Conf, OL1, OA1>[nj - 1](k, W, Ld, Ad, D_);
483 // Process all rows below the diagonal block (in multiples of S).
484 foreach_chunked_merged(
485 j + nj, L11.rows(), S,
486 [&](index_t i, auto ni) {
487 auto As = A1_.middle_rows(i);
488 auto Ls = L11_.block(i, j);
489 microkernel_tail_lut_2<T, Abi, Conf, OL1, OA1, OA1>[nj - 1][ni - 1](
490 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
491 },
492 LoopDir::Backward); // TODO: decide on order
// Second block row: all rows of L21/A2 (starting at row 0) are updated in place with the same W
// and the same Householder vectors Ad taken from A1.
494 0, L21.rows(), S,
495 [&](index_t i, auto ni) {
496 auto As = A2_.middle_rows(i);
497 auto Ls = L21_.block(i, j);
498 microkernel_tail_lut_2<T, Abi, Conf, OL2, OA2, OA1>[nj - 1][ni - 1](
499 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
500 },
501 LoopDir::Backward); // TODO: decide on order
502 });
503}
504
505/**
506 * Performs a factorization update of the following matrix:
507 *
508 * [ A11 A12 | L11 ] [ 0 0 | L̃11 ]
509 * [ 0 A22 | L21 ] Q = [ Ã21 Ã22 | L̃21 ]
510 * [ A31 0 | L31 ] [ Ã31 Ã32 | L̃31 ]
511 */
// NOTE(review): listing lines 514 (function name and first parameter), 537, 549, 570 and 582 are
// absent from this extract. Per the reference section of this page the function is
// hyhound_diag_cyclic_register and its first parameter is view<T, Abi, OL> L11; W_t, D_ and the
// second and third inner loop calls are declared on the missing lines.
// Column layout: A1 = [A11 A12] has k = k1 + k2 columns, A31 has k1 columns (the first k1 columns
// of the third block row) and A22 has k2 columns (the last k2 columns of the second block row).
// The zero parts of the second and third block rows are never read: on the first diagonal block
// the tail kernel is told which column window of the input is nonzero, and the full-width result
// is written to A2_out / A3_out. Later diagonal blocks update A2_out / A3_out in place.
// Preconditions below are optimizer assumptions (undefined behavior if violated).
512template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OW, StorageOrder OY,
513 StorageOrder OU>
515 const view<T, Abi, OW> A1, // work
516 const view<T, Abi, OY> L21, // Y
517 const view<const T, Abi, OW> A22, // work
518 const view<T, Abi, OW> A2_out, // work
519 const view<T, Abi, OU> L31, // U
520 const view<const T, Abi, OW> A31, // work
521 const view<T, Abi, OW> A3_out, // work
522 const view<const T, Abi> D) noexcept {
523 const index_t k = A1.cols(), k1 = A31.cols(), k2 = A22.cols();
524 BATMAT_ASSUME(k > 0);
525 BATMAT_ASSUME(L11.rows() >= L11.cols());
526 BATMAT_ASSUME(L11.rows() == A1.rows());
527 BATMAT_ASSUME(L21.rows() == A22.rows());
528 BATMAT_ASSUME(L31.rows() == A31.rows());
529 BATMAT_ASSUME(A22.rows() == A2_out.rows());
530 BATMAT_ASSUME(A31.rows() == A3_out.rows());
531 BATMAT_ASSUME(D.rows() == k);
532 BATMAT_ASSUME(L21.cols() == L11.cols());
533 BATMAT_ASSUME(L31.cols() == L11.cols());
534 BATMAT_ASSUME(k1 + k2 == k);
535
536 static constexpr index_constant<SizeR<T, Abi>> R;
// Local scratch buffer for W of one diagonal block, overwritten by each diagonal block.
538 alignas(W_t::alignment()) T W[W_t::size()];
539
540 // Sizeless views to partition and pass to the micro-kernels
541 const uview<T, Abi, OL> L11_ = L11;
542 const uview<T, Abi, OW> A1_ = A1;
543 const uview<T, Abi, OY> L21_ = L21;
544 const uview<const T, Abi, OW> A22_ = A22;
545 const uview<T, Abi, OW> A2_out_ = A2_out;
546 const uview<T, Abi, OU> L31_ = L31;
547 const uview<const T, Abi, OW> A31_ = A31;
548 const uview<T, Abi, OW> A3_out_ = A3_out;
550
551 // Process all diagonal blocks (in multiples of R, except the last).
552 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
553 static constexpr index_constant<SizeS<T, Abi>> S;
554 // Part of A corresponding to this diagonal block
555 // TODO: packing
556 auto Ad = A1_.middle_rows(j);
557 auto Ld = L11_.block(j, j);
558 // Process the diagonal block itself
559 microkernel_diag_lut<T, Abi, Conf, OL, OW>[nj - 1](k, W, Ld, Ad, D_);
560 // Process all rows below the diagonal block (in multiples of S).
561 foreach_chunked_merged(
562 j + nj, L11.rows(), S,
563 [&](index_t i, auto ni) {
564 auto As = A1_.middle_rows(i);
565 auto Ls = L11_.block(i, j);
566 microkernel_tail_lut_2<T, Abi, Conf, OL, OW, OW>[nj - 1][ni - 1](
567 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
568 },
569 LoopDir::Backward); // TODO: decide on order
// Second block row (L21): input is A22 on the first diagonal block, A2_out afterwards.
571 0, L21.rows(), S,
572 [&](index_t i, auto ni) {
573 auto As_out = A2_out_.middle_rows(i);
574 auto As = j == 0 ? A22_.middle_rows(i) : As_out;
575 auto Ls = L21_.block(i, j);
576 // First half of A2 is implicitly zero in first pass
// i.e. A22 occupies columns [k1, k1 + k2) of the full-width row on the first pass.
577 index_t offset_s = j == 0 ? k1 : 0, k_s = j == 0 ? k2 : k;
578 microkernel_tail_lut_2<T, Abi, Conf, OY, OW, OW>[nj - 1][ni - 1](
579 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
580 },
581 LoopDir::Backward); // TODO: decide on order
// Third block row (L31): input is A31 on the first diagonal block, A3_out afterwards.
583 0, L31.rows(), S,
584 [&](index_t i, auto ni) {
585 auto As_out = A3_out_.middle_rows(i);
586 auto As = j == 0 ? A31_.middle_rows(i) : As_out;
587 auto Ls = L31_.block(i, j);
588 // Second half of A3 is implicitly zero in first pass
// i.e. A31 occupies columns [0, k1) of the full-width row on the first pass.
589 index_t offset_s = 0, k_s = j == 0 ? k1 : k;
590 microkernel_tail_lut_2<T, Abi, Conf, OU, OW, OW>[nj - 1][ni - 1](
591 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
592 },
593 LoopDir::Backward); // TODO: decide on order
594 });
595}
596
597/**
598 * Performs a factorization update of the following matrix:
599 *
600 * [ A1 | L11 ] [ 0 | L̃11 ]
601 * [ A2 | L21 ] Q = [ Ã2 | L̃21 ]
602 * [ 0 | Lu1 ] [ Ãu | L̃u1 ]
603 *
604 * where Lu1 and L̃u1 are upper triangular
605 */
// NOTE(review): listing lines 608 (function name and first parameters), 629, 640, 661 and 671 are
// absent from this extract. Per the reference section of this page the function is
// hyhound_diag_riccati_register and its first parameters are view<T, Abi, OL> L11 and
// view<T, Abi, OA> A1; W_t, D_ and the second and third inner loop calls are declared on the
// missing lines.
// A2 is read only on the first diagonal block; its result goes to A2_out, which is then updated
// in place. The block row belonging to Lu1 has an implicitly zero input, and its result is written
// to Au_out. When shift_A_out is true, the outputs written while processing the LAST diagonal
// block are rotated by passing rotate_A = -1 to the tail kernel (A1 is never rotated).
// Preconditions below are optimizer assumptions (undefined behavior if violated).
606template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OLu,
607 StorageOrder OAu>
609 const view<T, Abi, OL> L21, const view<const T, Abi, OA> A2,
610 const view<T, Abi, OA> A2_out, const view<T, Abi, OLu> Lu1,
611 const view<T, Abi, OAu> Au_out, const view<const T, Abi> D,
612 bool shift_A_out) noexcept {
613 const index_t k = A1.cols();
614 BATMAT_ASSUME(k > 0);
615 BATMAT_ASSUME(L11.rows() >= L11.cols());
616 BATMAT_ASSUME(L11.rows() == A1.rows());
617 BATMAT_ASSUME(L21.rows() == A2.rows());
618 BATMAT_ASSUME(A2_out.rows() == A2.rows());
619 BATMAT_ASSUME(A2_out.cols() == A2.cols());
620 BATMAT_ASSUME(Lu1.rows() == Au_out.rows());
621 BATMAT_ASSUME(A1.cols() == D.rows());
622 BATMAT_ASSUME(A2.cols() == A1.cols());
623 BATMAT_ASSUME(L21.cols() == L11.cols());
624 BATMAT_ASSUME(Lu1.cols() == L11.cols());
625
626 static constexpr index_constant<SizeR<T, Abi>> R;
627 static constexpr index_constant<SizeS<T, Abi>> S;
// Equal row and column block sizes are required so that the row chunks of Lu1 line up with the
// column blocks: the structure selection below compares the chunk start i with the block start j.
628 static_assert(R == S);
// Local scratch buffer for W of one diagonal block, overwritten by each diagonal block.
630 alignas(W_t::alignment()) T W[W_t::size()];
631
632 // Sizeless views to partition and pass to the micro-kernels
633 const uview<T, Abi, OL> L11_ = L11;
634 const uview<T, Abi, OA> A1_ = A1;
635 const uview<T, Abi, OL> L21_ = L21;
636 const uview<const T, Abi, OA> A2_ = A2;
637 const uview<T, Abi, OA> A2_out_ = A2_out;
638 const uview<T, Abi, OLu> Lu1_ = Lu1;
639 const uview<T, Abi, OAu> Au_out_ = Au_out;
641
642 // Process all diagonal blocks (in multiples of R, except the last).
643 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
// True only for the final diagonal block, and only if the caller requested the shift.
644 const bool do_shift = shift_A_out && j + nj == L11.cols();
645 // Part of A corresponding to this diagonal block
646 // TODO: packing
647 auto Ad = A1_.middle_rows(j);
648 auto Ld = L11_.block(j, j);
649 // Process the diagonal block itself
650 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, W, Ld, Ad, D_);
651 // Process all rows below the diagonal block (in multiples of S).
652 foreach_chunked_merged(
653 j + nj, L11.rows(), S,
654 [&](index_t i, auto ni) {
655 auto As = A1_.middle_rows(i);
656 auto Ls = L11_.block(i, j);
657 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
658 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
659 },
660 LoopDir::Backward); // TODO: decide on order
// Second block row (L21): input is A2 on the first diagonal block, A2_out afterwards.
662 0, L21.rows(), S,
663 [&](index_t i, auto ni) {
664 auto As_out = A2_out_.middle_rows(i);
665 auto As = j == 0 ? A2_.middle_rows(i) : As_out;
666 auto Ls = L21_.block(i, j);
667 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
668 0, k, k, W, Ls, As, As_out, Ad, D_, Structure::General, do_shift ? -1 : 0);
669 },
670 LoopDir::Backward); // TODO: decide on order
// Third block row (upper-triangular Lu1): input and output are both Au_out.
672 0, Lu1.rows(), S,
673 [&](index_t i, auto ni) {
674 auto As_out = Au_out_.middle_rows(i);
675 auto As = As_out;
676 auto Ls = Lu1_.block(i, j);
677 // Au is implicitly zero in first pass
// Block (i, j) of the upper-triangular Lu1: on the diagonal (i == j) only the upper triangle is
// touched, above it (i < j) the block is dense, below it (i > j) the block is zero and L is
// neither read nor written by the kernel.
678 const auto struc = i == j ? Structure::Upper
679 : i < j ? Structure::General
680 : Structure::Zero;
// kA_in == 0 on the first diagonal block means no column of As is read, so Au_out does not need
// to be initialized by the caller.
681 microkernel_tail_lut_2<T, Abi, Conf, OLu, OAu, OA>[nj - 1][ni - 1](
682 0, j == 0 ? 0 : k, k, W, Ls, As, As_out, Ad, D_, struc, do_shift ? -1 : 0);
683 },
684 LoopDir::Backward); // TODO: decide on order
685 });
686}
687
688} // namespace batmat::linalg::micro_kernels::hyhound
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
Definition rotate.hpp:239
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
Rotate the elements of x to the right by s positions.
Definition rotate.hpp:18
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
Definition loop.hpp:20
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
auto rotate(datapar::simd< T, Abi > x, std::integral_constant< int, S >)
Definition hyhound.tpp:142
void hyhound_diag_full_microkernel(index_t kA, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:96
const constinit auto microkernel_full_lut
Definition hyhound.tpp:24
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:514
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
Definition hyhound.tpp:288
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:608
void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< const T, Abi, OA > A_in, uview< T, Abi, OA > A_out, uview< const T, Abi, OB > B, uview< const T, Abi, StorageOrder::ColMajor > diag, Structure struc_L, int rotate_A) noexcept
Definition hyhound.tpp:160
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
Definition hyhound.hpp:82
void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:43
void hyhound_diag_apply_register(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0) noexcept
Apply a block hyperbolic Householder transformation.
Definition hyhound.tpp:401
const constinit auto microkernel_tail_lut_2
Definition hyhound.tpp:36
const constinit auto microkernel_tail_lut
Definition hyhound.tpp:30
const constinit auto microkernel_diag_lut
Definition hyhound.tpp:18
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
Definition hyhound.tpp:452
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i....
Definition cneg.hpp:42
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114