batmat 0.0.19
Batched linear algebra routines
Loading...
Searching...
No Matches
hyhound.hpp
Go to the documentation of this file.
1#pragma once
2
10#include <batmat/loop.hpp>
12#include <guanaqo/trace.hpp>
13
14namespace batmat::linalg {
15
16namespace detail {
17
/// Checks the operand shapes of a hyperbolic Householder update of L by A with diagonal D and
/// records a trace entry for it.
///
/// Shape checks made here: L has at least as many rows as columns, A has as many rows as L, and
/// D has one row per column of A.
///
/// NOTE(review): this listing omits the signature line and the statement after the early return.
/// The symbol index on the same page gives the signature as
/// `void hyhound_diag(view<T, Abi, OL> L, view<T, Abi, OA> A, view<const T, Abi> D)`; the omitted
/// statement presumably forwards to hyhound_diag_register — confirm against the real header.
18template <class T, class Abi, micro_kernels::hyhound::KernelConfig Conf, StorageOrder OL,
19 StorageOrder OA>
// k is the number of columns of the update matrix A.
21 const index_t k = A.cols();
22 BATMAT_ASSERT(L.rows() >= L.cols());
23 BATMAT_ASSERT(L.rows() == A.rows());
24 BATMAT_ASSERT(A.cols() == D.rows());
// flop_count is only consumed by the trace macro in the visible code, hence [[maybe_unused]].
// The value handed to the trace is the per-matrix count multiplied by L.depth().
25 [[maybe_unused]] const index_t flop_count = total(flops::hyh(L.rows(), L.cols(), k));
26 GUANAQO_TRACE_LINALG("hyhound_diag", flop_count * L.depth());
// A has no columns: return without doing any update (the trace macro above is still invoked).
27 if (k == 0) [[unlikely]]
28 return;
30}
31
/// Variant of the hyperbolic Householder update that additionally takes a matrix W.
///
/// Performs the same shape checks on L, A and D as the overload without W, and also requires
/// (W.rows(), W.cols()) to equal hyhound_W_size<T, Abi>(L). After tracing, the work is delegated
/// to hyhound_diag_register<T, Abi, Conf>(L, A, D, W).
///
/// NOTE(review): the signature line is not present in this listing; the parameter names L, A, D
/// and W are taken from their uses in the body.
32template <class T, class Abi, micro_kernels::hyhound::KernelConfig Conf, StorageOrder OL,
33 StorageOrder OA>
35 using namespace micro_kernels::hyhound;
// k is the number of columns of the update matrix A.
36 const index_t k = A.cols();
37 BATMAT_ASSERT(L.rows() >= L.cols());
38 BATMAT_ASSERT(L.rows() == A.rows());
39 BATMAT_ASSERT(A.cols() == D.rows());
// W must have exactly the size reported by hyhound_W_size for this L.
40 BATMAT_ASSERT(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
// flop_count is only consumed by the trace macro, hence [[maybe_unused]]. The trace name is the
// same string as in the overload without W.
41 [[maybe_unused]] const index_t flop_count = total(flops::hyh(L.rows(), L.cols(), k));
42 GUANAQO_TRACE_LINALG("hyhound_diag", flop_count * L.depth());
// A has no columns: skip the kernel call entirely.
43 if (k == 0) [[unlikely]]
44 return;
45 return hyhound_diag_register<T, Abi, Conf>(L, A, D, W);
46}
47
/// Checks operand shapes, records a trace entry, and delegates to
/// hyhound_diag_apply_register<T, Abi, Conf>(L, Ain, Aout, B, D, W, kA_in_offset).
///
/// Shape requirements enforced below:
///  - Aout, Ain and L have the same number of rows;
///  - B has L.cols() rows and Aout.cols() columns;
///  - D has Aout.cols() rows;
///  - Ain may have fewer columns than Aout: 0 <= kA_in_offset and
///    kA_in_offset + Ain.cols() <= Aout.cols();
///  - W has the size reported by hyhound_W_size<T, Abi>(L).
///
/// NOTE(review): the first lines of the signature are not present in this listing. The symbol
/// index on the same page gives it as `void hyhound_diag_apply(view<T, Abi, OL> L,
/// view<const T, Abi, OA> Ain, view<T, Abi, OA> Aout, view<const T, Abi, OA> B,
/// view<const T, Abi> D, view<const T, Abi> W, index_t kA_in_offset = 0)`.
48template <class T, class Abi, micro_kernels::hyhound::KernelConfig Conf, StorageOrder OL,
49 StorageOrder OA>
52 index_t kA_in_offset = 0) {
53 using namespace micro_kernels::hyhound;
// k is taken from the output matrix, not from the (possibly narrower) input matrix.
54 const index_t k = Aout.cols();
55 BATMAT_ASSERT(Aout.rows() == Ain.rows());
56 BATMAT_ASSERT(Ain.rows() == L.rows());
57 BATMAT_ASSERT(B.rows() == L.cols());
58 BATMAT_ASSERT(B.cols() == Aout.cols());
59 BATMAT_ASSERT(D.rows() == Aout.cols());
// The columns of Ain must fit inside Aout when placed at column offset kA_in_offset.
60 BATMAT_ASSERT(0 <= kA_in_offset);
61 BATMAT_ASSERT(kA_in_offset + Ain.cols() <= Aout.cols());
62 BATMAT_ASSERT(std::make_pair(W.rows(), W.cols()) == (hyhound_W_size<T, Abi>)(L));
63 // Note: ignoring initial zero values of A in the FLOP count for simplicity (for large matrices
64 // this does not matter)
// flop_count is only consumed by the trace macro, hence [[maybe_unused]].
65 [[maybe_unused]] const index_t flop_count = total(flops::hyh_apply(L.rows(), L.cols(), k));
66 GUANAQO_TRACE_LINALG("hyhound_diag_apply", flop_count * L.depth());
// Aout has no columns: skip the kernel call entirely.
67 if (k == 0) [[unlikely]]
68 return;
69 return hyhound_diag_apply_register<T, Abi, Conf>(L, Ain, Aout, B, D, W, kA_in_offset);
70}
71
/// Checks operand shapes and records a trace entry for an update in which L and A are each given
/// as two block rows, (L11, L21) and (A1, A2).
///
/// Shape checks made here: L11 has at least as many rows as columns; A1 has as many rows as L11;
/// D has one row per column of A1; A2 has as many columns as A1; L21 has as many columns as L11.
///
/// NOTE(review): the rest of the template header, the signature and the start of the final call
/// are not present in this listing. The symbol index on the same page gives the signature as
/// `void hyhound_diag_2(view<T, Abi, OL1> L11, view<T, Abi, OA1> A1, view<T, Abi, OL2> L21,
/// view<T, Abi, OA2> A2, view<const T, Abi> D)`; the trailing argument list presumably belongs to
/// a call to hyhound_diag_2_register — confirm against the real header.
/// NOTE(review): no check that L21.rows() == A2.rows() is visible here, whereas
/// hyhound_diag_riccati asserts the analogous condition — confirm whether that is intentional.
72template <class T, class Abi, micro_kernels::hyhound::KernelConfig Conf, StorageOrder OL1,
// k: number of update columns; m: total number of rows over both block rows of L.
76 const index_t k = A1.cols(), m = L11.rows() + L21.rows();
77 BATMAT_ASSERT(L11.rows() >= L11.cols());
78 BATMAT_ASSERT(L11.rows() == A1.rows());
79 BATMAT_ASSERT(A1.cols() == D.rows());
80 BATMAT_ASSERT(A2.cols() == A1.cols());
81 BATMAT_ASSERT(L21.cols() == L11.cols());
// The traced flop count treats the two block rows as a single m-row update.
82 [[maybe_unused]] const index_t flop_count = total(flops::hyh(m, L11.cols(), k));
83 GUANAQO_TRACE_LINALG("hyhound_diag_2", flop_count * L11.depth());
// A1 has no columns: return before the kernel call.
84 if (k == 0) [[unlikely]]
85 return;
87 L11, A1, L21, A2, D);
88}
89
/// Checks operand shapes and records a trace entry for an update with three block rows
/// (L11, L21, L31), where the second and third block rows of the update matrix are supplied as
/// A22 and A31 and written to separate outputs A2_out and A3_out.
///
/// Shape checks made here: L11 has at least as many rows as columns; A1, A22 and A31 have as many
/// rows as L11, L21 and L31 respectively; D has one row per column of A1; the column counts of
/// A22 and A31 add up to that of A1; L21 and L31 have as many columns as L11.
///
/// NOTE(review): the rest of the template header, the signature and the start of the final call
/// are not present in this listing. The symbol index on the same page gives the signature as
/// `void hyhound_diag_cyclic(view<T, Abi, OL> L11, view<T, Abi, OW> A1, view<T, Abi, OY> L21,
/// view<const T, Abi, OW> A22, view<T, Abi, OW> A2_out, view<T, Abi, OU> L31,
/// view<const T, Abi, OW> A31, view<T, Abi, OW> A3_out, view<const T, Abi> D)`; the trailing
/// argument list presumably belongs to a call to hyhound_diag_cyclic_register — confirm.
/// NOTE(review): no shape checks on A2_out or A3_out are visible here — confirm whether the
/// callee validates them.
90template <class T, class Abi, micro_kernels::hyhound::KernelConfig Conf, StorageOrder OL,
// k: number of update columns; m: total number of rows over all three block rows of L.
96 const index_t k = A1.cols(), m = L11.rows() + L21.rows() + L31.rows();
97 BATMAT_ASSERT(L11.rows() >= L11.cols());
98 BATMAT_ASSERT(L11.rows() == A1.rows());
99 BATMAT_ASSERT(L21.rows() == A22.rows());
100 BATMAT_ASSERT(L31.rows() == A31.rows());
101 BATMAT_ASSERT(A1.cols() == D.rows());
// A22 and A31 together account for all k columns of the update.
102 BATMAT_ASSERT(A22.cols() + A31.cols() == A1.cols());
103 BATMAT_ASSERT(L21.cols() == L11.cols());
104 BATMAT_ASSERT(L31.cols() == L11.cols());
105 // Note: ignoring initial zero values of A in the FLOP count for simplicity (for large matrices
106 // this does not matter)
107 [[maybe_unused]] const index_t flop_count = total(flops::hyh(m, L11.cols(), k));
108 GUANAQO_TRACE_LINALG("hyhound_diag_cyclic", flop_count * L11.depth());
// A1 has no columns: return before the kernel call.
109 if (k == 0) [[unlikely]]
110 return;
112 L11, A1, L21, A22, A2_out, L31, A31, A3_out, D);
113}
114
/// Checks operand shapes and records a trace entry for an update with block rows
/// (L11, L21, Lu1), where A2 is read-only input, A2_out receives the second block row of the
/// result and Au_out is paired with Lu1.
///
/// Shape checks made here: L11 has at least as many rows as columns; A1 and A2 have as many rows
/// as L11 and L21 respectively; A2_out has the same shape as A2; Au_out has as many rows as Lu1;
/// D has one row per column of A1; A2 has as many columns as A1; L21 and Lu1 have as many
/// columns as L11.
///
/// NOTE(review): most of the template header and signature, and the start of the final call, are
/// not present in this listing. The symbol index on the same page gives the signature as
/// `void hyhound_diag_riccati(view<T, Abi, OL> L11, view<T, Abi, OA> A1, view<T, Abi, OL> L21,
/// view<const T, Abi, OA> A2, view<T, Abi, OA> A2_out, view<T, Abi, OLu> Lu1,
/// view<T, Abi, OAu> Au_out, view<const T, Abi> D, bool shift_A_out)`; the trailing argument list
/// presumably belongs to a call to hyhound_diag_riccati_register — confirm.
/// NOTE(review): no check on Au_out.cols() is visible, and flop_count is not declared const here
/// although it is in the sibling functions — confirm whether either is intentional.
115template <class T, class Abi, micro_kernels::hyhound::KernelConfig Conf, StorageOrder OL,
119 view<T, Abi, OAu> Au_out, view<const T, Abi> D, bool shift_A_out) {
// k: number of update columns; m: total number of rows over all three block rows of L.
120 const index_t k = A1.cols(), m = L11.rows() + L21.rows() + Lu1.rows();
121 BATMAT_ASSERT(L11.rows() >= L11.cols());
122 BATMAT_ASSERT(L11.rows() == A1.rows());
123 BATMAT_ASSERT(L21.rows() == A2.rows());
124 BATMAT_ASSERT(A2_out.rows() == A2.rows());
125 BATMAT_ASSERT(A2_out.cols() == A2.cols());
126 BATMAT_ASSERT(Lu1.rows() == Au_out.rows());
127 BATMAT_ASSERT(A1.cols() == D.rows());
128 BATMAT_ASSERT(A2.cols() == A1.cols());
129 BATMAT_ASSERT(L21.cols() == L11.cols());
130 BATMAT_ASSERT(Lu1.cols() == L11.cols());
131 // Note: ignoring upper trapezoidal shape of Lu and initial zero value of Au for simplicity
132 // (for large matrices this does not matter)
133 [[maybe_unused]] index_t flop_count = total(flops::hyh(m, L11.cols(), k));
134 GUANAQO_TRACE_LINALG("hyhound_diag_riccati", flop_count * L11.depth());
// A1 has no columns: return before the kernel call.
135 if (k == 0) [[unlikely]]
136 return;
138 L11, A1, L21, A2, A2_out, Lu1, Au_out, D, shift_A_out);
139}
140
141} // namespace detail
142
143/// @addtogroup topic-linalg
144/// @{
145
146/// @name Cholesky factorization updates
147/// @{
148
149/// Update Cholesky factor L using low-rank term A diag(d) Aᵀ.
///
/// Only `MatrixStructure::LowerTriangular` is accepted for @p L (enforced by the static_assert
/// in the body). The arguments are converted with `simdify` before being forwarded: the matrix
/// wrapped by @p L and @p A as mutable views, @p d as a const view.
///
/// NOTE(review): the constraint line after the template header and the line that starts the
/// forwarding call are not present in this listing.
150template <MatrixStructure SL, simdifiable VL, simdifiable VA, simdifiable Vd>
152void hyhound_diag(Structured<VL, SL> L, VA &&A, Vd &&d) {
153 static_assert(SL == MatrixStructure::LowerTriangular); // TODO
155 simdify(L.value), simdify(A), simdify(d).as_const());
156}
157
158/// Update Cholesky factor L using low-rank term A diag(d) Aᵀ, with full Householder representation.
///
/// Only `MatrixStructure::LowerTriangular` is accepted for @p L (enforced by the static_assert
/// in the body). @p W is forwarded as a mutable view; the documentation of hyhound_size_W
/// describes it as the storage for the matrix W returned by this function. @p d is forwarded as
/// a const view.
///
/// NOTE(review): the constraint line after the template header and the line that starts the
/// forwarding call are not present in this listing.
159template <MatrixStructure SL, simdifiable VL, simdifiable VA, simdifiable Vd, simdifiable VW>
161void hyhound_diag(Structured<VL, SL> L, VA &&A, Vd &&d, VW &&W) {
162 static_assert(SL == MatrixStructure::LowerTriangular); // TODO
164 simdify(L.value), simdify(A), simdify(d).as_const(), simdify(W));
165}
166
167/// Get the size of the storage for the matrix W returned by
168/// @ref hyhound_diag(Structured<VL,SL>, VA&&, Vd&&, VW&&).
169template <MatrixStructure SL, simdifiable VL>
174
175/// Apply Householder transformation generated by @ref hyhound_diag, computing (L̃, D) = (L, A) Q̆.
176/// @param[in,out] L Part of the Cholesky factor to be updated (rectangular). Overwritten by L̃.
177/// @param[in] A Update matrix (rectangular).
178/// @param[out] D Updated update matrix (rectangular).
179/// @param[in] B Householder reflector vectors returned by @ref hyhound_diag.
180/// @param[in] d Diagonal update matrix.
181/// @param[in] W Householder representation returned by @ref hyhound_diag.
182/// @param[in] kA_in_offset If A is smaller than D, A is implicitly padded with zero
183/// columns: the offset of the nonzero part of A in the padded
184/// matrix can be specified using @p kA_in_offset.
///
/// @p A, @p B, @p d and @p W are forwarded as const views; @p L and @p D as mutable views.
///
/// NOTE(review): the constraint line and the line that starts the forwarding call are not
/// present in this listing. The argument order matches the detail-level hyhound_diag_apply
/// listed on this page, whose parameters are named (L, Ain, Aout, B, D, W, kA_in_offset) — so
/// this function's A, D and d presumably map to Ain, Aout and D there; confirm.
185template <simdifiable VL, simdifiable VA, simdifiable VD, simdifiable VB, simdifiable Vd,
186 simdifiable VW>
188void hyhound_diag_apply(VL &&L, VA &&A, VD &&D, VB &&B, Vd &&d, VW &&W, index_t kA_in_offset = 0) {
190 simdify(L), simdify(A).as_const(), simdify(D), simdify(B).as_const(), simdify(d).as_const(),
191 simdify(W).as_const(), kA_in_offset);
192}
193
194/// Apply Householder transformation generated by @ref hyhound_diag, computing (L̃, Ã) = (L, A) Q̆.
195/// @param[in,out] L Part of the Cholesky factor to be updated (rectangular). Overwritten by L̃.
196/// @param[in,out] A Update matrix (rectangular). Overwritten by Ã.
197/// @param[in] B Householder reflector vectors returned by @ref hyhound_diag.
198/// @param[in] d Diagonal update matrix.
199/// @param[in] W Householder representation returned by @ref hyhound_diag.
///
/// In-place form: @p A is forwarded twice, once as a const view (input) and once as a mutable
/// view (output), with a column offset of 0.
///
/// NOTE(review): the constraint line and the line that starts the forwarding call are not
/// present in this listing.
200template <simdifiable VL, simdifiable VA, simdifiable VB, simdifiable Vd, simdifiable VW>
202void hyhound_diag_apply(VL &&L, VA &&A, VB &&B, Vd &&d, VW &&W) {
204 simdify(L), simdify(A).as_const(), simdify(A), simdify(B).as_const(), simdify(d).as_const(),
205 simdify(W).as_const(), 0);
206}
207
208/// Update Cholesky factor L using low-rank term A diag(copysign(1, d)) Aᵀ,
209/// where d contains only ±0 values.
///
/// The matrix wrapped by @p L and @p A are forwarded as mutable views, @p d as a const view.
///
/// NOTE(review): the constraint line and the line that starts the forwarding call are not
/// present in this listing, so the callee cannot be named from here. Unlike hyhound_diag, no
/// static_assert on SL is visible in this body — confirm whether other structures are accepted.
210template <MatrixStructure SL, simdifiable VL, simdifiable VA, simdifiable Vd>
212void hyhound_sign(Structured<VL, SL> L, VA &&A, Vd &&d) {
214 simdify(L.value), simdify(A), simdify(d).as_const());
215}
216
217/// Update Cholesky factor L using low-rank term A diag(d) Aᵀ, where L and A are stored as two
218/// separate block rows.
219/// @f[
220/// L = \begin{pmatrix} L_{11} \\ L_{21} \end{pmatrix}, \quad
221/// A = \begin{pmatrix} A_{1} \\ A_{2} \end{pmatrix}.
222/// @f]
///
/// @p L1 and @p L2 correspond to the block rows written L₁₁ and L₂₁ in the formula above. All
/// matrix arguments are forwarded as mutable views; @p d is forwarded as a const view.
///
/// NOTE(review): the constraint line and the line that starts the forwarding call are not
/// present in this listing. No static_assert on SL is visible in this body.
223template <MatrixStructure SL, simdifiable VL1, simdifiable VA1, simdifiable VL2, simdifiable VA2,
224 simdifiable Vd>
226void hyhound_diag_2(Structured<VL1, SL> L1, VA1 &&A1, VL2 &&L2, VA2 &&A2, Vd &&d) {
228 simdify(L1.value), simdify(A1), simdify(L2), simdify(A2), simdify(d).as_const());
229}
230
231/// Update structured Cholesky factor L using structured low-rank term A diag(d) Aᵀ,
232/// @f[
233/// L = \begin{pmatrix} L_{11} \\ L_{21} \\ L_{31} \end{pmatrix}, \quad
234/// A = \begin{pmatrix} A_{11} & A_{12} \\ 0 & A_{22} \\ A_{31} & 0 \end{pmatrix}, \quad
235/// \tilde A = \begin{pmatrix} 0 \\ \tilde A_{2} \\ \tilde A_{3} \end{pmatrix}.
236/// @f]
///
/// @p A22 and @p A31 are forwarded as const views; @p A1, @p A2_out and @p A3_out are forwarded
/// as mutable views, as are the blocks of L. @p d is forwarded as a const view.
///
/// NOTE(review): the constraint line and the line that starts the forwarding call are not
/// present in this listing. The first block row of A is passed as the single argument @p A1,
/// which presumably holds both A₁₁ and A₁₂ side by side — confirm.
237template <MatrixStructure SL, simdifiable VL11, simdifiable VA1, simdifiable VL21, simdifiable VA2,
238 simdifiable VA2o, simdifiable VU, simdifiable VA3, simdifiable VA3o, simdifiable Vd>
240void hyhound_diag_cyclic(Structured<VL11, SL> L11, VA1 &&A1, VL21 &&L21, VA2 &&A22, VA2o &&A2_out,
241 VU &&L31, VA3 &&A31, VA3o &&A3_out, Vd &&d) {
243 simdify(L11.value), simdify(A1), simdify(L21), simdify(A22).as_const(), simdify(A2_out),
244 simdify(L31), simdify(A31).as_const(), simdify(A3_out), simdify(d).as_const());
245}
246
247/// Update structured Cholesky factor L using structured low-rank term A diag(d) Aᵀ,
248/// @f[
249/// L = \begin{pmatrix} L_{11} \\ L_{21} \\ L_{u} \end{pmatrix}, \quad
250/// A = \begin{pmatrix} A_{1} \\ A_{2} \\ 0 \end{pmatrix}, \quad
251/// \tilde A = \begin{pmatrix} 0 \\ \tilde A_{2} \\ \tilde A_{u} \end{pmatrix}.
252/// @f]
253/// The @p shift_A_out parameter indicates whether the output matrix A2_out should be shifted along
254/// the batch dimension. This is used in the Cyqlone solver.
///
/// @p A2 is forwarded as a const view; @p A1, @p A2_out, @p Au_out and the blocks of L are
/// forwarded as mutable views. @p d is forwarded as a const view, and @p shift_A_out (default
/// `false`) is passed through unchanged.
///
/// NOTE(review): the constraint line and the line that starts the forwarding call are not
/// present in this listing.
255template <MatrixStructure SL, simdifiable VL11, simdifiable VA1, simdifiable VL21, simdifiable VA2,
256 simdifiable VA2o, simdifiable VLu1, simdifiable VAuo, simdifiable Vd>
258void hyhound_diag_riccati(Structured<VL11, SL> L11, VA1 &&A1, VL21 &&L21, VA2 &&A2, VA2o &&A2_out,
259 VLu1 &&Lu1, VAuo &&Au_out, Vd &&d, bool shift_A_out = false) {
261 simdify(L11.value), simdify(A1), simdify(L21), simdify(A2).as_const(), simdify(A2_out),
262 simdify(Lu1), simdify(Au_out), simdify(d).as_const(), shift_A_out);
263}
264
265/// @}
266
267/// @}
268
269} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
constexpr FlopCount hyh(index_t nr, index_t nc, index_t k)
Hyperbolic Householder factorization update with L nr×nc and A nr×k.
Definition flops.hpp:173
constexpr FlopCount hyh_apply(index_t nr, index_t nc, index_t k)
Hyperbolic Householder factorization application to L2 nr×nc and A2 nr×k.
Definition flops.hpp:161
void hyhound_diag_apply(VL &&L, VA &&A, VD &&D, VB &&B, Vd &&d, VW &&W, index_t kA_in_offset=0)
Apply Householder transformation generated by hyhound_diag, computing (L̃, D) = (L, A) Q̆.
Definition hyhound.hpp:188
void hyhound_diag_riccati(Structured< VL11, SL > L11, VA1 &&A1, VL21 &&L21, VA2 &&A2, VA2o &&A2_out, VLu1 &&Lu1, VAuo &&Au_out, Vd &&d, bool shift_A_out=false)
Update structured Cholesky factor L using structured low-rank term A diag(d) Aᵀ.
Definition hyhound.hpp:258
void hyhound_sign(Structured< VL, SL > L, VA &&A, Vd &&d)
Update Cholesky factor L using low-rank term A diag(copysign(1, d)) Aᵀ, where d contains only ±0 values.
Definition hyhound.hpp:212
auto hyhound_size_W(Structured< VL, SL > L)
Get the size of the storage for the matrix W returned by hyhound_diag(Structured<VL,SL>, VA&&, Vd&&, VW&&).
Definition hyhound.hpp:170
void hyhound_diag_2(Structured< VL1, SL > L1, VA1 &&A1, VL2 &&L2, VA2 &&A2, Vd &&d)
Update Cholesky factor L using low-rank term A diag(d) Aᵀ, where L and A are stored as two separate block rows.
Definition hyhound.hpp:226
void hyhound_diag(Structured< VL, SL > L, VA &&A, Vd &&d)
Update Cholesky factor L using low-rank term A diag(d) Aᵀ.
Definition hyhound.hpp:152
void hyhound_diag_cyclic(Structured< VL11, SL > L11, VA1 &&A1, VL21 &&L21, VA2 &&A22, VA2o &&A2_out, VU &&L31, VA3 &&A31, VA3o &&A3_out, Vd &&d)
Update structured Cholesky factor L using structured low-rank term A diag(d) Aᵀ.
Definition hyhound.hpp:240
#define GUANAQO_TRACE_LINALG(name, gflops)
void hyhound_diag_apply(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0)
Definition hyhound.hpp:50
void hyhound_diag_riccati(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out)
Definition hyhound.hpp:117
void hyhound_diag(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D)
Definition hyhound.hpp:20
void hyhound_diag_cyclic(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D)
Definition hyhound.hpp:92
void hyhound_diag_2(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D)
Definition hyhound.hpp:74
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:514
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
Definition hyhound.tpp:288
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:608
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
Definition hyhound.hpp:82
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
Definition hyhound.tpp:452
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Aligned allocation for matrix storage.
Light-weight wrapper class used for overload resolution of triangular and symmetric matrices.