batmat 0.0.17
Batched linear algebra routines
Loading...
Searching...
No Matches
symv.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
9
// Local shorthand for BATMAT_FULLY_UNROLLED_FOR (presumably requests full unrolling, per
// its name). Used below only for loops over register-block indices bounded by RowsReg.
10#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
11
13
// Dispatch table of micro-kernels, one entry per register-block height.
// symv_copy_register calls entry [n - 1] to process a block of n rows (n <= RowsReg<T, Abi>),
// with the argument list (A, B, C, D, k) of symv_copy_microkernel.
// NOTE(review): the initializer is not visible in this listing; presumably built with
//               make_1d_lut over symv_copy_microkernel instantiations — confirm.
14template <class T, class Abi, KernelConfig Conf, StorageOrder OA>
15inline const constinit auto symv_copy_lut =
18 });
19
20/// Symmetric matrix-vector multiplication d = c ± A b. Single register block.
///
/// Processes the first RowsReg columns of a k-row panel of A. Rows [0, RowsReg) form the
/// diagonal block and rows [RowsReg, k) the subdiagonal block. Only entries on or below the
/// diagonal of A are loaded. Each off-diagonal entry A(i, l) is applied twice, to d(i) via b(l)
/// and to d(l) via b(i), which forms the symmetric product from the lower triangle alone.
///
/// @param A  Panel of A; only its lower-triangular part (incl. diagonal) is read.
/// @param B  Right-hand side; rows [0, k) of column 0 are read.
/// @param C  Optional addend; rows [0, k) of column 0 are read. Treated as zero when empty.
/// @param D  Output; rows [0, k) of column 0 are written. Each row of C is loaded before the
///           same row of D is stored, so C and D may refer to the same storage (symv_copy_register
///           relies on this when accumulating in place).
/// @param k  Number of rows in the panel.
/// Conf.negate selects d = c - A b instead of d = c + A b.
/// NOTE(review): rows [0, RowsReg) are accessed unconditionally, so k >= RowsReg is assumed;
///               this is not checked here.
21template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder OA>
22[[gnu::hot, gnu::flatten]] void
26 const uview<T, Abi, StorageOrder::ColMajor> D, const index_t k) noexcept {
27 static_assert(RowsReg > 0);
28 using enum MatrixStructure;
29 using namespace ops;
30 using simd = datapar::simd<T, Abi>;
32
33 // TODO: optimize for row-major case
34
35 // Load B and C into registers
36 simd B_reg[RowsReg], C_reg[RowsReg]; // NOLINT(*-c-arrays)
37 UNROLL_FOR (index_t l = 0; l < RowsReg; ++l) {
38 B_reg[l] = B.load(l, 0);
39 C_reg[l] = C ? C->load(l, 0) : simd{0};
40 }
41 // Matrix-vector multiplication kernel (diagonal block)
42 const auto A_cached = with_cached_access<0, RowsReg>(A);
43 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
44 auto Blj = B_reg[ll];
45 auto All = A_cached.load(ll, ll);
46 Conf.negate ? (C_reg[ll] -= All * Blj) : (C_reg[ll] += All * Blj);
 // Strictly-lower entries only: A(ii, ll) also stands in for A(ll, ii) by symmetry.
47 UNROLL_FOR (index_t ii = ll + 1; ii < RowsReg; ++ii) {
48 auto Ail = A_cached.load(ii, ll);
49 auto Bil = B_reg[ii];
50 Conf.negate ? (C_reg[ii] -= Ail * Blj) : (C_reg[ii] += Ail * Blj);
51 Conf.negate ? (C_reg[ll] -= Ail * Bil) : (C_reg[ll] += Ail * Bil);
52 }
53 }
54 // Matrix-vector multiplication kernel (subdiagonal block)
 // Row i of D is complete after this iteration and is stored immediately; the transposed
 // contributions A(i, ll) * b(i) are accumulated into the diagonal-block registers.
