batmat 0.0.23
Batched linear algebra routines
Loading...
Searching...
No Matches
hyhound.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
8#include <batmat/ops/cneg.hpp>
10#include <guanaqo/trace.hpp>
11#include <type_traits>
12
13#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
14
16
// Compile-time lookup tables of micro-kernel instantiations. The drivers below index them
// with (block size - 1), e.g. microkernel_diag_lut<...>[nj - 1](k, W, Ld, Ad, D_), so that a
// partial (remainder) block dispatches to a kernel specialized for its size.
// NOTE(review): the initializers (presumably make_1d_lut / make_2d_lut calls) are elided in
// this listing.

// Diagonal-block kernels that also store W; called as (k, W, L, A, D).
17template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
18inline const constinit auto microkernel_diag_lut =
21 });
22
// Diagonal-block kernels that do not store W; called as (k, L, A, D).
23template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
24inline const constinit auto microkernel_full_lut =
27 });
28
// Tail kernels indexed by (number of rows - 1); called as
// (kA_in_offset, kA_in, k, W, L, A_in, A_out, B, D, structure, rotation).
29template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
30inline const constinit auto microkernel_tail_lut =
33 });
34
// Tail kernels indexed [diagonal block size - 1][number of rows - 1], with the same call
// signature as microkernel_tail_lut. The declarator (lines 36-38) is elided in this listing;
// the generated index names it microkernel_tail_lut_2.
35template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
39 });
40
/// Diagonal-block micro-kernel (hyhound_diag_diag_microkernel): processes one R×R diagonal
/// block of L together with the matching R rows of A, one column j at a time. For each
/// column it:
///  - forms the inner products of the rows of A with row j, weighted by diag (or by only
///    the signs in diag when Conf.sign_only);
///  - updates column j of L in place to L̃;
///  - overwrites row j of A with the Householder vector scaled by 1/β;
///  - updates the rows of A below j;
///  - writes column j of the upper-triangular block Householder matrix W, with
///    W(j, j) = γ, which the tail kernel uses as the already-inverted diagonal.
/// NOTE(review): the declarator (original lines 43-46) is elided in this listing; per the
/// generated index the parameters are (kA, W, L, A, diag).
41template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
42[[gnu::hot, gnu::flatten]] void
47 using std::copysign;
48 using std::sqrt;
49 using simd = datapar::simd<T, Abi>;
50 // Pre-compute the offsets of the columns of L
51 auto L_cached = with_cached_access<R, R>(L);
52 BATMAT_ASSUME(kA > 0);
53 static constexpr auto safe_min = std::numeric_limits<T>::min();
54
55 UNROLL_FOR (index_t j = 0; j < R; ++j) {
56 // Compute all inner products between the rows of A and the weighted row a = diag ∘ A(j, :)
57 simd bb[R]{};
58 for (index_t l = 0; l < kA; ++l) {
59 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0)) //
60 : A.load(j, l) * diag.load(l, 0);
61 UNROLL_FOR (index_t i = 0; i < R; ++i)
62 bb[i] += A.load(i, l) * Ajl;
63 }
64 // Energy condition and Householder coefficients
// L̃jj = sign(Ljj)·sqrt(Ljj² + α2), where α2 = bb[j] is the weighted squared norm of row j
// of A; β = Ljj + L̃jj and γ = β / L̃jj. In lanes where |L̃jj| does not exceed
// numeric_limits<T>::min(), 1/L̃jj and 1/β are replaced by 0 (select) instead of dividing
// by (almost) zero.
// NOTE(review): a negative radicand (possible when some weights are negative) yields NaN.
65 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
66 const simd abs_L̃jj = sqrt(Ljj * Ljj + α2);
67 const simd L̃jj = copysign(abs_L̃jj, Ljj), β = Ljj + L̃jj;
68 simd γoβ = datapar::select(abs_L̃jj > safe_min, simd{1} / L̃jj, simd{0}), γ = β * γoβ,
69 inv_β = datapar::select(abs_L̃jj > safe_min, simd{1} / β, simd{0});
70 L_cached.store(L̃jj, j, j);
71 // Compute L̃
// Stores L̃ij = γ·Lij + bb[i]/L̃jj − Lij, i.e. (Ljj·Lij + bb[i]) / L̃jj; bb[i] keeps
// γ·Lij + bb[i]/L̃jj, which is the coefficient used in the A update below.
72 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
73 simd Lij = L_cached.load(i, j);
74 bb[i] = γ * Lij + bb[i] * γoβ;
75 L_cached.store(bb[i] - Lij, i, j);
76 }
77 // Update A
78 for (index_t l = 0; l < kA; ++l) {
79 simd Ajl = A.load(j, l) * inv_β; // Scale Householder vector
80 A.store(Ajl, j, l);
81 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
82 simd Ail = A.load(i, l);
83 Ail -= bb[i] * Ajl;
84 A.store(Ail, i, l);
85 }
86 }
87 // Save block Householder matrix W
// Column j of W: entries i < j are the inner products with the previous rows of A (which
// already hold scaled Householder vectors), scaled by 1/β; the diagonal entry is γ.
88 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
89 bb[i] *= inv_β;
90 bb[j] = γ; // inverse of diagonal
91 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
92 W.store(bb[i], i, j);
93 // TODO: try moving this to before update of A
94 }
95}
96
/// Diagonal-block micro-kernel without W (hyhound_diag_full_microkernel): performs the same
/// column-by-column update of a diagonal block of L and the matching rows of A as
/// hyhound_diag_diag_microkernel, with two differences: only the inner products with rows
/// i >= j are formed, and the block Householder matrix W is not stored. hyhound_diag_register
/// reaches it through microkernel_full_lut for the final partial diagonal block of a square L,
/// where no rows remain below that would need W.
/// NOTE(review): the declarator (original lines 99-100) is elided in this listing; per the
/// generated index the parameters are (kA, L, A, diag).
97template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
98[[gnu::hot, gnu::flatten]] void
101 using batmat::ops::cneg;
102 using std::copysign;
103 using std::sqrt;
104 using simd = datapar::simd<T, Abi>;
105 // Pre-compute the offsets of the columns of L
106 auto L_cached = with_cached_access<R, R>(L);
107 BATMAT_ASSUME(kA > 0);
108 static constexpr auto safe_min = std::numeric_limits<T>::min();
109
110 UNROLL_FOR (index_t j = 0; j < R; ++j) {
111 // Compute inner products between rows i >= j of A and the weighted row a = diag ∘ A(j, :)
112 simd bb[R]{};
113 for (index_t l = 0; l < kA; ++l) {
114 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0)) //
115 : A.load(j, l) * diag.load(l, 0);
116 UNROLL_FOR (index_t i = j; i < R; ++i)
117 bb[i] += A.load(i, l) * Ajl;
118 }
119 // Energy condition and Householder coefficients
// L̃jj = sign(Ljj)·sqrt(Ljj² + α2), β = Ljj + L̃jj, γ = β / L̃jj; reciprocals are replaced
// by 0 in lanes where |L̃jj| does not exceed numeric_limits<T>::min().
120 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
121 const simd abs_L̃jj = sqrt(Ljj * Ljj + α2);
122 const simd L̃jj = copysign(abs_L̃jj, Ljj), β = Ljj + L̃jj;
123 simd γoβ = datapar::select(abs_L̃jj > safe_min, simd{1} / L̃jj, simd{0}), γ = β * γoβ,
124 inv_β = datapar::select(abs_L̃jj > safe_min, simd{1} / β, simd{0});
125 L_cached.store(L̃jj, j, j);
126 // Compute L̃
// Stores L̃ij = (Ljj·Lij + bb[i]) / L̃jj (written as γ·Lij + bb[i]/L̃jj − Lij).
127 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
128 simd Lij = L_cached.load(i, j);
129 bb[i] = γ * Lij + bb[i] * γoβ;
130 L_cached.store(bb[i] - Lij, i, j);
131 }
132 // Update A
133 for (index_t l = 0; l < kA; ++l) {
134 simd Ajl = A.load(j, l) * inv_β; // Scale Householder vector
135 A.store(Ajl, j, l);
136 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
137 simd Ail = A.load(i, l);
138 Ail -= bb[i] * Ajl;
139 A.store(Ail, i, l);
140 }
141 }
142 }
143}
144
145namespace detail {
146
147template <class T, class Abi, int S>
148[[gnu::always_inline]] inline auto rotate(datapar::simd<T, Abi> x, std::integral_constant<int, S>) {
149 using ops::rotr;
150 return rotr<S>(x);
151}
152
153template <class T, class Abi>
154[[gnu::always_inline]] inline auto rotate(datapar::simd<T, Abi> x, int s) {
155 using ops::rot;
156 return rot(x, s);
157}
158
159} // namespace detail
160
161// Tail micro-kernel: applies the block Householder transformation (W, B) of one diagonal block
162// to an S-row block of L and A. A_out and B have the same number of columns (k); A_in has the
163// same number of rows as A_out but only kA_in columns, holding columns
// [kA_in_offset, kA_in_offset + kA_in) of the full-width input. All other input columns are
// treated as zero. struc_L describes the L block (General, Upper or Zero), and rotate_A selects
// a lane rotation (only 0 or -1) applied to every column stored into A_out.
// NOTE(review): the zero-fill loops below write A_out columns [0, kA_in_offset) before the last
// loop re-reads A_in, so if A_in aliases A_out it must occupy the same column positions (e.g.
// kA_in_offset == 0, as in the in-place calls). The declarator lines 165 and 168-169 are elided.
164template <class T, class Abi, KernelConfig Conf, index_t R, index_t S, StorageOrder OL,
166[[gnu::hot, gnu::flatten]] void hyhound_diag_tail_microkernel(
167 index_t kA_in_offset, index_t kA_in, index_t k,
170 uview<const T, Abi, StorageOrder::ColMajor> diag, Structure struc_L, int rotate_A) noexcept {
171 using batmat::ops::cneg;
172 using simd = datapar::simd<T, Abi>;
173 BATMAT_ASSUME(k > 0);
174
175 // Compute product V = A_in · diag · Bᵀ (only the kA_in nonzero columns of A_in contribute)
176 simd V[S][R]{};
177 for (index_t lA = 0; lA < kA_in; ++lA) {
178 index_t lB = lA + kA_in_offset;
179 UNROLL_FOR (index_t j = 0; j < R; ++j) {
180 auto Bjl = Conf.sign_only ? cneg(B.load(j, lB), diag.load(lB, 0)) //
181 : B.load(j, lB) * diag.load(lB, 0);
182 UNROLL_FOR (index_t i = 0; i < S; ++i)
183 V[i][j] += A_in.load(i, lA) * Bjl;
184 }
185 }
186
187 // Solve system V = (L+U)W⁻¹ (in-place)
// Per column j: V(:,j) ← (V(:,j) + L(:,j) − Σ_{l<j} V(:,l)·W(l,j)) · W(j,j), then
// L(:,j) ← V(:,j) − L_old(:,j). General: full L block. Zero: L block treated as zero and
// neither read nor written. Upper: only entries with i <= j are read and written.
188 auto L_cached = with_cached_access<S, R>(L);
189 switch (struc_L) {
190 [[likely]]
191 case Structure::General: {
192 UNROLL_FOR (index_t j = 0; j < R; ++j) {
193 simd Wj[R];
194 UNROLL_FOR (index_t i = 0; i < j; ++i)
195 Wj[i] = W.load(i, j);
196 UNROLL_FOR (index_t i = 0; i < S; ++i) {
197 simd Lij = L_cached.load(i, j);
198 V[i][j] += Lij;
199 UNROLL_FOR (index_t l = 0; l < j; ++l)
200 V[i][j] -= V[i][l] * Wj[l];
201 V[i][j] *= W.load(j, j); // diagonal already inverted
202 Lij = V[i][j] - Lij;
203 L_cached.store(Lij, i, j);
204 }
205 }
206 } break;
207 case Structure::Zero: {
208 UNROLL_FOR (index_t j = 0; j < R; ++j) {
209 simd Wj[R];
210 UNROLL_FOR (index_t i = 0; i < j; ++i)
211 Wj[i] = W.load(i, j);
212 UNROLL_FOR (index_t i = 0; i < S; ++i) {
213 UNROLL_FOR (index_t l = 0; l < j; ++l)
214 V[i][j] -= V[i][l] * Wj[l];
215 V[i][j] *= W.load(j, j); // diagonal already inverted
216 }
217 }
218 } break;
219 case Structure::Upper: {
220 UNROLL_FOR (index_t j = 0; j < R; ++j) {
221 simd Wj[R];
222 UNROLL_FOR (index_t i = 0; i < j; ++i)
223 Wj[i] = W.load(i, j);
224 UNROLL_FOR (index_t i = 0; i < S; ++i) {
225 simd Lij;
226 if (i <= j) {
227 Lij = L_cached.load(i, j);
228 V[i][j] += Lij;
229 }
230 UNROLL_FOR (index_t l = 0; l < j; ++l)
231 V[i][j] -= V[i][l] * Wj[l];
232 V[i][j] *= W.load(j, j); // diagonal already inverted
233 if (i <= j) {
234 Lij = V[i][j] - Lij;
235 L_cached.store(Lij, i, j);
236 }
237 }
238 }
239 } break;
240 default: BATMAT_ASSUME(false);
241 }
242 // Update A -= V Bᵀ
// A_out = A_in − V·Bᵀ, where A_in is zero outside [kA_in_offset, kA_in_offset + kA_in);
// every stored column is lane-rotated by s via detail::rotate.
243 const auto update_A = [&] [[gnu::always_inline]] (auto s) {
244 simd Bjl[R];
245 for (index_t lB = 0; lB < kA_in_offset; ++lB) [[unlikely]] {
246 UNROLL_FOR (index_t j = 0; j < R; ++j)
247 Bjl[j] = B.load(j, lB);
248 UNROLL_FOR (index_t i = 0; i < S; ++i) {
249 simd Ail{0};
250 UNROLL_FOR (index_t j = 0; j < R; ++j)
251 Ail -= V[i][j] * Bjl[j];
252 A_out.store(detail::rotate(Ail, s), i, lB);
253 }
254 }
255 for (index_t lB = kA_in_offset + kA_in; lB < k; ++lB) [[unlikely]] {
256 UNROLL_FOR (index_t j = 0; j < R; ++j)
257 Bjl[j] = B.load(j, lB);
258 UNROLL_FOR (index_t i = 0; i < S; ++i) {
259 simd Ail{0};
260 UNROLL_FOR (index_t j = 0; j < R; ++j)
261 Ail -= V[i][j] * Bjl[j];
262 A_out.store(detail::rotate(Ail, s), i, lB);
263 }
264 }
265 for (index_t lA = 0; lA < kA_in; ++lA) [[likely]] {
266 index_t lB = lA + kA_in_offset;
267 UNROLL_FOR (index_t j = 0; j < R; ++j)
268 Bjl[j] = B.load(j, lB);
269 UNROLL_FOR (index_t i = 0; i < S; ++i) {
270 auto Ail = A_in.load(i, lA);
271 UNROLL_FOR (index_t j = 0; j < R; ++j)
272 Ail -= V[i][j] * Bjl[j];
273 A_out.store(detail::rotate(Ail, s), i, lB);
274 }
275 }
276 };
// The run-time rotation path is disabled (`&& 0`). rotate_A is instead dispatched to a
// compile-time constant; only 0 and -1 are supported, and any other value is UB through
// BATMAT_ASSUME(false).
277#if defined(__AVX512F__) && 0
278 update_A(rotate_A);
279#else
280 switch (rotate_A) {
281 [[likely]] case 0:
282 update_A(std::integral_constant<int, 0>{});
283 break;
284 case -1: update_A(std::integral_constant<int, -1>{}); break;
285 // case 1: update_A(std::integral_constant<int, 1>{}); break;
286 default: BATMAT_ASSUME(false);
287 }
288#endif
289}
290
291/// Block hyperbolic Householder factorization update using register blocking.
292/// This variant does not store the Householder representation W.
/// Updates L in place (L ← L̃) using the k = A.cols() columns of A, weighted by D (or by only
/// the signs of D when Conf.sign_only). A is overwritten in the process.
/// Judging from the diagonal kernel (L̃jj = ±sqrt(Ljj² + Σ_l D_l·A_jl²)), this presumably
/// computes L̃·L̃ᵀ = L·Lᵀ + A·diag(D)·Aᵀ (TODO confirm).
/// Preconditions (BATMAT_ASSUME): k > 0, L.rows() >= L.cols(), L.rows() == A.rows(),
/// A.cols() == D.rows(). The local buffer W is reused for every diagonal block column.
/// NOTE(review): the declarator (original line 294) and lines 304, 310, 314 are elided in this
/// listing.
293template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
295 const view<const T, Abi> D) noexcept {
296 static constexpr index_constant<SizeR<T, Abi>> R;
297 static constexpr index_constant<SizeS<T, Abi>> S;
298 const index_t k = A.cols();
299 BATMAT_ASSUME(k > 0);
300 BATMAT_ASSUME(L.rows() >= L.cols());
301 BATMAT_ASSUME(L.rows() == A.rows());
302 BATMAT_ASSUME(A.cols() == D.rows());
303
305 alignas(W_t::alignment()) T W[W_t::size()];
306
307 // Sizeless views to partition and pass to the micro-kernels
308 const uview<T, Abi, OL> L_ = L;
309 const uview<T, Abi, OA> A_ = A;
311
312 // Process all diagonal blocks (in multiples of R, except the last).
// Square L: full R-column blocks use the W-storing diagonal kernel followed by tail kernels
// for the rows below. The remainder callback (presumably the final partial block) has no
// rows below it, so it uses the W-less full kernel.
313 if (L.rows() == L.cols()) {
315 0, L.cols(), R,
316 [&](index_t j) {
317 // Part of A corresponding to this diagonal block
318 // TODO: packing
319 auto Ad = A_.middle_rows(j);
320 auto Ld = L_.block(j, j);
321 // Process the diagonal block itself
322 hyhound_diag_diag_microkernel<T, Abi, Conf, R, OL, OA>(k, W, Ld, Ad, D_);
323 // Process all rows below the diagonal block (in multiples of S).
324 foreach_chunked_merged(
325 j + R, L.rows(), S,
326 [&](index_t i, auto rem_i) {
327 auto As = A_.middle_rows(i);
328 auto Ls = L_.block(i, j);
329 microkernel_tail_lut<T, Abi, Conf, OL, OA, OA>[rem_i - 1](
330 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
331 },
332 LoopDir::Backward); // TODO: decide on order
333 },
334 [&](index_t j, index_t rem_j) {
335 auto Ad = A_.middle_rows(j);
336 auto Ld = L_.block(j, j);
337 microkernel_full_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, Ld, Ad, D_);
338 });
339 } else {
// Tall L: every diagonal block (including a partial last one) stores W, because there are
// always rows below it to update.
340 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto rem_j) {
341 // Part of A corresponding to this diagonal block
342 // TODO: packing
343 auto Ad = A_.middle_rows(j);
344 auto Ld = L_.block(j, j);
345 // Process the diagonal block itself
346 microkernel_diag_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, W, Ld, Ad, D_);
347 // Process all rows below the diagonal block (in multiples of S).
348 foreach_chunked_merged(
349 j + rem_j, L.rows(), S,
350 [&](index_t i, auto rem_i) {
351 auto As = A_.middle_rows(i);
352 auto Ls = L_.block(i, j);
353 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[rem_j - 1][rem_i - 1](
354 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
355 },
356 LoopDir::Backward); // TODO: decide on order
357 });
358 }
359}
360
361/// Block hyperbolic Householder factorization update using register blocking.
362/// This variant stores the Householder representation W.
/// W must have the shape returned by hyhound_W_size(L). Diagonal block column j of L writes
/// its block Householder matrix into the slot W_.middle_cols(j / R), and rows j..j+nj of A are
/// left holding the scaled Householder vectors. Together these are presumably the (W, B) inputs
/// of hyhound_diag_apply_register (TODO confirm).
/// Preconditions (BATMAT_ASSUME): k > 0, L.rows() >= L.cols(), L.rows() == A.rows(),
/// D.rows() == k.
/// NOTE(review): the declarator (original line 364) and lines 374, 379-380, 393 are elided in
/// this listing.
363template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
365 const view<const T, Abi> D, const view<T, Abi> W) noexcept {
366 const index_t k = A.cols();
367 BATMAT_ASSUME(k > 0);
368 BATMAT_ASSUME(L.rows() >= L.cols());
369 BATMAT_ASSUME(L.rows() == A.rows());
370 BATMAT_ASSUME(D.rows() == k);
371 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
372
373 static constexpr index_constant<SizeR<T, Abi>> R;
375
376 // Sizeless views to partition and pass to the micro-kernels
377 const uview<T, Abi, OL> L_ = L;
378 const uview<T, Abi, OA> A_ = A;
381
382 // Process all diagonal blocks (in multiples of R, except the last).
383 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto nj) {
384 static constexpr index_constant<SizeS<T, Abi>> S;
385 // Part of A corresponding to this diagonal block
386 // TODO: packing
387 auto Ad = A_.middle_rows(j);
388 auto Ld = L_.block(j, j);
// Persistent W slot for this block column (kept for later application).
389 auto Wd = W_t{W_.middle_cols(j / R).data};
390 // Process the diagonal block itself
391 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, Wd, Ld, Ad, D_);
392 // Process all rows below the diagonal block (in multiples of S).
394 j + nj, L.rows(), S,
395 [&](index_t i, auto ni) {
396 auto As = A_.middle_rows(i);
397 auto Ls = L_.block(i, j);
398 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
399 0, k, k, Wd, Ls, As, As, Ad, D_, Structure::General, 0);
400 },
401 LoopDir::Backward); // TODO: decide on order
402 });
403}
404
405/// Apply a block hyperbolic Householder transformation.
/// For every diagonal block column j of L, applies the block Householder transformation given
/// by W (slot j / R) and the rows j..j+nj of B to all rows of L and A.
/// On the first block column (j == 0) the input is read from Ain: its k_in columns map to
/// columns [kA_in_offset, kA_in_offset + k_in) of Aout, and the other columns start as zero.
/// Later block columns read and update Aout in place.
/// Preconditions (BATMAT_ASSUME): k = Aout.cols() > 0, Aout.rows() == Ain.rows() == L.rows(),
/// B is L.cols() × k, D.rows() == k, 0 <= kA_in_offset, kA_in_offset + k_in <= k, and W has
/// the shape returned by hyhound_W_size(L).
/// NOTE(review): the declarator (original line 407) and lines 422, 430-431 are elided in this
/// listing.
406template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
408 const view<T, Abi, OA> Aout, const view<const T, Abi, OA> B,
409 const view<const T, Abi> D, const view<const T, Abi> W,
410 index_t kA_in_offset) noexcept {
411 const index_t k_in = Ain.cols(), k = Aout.cols();
412 BATMAT_ASSUME(k > 0);
413 BATMAT_ASSUME(Aout.rows() == Ain.rows());
414 BATMAT_ASSUME(Ain.rows() == L.rows());
415 BATMAT_ASSUME(B.rows() == L.cols());
416 BATMAT_ASSUME(B.cols() == k);
417 BATMAT_ASSUME(D.rows() == k);
418 BATMAT_ASSUME(0 <= kA_in_offset);
419 BATMAT_ASSUME(kA_in_offset + k_in <= k);
420
421 static constexpr index_constant<SizeR<T, Abi>> R;
423 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
424
425 // Sizeless views to partition and pass to the micro-kernels
426 const uview<T, Abi, OL> L_ = L;
427 const uview<const T, Abi, OA> Ain_ = Ain;
428 const uview<T, Abi, OA> Aout_ = Aout;
429 const uview<const T, Abi, OA> B_ = B;
432
433 // Process all diagonal blocks (in multiples of R, except the last).
434 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto nj) {
435 static constexpr index_constant<SizeS<T, Abi>> S;
436 // Rows of B corresponding to this diagonal block (the tail kernel's B operand)
437 // TODO: packing
438 auto Ad = B_.middle_rows(j);
439 auto Wd = W_t{W_.middle_cols(j / R).data};
440 // Process all rows (in multiples of S).
441 foreach_chunked_merged( // TODO: swap loop order?
442 0, L.rows(), S,
443 [&](index_t i, auto ni) {
// First block column reads the compact input Ain; afterwards Aout is updated in place.
444 auto Aini = j == 0 ? Ain_.middle_rows(i) : Aout_.middle_rows(i);
445 auto Aouti = Aout_.middle_rows(i);
446 auto Ls = L_.block(i, j);
447 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
448 j == 0 ? kA_in_offset : 0, j == 0 ? k_in : k, k, Wd, Ls, Aini, Aouti, Ad, D_,
449 Structure::General, 0);
450 },
451 LoopDir::Backward); // TODO: decide on order
452 });
453}
454
455/// Same as hyhound_diag_register but for two block rows at once.
/// L11/A1 are processed as in hyhound_diag_register: diagonal blocks first, then the rows
/// below them. For each block column, the same W and the same diagonal rows of A1 are then
/// also applied to every row block of L21/A2.
/// Preconditions (BATMAT_ASSUME): k = A1.cols() > 0, L11.rows() >= L11.cols(),
/// L11.rows() == A1.rows(), D.rows() == k, A2.cols() == k, L21.cols() == L11.cols().
/// NOTE(review): unlike the cyclic and riccati variants there is no
/// BATMAT_ASSUME(L21.rows() == A2.rows()), although the A2 loop runs over L21.rows().
/// NOTE(review): the declarator (original line 458) and lines 470, 478, 499 are elided in this
/// listing.
456template <class T, class Abi, KernelConfig Conf, StorageOrder OL1, StorageOrder OA1,
457 StorageOrder OL2, StorageOrder OA2>
459 const view<T, Abi, OL2> L21, const view<T, Abi, OA2> A2,
460 const view<const T, Abi> D) noexcept {
461 const index_t k = A1.cols();
462 BATMAT_ASSUME(k > 0);
463 BATMAT_ASSUME(L11.rows() >= L11.cols());
464 BATMAT_ASSUME(L11.rows() == A1.rows());
465 BATMAT_ASSUME(D.rows() == k);
466 BATMAT_ASSUME(A2.cols() == k);
467 BATMAT_ASSUME(L21.cols() == L11.cols());
468
469 static constexpr index_constant<SizeR<T, Abi>> R;
471 alignas(W_t::alignment()) T W[W_t::size()];
472
473 // Sizeless views to partition and pass to the micro-kernels
474 const uview<T, Abi, OL1> L11_ = L11;
475 const uview<T, Abi, OA1> A1_ = A1;
476 const uview<T, Abi, OL2> L21_ = L21;
477 const uview<T, Abi, OA2> A2_ = A2;
479
480 // Process all diagonal blocks (in multiples of R, except the last).
481 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
482 static constexpr index_constant<SizeS<T, Abi>> S;
483 // Part of A corresponding to this diagonal block
484 // TODO: packing
485 auto Ad = A1_.middle_rows(j);
486 auto Ld = L11_.block(j, j);
487 // Process the diagonal block itself
488 microkernel_diag_lut<T, Abi, Conf, OL1, OA1>[nj - 1](k, W, Ld, Ad, D_);
489 // Process all rows below the diagonal block (in multiples of S).
490 foreach_chunked_merged(
491 j + nj, L11.rows(), S,
492 [&](index_t i, auto ni) {
493 auto As = A1_.middle_rows(i);
494 auto Ls = L11_.block(i, j);
495 microkernel_tail_lut_2<T, Abi, Conf, OL1, OA1, OA1>[nj - 1][ni - 1](
496 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
497 },
498 LoopDir::Backward); // TODO: decide on order
// Second block row: apply the same transformation (B operand = rows of A1) to L21/A2.
500 0, L21.rows(), S,
501 [&](index_t i, auto ni) {
502 auto As = A2_.middle_rows(i);
503 auto Ls = L21_.block(i, j);
504 microkernel_tail_lut_2<T, Abi, Conf, OL2, OA2, OA1>[nj - 1][ni - 1](
505 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
506 },
507 LoopDir::Backward); // TODO: decide on order
508 });
509}
510
511/**
512 * Performs a factorization update of the following matrix:
513 *
514 * [ A11 A12 | L11 ] [ 0 0 | L̃11 ]
515 * [ 0 A22 | L21 ] Q = [ Ã21 Ã22 | L̃21 ]
516 * [ A31 0 | L31 ] [ Ã31 Ã32 | L̃31 ]
 *
 * A1 holds the k = k1 + k2 columns of [A11 A12] and is overwritten in place.
 * The second and third block rows are passed in compact form:
 *  - A22 holds only the k2 nonzero columns of [0 A22];
 *  - A31 holds only the k1 nonzero columns of [A31 0].
 * On the first block column (j == 0) they are read from A22/A31, and the full-width
 * (k-column) results are written to A2_out/A3_out. Later block columns update A2_out/A3_out
 * in place.
 * NOTE(review): the declarator line (520) and lines 543, 555, 576, 588 are elided in this
 * listing.
517 */
518template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OW, StorageOrder OY,
519 StorageOrder OU>
521 const view<T, Abi, OW> A1, // work
522 const view<T, Abi, OY> L21, // Y
523 const view<const T, Abi, OW> A22, // work
524 const view<T, Abi, OW> A2_out, // work
525 const view<T, Abi, OU> L31, // U
526 const view<const T, Abi, OW> A31, // work
527 const view<T, Abi, OW> A3_out, // work
528 const view<const T, Abi> D) noexcept {
529 const index_t k = A1.cols(), k1 = A31.cols(), k2 = A22.cols();
530 BATMAT_ASSUME(k > 0);
531 BATMAT_ASSUME(L11.rows() >= L11.cols());
532 BATMAT_ASSUME(L11.rows() == A1.rows());
533 BATMAT_ASSUME(L21.rows() == A22.rows());
534 BATMAT_ASSUME(L31.rows() == A31.rows());
535 BATMAT_ASSUME(A22.rows() == A2_out.rows());
536 BATMAT_ASSUME(A31.rows() == A3_out.rows());
537 BATMAT_ASSUME(D.rows() == k);
538 BATMAT_ASSUME(L21.cols() == L11.cols());
539 BATMAT_ASSUME(L31.cols() == L11.cols());
540 BATMAT_ASSUME(k1 + k2 == k);
541
542 static constexpr index_constant<SizeR<T, Abi>> R;
544 alignas(W_t::alignment()) T W[W_t::size()];
545
546 // Sizeless views to partition and pass to the micro-kernels
547 const uview<T, Abi, OL> L11_ = L11;
548 const uview<T, Abi, OW> A1_ = A1;
549 const uview<T, Abi, OY> L21_ = L21;
550 const uview<const T, Abi, OW> A22_ = A22;
551 const uview<T, Abi, OW> A2_out_ = A2_out;
552 const uview<T, Abi, OU> L31_ = L31;
553 const uview<const T, Abi, OW> A31_ = A31;
554 const uview<T, Abi, OW> A3_out_ = A3_out;
556
557 // Process all diagonal blocks (in multiples of R, except the last).
558 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
559 static constexpr index_constant<SizeS<T, Abi>> S;
560 // Part of A corresponding to this diagonal block
561 // TODO: packing
562 auto Ad = A1_.middle_rows(j);
563 auto Ld = L11_.block(j, j);
564 // Process the diagonal block itself
565 microkernel_diag_lut<T, Abi, Conf, OL, OW>[nj - 1](k, W, Ld, Ad, D_);
566 // Process all rows below the diagonal block (in multiples of S).
567 foreach_chunked_merged(
568 j + nj, L11.rows(), S,
569 [&](index_t i, auto ni) {
570 auto As = A1_.middle_rows(i);
571 auto Ls = L11_.block(i, j);
572 microkernel_tail_lut_2<T, Abi, Conf, OL, OW, OW>[nj - 1][ni - 1](
573 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
574 },
575 LoopDir::Backward); // TODO: decide on order
577 0, L21.rows(), S,
578 [&](index_t i, auto ni) {
579 auto As_out = A2_out_.middle_rows(i);
580 auto As = j == 0 ? A22_.middle_rows(i) : As_out;
581 auto Ls = L21_.block(i, j);
582 // First half of A2 is implicitly zero in first pass
583 index_t offset_s = j == 0 ? k1 : 0, k_s = j == 0 ? k2 : k;
584 microkernel_tail_lut_2<T, Abi, Conf, OY, OW, OW>[nj - 1][ni - 1](
585 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
586 },
587 LoopDir::Backward); // TODO: decide on order
589 0, L31.rows(), S,
590 [&](index_t i, auto ni) {
591 auto As_out = A3_out_.middle_rows(i);
592 auto As = j == 0 ? A31_.middle_rows(i) : As_out;
593 auto Ls = L31_.block(i, j);
594 // Second half of A3 is implicitly zero in first pass
595 index_t offset_s = 0, k_s = j == 0 ? k1 : k;
596 microkernel_tail_lut_2<T, Abi, Conf, OU, OW, OW>[nj - 1][ni - 1](
597 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
598 },
599 LoopDir::Backward); // TODO: decide on order
600 });
601}
602
603/**
604 * Performs a factorization update of the following matrix:
605 *
606 * [ A1 | L11 ] [ 0 | L̃11 ]
607 * [ A2 | L21 ] Q = [ Ã2 | L̃21 ]
608 * [ 0 | Lu1 ] [ Ãu | L̃u1 ]
609 *
610 * where Lu1 and L̃u1 are upper triangular
 *
 * A1 is overwritten in place. A2 is read only on the first block column; all results for
 * the second block row go to A2_out. The third block row of A is implicitly zero on input,
 * and its result is written to Au_out.
 * Lu1 blocks are treated as General above the block diagonal, Upper on it and Zero below it.
 * If shift_A_out is set, the columns written to A2_out and Au_out during the last block
 * column are lane-rotated by -1 (rotate_A = -1 in the tail kernel).
 * NOTE(review): the declarator line (614) and lines 635, 646, 667, 677 are elided in this
 * listing.
611 */
612template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OLu,
613 StorageOrder OAu>
615 const view<T, Abi, OL> L21, const view<const T, Abi, OA> A2,
616 const view<T, Abi, OA> A2_out, const view<T, Abi, OLu> Lu1,
617 const view<T, Abi, OAu> Au_out, const view<const T, Abi> D,
618 bool shift_A_out) noexcept {
619 const index_t k = A1.cols();
620 BATMAT_ASSUME(k > 0);
621 BATMAT_ASSUME(L11.rows() >= L11.cols());
622 BATMAT_ASSUME(L11.rows() == A1.rows());
623 BATMAT_ASSUME(L21.rows() == A2.rows());
624 BATMAT_ASSUME(A2_out.rows() == A2.rows());
625 BATMAT_ASSUME(A2_out.cols() == A2.cols());
626 BATMAT_ASSUME(Lu1.rows() == Au_out.rows());
627 BATMAT_ASSUME(A1.cols() == D.rows());
628 BATMAT_ASSUME(A2.cols() == A1.cols());
629 BATMAT_ASSUME(L21.cols() == L11.cols());
630 BATMAT_ASSUME(Lu1.cols() == L11.cols());
631
632 static constexpr index_constant<SizeR<T, Abi>> R;
633 static constexpr index_constant<SizeS<T, Abi>> S;
634 static_assert(R == S);
// R == S makes row block i and column block j of Lu1 line up, so that i == j below
// identifies exactly the diagonal blocks.
636 alignas(W_t::alignment()) T W[W_t::size()];
637
638 // Sizeless views to partition and pass to the micro-kernels
639 const uview<T, Abi, OL> L11_ = L11;
640 const uview<T, Abi, OA> A1_ = A1;
641 const uview<T, Abi, OL> L21_ = L21;
642 const uview<const T, Abi, OA> A2_ = A2;
643 const uview<T, Abi, OA> A2_out_ = A2_out;
644 const uview<T, Abi, OLu> Lu1_ = Lu1;
645 const uview<T, Abi, OAu> Au_out_ = Au_out;
647
648 // Process all diagonal blocks (in multiples of R, except the last).
649 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
// Rotate the A outputs only while processing the last block column.
650 const bool do_shift = shift_A_out && j + nj == L11.cols();
651 // Part of A corresponding to this diagonal block
652 // TODO: packing
653 auto Ad = A1_.middle_rows(j);
654 auto Ld = L11_.block(j, j);
655 // Process the diagonal block itself
656 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, W, Ld, Ad, D_);
657 // Process all rows below the diagonal block (in multiples of S).
658 foreach_chunked_merged(
659 j + nj, L11.rows(), S,
660 [&](index_t i, auto ni) {
661 auto As = A1_.middle_rows(i);
662 auto Ls = L11_.block(i, j);
663 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
664 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
665 },
666 LoopDir::Backward); // TODO: decide on order
668 0, L21.rows(), S,
669 [&](index_t i, auto ni) {
670 auto As_out = A2_out_.middle_rows(i);
671 auto As = j == 0 ? A2_.middle_rows(i) : As_out;
672 auto Ls = L21_.block(i, j);
673 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
674 0, k, k, W, Ls, As, As_out, Ad, D_, Structure::General, do_shift ? -1 : 0);
675 },
676 LoopDir::Backward); // TODO: decide on order
678 0, Lu1.rows(), S,
679 [&](index_t i, auto ni) {
680 auto As_out = Au_out_.middle_rows(i);
681 auto As = As_out;
682 auto Ls = Lu1_.block(i, j);
683 // Au is implicitly zero in first pass
684 const auto struc = i == j ? Structure::Upper
685 : i < j ? Structure::General
686 : Structure::Zero;
// kA_in == 0 on the first pass, so the tail kernel never reads the (uninitialized) Au_out.
687 microkernel_tail_lut_2<T, Abi, Conf, OLu, OAu, OA>[nj - 1][ni - 1](
688 0, j == 0 ? 0 : k, k, W, Ls, As, As_out, Ad, D_, struc, do_shift ? -1 : 0);
689 },
690 LoopDir::Backward); // TODO: decide on order
691 });
692}
693
694} // namespace batmat::linalg::micro_kernels::hyhound
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
Definition rotate.hpp:239
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
Rotate the elements of x to the right by s positions.
Definition rotate.hpp:18
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
Definition loop.hpp:20
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
auto select(auto cond, auto t, auto f)
Definition simd.hpp:245
stdx::simd< Tp, Abi > simd
Definition simd.hpp:148
auto rotate(datapar::simd< T, Abi > x, std::integral_constant< int, S >)
Definition hyhound.tpp:148
void hyhound_diag_full_microkernel(index_t kA, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:99
const constinit auto microkernel_full_lut
Definition hyhound.tpp:24
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:520
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
Definition hyhound.tpp:294
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:614
void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< const T, Abi, OA > A_in, uview< T, Abi, OA > A_out, uview< const T, Abi, OB > B, uview< const T, Abi, StorageOrder::ColMajor > diag, Structure struc_L, int rotate_A) noexcept
Definition hyhound.tpp:166
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
Definition hyhound.hpp:48
void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:43
void hyhound_diag_apply_register(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0) noexcept
Apply a block hyperbolic Householder transformation.
Definition hyhound.tpp:407
const constinit auto microkernel_tail_lut_2
Definition hyhound.tpp:36
const constinit auto microkernel_tail_lut
Definition hyhound.tpp:30
const constinit auto microkernel_diag_lut
Definition hyhound.tpp:18
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
Definition hyhound.tpp:458
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i....
Definition cneg.hpp:42
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
int index_t
Definition config.hpp:13
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114