batmat 0.0.13
Batched linear algebra routines
Loading...
Searching...
No Matches
gemm-diag.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
8
9#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
10
12
13template <MatrixStructure Struc>
14inline constexpr auto first_column =
15 [](index_t row_index) { return Struc == MatrixStructure::UpperTriangular ? row_index : 0; };
16
17template <index_t ColsReg, MatrixStructure Struc>
18inline constexpr auto last_column = [](index_t row_index) {
19 return Struc == MatrixStructure::LowerTriangular ? std::min(row_index, ColsReg - 1)
20 : ColsReg - 1;
21};
22
23/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Single register block.
/// When Conf.track_zeros is set, returns the half-open range [first, last+1) of
/// inner indices l for which d(l)·A(:,l) had a nonzero entry (in any row or
/// SIMD lane), or {k, k} when the whole product was zero; otherwise the return
/// type is void.
/// NOTE(review): the Doxygen scrape dropped part of the signature (doc-lines
/// 25/27/29); see the tooltip at the bottom of this page for the full one.
24template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
26[[gnu::hot, gnu::flatten]] std::conditional_t<Conf.track_zeros, std::pair<index_t, index_t>, void>
28 const std::optional<uview<const T, Abi, OC>> C,
30 const index_t k) noexcept {
31 static_assert(RowsReg > 0 && ColsReg > 0);
32 using enum MatrixStructure;
33 using namespace ops;
34 using simd = datapar::simd<T, Abi>;
35 // Column range for triangular matrix C (gemmt)
36 static constexpr auto min_col = first_column<Conf.struc_C>;
37 static constexpr auto max_col = last_column<ColsReg, Conf.struc_C>;
38 // The following assumption ensures that there is no unnecessary branch
39 // for k == 0 in between the loops. This is crucial for good code
40 // generation, otherwise the compiler inserts jumps and labels between
41 // the matmul kernel and the loading/storing of C, which will cause it to
42 // place C_reg on the stack, resulting in many unnecessary loads and stores.
43 BATMAT_ASSUME(k > 0);
44 // Load accumulator into registers
 // Only entries in the (possibly triangular) column range
 // [min_col(ii), max_col(ii)] are ever initialized, updated, or stored;
 // the remaining entries of C_reg stay uninitialized but are never read.
45 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
46 if (C) [[likely]] {
47 const auto C_cached = with_cached_access<RowsReg, ColsReg>(*C);
48 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
49 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
50 C_reg[ii][jj] = C_cached.load(ii, jj);
51 } else {
 // No C given: accumulate from zero (D = ± A diag(d) B).
52 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
53 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
54 C_reg[ii][jj] = simd{0};
55 }
56
57 const auto A_cached = with_cached_access<RowsReg, 0>(A);
58 const auto B_cached = with_cached_access<0, ColsReg>(B);
59
60 // Rectangular matrix multiplication kernel
 // -1 marks "no nonzero inner index seen yet" (only used with Conf.track_zeros).
61 index_t first_nonzero = -1, last_nonzero = -1;
62 for (index_t l = 0; l < k; ++l) {
63 bool all_zero = true;
64 simd dl = d.load(l);
65 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
66 simd Ail = dl * A_cached.load(ii, l);
67 if constexpr (Conf.track_zeros)
68 all_zero &= all_of(Ail == simd{0});
69 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
70 simd &Cij = C_reg[ii][jj];
71 simd Blj = B_cached.load(l, jj);
 // Sign selected by the compile-time constant Conf.negate.
72 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
73 }
74 }
75 if constexpr (Conf.track_zeros)
76 if (!all_zero) {
77 last_nonzero = l;
78 if (first_nonzero < 0)
79 first_nonzero = l;
80 }
81 }
82
83 const auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
84 // Store accumulator to memory again
85 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
86 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
87 D_cached.store(C_reg[ii][jj], ii, jj);
88
 // Report the observed nonzero range as half-open [first, last+1);
 // {k, k} signals that d(l)·A(:,l) was zero for every l.