55 for (index_t i = RowsReg; i < k; ++i) {
56 auto Cij = C ? C->load(i, 0) : simd{0};
57 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll) {
58 auto Blj = B_reg[ll];
59 auto Ail = A_cached.load(i, ll);
 // NOTE(review): this load does not depend on ll and could be hoisted out of the loop.
60 auto Bil = B.load(i, 0);
61 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
62 Conf.negate ? (C_reg[ll] -= Ail * Bil) : (C_reg[ll] += Ail * Bil);
63 }
64 D.store(Cij, i, 0);
65 }
 // Diagonal-block rows are stored last, after all subdiagonal contributions are in.
66 UNROLL_FOR (index_t ll = 0; ll < RowsReg; ++ll)
67 D.store(C_reg[ll], ll, 0);
68}
69
70/// Symmetric matrix-vector multiplication d = c ± A b. Using register blocking.
///
/// A must be square with I = D.rows() rows; B and D are single columns with I rows.
/// symv_copy_microkernel reads only entries on or below the diagonal, so only the lower
/// triangle of A is needed. If I <= Rows, one kernel call with block height I does all the work.
/// Otherwise the work is split by columns:
/// - the first kernel call handles columns [0, Rows) and writes all I rows of D, folding in C;
/// - each later chunk [k, k + nk) works on the diagonal sub-block A(k:, k:) and accumulates
///   into D in place, passing the same view Dk as both the addend and the output.
/// NOTE(review): C's dimensions are not covered by BATMAT_ASSUME even though the kernels read
///               rows [0, I) of C when it is present — consider asserting C->rows() == I.
71template <class T, class Abi, KernelConfig Conf, StorageOrder OA>
73 const std::optional<view<const T, Abi>> C, const view<T, Abi> D) noexcept {
74 using enum MatrixStructure;
75 constexpr auto Rows = RowsReg<T, Abi>;
76 // Check dimensions
77 const index_t I = D.rows();
78 BATMAT_ASSUME(A.rows() == I);
79 BATMAT_ASSUME(A.cols() == I);
80 BATMAT_ASSUME(B.rows() == I);
81 BATMAT_ASSUME(B.cols() == 1);
82 BATMAT_ASSUME(D.cols() == 1);
83 BATMAT_ASSUME(I > 0);
// NOTE(review): copying the table into a function-local static may add a thread-safe init
//               guard on every call; binding a const reference to symv_copy_lut would avoid it.
84 static const auto microkernel = symv_copy_lut<T, Abi, Conf, OA>;
85 // Sizeless views to partition and pass to the micro-kernels
86 const uview<const T, Abi, OA> A_ = A;
88 const std::optional<uview<const T, Abi, StorageOrder::ColMajor>> C_ = C;
90
// Small problem: a single block of height I covers everything.
91 if (I <= Rows)
92 return microkernel[I - 1](A_, B_, C_, D_, I);
// First column block: initializes every row of D from C.
93 microkernel[Rows - 1](A_, B_, C_, D_, I);
// Remaining column blocks accumulate into D in place (D serves as its own addend).
94 foreach_chunked_merged(Rows, I, Rows, [&](index_t k, auto nk) {
95 auto Dk = D_.middle_rows(k);
96 microkernel[nk - 1](A_.block(k, k), B_.middle_rows(k), Dk, Dk, I - k);
97 });
98}
99
100} // namespace batmat::linalg::micro_kernels::symv
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
void symv_copy_register(view< const T, Abi, OA > A, view< const T, Abi > B, std::optional< view< const T, Abi > > C, view< T, Abi > D) noexcept
Symmetric matrix-vector multiplication d = c ± A b. Using register blocking.
Definition symv.tpp:72
void symv_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > B, std::optional< uview< const T, Abi, StorageOrder::ColMajor > > C, uview< T, Abi, StorageOrder::ColMajor > D, index_t k) noexcept
Symmetric matrix-vector multiplication d = c ± A b. Single register block.
Definition symv.tpp:23
const constinit auto symv_copy_lut
Definition symv.tpp:15
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114