batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
sytrd.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
11#include <guanaqo/trace.hpp>
12
13namespace batmat::linalg {
14
15namespace detail {
// detail::sytrd: checks the shapes of D, W and Y, records a trace entry, and returns early when
// D has fewer than three rows.
// NOTE(review): this listing skips source lines 17, 20, 24 and 34 (presumably the function
// signature, the opening/continuation of two assertions, and the call into the kernel) —
// confirm against the original header.
16template <class T, class Abi, micro_kernels::sytrd::KernelConfig Conf, StorageOrder OD>
18 // Check dimensions
19 BATMAT_ASSERT(D.rows() == D.cols());
// Accepted shapes for W: no rows at all, a single column with max(D.cols(), 1) - 1 rows, or
// exactly the size reported by micro_kernels::sytrd::sytrd_W_size for D.
21 W.rows() == 0 || (W.cols() == 1 && W.rows() == std::max<index_t>(D.cols(), 1) - 1) ||
22 std::make_pair(W.rows(), W.cols()) == (micro_kernels::sytrd::sytrd_W_size<T, Abi>)(D));
23 BATMAT_ASSERT(std::make_pair(Y.rows(), Y.cols()) ==
25 const index_t M = D.rows();
// Flop count for an M×M problem, scaled by D.depth() when handed to the trace macro. It is
// marked [[maybe_unused]], presumably because the macro may expand to nothing.
26 [[maybe_unused]] const auto fc = flops::sytrd(M);
27 GUANAQO_TRACE_LINALG("sytrd", total(fc) * D.depth());
28 // Degenerate case
// For M < 3 the only visible effect is that a non-empty W is filled with T{} before returning.
29 if (M < 3) [[unlikely]] {
30 if (W.rows() > 0 && W.cols() > 0)
31 W.set_constant(T{}); // identity
32 return;
33 }
35}
36
// detail::sytrd_apply: copies the first row of A into D when the two views do not share storage,
// then hands the remaining k - 1 rows to geqrf_apply together with the bottom-left
// (k - 1)×(k - 1) corner of B.
// NOTE(review): source lines 38-39 (rest of the template header and the start of the parameter
// list) are missing from this listing.
37template <class T, class Abi, micro_kernels::geqrf::KernelConfig Conf, StorageOrder OA,
40 view<const T, Abi> W, bool transposed) {
41 const index_t k = A.rows();
// Nothing to do for an empty input.
42 if (k == 0)
43 return;
// Skip the copy when A and D point at the same data (in-place use).
44 if (A.data() != D.data())
45 linalg::copy(A.top_rows(1), D.top_rows(1));
// The trailing `false` is passed for geqrf_apply's last parameter (`reversed`).
46 geqrf_apply<T, Abi, Conf>(A.bottom_rows(k - 1), D.bottom_rows(k - 1),
47 B.bottom_left(k - 1, k - 1), W, transposed, false);
48}
49} // namespace detail
50
51/// @addtogroup topic-linalg
52/// @{
53
54/// @name Tridiagonalization of batches of matrices
55/// @{
56
57/// Tridiagonalization. The resulting diagonal and subdiagonal elements overwrite D, and the
58/// Householder vectors are stored below the subdiagonal (with the first component implicitly equal
59/// to 1). The Householder coefficients are stored in W, which should either be a vector of
60/// `D.cols() - 1` elements, or a matrix of size `sytrd_size_W(D)`. If W has zero rows, the
61/// coefficients are discarded. The workspace Y is used internally and should have size
62/// `sytrd_size_Y(D)`.
63template <simdifiable VD, simdifiable VW, simdifiable VY>
69
// Public sytrd_apply overload with separate input A and output D: forwards simdified views of
// A (const), D, B.value (const) and W (const) along with `transposed`.
// NOTE(review): source lines 71-72 and 74 (start of the signature and the callee name) are
// missing from this listing.
70template <simdifiable VA, simdifiable VD, simdifiable VB, simdifiable VW>
73 bool transposed = false) {
75 simdify(A).as_const(), simdify(D), simdify(B.value).as_const(), simdify(W).as_const(),
76 transposed);
77}
78
// Public sytrd_apply overload without a separate input matrix: D is passed both as the const
// input view and as the output view, so the operation is applied to D in place.
// NOTE(review): source lines 80-81 and 83 (start of the signature and the callee name) are
// missing from this listing.
79template <simdifiable VD, simdifiable VB, simdifiable VW>
82 bool transposed = false) {
84 simdify(D).as_const(), simdify(D), simdify(B.value).as_const(), simdify(W).as_const(),
85 transposed);
86}
87
88/// Get the size of the storage for the matrix W returned by
89/// @ref sytrd(Structured<VD, MatrixStructure::LowerTriangular> D, VW &&W, VY &&Y).
90template <simdifiable VD>
95
96/// Get the size of the storage for the matrix Y used by
97/// @ref sytrd(Structured<VD, MatrixStructure::LowerTriangular> D, VW &&W, VY &&Y).
98template <simdifiable VD>
103
104/// @}
105
106/// @}
107
108} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
constexpr FlopCount sytrd(index_t m)
Symmetric tridiagonalization of an m×m matrix.
Definition flops.hpp:229
auto sytrd_size_Y(VD &&D)
Get the size of the storage for the matrix Y used by sytrd(Structured<VD, MatrixStructure::LowerTrian...
Definition sytrd.hpp:99
void sytrd_apply(VA &&A, VD &&D, Structured< VB, MatrixStructure::LowerTriangular > B, VW &&W, bool transposed=false)
Definition sytrd.hpp:72
void copy(VA &&A, VB &&B, Opts... opts)
B = A.
Definition copy.hpp:187
auto sytrd_size_W(VD &&D)
Get the size of the storage for the matrix W returned by sytrd(Structured<VD, MatrixStructure::LowerT...
Definition sytrd.hpp:91
void sytrd(Structured< VD, MatrixStructure::LowerTriangular > D, VW &&W, VY &&Y)
Tridiagonalization.
Definition sytrd.hpp:65
#define GUANAQO_TRACE_LINALG(name, gflops)
void geqrf_apply(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< const T, Abi, OB > B, view< const T, Abi > W, bool transposed, bool reversed)
Definition geqrf.hpp:36
void sytrd(view< T, Abi, OD > D, view< T, Abi > W, view< T, Abi > Y)
Definition sytrd.hpp:17
void sytrd_apply(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< const T, Abi, OB > B, view< const T, Abi > W, bool transposed)
Definition sytrd.hpp:39
constexpr std::pair< index_t, index_t > sytrd_W_size(view< T, Abi, OD > D)
Definition sytrd.hpp:24
void sytrd_register(view< T, Abi, OD > D, view< T, Abi > W, view< T, Abi > Y) noexcept
Symmetric block tridiagonalization.
Definition sytrd.tpp:204
constexpr std::pair< index_t, index_t > sytrd_Y_size(view< T, Abi, OD > D)
Definition sytrd.hpp:32
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:216
constexpr bool simdify_compatible
Definition simdify.hpp:221
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
int index_t
Definition config.hpp:13
Light-weight wrapper class used for overload resolution of triangular and symmetric matrices.