89 if constexpr (Conf.track_zeros) {
90 if (first_nonzero < 0)
91 return {k, k};
92 else
93 return {first_nonzero, last_nonzero + 1};
94 }
95}
96
97/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
/// Partitions D into Rows×Cols register tiles, dispatches each tile to a
/// micro-kernel from the gemm_diag_copy_lut table, and uses the zero-tracking
/// kernel to skip inner indices where A diag(d) is entirely zero.
/// NOTE(review): doc-line 100 (the signature head) is missing from this
/// scrape; the full signature is shown in the tooltip at the end of the page.
98template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
99 StorageOrder OD>
101 const std::optional<view<const T, Abi, OC>> C,
102 const view<T, Abi, OD> D, view<const T, Abi> d) noexcept {
103 using enum MatrixStructure;
104 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
105 // Check dimensions
106 const index_t I = D.rows(), J = D.cols(), K = A.cols();
107 BATMAT_ASSUME(A.rows() == I);
108 BATMAT_ASSUME(B.rows() == K);
109 BATMAT_ASSUME(B.cols() == J);
110 BATMAT_ASSUME(d.rows() == K);
111 BATMAT_ASSUME(d.cols() == 1);
 // A triangular result only makes sense for a square D.
112 if constexpr (Conf.struc_C != General)
113 BATMAT_ASSUME(I == J);
114 BATMAT_ASSUME(I > 0);
115 BATMAT_ASSUME(J > 0);
116 BATMAT_ASSUME(K > 0);
117 // Configurations for the various micro-kernels
118 constexpr KernelConfig ConfSmall{.negate = Conf.negate, .struc_C = Conf.struc_C};
119 constexpr KernelConfig ConfSub{.negate = Conf.negate, .struc_C = General};
120 static const auto microkernel = gemm_diag_copy_lut<T, Abi, Conf, OA, OB, OC, OD>;
121 static const auto microkernel_small = gemm_diag_copy_lut<T, Abi, ConfSmall, OA, OB, OC, OD>;
122 static const auto microkernel_sub = gemm_diag_copy_lut<T, Abi, ConfSub, OA, OB, OC, OD>;
123 (void)microkernel_sub; // GCC incorrectly warns about unused variable
124 // Sizeless views to partition and pass to the micro-kernels
125 const uview<const T, Abi, OA> A_ = A;
126 const uview<const T, Abi, OB> B_ = B;
127 const std::optional<uview<const T, Abi, OC>> C_ = C;
128 const uview<T, Abi, OD> D_ = D;
129 const uview_vec<const T, Abi> d_{d};
130
131 // Optimization for very small matrices
 // (everything fits in a single register tile: one kernel call, no blocking)
132 if (I <= Rows && J <= Cols)
133 return microkernel_small[I - 1][J - 1](A_, B_, C_, D_, d_, K);
134
135 // Loop over block rows of A and block columns of B
136 foreach_chunked_merged(0, I, index_constant<Rows>(), [&](index_t i, auto ni) {
137 const auto Ai = A_.middle_rows(i);
138 // If triangular: use diagonal block (i, i) for counting zeros.
139 // If general: use first block column (i, 0).
 // [j0, j1) is this block row's remaining column range, excluding the
 // block handled by the zero-tracking micro-kernel call just below.
140 const auto j0 = Conf.struc_C == UpperTriangular ? i + ni
141 : Conf.struc_C == LowerTriangular ? 0
142 : std::min(Cols, J),
143 j1 = Conf.struc_C == LowerTriangular ? i : J;
144 // First micro-kernel call that keeps track of the leading/trailing zeros in A diag(d)
145 auto [l0, l1] = [&] {
146 const auto j = Conf.struc_C == General ? 0 : i;
147 const auto nj = std::min<index_t>(Cols, J - j);
148 const auto Bj = B_.middle_cols(j);
149 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
150 const auto Dij = D_.block(i, j);
151 if constexpr (Conf.track_zeros)
152 return microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
153 else {
 // Zero-tracking disabled: process the full inner range [0, K).
154 microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
155 return std::pair<index_t, index_t>{0, K};
156 }
157 }();
 // Empty nonzero range: A diag(d) contributed nothing to this block row,
 // so the remaining blocks reduce to D = C (or D = 0 when C is absent).
158 if (l1 == l0) {
159 if (!C)
160 D.block(i, j0, ni, j1 - j0).set_constant(T{});
 // C and D alias (same data and strides): D already holds C, skip the copy.
161 else if (C->data() == D.data() && C->outer_stride() == D.outer_stride())
162 BATMAT_ASSUME(C->storage_order == D.storage_order); // Nothing to do
 // Otherwise copy C into D, iterating in OC's storage order.
163 else if constexpr (OC == StorageOrder::ColMajor)
164 for (index_t jj = j0; jj < j1; ++jj) // TODO: suboptimal when transpose required
165 for (index_t ii = i; ii < i + ni; ++ii)
166 D_.store(C_->load(ii, jj), ii, jj);
167 else
168 for (index_t ii = i; ii < i + ni; ++ii) // TODO: suboptimal when transpose required
169 for (index_t jj = j0; jj < j1; ++jj)
170 D_.store(C_->load(ii, jj), ii, jj);
171 return;
172 }
173 // Process other blocks, trimming any leading/trailing zeros (before l0 or after l1)
174 foreach_chunked_merged(j0, j1, index_constant<Cols>(), [&](index_t j, auto nj) {
175 const auto Bj = B_.middle_cols(j);
176 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
177 const auto Dij = D_.block(i, j);
178 const auto Ail = Ai.middle_cols(l0);
179 const auto Blj = Bj.middle_rows(l0);
180 const auto dl = d_.segment(l0);
181 microkernel_sub[ni - 1][nj - 1](Ail, Blj, Cij, Dij, dl, l1 - l0);
182 });
183 });
184}
185
186} // namespace batmat::linalg::micro_kernels::gemm_diag
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
std::conditional_t< Conf.track_zeros, std::pair< index_t, index_t >, void > gemm_diag_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, uview_vec< const T, Abi > diag, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Single register block.
Definition gemm-diag.tpp:27
void gemm_diag_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, view< const T, Abi > diag) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self segment(this const Self &self, index_t r) noexcept
Definition uview.hpp:162
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
void store(simd x, index_t r, index_t c) const noexcept
Definition uview.hpp:104
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118