batmat 0.0.16
Batched linear algebra routines
Loading...
Searching...
No Matches
gemm-diag.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/kib.hpp>
11#include <batmat/loop.hpp>
13#include <guanaqo/trace.hpp>
14#include <optional>
15
16namespace batmat::linalg {
17
// NOTE(review): this is a Doxygen-rendered extract; source lines 20-22 (the
// rest of the template-parameter list and the function signature) were dropped
// by the extraction. The tooltip at the bottom of the page shows the full
// signature of this blocked driver:
//   void gemm_diag(view<const T, Abi, OA> A, view<const T, Abi, OB> B,
//                  std::optional<view<const T, Abi, OC>> C,
//                  view<T, Abi, OD> D, view<const T, Abi> d)
// i.e. D = C (+/-) A diag(d) B with an optional additive C term — confirm
// against the real header.
19template <class T, class Abi, micro_kernels::gemm_diag::KernelConfig Conf = {}, StorageOrder OA,
23 // Check dimensions
// Shape preconditions: C (when present) and D agree; A is M×K, B is K×N,
// d is a K×1 column holding the diagonal.
24 BATMAT_ASSERT(!C || C->rows() == D.rows());
25 BATMAT_ASSERT(!C || C->cols() == D.cols());
26 BATMAT_ASSERT(A.rows() == D.rows());
27 BATMAT_ASSERT(A.cols() == B.rows());
28 BATMAT_ASSERT(A.cols() == d.rows());
29 BATMAT_ASSERT(d.cols() == 1);
30 BATMAT_ASSERT(B.cols() == D.cols());
31 const index_t M = D.rows(), N = D.cols(), K = A.cols();
// Flop accounting for tracing; multiplied by batch depth since the routine
// operates on a batch of matrices.
32 [[maybe_unused]] const auto fc = flops::gemmt_diag(M, N, K, Conf.struc_C);
33 GUANAQO_TRACE_LINALG("gemm_diag", total(fc) * A.depth());
34
35 // Degenerate case
36 if (M == 0 || N == 0) [[unlikely]]
37 return;
// K == 0: the product term vanishes, so D is either a structured copy of C
// or a structured fill (the actual copy/fill calls were on source lines 42
// and 44, dropped by the extraction — confirm against the real header).
38 if (K == 0) [[unlikely]] {
39 constexpr detail::copy::CopyConfig rot{.struc = Conf.struc_C};
40 constexpr detail::copy::FillConfig msk{.struc = Conf.struc_C};
41 if (C)
43 else
45 return;
46 }
47 // TODO: cache blocking
// NOTE(review): source line 48 (presumably the register-blocked kernel call,
// gemm_diag_copy_register per the tooltip below) was dropped here.
49}
50} // namespace detail::gemm_diag
51
/// Compile-time option tag carrying the "track structural zeros" flag as a
/// boolean; passed by value to the gemm_diag/syrk_diag front ends.
/// (Doxygen line-number residue stripped from the extracted listing.)
template <bool Z>
struct track_zeros_t : std::bool_constant<Z> {};
54
55template <bool Z = true>
57
58namespace detail::gemm_diag {
59template <class...>
60inline constexpr std::optional<bool> get_track_zeros = std::nullopt;
61template <class T, class... Ts>
62inline constexpr std::optional<bool> get_track_zeros<T, Ts...> = get_track_zeros<Ts...>;
63template <bool Z, class... Ts>
64inline constexpr std::optional<bool> get_track_zeros<track_zeros_t<Z>, Ts...> = Z;
65
66template <class>
67inline constexpr bool is_track_zeros_opt = false;
68template <bool Z>
69inline constexpr bool is_track_zeros_opt<track_zeros_t<Z>> = true;
70
71template <class Opt>
73
74template <class... Opts>
77 if (auto z = get_track_zeros<Opts...>)
78 conf.track_zeros = *z;
79 return conf;
80}
81} // namespace detail::gemm_diag
82
83/// @addtogroup topic-linalg
84/// @{
85
86/// @name Multiplication of batches of matrices with diagonal scaling
87/// @{
88
89/// D = A diag(d) B
// NOTE(review): Doxygen extract — source lines 91-92 (rest of the template
// parameter list, presumably the track_zeros_opt... Opts pack as in the
// sibling overloads) and line 96 (the call into the detail::gemm_diag
// blocked driver) were dropped; confirm against the real header.
90template <simdifiable VA, simdifiable VB, simdifiable VD, simdifiable Vd,
93void gemm_diag(VA &&A, VB &&B, VD &&D, Vd &&d, Opts... opts) {
// Empty optional accumulator: the plain product has no additive C term.
94 std::optional<decltype(simdify(D).as_const())> null;
// Fold option tags into the kernel config; no negation for this variant.
95 constexpr auto conf = detail::gemm_diag::apply_options({.negate = false}, opts...);
97 simdify(A).as_const(), simdify(B).as_const(), null, simdify(D), simdify(d).as_const());
98}
99
100/// D = C + A diag(d) B
// NOTE(review): Doxygen extract — source line 103 (likely attributes or a
// requires-clause) and line 106 (the call into the detail::gemm_diag blocked
// driver) were dropped; confirm against the real header.
101template <simdifiable VA, simdifiable VB, simdifiable VC, simdifiable VD, simdifiable Vd,
102 detail::gemm_diag::track_zeros_opt... Opts>
104void gemm_diag_add(VA &&A, VB &&B, VC &&C, VD &&D, Vd &&d, Opts... opts) {
// Fold option tags into the kernel config; no negation for this variant.
105 constexpr auto conf = detail::gemm_diag::apply_options({.negate = false}, opts...);
// C is passed as an engaged optional so the driver adds it into the product.
107 simdify(A).as_const(), simdify(B).as_const(), std::make_optional(simdify(C).as_const()),
108 simdify(D), simdify(d).as_const());
109}
110/// D += A diag(d) B
// NOTE(review): Doxygen extract — source line 113 (likely attributes or a
// requires-clause) was dropped; confirm against the real header.
111template <simdifiable VA, simdifiable VB, simdifiable VD, simdifiable Vd,
112 detail::gemm_diag::track_zeros_opt... Opts>
114void gemm_diag_add(VA &&A, VB &&B, VD &&D, Vd &&d, Opts... opts) {
// In-place accumulation: forward D as both the additive C term and the
// destination of the four-matrix overload above.
115 gemm_diag_add(A, B, D, D, d, opts...);
116}
117
118/// D = C + A diag(d) Aᵀ with C, D symmetric
// NOTE(review): Doxygen extract — source line 121 (likely attributes or a
// requires-clause) and line 126 (the call into the detail::gemm_diag blocked
// driver) were dropped; confirm against the real header.
119template <MatrixStructure SC, simdifiable VA, simdifiable VC, simdifiable VD, simdifiable Vd,
120 detail::gemm_diag::track_zeros_opt... Opts>
122void syrk_diag_add(VA &&A, Structured<VC, SC> C, Structured<VD, SC> D, Vd &&d, Opts... opts) {
// A rank-k update must target a triangular/symmetric destination, never a
// general matrix.
123 static_assert(SC != MatrixStructure::General);
// Propagate the destination structure so the driver only touches one triangle.
124 constexpr auto conf =
125 detail::gemm_diag::apply_options({.negate = false, .struc_C = SC}, opts...);
// B is A transposed, yielding the symmetric product A diag(d) Aᵀ.
127 simdify(A).as_const(), simdify(A).as_const().transposed(),
128 std::make_optional(simdify(C.value).as_const()), simdify(D.value), simdify(d).as_const());
129}
130/// D += A diag(d) Aᵀ with D symmetric
// NOTE(review): Doxygen extract — source line 133 (likely attributes or a
// requires-clause) was dropped; confirm against the real header.
131template <MatrixStructure SC, simdifiable VA, simdifiable VD, simdifiable Vd,
132 detail::gemm_diag::track_zeros_opt... Opts>
134void syrk_diag_add(VA &&A, Structured<VD, SC> D, Vd &&d, Opts... opts) {
// In-place accumulation: forward D as both the additive C term and the
// destination of the three-matrix overload above.
135 syrk_diag_add(A, D, D, d, opts...);
136}
137
138/// @}
139
140/// @}
141
142} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
constexpr FlopCount gemmt_diag(index_t m, index_t n, index_t k, MatrixStructure sC)
Matrix-matrix multiplication of m×k and k×n matrices with a diagonal k×k matrix in the middle,...
Definition flops.hpp:122
void gemm_diag_add(VA &&A, VB &&B, VC &&C, VD &&D, Vd &&d, Opts... opts)
D = C + A diag(d) B.
void syrk_diag_add(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, Vd &&d, Opts... opts)
D = C + A diag(d) Aᵀ with C, D symmetric.
void gemm_diag(VA &&A, VB &&B, VD &&D, Vd &&d, Opts... opts)
D = A diag(d) B.
Definition gemm-diag.hpp:93
#define GUANAQO_TRACE_LINALG(name, gflops)
void fill(T a, view< T, Abi, OB > B)
Definition copy.hpp:27
void copy(view< const T, Abi, OA > A, view< T, Abi, OB > B)
Definition copy.hpp:68
void gemm_diag(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, view< const T, Abi > d)
Definition gemm-diag.hpp:21
constexpr micro_kernels::gemm_diag::KernelConfig apply_options(micro_kernels::gemm_diag::KernelConfig conf, Opts...)
Definition gemm-diag.hpp:76
constexpr std::optional< bool > get_track_zeros
Definition gemm-diag.hpp:60
void gemm_diag_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, view< const T, Abi > diag) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr track_zeros_t< Z > track_zeros
Definition gemm-diag.hpp:56
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Aligned allocation for matrix storage.
Light-weight wrapper class used for overload resolution of triangular and symmetric matrices.