batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
gemv.hpp
Go to the documentation of this file.
1#pragma once
2
9#include <batmat/loop.hpp>
11#include <guanaqo/trace.hpp>
12
13namespace batmat::linalg {
14
15namespace detail {
16template <class T, class Abi, micro_kernels::gemv::KernelConfig Conf = {}, StorageOrder OA>
18 view<T, Abi> D) {
19 GUANAQO_TRACE_LINALG("gemv", A.rows() * A.cols() * B.cols() * A.depth());
20 // Check dimensions
21 BATMAT_ASSERT(!C || C->rows() == D.rows());
22 BATMAT_ASSERT(!C || C->cols() == D.cols());
23 BATMAT_ASSERT(A.rows() == D.rows());
24 BATMAT_ASSERT(A.cols() == B.rows());
25 BATMAT_ASSERT(B.cols() == D.cols());
26 BATMAT_ASSERT(B.cols() == 1);
27 const index_t M = D.rows(), K = A.cols();
28
29 // Degenerate case
30 if (M == 0) [[unlikely]]
31 return;
32 if (K == 0) [[unlikely]] {
33 // https://github.com/llvm/llvm-project/issues/146272
34 constexpr detail::copy::CopyConfig rot{.rotate = Conf.rotate_C - Conf.rotate_D,
35 .mask = Conf.mask_D};
36 constexpr detail::copy::FillConfig msk{.mask = Conf.mask_D};
37 if (C)
39 else
41 return;
42 }
44}
45
46template <shift_opt... Opts>
49 if (auto s = shift_A<Opts...>)
50 conf.shift_A = *s;
51 if (auto s = shift_B<Opts...>)
52 conf.shift_B = *s;
53 if (auto s = rotate_C<Opts...>)
54 conf.rotate_C = *s;
55 if (auto s = rotate_D<Opts...>)
56 conf.rotate_D = *s;
57 if (auto s = mask_D<Opts...>)
58 conf.mask_D = *s;
59 return conf;
60}
61
62} // namespace detail
63
64/// @addtogroup topic-linalg
65/// @{
66
67/// @name Matrix-vector multiplication of batches of matrices
68/// @{
69
70/// d = A b
71template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
73void gemv(VA &&A, VB &&B, VD &&D, Opts... opts) {
74 constexpr auto conf = detail::apply_gemv_options({.negate = false}, opts...);
75 std::optional<decltype(simdify(D).as_const())> null;
77 simdify(A).as_const(), simdify(B).as_const(), null, simdify(D));
78}
79
80/// d = -A b
81template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
83void gemv_neg(VA &&A, VB &&B, VD &&D, Opts... opts) {
84 constexpr auto conf = detail::apply_gemv_options({.negate = true}, opts...);
85 std::optional<decltype(simdify(D).as_const())> null;
87 simdify(A).as_const(), simdify(B).as_const(), null, simdify(D));
88}
89
90/// d = c + A b
91template <simdifiable VA, simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
93void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts) {
94 constexpr auto conf = detail::apply_gemv_options({.negate = false}, opts...);
96 simdify(A).as_const(), simdify(B).as_const(), simdify(C).as_const(), simdify(D));
97}
98/// d = d + A b
99template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
101void gemv_add(VA &&A, VB &&B, VD &&D, Opts... opts) {
102 gemv_add(A, B, D, D, opts...);
103}
104
105/// d = c - A b
106template <simdifiable VA, simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
108void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts) {
109 constexpr auto conf = detail::apply_gemv_options({.negate = true}, opts...);
111 simdify(A).as_const(), simdify(B).as_const(), simdify(C).as_const(), simdify(D));
112}
113/// d = d - A b
114template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
116void gemv_sub(VA &&A, VB &&B, VD &&D, Opts... opts) {
117 gemv_sub(A, B, D, D, opts...);
118}
119
120/// @}
121
122/// @}
123
124} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
d = c + A b
Definition gemv.hpp:93
void gemv(VA &&A, VB &&B, VD &&D, Opts... opts)
d = A b
Definition gemv.hpp:73
void gemv_neg(VA &&A, VB &&B, VD &&D, Opts... opts)
d = -A b
Definition gemv.hpp:83
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
d = c - A b
Definition gemv.hpp:108
#define GUANAQO_TRACE_LINALG(name, gflops)
void fill(T a, view< T, Abi, OB > B)
Definition copy.hpp:27
void copy(view< const T, Abi, OA > A, view< T, Abi, OB > B)
Definition copy.hpp:68
constexpr micro_kernels::gemv::KernelConfig apply_gemv_options(micro_kernels::gemv::KernelConfig conf, Opts...)
Definition gemv.hpp:48
void gemv(view< const T, Abi, OA > A, view< const T, Abi > B, std::optional< view< const T, Abi > > C, view< T, Abi > D)
Definition gemv.hpp:17
void gemv_copy_register(view< const T, Abi, OA > A, view< const T, Abi > B, std::optional< view< const T, Abi > > C, view< T, Abi > D) noexcept
Generalized matrix-vector multiplication d = c ± A⁽ᵀ⁾ b, implemented using register blocking.
Definition gemv.tpp:79
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr std::optional< int > rotate_C
Definition shift.hpp:45
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr std::optional< int > mask_D
Definition shift.hpp:59
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
constexpr std::optional< int > shift_B
Definition shift.hpp:38
constexpr std::optional< int > rotate_D
Definition shift.hpp:52
constexpr std::optional< int > shift_A
Definition shift.hpp:31
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Aligned allocation for matrix storage.