batmat 0.0.14 — Batched linear algebra routines
gemv.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
8
9#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
10
12
13/// Generalized matrix-vector multiplication d = c ± A⁽ᵀ⁾ b. Single register block.
14template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder OA>
15[[gnu::hot, gnu::flatten]] void
19 const uview<T, Abi, StorageOrder::ColMajor> D, const index_t k) noexcept {
20 static_assert(RowsReg > 0);
21 using enum MatrixStructure;
22 using namespace ops;
23 using simd = datapar::simd<T, Abi>;
24 BATMAT_ASSUME(k > 0);
25 if constexpr (OA == StorageOrder::RowMajor) {
26 // Load accumulator into registers
27 simd C_reg[RowsReg]; // NOLINT(*-c-arrays)
28 if (C) [[likely]] {
29 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
30 C_reg[ii] = rotl<Conf.rotate_C>(C->load(ii, 0));
31 } else {
32 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
33 C_reg[ii] = simd{0};
34 }
35 // Matrix-vector multiplication kernel
36 const auto A_cached = with_cached_access<RowsReg, 0>(A);
37 for (index_t l = 0; l < k; ++l) {
38 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
39 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l));
40 simd &Cij = C_reg[ii];
41 simd Blj = shiftl<Conf.shift_B>(B.load(l, 0));
42 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
43 }
44 }
45 // Store accumulator to memory again
46 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
47 D.template store<Conf.mask_D>(rotr<Conf.rotate_D>(C_reg[ii]), ii, 0);
48 } else {
49 // Load B into registers
50 simd B_reg[RowsReg]; // NOLINT(*-c-arrays)
51 UNROLL_FOR (index_t l = 0; l < RowsReg; ++l)
52 B_reg[l] = shiftl<Conf.shift_B>(B.load(l, 0));
53 // Matrix-vector multiplication kernel
54 const auto A_cached = with_cached_access<0, RowsReg>(A);
55 if (C) [[likely]] {
56 for (index_t i = 0; i < k; ++i) {
57 simd Cij = rotl<Conf.rotate_C>(C->load(i, 0));
58 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
59 simd Ail = shiftl<Conf.shift_A>(A_cached.load(i, ll));
60 Conf.negate ? (Cij -= Ail * B_reg[ll]) : (Cij += Ail * B_reg[ll]);
61 }
62 D.template store<Conf.mask_D>(rotr<Conf.rotate_D>(Cij), i, 0);
63 }
64 } else {
65 for (index_t i = 0; i < k; ++i) {
66 simd Cij{0};
67 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
68 simd Ail = shiftl<Conf.shift_A>(A_cached.load(i, ll));
69 Conf.negate ? (Cij -= Ail * B_reg[ll]) : (Cij += Ail * B_reg[ll]);
70 }
71 D.template store<Conf.mask_D>(rotr<Conf.rotate_D>(Cij), i, 0);
72 }
73 }
74 }
75}
76
77/// Generalized matrix multiplication d = c ± A⁽ᵀ⁾ b. Using register blocking.
78template <class T, class Abi, KernelConfig Conf, StorageOrder OA>
80 const std::optional<view<const T, Abi>> C, const view<T, Abi> D) noexcept {
81 using enum MatrixStructure;
82 constexpr auto Rows = RowsReg<T, Abi>;
83 // Check dimensions
84 const index_t I = D.rows(), K = A.cols();
85 BATMAT_ASSUME(A.rows() == I);
86 BATMAT_ASSUME(B.rows() == K);
87 BATMAT_ASSUME(B.cols() == 1);
88 BATMAT_ASSUME(D.cols() == 1);
89 BATMAT_ASSUME(I > 0);
90 BATMAT_ASSUME(K > 0);
91 static const auto microkernel = gemv_copy_lut<T, Abi, Conf, OA>;
92 // Sizeless views to partition and pass to the micro-kernels
93 const uview<const T, Abi, OA> A_ = A;
95 const std::optional<uview<const T, Abi, StorageOrder::ColMajor>> C_ = C;
97
98 if constexpr (OA == StorageOrder::RowMajor) {
99 if (I <= Rows)
100 return microkernel[I - 1](A_, B_, C_, D_, K);
101 foreach_chunked_merged(0, I, Rows, [&](index_t i, auto ni) {
102 auto Cj = C_ ? std::make_optional(C_->middle_rows(i)) : std::nullopt;
103 microkernel[ni - 1](A_.middle_rows(i), B_, Cj, D_.middle_rows(i), K);
104 });
105 } else {
106 if (K <= Rows)
107 return microkernel[K - 1](A_, B_, C_, D_, I);
108 microkernel[Rows - 1](A_.middle_cols(0), B_.middle_rows(0), C_, D_, I);
109 foreach_chunked_merged(Rows, K, Rows, [&](index_t k, auto nk) {
110 microkernel[nk - 1](A_.middle_cols(k), B_.middle_rows(k), D_, D_, I);
111 });
112 }
113}
114
115} // namespace batmat::linalg::micro_kernels::gemv
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
const constinit auto gemv_copy_lut
Definition gemv.hpp:33
void gemv_copy_register(view< const T, Abi, OA > A, view< const T, Abi > B, std::optional< view< const T, Abi > > C, view< T, Abi > D) noexcept
Generalized matrix multiplication d = c ± A⁽ᵀ⁾ b. Using register blocking.
Definition gemv.tpp:79
void gemv_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > B, std::optional< uview< const T, Abi, StorageOrder::ColMajor > > C, uview< T, Abi, StorageOrder::ColMajor > D, index_t k) noexcept
Generalized matrix-vector multiplication d = c ± A⁽ᵀ⁾ b. Single register block.
Definition gemv.tpp:16
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70