batmat 0.0.17
Batched linear algebra routines
Loading...
Searching...
No Matches
gemv.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
9
10#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
11
13
14template <class T, class Abi, KernelConfig Conf, StorageOrder OA>
15inline const constinit auto gemv_copy_lut =
18 });
19
20/// Generalized matrix-vector multiplication d = c ± A⁽ᵀ⁾ b. Single register block.
///
/// Single register-blocked invocation: computes one RowsReg-sized block of the
/// output, accumulating entirely in SIMD registers before a single store pass.
/// - When the optional accumulator C is absent, the accumulator starts at zero,
///   i.e. d = ±A⁽ᵀ⁾ b (see the simd{0} initialisation branches below).
/// - Conf.negate selects d = c - A⁽ᵀ⁾ b instead of d = c + A⁽ᵀ⁾ b.
/// - The rotl/rotr/shiftl<Conf.…> calls permute SIMD lanes as dictated by the
///   kernel configuration — presumably rotations across the batch dimension;
///   TODO(review): confirm against KernelConfig's documentation.
/// - Meaning of k depends on the storage order of A: for row-major A it is the
///   shared (inner) dimension length; for col-major A it is the number of rows
///   of D produced (it bounds the store loop in that branch).
/// NOTE(review): the Doxygen listing elides signature lines 23–25, so the full
/// parameter list is not visible here; see the cross-reference footer for it.
21template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder OA>
22[[gnu::hot, gnu::flatten]] void
26 const uview<T, Abi, StorageOrder::ColMajor> D, const index_t k) noexcept {
27 static_assert(RowsReg > 0);
28 using enum MatrixStructure;
29 using namespace ops;
30 using simd = datapar::simd<T, Abi>;
// The kernel loops assume at least one iteration; tells the optimizer the
// loops are entered (UB if violated, per BATMAT_ASSUME).
31 BATMAT_ASSUME(k > 0);
32 if constexpr (OA == StorageOrder::RowMajor) {
// Row-major A: block RowsReg rows of the output in registers and stream
// over the k inner-dimension elements.
33 // Load accumulator into registers
34 simd C_reg[RowsReg]; // NOLINT(*-c-arrays)
35 if (C) [[likely]] {
36 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
37 C_reg[ii] = rotl<Conf.rotate_C>(C->load(ii, 0));
38 } else {
// No input accumulator: start from zero so the result is ±A b.
39 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
40 C_reg[ii] = simd{0};
41 }
42 // Matrix-vector multiplication kernel
43 const auto A_cached = with_cached_access<RowsReg, 0>(A);
44 for (index_t l = 0; l < k; ++l) {
45 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
46 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l));
47 simd &Cij = C_reg[ii];
48 simd Blj = rotl<Conf.rotate_B>(B.load(l, 0));
// Compile-time sign selection: subtract when Conf.negate, add otherwise.
49 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
50 }
51 }
52 // Store accumulator to memory again
53 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
54 D.template store<Conf.mask_D>(rotr<Conf.rotate_D>(C_reg[ii]), ii, 0);
55 } else {
// Col-major A (i.e. the A⁽ᵀ⁾ case): keep RowsReg elements of B resident in
// registers instead, and stream over the k output rows, one accumulator
// per iteration.
56 // Load B into registers
57 simd B_reg[RowsReg]; // NOLINT(*-c-arrays)
58 UNROLL_FOR (index_t l = 0; l < RowsReg; ++l)
59 B_reg[l] = rotl<Conf.rotate_B>(B.load(l, 0));
60 // Matrix-vector multiplication kernel
61 const auto A_cached = with_cached_access<0, RowsReg>(A);
62 if (C) [[likely]] {
63 for (index_t i = 0; i < k; ++i) {
64 simd Cij = rotl<Conf.rotate_C>(C->load(i, 0));
65 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
66 simd Ail = shiftl<Conf.shift_A>(A_cached.load(i, ll));
67 Conf.negate ? (Cij -= Ail * B_reg[ll]) : (Cij += Ail * B_reg[ll]);
68 }
69 D.template store<Conf.mask_D>(rotr<Conf.rotate_D>(Cij), i, 0);
70 }
71 } else {
// Same loop with a zero-initialised accumulator; duplicated rather than
// branching on C inside the hot loop.
72 for (index_t i = 0; i < k; ++i) {
73 simd Cij{0};
74 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
75 simd Ail = shiftl<Conf.shift_A>(A_cached.load(i, ll));
76 Conf.negate ? (Cij -= Ail * B_reg[ll]) : (Cij += Ail * B_reg[ll]);
77 }
78 D.template store<Conf.mask_D>(rotr<Conf.rotate_D>(Cij), i, 0);
79 }
80 }
81 }
82}
83
84/// Generalized matrix-vector multiplication d = c ± A⁽ᵀ⁾ b. Using register blocking.
///
/// Driver: validates dimensions, then partitions the problem into register-
/// sized chunks and dispatches each chunk to the micro-kernel selected from
/// gemv_copy_lut by the chunk size (index ni-1/nk-1 into the LUT).
/// - B and D are column vectors (asserted to have one column below).
/// - Row-major A: chunks of output rows are independent; each chunk gets its
///   own slice of C (or nullopt) and D.
/// - Col-major A: chunks partition the shared dimension K, so the chunks must
///   accumulate into the same output; see the NOTE on the second call below.
/// NOTE(review): the listing elides signature line 86 and the declarations of
/// B_ and D_ (lines 101, 103); they are presumably the sizeless uview
/// counterparts of B and D, mirroring A_ and C_ — confirm in the full source.
85template <class T, class Abi, KernelConfig Conf, StorageOrder OA>
87 const std::optional<view<const T, Abi>> C, const view<T, Abi> D) noexcept {
88 using enum MatrixStructure;
89 constexpr auto Rows = RowsReg<T, Abi>;
90 // Check dimensions
91 const index_t I = D.rows(), K = A.cols();
92 BATMAT_ASSUME(A.rows() == I);
93 BATMAT_ASSUME(B.rows() == K);
94 BATMAT_ASSUME(B.cols() == 1);
95 BATMAT_ASSUME(D.cols() == 1);
// Both loop bounds below rely on at least one chunk existing (LUT index
// I-1 / K-1 would underflow otherwise).
96 BATMAT_ASSUME(I > 0);
97 BATMAT_ASSUME(K > 0);
98 static const auto microkernel = gemv_copy_lut<T, Abi, Conf, OA>;
99 // Sizeless views to partition and pass to the micro-kernels
100 const uview<const T, Abi, OA> A_ = A;
102 const std::optional<uview<const T, Abi, StorageOrder::ColMajor>> C_ = C;
104
105 if constexpr (OA == StorageOrder::RowMajor) {
// Small problem: a single micro-kernel call covers all I rows.
106 if (I <= Rows)
107 return microkernel[I - 1](A_, B_, C_, D_, K);
108 foreach_chunked_merged(0, I, Rows, [&](index_t i, auto ni) {
// std::nullopt converts to the optional view type, so the conditional
// expression's common type is the optional itself.
109 auto Cj = C_ ? std::make_optional(C_->middle_rows(i)) : std::nullopt;
110 microkernel[ni - 1](A_.middle_rows(i), B_, Cj, D_.middle_rows(i), K);
111 });
112 } else {
113 if (K <= Rows)
114 return microkernel[K - 1](A_, B_, C_, D_, I);
// First chunk consumes the original C; d = c ± A⁽ᵀ⁾[:, 0:Rows] b[0:Rows].
115 microkernel[Rows - 1](A_.middle_cols(0), B_.middle_rows(0), C_, D_, I);
// Subsequent chunks pass D_ as the C argument so each call accumulates
// the partial product onto the result already stored in D.
// NOTE(review): this appears intentional (read-back accumulation), but it
// makes each store in D immediately reloaded by the next chunk — confirm
// this interacts correctly with Conf.mask_D / Conf.rotate_D round-trips.
116 foreach_chunked_merged(Rows, K, Rows, [&](index_t k, auto nk) {
117 microkernel[nk - 1](A_.middle_cols(k), B_.middle_rows(k), D_, D_, I);
118 });
119 }
120}
121
122} // namespace batmat::linalg::micro_kernels::gemv
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
const constinit auto gemv_copy_lut
Definition gemv.tpp:15
void gemv_copy_register(view< const T, Abi, OA > A, view< const T, Abi > B, std::optional< view< const T, Abi > > C, view< T, Abi > D) noexcept
Generalized matrix multiplication d = c ± A⁽ᵀ⁾ b. Using register blocking.
Definition gemv.tpp:86
void gemv_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > B, std::optional< uview< const T, Abi, StorageOrder::ColMajor > > C, uview< T, Abi, StorageOrder::ColMajor > D, index_t k) noexcept
Generalized matrix-vector multiplication d = c ± A⁽ᵀ⁾ b. Single register block.
Definition gemv.tpp:23
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10