batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
symv.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
8
9#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
10
12
13/// Symmetric matrix-vector multiplication d = c ± A b. Single register block of RowsReg columns; only the lower triangle of A is read (entries A(i, l) with i >= l), the strictly-upper part is applied by symmetry.
14template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder OA>
15[[gnu::hot, gnu::flatten]] void
19 const uview<T, Abi, StorageOrder::ColMajor> D, const index_t k) noexcept { // k: number of rows of this panel (every visible call site passes k >= RowsReg)
20 static_assert(RowsReg > 0);
21 using enum MatrixStructure;
22 using namespace ops;
23 using simd = datapar::simd<T, Abi>;
25
26 // TODO: optimize for row-major case
27
28 // Load rows [0, RowsReg) of B and C into registers; a missing C is treated as c = 0.
29 simd B_reg[RowsReg], C_reg[RowsReg]; // NOLINT(*-c-arrays)
30 UNROLL_FOR (index_t l = 0; l < RowsReg; ++l) {
31 B_reg[l] = B.load(l, 0);
32 C_reg[l] = C ? C->load(l, 0) : simd{0};
33 }
34 // Diagonal RowsReg x RowsReg block: each stored lower entry A(ii, ll) with ii > ll is applied twice, once as itself (into row ii) and once as its mirror A(ll, ii) (into row ll). Conf.negate selects c - A b instead of c + A b.
35 const auto A_cached = with_cached_access<0, RowsReg>(A);
36 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
37 auto Blj = B_reg[ll];
38 auto All = A_cached.load(ll, ll);
39 Conf.negate ? (C_reg[ll] -= All * Blj) : (C_reg[ll] += All * Blj);
40 UNROLL_FOR (index_t ii = ll + 1; ii < RowsReg; ++ii) {
41 auto Ail = A_cached.load(ii, ll);
42 auto Bil = B_reg[ii];
43 Conf.negate ? (C_reg[ii] -= Ail * Blj) : (C_reg[ii] += Ail * Blj); // lower entry A(ii, ll)
44 Conf.negate ? (C_reg[ll] -= Ail * Bil) : (C_reg[ll] += Ail * Bil); // mirrored entry A(ll, ii) == A(ii, ll)
45 }
46 }
47 // Subdiagonal rows [RowsReg, k): row i of D = c_i ± A(i, 0:RowsReg) b(0:RowsReg), stored immediately, so these rows only get this panel's columns (the caller must accumulate the remaining columns; symv_copy_register passes D as C for later panels). The mirrored entries A(i, l) b_i go into the register rows. Row i of C is loaded before row i of D is stored, so C may alias D.
48 for (index_t i = RowsReg; i < k; ++i) {
49 auto Cij = C ? C->load(i, 0) : simd{0};
50 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
51 auto Blj = B_reg[ll];
52 auto Ail = A_cached.load(i, ll);
53 auto Bil = B.load(i, 0); // NOTE(review): independent of ll; presumably hoisted out of the unrolled loop by the compiler
54 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
55 Conf.negate ? (C_reg[ll] -= Ail * Bil) : (C_reg[ll] += Ail * Bil);
56 }
57 D.store(Cij, i, 0);
58 }
59 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) // rows [0, RowsReg) are complete only after the subdiagonal loop
60 D.store(C_reg[ll], ll, 0);
61}
62
63/// Symmetric matrix-vector multiplication d = c ± A b (only the lower triangle of A is read). Using register blocking.
64template <class T, class Abi, KernelConfig Conf, StorageOrder OA>
66 const std::optional<view<const T, Abi>> C, const view<T, Abi> D) noexcept {
67 using enum MatrixStructure;
68 constexpr auto Rows = RowsReg<T, Abi>;
69 // Check dimensions: A is square I x I, B and D are I x 1 column vectors, I > 0.
70 const index_t I = D.rows();
71 BATMAT_ASSUME(A.rows() == I);
72 BATMAT_ASSUME(A.cols() == I);
73 BATMAT_ASSUME(B.rows() == I);
74 BATMAT_ASSUME(B.cols() == 1);
75 BATMAT_ASSUME(D.cols() == 1); // NOTE(review): the dimensions of C (when present) are not checked against I x 1
76 BATMAT_ASSUME(I > 0);
77 static const auto microkernel = symv_copy_lut<T, Abi, Conf, OA>; // lookup table indexed by (rows in the register block) - 1; NOTE(review): a function-local static copy may carry an init guard, a const reference to the constinit LUT would avoid it
78 // Sizeless views to partition and pass to the micro-kernels
79 const uview<const T, Abi, OA> A_ = A;
81 const std::optional<uview<const T, Abi, StorageOrder::ColMajor>> C_ = C;
83
84 if (I <= Rows) // the whole matrix fits in a single register block
85 return microkernel[I - 1](A_, B_, C_, D_, I);
86 microkernel[Rows - 1](A_, B_, C_, D_, I); // first panel (columns [0, Rows)): reads C and writes every row of D
87 foreach_chunked_merged(Rows, I, Rows, [&](index_t k, auto nk) { // later panels accumulate into D, so D is passed as both C and D
88 auto Dk = D_.middle_rows(k);
89 microkernel[nk - 1](A_.block(k, k), B_.middle_rows(k), Dk, Dk, I - k);
90 });
91}
92
93} // namespace batmat::linalg::micro_kernels::symv
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
void symv_copy_register(view< const T, Abi, OA > A, view< const T, Abi > B, std::optional< view< const T, Abi > > C, view< T, Abi > D) noexcept
Symmetric matrix-vector multiplication d = c ± A b. Using register blocking.
Definition symv.tpp:65
void symv_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > B, std::optional< uview< const T, Abi, StorageOrder::ColMajor > > C, uview< T, Abi, StorageOrder::ColMajor > D, index_t k) noexcept
Symmetric matrix-vector multiplication d = c ± A b. Single register block.
Definition symv.tpp:16
const constinit auto symv_copy_lut
Definition symv.hpp:29
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114