batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
hyhound.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/ops/cneg.hpp>
9#include <guanaqo/trace.hpp>
10#include <type_traits>
11#include <utility>
12
13#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
14
16
17template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
18[[gnu::hot, gnu::flatten]] void
23 using std::copysign;
24 using std::sqrt;
25 using simd = datapar::simd<T, Abi>;
26 // Pre-compute the offsets of the columns of L
27 auto L_cached = with_cached_access<R, R>(L);
28 BATMAT_ASSUME(kA > 0);
29
30 UNROLL_FOR (index_t j = 0; j < R; ++j) {
31 // Compute all inner products between A and a
32 simd bb[R]{};
33 for (index_t l = 0; l < kA; ++l) {
34 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0)) //
35 : A.load(j, l) * diag.load(l, 0);
36 UNROLL_FOR (index_t i = 0; i < R; ++i)
37 bb[i] += A.load(i, l) * Ajl;
38 }
39 // Energy condition and Householder coefficients
40 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
41 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
42 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
43 L_cached.store(L̃jj, j, j);
44 // Compute L̃
45 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
46 simd Lij = L_cached.load(i, j);
47 bb[i] = γ * Lij + bb[i] * γoβ;
48 L_cached.store(bb[i] - Lij, i, j);
49 }
50 // Update A
51 for (index_t l = 0; l < kA; ++l) {
52 simd Ajl = A.load(j, l) * inv_β; // Scale Householder vector
53 A.store(Ajl, j, l);
54 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
55 simd Ail = A.load(i, l);
56 Ail -= bb[i] * Ajl;
57 A.store(Ail, i, l);
58 }
59 }
60 // Save block Householder matrix W
61 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
62 bb[i] *= inv_β;
63 bb[j] = γ; // inverse of diagonal
64 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
65 W.store(bb[i], i, j);
66 // TODO: try moving this to before update of A
67 }
68}
69
70template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
71[[gnu::hot, gnu::flatten]] void
75 using std::copysign;
76 using std::sqrt;
77 using simd = datapar::simd<T, Abi>;
78 // Pre-compute the offsets of the columns of L
79 auto L_cached = with_cached_access<R, R>(L);
80 BATMAT_ASSUME(kA > 0);
81
82 UNROLL_FOR (index_t j = 0; j < R; ++j) {
83 // Compute some inner products between A and a
84 simd bb[R]{};
85 for (index_t l = 0; l < kA; ++l) {
86 simd Ajl = Conf.sign_only ? cneg(A.load(j, l), diag.load(l, 0)) //
87 : A.load(j, l) * diag.load(l, 0);
88 UNROLL_FOR (index_t i = j; i < R; ++i)
89 bb[i] += A.load(i, l) * Ajl;
90 }
91 // Energy condition and Householder coefficients
92 const simd α2 = bb[j], Ljj = L_cached.load(j, j);
93 const simd L̃jj = copysign(sqrt(Ljj * Ljj + α2), Ljj), β = Ljj + L̃jj;
94 simd γoβ = simd{2} * β / (β * β + α2), γ = β * γoβ, inv_β = simd{1} / β;
95 L_cached.store(L̃jj, j, j);
96 // Compute L̃
97 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
98 simd Lij = L_cached.load(i, j);
99 bb[i] = γ * Lij + bb[i] * γoβ;
100 L_cached.store(bb[i] - Lij, i, j);
101 }
102 // Update A
103 for (index_t l = 0; l < kA; ++l) {
104 simd Ajl = A.load(j, l) * inv_β; // Scale Householder vector
105 A.store(Ajl, j, l);
106 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
107 simd Ail = A.load(i, l);
108 Ail -= bb[i] * Ajl;
109 A.store(Ail, i, l);
110 }
111 }
112 }
113}
114
115namespace detail {
116
117template <class T, class Abi, int S>
118[[gnu::always_inline]] inline auto rotate(datapar::simd<T, Abi> x, std::integral_constant<int, S>) {
119 using ops::rotr;
120 return rotr<S>(x);
121}
122
123template <class T, class Abi>
124[[gnu::always_inline]] inline auto rotate(datapar::simd<T, Abi> x, int s) {
125 using ops::rot;
126 return rot(x, s);
127}
128
129} // namespace detail
130
131// A_out and B have the same size. A_in has the same number of rows but may have a different number
132// of columns, which means that only a part of A_in is nonzero. The nonzero part is defined by kAin
133// and kAin_offset.
134template <class T, class Abi, KernelConfig Conf, index_t R, index_t S, StorageOrder OL,
136[[gnu::hot, gnu::flatten]] void hyhound_diag_tail_microkernel(
137 index_t kA_in_offset, index_t kA_in, index_t k,
140 uview<const T, Abi, StorageOrder::ColMajor> diag, Structure struc_L, int rotate_A) noexcept {
141 using batmat::ops::cneg;
142 using simd = datapar::simd<T, Abi>;
143 BATMAT_ASSUME(k > 0);
144
145 // Compute product W = A B
146 simd V[S][R]{};
147 for (index_t lA = 0; lA < kA_in; ++lA) {
148 index_t lB = lA + kA_in_offset;
149 UNROLL_FOR (index_t j = 0; j < R; ++j) {
150 auto Bjl = Conf.sign_only ? cneg(B.load(j, lB), diag.load(lB, 0)) //
151 : B.load(j, lB) * diag.load(lB, 0);
152 UNROLL_FOR (index_t i = 0; i < S; ++i)
153 V[i][j] += A_in.load(i, lA) * Bjl;
154 }
155 }
156
157 // Solve system V = (L+U)W⁻¹ (in-place)
158 auto L_cached = with_cached_access<S, R>(L);
159 switch (struc_L) {
160 [[likely]]
161 case Structure::General: {
162 UNROLL_FOR (index_t j = 0; j < R; ++j) {
163 simd Wj[R];
164 UNROLL_FOR (index_t i = 0; i < j; ++i)
165 Wj[i] = W.load(i, j);
166 UNROLL_FOR (index_t i = 0; i < S; ++i) {
167 simd Lij = L_cached.load(i, j);
168 V[i][j] += Lij;
169 UNROLL_FOR (index_t l = 0; l < j; ++l)
170 V[i][j] -= V[i][l] * Wj[l];
171 V[i][j] *= W.load(j, j); // diagonal already inverted
172 Lij = V[i][j] - Lij;
173 L_cached.store(Lij, i, j);
174 }
175 }
176 } break;
177 case Structure::Zero: {
178 UNROLL_FOR (index_t j = 0; j < R; ++j) {
179 simd Wj[R];
180 UNROLL_FOR (index_t i = 0; i < j; ++i)
181 Wj[i] = W.load(i, j);
182 UNROLL_FOR (index_t i = 0; i < S; ++i) {
183 UNROLL_FOR (index_t l = 0; l < j; ++l)
184 V[i][j] -= V[i][l] * Wj[l];
185 V[i][j] *= W.load(j, j); // diagonal already inverted
186 }
187 }
188 } break;
189 case Structure::Upper: {
190 UNROLL_FOR (index_t j = 0; j < R; ++j) {
191 simd Wj[R];
192 UNROLL_FOR (index_t i = 0; i < j; ++i)
193 Wj[i] = W.load(i, j);
194 UNROLL_FOR (index_t i = 0; i < S; ++i) {
195 simd Lij;
196 if (i <= j) {
197 Lij = L_cached.load(i, j);
198 V[i][j] += Lij;
199 }
200 UNROLL_FOR (index_t l = 0; l < j; ++l)
201 V[i][j] -= V[i][l] * Wj[l];
202 V[i][j] *= W.load(j, j); // diagonal already inverted
203 if (i <= j) {
204 Lij = V[i][j] - Lij;
205 L_cached.store(Lij, i, j);
206 }
207 }
208 }
209 } break;
210 default: BATMAT_ASSUME(false);
211 }
212 // Update A -= V Bᵀ
213 const auto update_A = [&] [[gnu::always_inline]] (auto s) {
214 simd Bjl[R];
215 for (index_t lB = 0; lB < kA_in_offset; ++lB) [[unlikely]] {
216 UNROLL_FOR (index_t j = 0; j < R; ++j)
217 Bjl[j] = B.load(j, lB);
218 UNROLL_FOR (index_t i = 0; i < S; ++i) {
219 simd Ail{0};
220 UNROLL_FOR (index_t j = 0; j < R; ++j)
221 Ail -= V[i][j] * Bjl[j];
222 A_out.store(detail::rotate(Ail, s), i, lB);
223 }
224 }
225 for (index_t lB = kA_in_offset + kA_in; lB < k; ++lB) [[unlikely]] {
226 UNROLL_FOR (index_t j = 0; j < R; ++j)
227 Bjl[j] = B.load(j, lB);
228 UNROLL_FOR (index_t i = 0; i < S; ++i) {
229 simd Ail{0};
230 UNROLL_FOR (index_t j = 0; j < R; ++j)
231 Ail -= V[i][j] * Bjl[j];
232 A_out.store(detail::rotate(Ail, s), i, lB);
233 }
234 }
235 for (index_t lA = 0; lA < kA_in; ++lA) [[likely]] {
236 index_t lB = lA + kA_in_offset;
237 UNROLL_FOR (index_t j = 0; j < R; ++j)
238 Bjl[j] = B.load(j, lB);
239 UNROLL_FOR (index_t i = 0; i < S; ++i) {
240 auto Ail = A_in.load(i, lA);
241 UNROLL_FOR (index_t j = 0; j < R; ++j)
242 Ail -= V[i][j] * Bjl[j];
243 A_out.store(detail::rotate(Ail, s), i, lB);
244 }
245 }
246 };
247#if defined(__AVX512F__) && 0
248 update_A(rotate_A);
249#else
250 switch (rotate_A) {
251 [[likely]] case 0:
252 update_A(std::integral_constant<int, 0>{});
253 break;
254 case -1: update_A(std::integral_constant<int, -1>{}); break;
255 // case 1: update_A(std::integral_constant<int, 1>{}); break;
256 default: BATMAT_ASSUME(false);
257 }
258#endif
259}
260
261/// Block hyperbolic Householder factorization update using register blocking.
262/// This variant does not store the Householder representation W.
// NOTE(review): generated listing — the extractor dropped original lines 264
// (function signature), 274 (W_t alias), 280 (D_ view definition) and 284
// (the foreach_chunked call head). Consult hyhound.tpp:264 for the full text.
263template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
265 const view<const T, Abi> D) noexcept {
266 static constexpr index_constant<SizeR<T, Abi>> R;
267 static constexpr index_constant<SizeS<T, Abi>> S;
268 const index_t k = A.cols();
269 BATMAT_ASSUME(k > 0);
270 BATMAT_ASSUME(L.rows() >= L.cols());
271 BATMAT_ASSUME(L.rows() == A.rows());
272 BATMAT_ASSUME(A.cols() == D.rows());
273
// Scratch storage for the block Householder matrix of the current diagonal
// block (W_t alias defined on the dropped line 274).
275 alignas(W_t::alignment()) T W[W_t::size()];
276
277 // Sizeless views to partition and pass to the micro-kernels
278 const uview<T, Abi, OL> L_ = L;
279 const uview<T, Abi, OA> A_ = A;
281
282 // Process all diagonal blocks (in multiples of R, except the last).
// Square L: full chunks use the W-storing diagonal kernel plus tail kernels;
// the final (possibly partial) chunk has no rows below it, so it is handled
// by the full (non-W) microkernel in the second lambda.
283 if (L.rows() == L.cols()) {
285 0, L.cols(), R,
286 [&](index_t j) {
287 // Part of A corresponding to this diagonal block
288 // TODO: packing
289 auto Ad = A_.middle_rows(j);
290 auto Ld = L_.block(j, j);
291 // Process the diagonal block itself
292 hyhound_diag_diag_microkernel<T, Abi, Conf, R, OL, OA>(k, W, Ld, Ad, D_);
293 // Process all rows below the diagonal block (in multiples of S).
294 foreach_chunked_merged(
295 j + R, L.rows(), S,
296 [&](index_t i, auto rem_i) {
297 auto As = A_.middle_rows(i);
298 auto Ls = L_.block(i, j);
299 microkernel_tail_lut<T, Abi, Conf, OL, OA, OA>[rem_i - 1](
300 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
301 },
302 LoopDir::Backward); // TODO: decide on order
303 },
304 [&](index_t j, index_t rem_j) {
305 auto Ad = A_.middle_rows(j);
306 auto Ld = L_.block(j, j);
307 microkernel_full_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, Ld, Ad, D_);
308 });
// Tall L: every diagonal chunk may have rows below it, so each chunk uses
// the LUT-dispatched diagonal kernel followed by the tail kernels.
309 } else {
310 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto rem_j) {
311 // Part of A corresponding to this diagonal block
312 // TODO: packing
313 auto Ad = A_.middle_rows(j);
314 auto Ld = L_.block(j, j);
315 // Process the diagonal block itself
316 microkernel_diag_lut<T, Abi, Conf, OL, OA>[rem_j - 1](k, W, Ld, Ad, D_);
317 // Process all rows below the diagonal block (in multiples of S).
318 foreach_chunked_merged(
319 j + rem_j, L.rows(), S,
320 [&](index_t i, auto rem_i) {
321 auto As = A_.middle_rows(i);
322 auto Ls = L_.block(i, j);
323 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[rem_j - 1][rem_i - 1](
324 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
325 },
326 LoopDir::Backward); // TODO: decide on order
327 });
328 }
329}
330
331/// Block hyperbolic Householder factorization update using register blocking.
332/// This variant stores the Householder representation W.
// NOTE(review): generated listing — the extractor dropped original lines 334
// (function signature), 344 (W_t alias), 349-350 (D_/W_ view definitions)
// and 363 (the foreach_chunked_merged call head). Consult the upstream
// hyhound.tpp for the full text.
333template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
335 const view<const T, Abi> D, const view<T, Abi> W) noexcept {
336 const index_t k = A.cols();
337 BATMAT_ASSUME(k > 0);
338 BATMAT_ASSUME(L.rows() >= L.cols());
339 BATMAT_ASSUME(L.rows() == A.rows());
340 BATMAT_ASSUME(D.rows() == k);
// W must have the exact dimensions dictated by hyhound_W_size(L).
341 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
342
343 static constexpr index_constant<SizeR<T, Abi>> R;
345
346 // Sizeless views to partition and pass to the micro-kernels
347 const uview<T, Abi, OL> L_ = L;
348 const uview<T, Abi, OA> A_ = A;
351
352 // Process all diagonal blocks (in multiples of R, except the last).
353 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto nj) {
354 static constexpr index_constant<SizeS<T, Abi>> S;
355 // Part of A corresponding to this diagonal block
356 // TODO: packing
357 auto Ad = A_.middle_rows(j);
358 auto Ld = L_.block(j, j);
// Column j/R of the caller-provided W storage holds this block's reflector.
359 auto Wd = W_t{W_.middle_cols(j / R).data};
360 // Process the diagonal block itself
361 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, Wd, Ld, Ad, D_);
362 // Process all rows below the diagonal block (in multiples of S).
364 j + nj, L.rows(), S,
365 [&](index_t i, auto ni) {
366 auto As = A_.middle_rows(i);
367 auto Ls = L_.block(i, j);
368 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
369 0, k, k, Wd, Ls, As, As, Ad, D_, Structure::General, 0);
370 },
371 LoopDir::Backward); // TODO: decide on order
372 });
373}
374
375/// Apply a block hyperbolic Householder transformation.
// NOTE(review): generated listing — the extractor dropped original lines 377
// (function signature), 392 (W_t alias) and 400-401 (D_/W_ view
// definitions). Consult hyhound.tpp:377 for the full text.
376template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
378 const view<T, Abi, OA> Aout, const view<const T, Abi, OA> B,
379 const view<const T, Abi> D, const view<const T, Abi> W,
380 index_t kA_in_offset) noexcept {
381 const index_t k_in = Ain.cols(), k = Aout.cols();
382 BATMAT_ASSUME(k > 0);
383 BATMAT_ASSUME(Aout.rows() == Ain.rows());
384 BATMAT_ASSUME(Ain.rows() == L.rows());
385 BATMAT_ASSUME(B.rows() == L.cols());
386 BATMAT_ASSUME(B.cols() == k);
387 BATMAT_ASSUME(D.rows() == k);
// The nonzero columns of Ain occupy [kA_in_offset, kA_in_offset + k_in)
// within the k columns of Aout.
388 BATMAT_ASSUME(0 <= kA_in_offset);
389 BATMAT_ASSUME(kA_in_offset + k_in <= k);
390
391 static constexpr index_constant<SizeR<T, Abi>> R;
393 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
394
395 // Sizeless views to partition and pass to the micro-kernels
396 const uview<T, Abi, OL> L_ = L;
397 const uview<const T, Abi, OA> Ain_ = Ain;
398 const uview<T, Abi, OA> Aout_ = Aout;
399 const uview<const T, Abi, OA> B_ = B;
402
403 // Process all diagonal blocks (in multiples of R, except the last).
404 foreach_chunked_merged(0, L.cols(), R, [&](index_t j, auto nj) {
405 static constexpr index_constant<SizeS<T, Abi>> S;
406 // Part of A corresponding to this diagonal block
407 // TODO: packing
408 auto Ad = B_.middle_rows(j);
409 auto Wd = W_t{W_.middle_cols(j / R).data};
410 // Process all rows (in multiples of S).
411 foreach_chunked_merged( // TODO: swap loop order?
412 0, L.rows(), S,
413 [&](index_t i, auto ni) {
// On the first block column (j == 0), read the sparse Ain with its
// offset/size; later columns read back the already-written Aout.
414 auto Aini = j == 0 ? Ain_.middle_rows(i) : Aout_.middle_rows(i);
415 auto Aouti = Aout_.middle_rows(i);
416 auto Ls = L_.block(i, j);
417 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
418 j == 0 ? kA_in_offset : 0, j == 0 ? k_in : k, k, Wd, Ls, Aini, Aouti, Ad, D_,
419 Structure::General, 0);
420 },
421 LoopDir::Backward); // TODO: decide on order
422 });
423}
424
425/// Same as hyhound_diag_register but for two block rows at once.
// NOTE(review): generated listing — the extractor dropped original lines 428
// (function signature), 440 (W_t alias), 448 (D_ view definition) and 469
// (the second foreach_chunked_merged call head). Consult hyhound.tpp:428 for
// the full text.
426template <class T, class Abi, StorageOrder OL1, StorageOrder OA1, StorageOrder OL2,
427 StorageOrder OA2, KernelConfig Conf>
429 const view<T, Abi, OL2> L21, const view<T, Abi, OA2> A2,
430 const view<const T, Abi> D) noexcept {
431 const index_t k = A1.cols();
432 BATMAT_ASSUME(k > 0);
433 BATMAT_ASSUME(L11.rows() >= L11.cols());
434 BATMAT_ASSUME(L11.rows() == A1.rows());
435 BATMAT_ASSUME(D.rows() == k);
436 BATMAT_ASSUME(A2.cols() == k);
437 BATMAT_ASSUME(L21.cols() == L11.cols());
438
439 static constexpr index_constant<SizeR<T, Abi>> R;
// Scratch storage for the current diagonal block's Householder matrix.
441 alignas(W_t::alignment()) T W[W_t::size()];
442
443 // Sizeless views to partition and pass to the micro-kernels
444 const uview<T, Abi, OL1> L11_ = L11;
445 const uview<T, Abi, OA1> A1_ = A1;
446 const uview<T, Abi, OL2> L21_ = L21;
447 const uview<T, Abi, OA2> A2_ = A2;
449
450 // Process all diagonal blocks (in multiples of R, except the last).
451 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
452 static constexpr index_constant<SizeS<T, Abi>> S;
453 // Part of A corresponding to this diagonal block
454 // TODO: packing
455 auto Ad = A1_.middle_rows(j);
456 auto Ld = L11_.block(j, j);
457 // Process the diagonal block itself
458 microkernel_diag_lut<T, Abi, Conf, OL1, OA1>[nj - 1](k, W, Ld, Ad, D_);
459 // Process all rows below the diagonal block (in multiples of S).
460 foreach_chunked_merged(
461 j + nj, L11.rows(), S,
462 [&](index_t i, auto ni) {
463 auto As = A1_.middle_rows(i);
464 auto Ls = L11_.block(i, j);
465 microkernel_tail_lut_2<T, Abi, Conf, OL1, OA1, OA1>[nj - 1][ni - 1](
466 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
467 },
468 LoopDir::Backward); // TODO: decide on order
// Second block row (L21/A2): all of its rows lie below the diagonal block.
470 0, L21.rows(), S,
471 [&](index_t i, auto ni) {
472 auto As = A2_.middle_rows(i);
473 auto Ls = L21_.block(i, j);
474 microkernel_tail_lut_2<T, Abi, Conf, OL2, OA2, OA1>[nj - 1][ni - 1](
475 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
476 },
477 LoopDir::Backward); // TODO: decide on order
478 });
479}
480
481/**
482 * Performs a factorization update of the following matrix:
483 *
484 * [ A11 A12 | L11 ] [ 0 0 | L̃11 ]
485 * [ 0 A22 | L21 ] Q = [ Ã21 Ã22 | L̃21 ]
486 * [ A31 0 | L31 ] [ Ã31 Ã32 | L̃31 ]
487 */
// NOTE(review): generated listing — the extractor dropped original lines 490
// (function signature), 513 (W_t alias), 525 (D_ view definition) and
// 546/558 (the foreach_chunked_merged call heads for the L21 and L31 block
// rows). Consult hyhound.tpp:490 for the full text.
488template <class T, class Abi, StorageOrder OL, StorageOrder OW, StorageOrder OY, StorageOrder OU,
489 KernelConfig Conf>
491 const view<T, Abi, OW> A1, // work
492 const view<T, Abi, OY> L21, // Y
493 const view<const T, Abi, OW> A22, // work
494 const view<T, Abi, OW> A2_out, // work
495 const view<T, Abi, OU> L31, // U
496 const view<const T, Abi, OW> A31, // work
497 const view<T, Abi, OW> A3_out, // work
498 const view<const T, Abi> D) noexcept {
// k1/k2 partition the k columns between the A31 and A22 halves.
499 const index_t k = A1.cols(), k1 = A31.cols(), k2 = A22.cols();
500 BATMAT_ASSUME(k > 0);
501 BATMAT_ASSUME(L11.rows() >= L11.cols());
502 BATMAT_ASSUME(L11.rows() == A1.rows());
503 BATMAT_ASSUME(L21.rows() == A22.rows());
504 BATMAT_ASSUME(L31.rows() == A31.rows());
505 BATMAT_ASSUME(A22.rows() == A2_out.rows());
506 BATMAT_ASSUME(A31.rows() == A3_out.rows());
507 BATMAT_ASSUME(D.rows() == k);
508 BATMAT_ASSUME(L21.cols() == L11.cols());
509 BATMAT_ASSUME(L31.cols() == L11.cols());
510 BATMAT_ASSUME(k1 + k2 == k);
511
512 static constexpr index_constant<SizeR<T, Abi>> R;
// Scratch storage for the current diagonal block's Householder matrix.
514 alignas(W_t::alignment()) T W[W_t::size()];
515
516 // Sizeless views to partition and pass to the micro-kernels
517 const uview<T, Abi, OL> L11_ = L11;
518 const uview<T, Abi, OW> A1_ = A1;
519 const uview<T, Abi, OY> L21_ = L21;
520 const uview<const T, Abi, OW> A22_ = A22;
521 const uview<T, Abi, OW> A2_out_ = A2_out;
522 const uview<T, Abi, OU> L31_ = L31;
523 const uview<const T, Abi, OW> A31_ = A31;
524 const uview<T, Abi, OW> A3_out_ = A3_out;
526
527 // Process all diagonal blocks (in multiples of R, except the last).
528 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
529 static constexpr index_constant<SizeS<T, Abi>> S;
530 // Part of A corresponding to this diagonal block
531 // TODO: packing
532 auto Ad = A1_.middle_rows(j);
533 auto Ld = L11_.block(j, j);
534 // Process the diagonal block itself
535 microkernel_diag_lut<T, Abi, Conf, OL, OW>[nj - 1](k, W, Ld, Ad, D_);
536 // Process all rows below the diagonal block (in multiples of S).
537 foreach_chunked_merged(
538 j + nj, L11.rows(), S,
539 [&](index_t i, auto ni) {
540 auto As = A1_.middle_rows(i);
541 auto Ls = L11_.block(i, j);
542 microkernel_tail_lut_2<T, Abi, Conf, OL, OW, OW>[nj - 1][ni - 1](
543 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
544 },
545 LoopDir::Backward); // TODO: decide on order
// Second block row (L21): on the first pass (j == 0) the first k1 columns
// of A2 are implicitly zero, so only the trailing k2 columns are read.
547 0, L21.rows(), S,
548 [&](index_t i, auto ni) {
549 auto As_out = A2_out_.middle_rows(i);
550 auto As = j == 0 ? A22_.middle_rows(i) : As_out;
551 auto Ls = L21_.block(i, j);
552 // First half of A2 is implicitly zero in first pass
553 index_t offset_s = j == 0 ? k1 : 0, k_s = j == 0 ? k2 : k;
554 microkernel_tail_lut_2<T, Abi, Conf, OY, OW, OW>[nj - 1][ni - 1](
555 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
556 },
557 LoopDir::Backward); // TODO: decide on order
// Third block row (L31): on the first pass only the first k1 columns of A3
// are nonzero.
559 0, L31.rows(), S,
560 [&](index_t i, auto ni) {
561 auto As_out = A3_out_.middle_rows(i);
562 auto As = j == 0 ? A31_.middle_rows(i) : As_out;
563 auto Ls = L31_.block(i, j);
564 // Second half of A3 is implicitly zero in first pass
565 index_t offset_s = 0, k_s = j == 0 ? k1 : k;
566 microkernel_tail_lut_2<T, Abi, Conf, OU, OW, OW>[nj - 1][ni - 1](
567 offset_s, k_s, k, W, Ls, As, As_out, Ad, D_, Structure::General, 0);
568 },
569 LoopDir::Backward); // TODO: decide on order
570 });
571}
572
573/**
574 * Performs a factorization update of the following matrix:
575 *
576 * [ A1 | L11 ] [ 0 | L̃11 ]
577 * [ A2 | L21 ] Q = [ Ã2 | L̃21 ]
578 * [ 0 | Lu1 ] [ Ãu | L̃u1 ]
579 *
580 * where Lu1 and L̃u1 are upper triangular
581 */
// NOTE(review): generated listing — the extractor dropped original lines 584
// (function signature), 605 (W_t alias), 616 (D_ view definition) and
// 637/647 (the foreach_chunked_merged call heads for the L21 and Lu1 block
// rows). Consult hyhound.tpp:584 for the full text.
582template <class T, class Abi, StorageOrder OL, StorageOrder OA, StorageOrder OLu, StorageOrder OAu,
583 KernelConfig Conf>
585 const view<T, Abi, OL> L21, const view<const T, Abi, OA> A2,
586 const view<T, Abi, OA> A2_out, const view<T, Abi, OLu> Lu1,
587 const view<T, Abi, OAu> Au_out, const view<const T, Abi> D,
588 bool shift_A_out) noexcept {
589 const index_t k = A1.cols();
590 BATMAT_ASSUME(k > 0);
591 BATMAT_ASSUME(L11.rows() >= L11.cols());
592 BATMAT_ASSUME(L11.rows() == A1.rows());
593 BATMAT_ASSUME(L21.rows() == A2.rows());
594 BATMAT_ASSUME(A2_out.rows() == A2.rows());
595 BATMAT_ASSUME(A2_out.cols() == A2.cols());
596 BATMAT_ASSUME(Lu1.rows() == Au_out.rows());
597 BATMAT_ASSUME(A1.cols() == D.rows());
598 BATMAT_ASSUME(A2.cols() == A1.cols());
599 BATMAT_ASSUME(L21.cols() == L11.cols());
600 BATMAT_ASSUME(Lu1.cols() == L11.cols());
601
602 static constexpr index_constant<SizeR<T, Abi>> R;
603 static constexpr index_constant<SizeS<T, Abi>> S;
// The Upper/General/Zero structure dispatch below compares i and j directly,
// which is only valid when the row and column chunk sizes coincide.
604 static_assert(R == S);
// Scratch storage for the current diagonal block's Householder matrix.
606 alignas(W_t::alignment()) T W[W_t::size()];
607
608 // Sizeless views to partition and pass to the micro-kernels
609 const uview<T, Abi, OL> L11_ = L11;
610 const uview<T, Abi, OA> A1_ = A1;
611 const uview<T, Abi, OL> L21_ = L21;
612 const uview<const T, Abi, OA> A2_ = A2;
613 const uview<T, Abi, OA> A2_out_ = A2_out;
614 const uview<T, Abi, OLu> Lu1_ = Lu1;
615 const uview<T, Abi, OAu> Au_out_ = Au_out;
617
618 // Process all diagonal blocks (in multiples of R, except the last).
619 foreach_chunked_merged(0, L11.cols(), R, [&](index_t j, auto nj) {
// Lane rotation (-1) is only requested on the final block column.
620 const bool do_shift = shift_A_out && j + nj == L11.cols();
621 // Part of A corresponding to this diagonal block
622 // TODO: packing
623 auto Ad = A1_.middle_rows(j);
624 auto Ld = L11_.block(j, j);
625 // Process the diagonal block itself
626 microkernel_diag_lut<T, Abi, Conf, OL, OA>[nj - 1](k, W, Ld, Ad, D_);
627 // Process all rows below the diagonal block (in multiples of S).
628 foreach_chunked_merged(
629 j + nj, L11.rows(), S,
630 [&](index_t i, auto ni) {
631 auto As = A1_.middle_rows(i);
632 auto Ls = L11_.block(i, j);
633 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
634 0, k, k, W, Ls, As, As, Ad, D_, Structure::General, 0);
635 },
636 LoopDir::Backward); // TODO: decide on order
// Second block row (L21): reads A2 on the first pass, A2_out afterwards.
638 0, L21.rows(), S,
639 [&](index_t i, auto ni) {
640 auto As_out = A2_out_.middle_rows(i);
641 auto As = j == 0 ? A2_.middle_rows(i) : As_out;
642 auto Ls = L21_.block(i, j);
643 microkernel_tail_lut_2<T, Abi, Conf, OL, OA, OA>[nj - 1][ni - 1](
644 0, k, k, W, Ls, As, As_out, Ad, D_, Structure::General, do_shift ? -1 : 0);
645 },
646 LoopDir::Backward); // TODO: decide on order
// Upper-triangular block row (Lu1): the structure of each S×R block of Lu1
// depends on its position relative to the diagonal.
648 0, Lu1.rows(), S,
649 [&](index_t i, auto ni) {
650 auto As_out = Au_out_.middle_rows(i);
651 auto As = As_out;
652 auto Ls = Lu1_.block(i, j);
653 // Au is implicitly zero in first pass
654 const auto struc = i == j ? Structure::Upper
655 : i < j ? Structure::General
656 : Structure::Zero;
657 microkernel_tail_lut_2<T, Abi, Conf, OLu, OAu, OA>[nj - 1][ni - 1](
658 0, j == 0 ? 0 : k, k, W, Ls, As, As_out, Ad, D_, struc, do_shift ? -1 : 0);
659 },
660 LoopDir::Backward); // TODO: decide on order
661 });
662}
663
664} // namespace batmat::linalg::micro_kernels::hyhound
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
Definition rotate.hpp:239
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
Rotate the elements of x to the right by s positions.
Definition rotate.hpp:18
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each full chunk and func_rem for the final partial chunk (if any).
Definition loop.hpp:20
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk; the final chunk may be smaller, and its actual size is passed to func_chunk.
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
auto rotate(datapar::simd< T, Abi > x, std::integral_constant< int, S >)
Definition hyhound.tpp:118
void hyhound_diag_full_microkernel(index_t kA, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:72
const constinit auto microkernel_full_lut
Definition hyhound.hpp:85
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:490
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
Definition hyhound.tpp:264
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:584
void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< const T, Abi, OA > A_in, uview< T, Abi, OA > A_out, uview< const T, Abi, OB > B, uview< const T, Abi, StorageOrder::ColMajor > diag, Structure struc_L, int rotate_A) noexcept
Definition hyhound.tpp:136
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
Definition hyhound.hpp:105
void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:19
void hyhound_diag_apply_register(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0) noexcept
Apply a block hyperbolic Householder transformation.
Definition hyhound.tpp:377
const constinit auto microkernel_diag_lut
Definition hyhound.hpp:79
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
Definition hyhound.tpp:428
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
T cneg(T x, T signs)
Conditionally negates the sign bit of x, depending on signs, which should contain only ±0 (i....
Definition cneg.hpp:42
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114