batmat 0.0.18
Batched linear algebra routines
Loading...
Searching...
No Matches
gemm.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
4// Work around GCC bug regarding declarations and explicit instantiations of extern const variables.
5#define BATMAT_LINALG_GEMM_NO_DECLARE_LUT
7#undef BATMAT_LINALG_GEMM_NO_DECLARE_LUT
9#include <batmat/loop.hpp>
10#include <batmat/ops/rotate.hpp>
11
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
// Out-of-line declaration of the micro-kernel lookup table, declared extern
// const constinit so that each (T, Abi, Conf, storage-order) combination is
// instantiated once; the #define/#undef dance above works around a GCC bug
// with declarations vs. explicit instantiations of extern const variables.
// NOTE(review): listing line 20 — presumably the variable name
// `gemm_copy_lut;` — is missing from this extract; confirm against the
// original gemm.tpp before editing.
16template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
17 StorageOrder OD>
18BATMAT_LINALG_GEMM_EXPORT extern const constinit decltype(detail::gemm_copy_lut<T, Abi, Conf, OA,
19 OB, OC, OD>)
21
22template <MatrixStructure Struc>
23inline constexpr auto first_column =
24 [](index_t row_index) { return Struc == MatrixStructure::UpperTriangular ? row_index : 0; };
25
26template <index_t ColsReg, MatrixStructure Struc>
27inline constexpr auto last_column = [](index_t row_index) {
28 return Struc == MatrixStructure::LowerTriangular ? std::min(row_index, ColsReg - 1)
29 : ColsReg - 1;
30};
31
32/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Single register block.
/// Computes one RowsReg×ColsReg block of the result entirely in SIMD
/// registers: the accumulator is seeded from C (rotated by Conf.rotate_C) or
/// from zero when C is absent, k rank-1 updates are applied, and the result
/// is rotated back (Conf.rotate_D) and stored to D under Conf.mask_D.
/// Triangular structure flags in Conf restrict the loop ranges so that only
/// the populated parts of A, B and C/D are touched.
/// NOTE(review): listing lines 34, 36, 54 and 56 are missing from this
/// extract (the rest of the template header, the signature line naming the
/// function and its A/B parameters, and the bodies of the two dimension
/// checks) — consult the original gemm.tpp before modifying.
33template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
35[[gnu::hot, gnu::flatten]] void
37 const std::optional<uview<const T, Abi, OC>> C, const uview<T, Abi, OD> D,
38 const index_t k) noexcept {
39 static_assert(RowsReg > 0 && ColsReg > 0);
40 using enum MatrixStructure;
41 using namespace ops;
42 using simd = datapar::simd<T, Abi>;
43 // Column range for triangular matrix C (gemmt)
44 static constexpr auto min_col = first_column<Conf.struc_C>;
45 static constexpr auto max_col = last_column<ColsReg, Conf.struc_C>;
46 // The following assumption ensures that there is no unnecessary branch
47 // for k == 0 in between the loops. This is crucial for good code
48 // generation, otherwise the compiler inserts jumps and labels between
49 // the matmul kernel and the loading/storing of C, which will cause it to
50 // place C_reg on the stack, resulting in many unnecessary loads and stores.
51 BATMAT_ASSUME(k > 0);
52 // Check dimensions in the triangular case
53 if constexpr (Conf.struc_A != General)
55 if constexpr (Conf.struc_B != General)
57 // Load accumulator into registers
58 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
59 if (C) [[likely]] {
60 const auto C_cached = with_cached_access<RowsReg, ColsReg>(*C);
61 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
62 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
63 C_reg[ii][jj] = rotl<Conf.rotate_C>(C_cached.load(ii, jj));
64 } else {
// No C operand: the accumulator starts at zero (plain product A·B).
65 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
66 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
67 C_reg[ii][jj] = simd{0};
68 }
69
70 const auto A_cached = with_cached_access<RowsReg, 0>(A);
71 const auto B_cached = with_cached_access<0, ColsReg>(B);
72
73 // Triangular matrix multiplication kernel
// Leading triangular part: an upper-triangular A and/or lower-triangular B
// is only partially populated in the first columns of A / rows of B, so the
// inner ll ranges are restricted to the nonzero entries. l counts how many
// of the k accumulation steps these kernels consume.
74 index_t l = 0;
75 if constexpr (Conf.struc_A == UpperTriangular && Conf.struc_B == LowerTriangular) {
76 l += std::max(RowsReg, ColsReg);
77 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
78 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
79 UNROLL_FOR (index_t ll = std::max(ii, jj); ll < std::max(RowsReg, ColsReg); ++ll) {
80 simd &Cij = C_reg[ii][jj];
81 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
82 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
// Conf.negate selects the subtracting variant (D = C - A·B).
83 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
84 }
85 }
86 }
87 } else if constexpr (Conf.struc_A == UpperTriangular) {
88 l += RowsReg;
89 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
90 UNROLL_FOR (index_t ll = ii; ll < RowsReg; ++ll) {
91 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
92 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
93 simd &Cij = C_reg[ii][jj];
94 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
95 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
96 }
97 }
98 }
99 } else if constexpr (Conf.struc_B == LowerTriangular) {
100 l += ColsReg;
101 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
102 UNROLL_FOR (index_t ll = 0; ll < ColsReg; ++ll) {
103 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
104 UNROLL_FOR (index_t jj = min_col(ii); jj <= std::min(ll, max_col(ii)); ++jj) {
105 simd &Cij = C_reg[ii][jj];
106 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
107 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
108 }
109 }
110 }
111 }
112
113 // Rectangular matrix multiplication kernel
// Dense middle part: full rank-1 updates for every l. For lower-triangular
// A the last RowsReg columns (and for upper-triangular B the last ColsReg
// rows) are excluded here and handled by the trailing kernels below.
114 const index_t l_end_A = Conf.struc_A == LowerTriangular ? k - RowsReg : k;
115 const index_t l_end_B = Conf.struc_B == UpperTriangular ? k - ColsReg : k;
116 for (; l < std::min(l_end_A, l_end_B); ++l) {
117 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
118 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l));
119 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
120 simd &Cij = C_reg[ii][jj];
121 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l, jj));
122 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
123 }
124 }
125 }
126
127 // Triangular matrix multiplication kernel
// Trailing triangular part, mirroring the leading one: restricted ll
// ranges for the final partially-populated columns of a lower-triangular A
// and/or rows of an upper-triangular B, offset by the l reached above.
128 if constexpr (Conf.struc_A == LowerTriangular && Conf.struc_B == UpperTriangular) {
129 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
130 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
131 const index_t lmax = std::min(ii, jj) + std::max(ColsReg, RowsReg) - RowsReg;
132 UNROLL_FOR (index_t ll = 0; ll <= lmax; ++ll) {
133 simd &Cij = C_reg[ii][jj];
134 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
135 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
136 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
137 }
138 }
139 }
140 } else if constexpr (Conf.struc_A == LowerTriangular) {
141 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
142 UNROLL_FOR (index_t ll = 0; ll <= ii; ++ll) {
143 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
144 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
145 simd &Cij = C_reg[ii][jj];
146 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
147 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
148 }
149 }
150 }
151 } else if constexpr (Conf.struc_B == UpperTriangular) {
152 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
153 UNROLL_FOR (index_t ll = 0; ll < ColsReg; ++ll) {
154 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
155 UNROLL_FOR (index_t jj = std::max(ll, min_col(ii)); jj <= max_col(ii); ++jj) {
156 simd &Cij = C_reg[ii][jj];
157 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
158 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
159 }
160 }
161 }
162 }
163
164 const auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
165 // Store accumulator to memory again
// (undoing the accumulator rotation, via a store masked by Conf.mask_D).
166 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
167 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
168 D_cached.template store<Conf.mask_D>(rotr<Conf.rotate_D>(C_reg[ii][jj]), ii, jj);
169}
170
171/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Using register blocking.
/// Partitions D into Rows×Cols register blocks and dispatches each block to
/// a micro-kernel from gemm_copy_lut, restricting the shared dimension to
/// the band [l0, l1) that can actually contain nonzeros when A or B is
/// triangular, and skipping blocks outside the stored triangle of C/D.
/// NOTE(review): listing lines 174, 283, 288, 294 and 299 are missing from
/// this extract (the signature line naming the function and its A/B
/// parameters, and the `foreach_chunked_merged(` call lines) — consult the
/// original gemm.tpp before modifying.
172template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
173 StorageOrder OD>
175 const std::optional<view<const T, Abi, OC>> C,
176 const view<T, Abi, OD> D) noexcept {
177 using enum MatrixStructure;
178 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
179 // Check dimensions
180 const index_t I = D.rows(), J = D.cols(), K = A.cols();
181 BATMAT_ASSUME(A.rows() == I);
182 BATMAT_ASSUME(B.rows() == K);
183 BATMAT_ASSUME(B.cols() == J);
184 if constexpr (Conf.struc_A != General)
185 BATMAT_ASSUME(I == K);
186 if constexpr (Conf.struc_B != General)
187 BATMAT_ASSUME(K == J);
188 if constexpr (Conf.struc_C != General)
189 BATMAT_ASSUME(I == J);
190 BATMAT_ASSUME(I > 0);
191 BATMAT_ASSUME(J > 0);
192 BATMAT_ASSUME(K > 0);
193 // Configurations for the various micro-kernels
// Naming: the three letters give the structure used for A, B, C in that
// order ("G" = forced General, "X" = keep the structure from Conf). E.g.
// ConfGXG treats A and C as general but preserves B's structure.
194 constexpr KernelConfig ConfGXG{.negate = Conf.negate,
195 .shift_A = Conf.shift_A,
196 .rotate_B = Conf.rotate_B,
197 .rotate_C = Conf.rotate_C,
198 .rotate_D = Conf.rotate_D,
199 .mask_D = Conf.mask_D,
200 .struc_A = General,
201 .struc_B = Conf.struc_B,
202 .struc_C = General};
203 constexpr KernelConfig ConfXGG{.negate = Conf.negate,
204 .shift_A = Conf.shift_A,
205 .rotate_B = Conf.rotate_B,
206 .rotate_C = Conf.rotate_C,
207 .rotate_D = Conf.rotate_D,
208 .mask_D = Conf.mask_D,
209 .struc_A = Conf.struc_A,
210 .struc_B = General,
211 .struc_C = General};
212 constexpr KernelConfig ConfXXG{.negate = Conf.negate,
213 .shift_A = Conf.shift_A,
214 .rotate_B = Conf.rotate_B,
215 .rotate_C = Conf.rotate_C,
216 .rotate_D = Conf.rotate_D,
217 .mask_D = Conf.mask_D,
218 .struc_A = Conf.struc_A,
219 .struc_B = Conf.struc_B,
220 .struc_C = General};
221 static const auto microkernel = gemm_copy_lut<T, Abi, Conf, OA, OB, OC, OD>;
222 static const auto microkernel_GXG = gemm_copy_lut<T, Abi, ConfGXG, OA, OB, OC, OD>;
223 static const auto microkernel_XGG = gemm_copy_lut<T, Abi, ConfXGG, OA, OB, OC, OD>;
224 static const auto microkernel_XXG = gemm_copy_lut<T, Abi, ConfXXG, OA, OB, OC, OD>;
225 // Sizeless views to partition and pass to the micro-kernels
226 const uview<const T, Abi, OA> A_ = A;
227 const uview<const T, Abi, OB> B_ = B;
228 const std::optional<uview<const T, Abi, OC>> C_ = C;
229 const uview<T, Abi, OD> D_ = D;
230
231 // Optimization for very small matrices
// (single register block: dispatch directly, indexed by block size - 1).
232 if (I <= Rows && J <= Cols)
233 return microkernel[I - 1][J - 1](A_, B_, C_, D_, K);
234
235 // Simply loop over all blocks in the given matrices.
// For block (i, j) of size ni×nj, [l0, l1) is the range of the shared
// dimension that can contain nonzeros given the triangular structure of A
// and B; the views Ail/Blj passed to the kernel are trimmed to start at l0.
236 auto run = [&] [[gnu::always_inline]] (index_t i, index_t ni, index_t j, index_t nj) {
237 const auto Bj = B_.middle_cols(j);
238 const auto l0A = Conf.struc_A == UpperTriangular ? i : 0;
239 const auto l1A = Conf.struc_A == LowerTriangular ? i + ni + std::max(K, I) - I : K;
240 const auto l0B = Conf.struc_B == LowerTriangular ? j : 0;
241 const auto l1B = Conf.struc_B == UpperTriangular ? j + nj + std::max(K, J) - J : K;
242 const auto l0 = std::max(l0A, l0B);
243 const auto l1 = std::min(l1A, l1B);
244 const auto Ai = A_.middle_rows(i);
245 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
246 const auto Dij = D_.block(i, j);
247 const auto Ail = Ai.middle_cols(l0);
248 const auto Blj = Bj.middle_rows(l0);
249
// NOTE(review): an empty band returns without touching Dij, so C is not
// copied through to D for such blocks — see the existing TODO.
250 if (l1 == l0) // TODO: this is wrong.
251 return;
// When the two operands' bands differ, the operand whose triangle lies
// outside the common band behaves as a general (rectangular) factor within
// it, so a cheaper micro-kernel variant with that structure relaxed is used.
252 if constexpr (Conf.struc_A == LowerTriangular && Conf.struc_B == UpperTriangular) { // LU
253 if (l1A > l1B) {
254 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
255 return;
256 } else if (l1A < l1B) {
257 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
258 return;
259 }
260 }
261 if constexpr (Conf.struc_A == UpperTriangular && Conf.struc_B == LowerTriangular) { // UL
262 if (l0A > l0B) {
263 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
264 return;
265 } else if (l0A < l0B) {
266 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
267 return;
268 }
269 }
270 if constexpr (Conf.struc_C != General) { // syrk
// Structured result: only diagonal blocks need the triangular-C column
// ranges; off-diagonal blocks of the stored triangle are computed in full.
271 if (i != j) {
272 microkernel_XXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
273 return;
274 }
275 }
276 microkernel[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
277 };
278 // Pick loop directions that allow having A=D or B=D
279 constexpr auto dir_i = Conf.struc_A == LowerTriangular ? LoopDir::Backward : LoopDir::Forward,
280 dir_j = Conf.struc_B == UpperTriangular ? LoopDir::Backward : LoopDir::Forward;
281 // Loop over block rows of A and block columns of B
// (i0/i1 resp. j0/j1 clip the inner loop to the stored triangle of C/D).
282 if constexpr (OB == StorageOrder::ColMajor)
284 0, J, index_constant<Cols>(),
285 [&](index_t j, auto nj) {
286 const auto i0 = Conf.struc_C == LowerTriangular ? j : 0,
287 i1 = Conf.struc_C == UpperTriangular ? j + nj : I;
289 i0, i1, index_constant<Rows>(), [&](index_t i, auto ni) { run(i, ni, j, nj); },
290 dir_i);
291 },
292 dir_j);
293 else // swap the loops for row-major B
295 0, I, index_constant<Rows>(),
296 [&](index_t i, auto ni) {
297 const auto j0 = Conf.struc_C == UpperTriangular ? i : 0,
298 j1 = Conf.struc_C == LowerTriangular ? i + ni : J;
300 j0, j1, index_constant<Cols>(), [&](index_t j, auto nj) { run(i, ni, j, nj); },
301 dir_j);
302 },
303 dir_i);
304}
305
306} // namespace batmat::linalg::micro_kernels::gemm
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
const constinit decltype(detail::gemm_copy_lut< T, Abi, Conf, OA, OB, OC, OD >) gemm_copy_lut
Definition gemm.tpp:20
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
void gemm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Using register blocking.
Definition gemm.tpp:174
void gemm_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Single register block.
Definition gemm.tpp:36
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118