batmat 0.0.16
Batched linear algebra routines
Loading...
Searching...
No Matches
trsm.hpp
Go to the documentation of this file.
1#pragma once
2
10#include <batmat/loop.hpp>
12#include <guanaqo/trace.hpp>
13
14namespace batmat::linalg {
15
16// TODO: move these to shift.hpp (but make sure the concepts for gemm are still correct)
/// Compile-time tag carrying the integer I for the non-triangular operand A of the
/// right-sided trsm overloads (D = A B⁻¹). There, I is forwarded to the kernel
/// configuration as `rotate_B`, because that overload is implemented as a transposed
/// left-sided solve in which A plays the role of the right-hand side.
/// NOTE(review): the exact meaning of I (presumably a rotation/shift of the operand,
/// cf. shift.hpp) is defined by the micro-kernel — confirm there.
17template <int I>
18struct with_rotate_A_t : std::integral_constant<int, I> {};
/// Variable template providing a ready-made with_rotate_A_t<I> tag value.
19template <int I>
21
22namespace detail {
/// Batched triangular solve D = A⁻¹ B on SIMD views.
/// A and B are read-only and D is written. Only available when A has a structure
/// other than MatrixStructure::General, as enforced by the requires-clause.
/// Kernel options (triangular structure of A, rotation of B) arrive through the
/// compile-time KernelConfig Conf.
23template <class T, class Abi, micro_kernels::trsm::KernelConfig Conf, StorageOrder OA,
25 requires(Conf.struc_A != MatrixStructure::General)
27 // Check dimensions
 // NOTE(review): in builds where BATMAT_ASSERT expands to nothing, these checks
 // vanish and mismatched shapes are not diagnosed.
28 BATMAT_ASSERT(A.rows() == A.cols()); // TODO: could be relaxed
29 BATMAT_ASSERT(A.cols() == B.rows());
30 BATMAT_ASSERT(B.rows() == D.rows());
31 BATMAT_ASSERT(B.cols() == D.cols());
 // NOTE(review): A.depth() and B.depth() are not checked against D.depth()
 // here — confirm whether a mismatch is caught elsewhere.
 // Given the square-A assertion, K == M.
32 const index_t M = A.rows(), K = A.cols(), N = B.cols();
 // Flop count per matrix. Only consumed by the trace macro, hence [[maybe_unused]].
33 [[maybe_unused]] const auto fc = flops::trsm(M, N);
 // Report the per-matrix flop count scaled by D.depth() (presumably the batch size).
34 GUANAQO_TRACE_LINALG("trsm", total(fc) * D.depth());
35 // Degenerate case
 // Nothing to solve: return before dispatching to the kernel.
36 if (M == 0 || N == 0 || K == 0) [[unlikely]]
37 return;
39}
40} // namespace detail
41
42/// @addtogroup topic-linalg
43/// @{
44
45/// @name Triangular solve of batches of matrices
46/// @{
47
48/// D = A⁻¹ B with A triangular
/// @param A Triangular matrix batch. It is wrapped in Structured so that its
///          (non-General) structure SA is known at compile time; only A.value is read.
/// @param B Right-hand side batch. It is read-only here.
/// @param D Output batch with the same dimensions as B. The in-place overload passes
///          the same matrix as both B and D.
/// @tparam RotB Value carried by the with_rotate_B_t tag. It is forwarded unchanged
///              to the kernel configuration as `rotate_B`.
49template <MatrixStructure SA, simdifiable VA, simdifiable VB, simdifiable VD, int RotB = 0>
51void trsm(Structured<VA, SA> A, VB &&B, VD &&D, with_rotate_B_t<RotB> = {}) {
52 detail::trsm<simdified_value_t<VA>, simdified_abi_t<VA>, {.struc_A = SA, .rotate_B = RotB}>(
53 simdify(A.value).as_const(), simdify(B).as_const(), simdify(D));
54}
55/// D = A⁻¹ D with A triangular
/// In-place variant: D serves both as the right-hand side and as the output.
/// It forwards to the three-operand overload with the same matrix for B and D.
/// @param A Triangular matrix batch with compile-time structure SA.
/// @param D Right-hand side on entry, solution on exit.
/// @param shift Tag forwarded unchanged to the three-operand overload.
56template <MatrixStructure SA, simdifiable VA, simdifiable VD, int RotB = 0>
58void trsm(Structured<VA, SA> A, VD &&D, with_rotate_B_t<RotB> shift = {}) {
59 trsm(A.ref(), D, D, shift);
60}
61
62/// D = A B⁻¹ with B triangular
/// @param A Left operand batch (general matrix). It is read-only here.
/// @param B Triangular matrix batch. It is wrapped in Structured so that its
///          structure SB is known at compile time; only B.value is read.
/// @param D Output batch. The in-place overload passes the same matrix as both A and D.
/// @tparam RotA Value carried by the with_rotate_A_t tag. It is forwarded to the
///              kernel as `rotate_B`, since A becomes the right-hand side of the
///              transposed solve below.
63template <MatrixStructure SB, simdifiable VA, simdifiable VB, simdifiable VD, int RotA = 0>
65void trsm(VA &&A, Structured<VB, SB> B, VD &&D, with_rotate_A_t<RotA> = {}) {
66 // D = A B⁻¹ <=> Dᵀ = B⁻ᵀ Aᵀ
 // Implemented as a left-sided solve: Bᵀ is the triangular operand (hence
 // transpose(SB)), Aᵀ is the right-hand side, and Dᵀ receives the result.
68 {.struc_A = transpose(SB), .rotate_B = RotA}>(
69 simdify(B.value).transposed().as_const(), simdify(A).transposed().as_const(),
70 simdify(D).transposed());
71}
72/// D = D B⁻¹ with B triangular
/// In-place variant: D serves both as the left operand and as the output.
/// It forwards to the three-operand overload with the same matrix for A and D.
/// @param D Left operand on entry, result on exit.
/// @param B Triangular matrix batch with compile-time structure SB.
/// @param shift Tag forwarded unchanged to the three-operand overload.
73template <MatrixStructure SB, simdifiable VB, simdifiable VD, int RotA = 0>
75void trsm(VD &&D, Structured<VB, SB> B, with_rotate_A_t<RotA> shift = {}) {
76 trsm(D, B.ref(), D, shift);
77}
78
79/// @}
80
81/// @}
82
83} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
constexpr FlopCount trsm(index_t m, index_t n)
Triangular solve of m×n matrices.
Definition flops.hpp:191
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
D = A⁻¹ B with A triangular.
Definition trsm.hpp:51
constexpr MatrixStructure transpose(MatrixStructure s)
Definition structure.hpp:11
#define GUANAQO_TRACE_LINALG(name, gflops)
void trsm(view< const T, Abi, OA > A, view< const T, Abi, OB > B, view< T, Abi, OD > D)
Definition trsm.hpp:26
void trsm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, view< T, Abi, OD > D) noexcept
Triangular solve D = (A⁽ᵀ⁾)⁻¹ B⁽ᵀ⁾ where A⁽ᵀ⁾ is lower triangular.
Definition trsm.tpp:85
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
constexpr with_rotate_A_t< I > with_rotate_A
Definition trsm.hpp:20
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Aligned allocation for matrix storage.
Lightweight wrapper class used for overload resolution of triangular and symmetric matrices.