batmat 0.0.13 — Batched linear algebra routines
Source listing: gemm.tpp
#pragma once

#include <batmat/assume.hpp>
#include <batmat/loop.hpp>
// NOTE(review): the doc extraction dropped additional include lines here
// (internal lines 4-5 and 7) — restore them from the repository. The standard
// headers below are needed by the code in this file (std::min/max, std::optional).

#include <algorithm>
#include <optional>
/// Shorthand for the project's fully-unrolled loop macro, used by all
/// register-blocked micro-kernels below.
#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)


namespace batmat::linalg::micro_kernels::gemm {
13template <MatrixStructure Struc>
14inline constexpr auto first_column =
15 [](index_t row_index) { return Struc == MatrixStructure::UpperTriangular ? row_index : 0; };
16
17template <index_t ColsReg, MatrixStructure Struc>
18inline constexpr auto last_column = [](index_t row_index) {
19 return Struc == MatrixStructure::LowerTriangular ? std::min(row_index, ColsReg - 1)
20 : ColsReg - 1;
21};
22
23/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Single register block.
24template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
26[[gnu::hot, gnu::flatten]] void
28 const std::optional<uview<const T, Abi, OC>> C, const uview<T, Abi, OD> D,
29 const index_t k) noexcept {
30 static_assert(RowsReg > 0 && ColsReg > 0);
31 using enum MatrixStructure;
32 using namespace ops;
33 using simd = datapar::simd<T, Abi>;
34 // Column range for triangular matrix C (gemmt)
35 static constexpr auto min_col = first_column<Conf.struc_C>;
36 static constexpr auto max_col = last_column<ColsReg, Conf.struc_C>;
37 // The following assumption ensures that there is no unnecessary branch
38 // for k == 0 in between the loops. This is crucial for good code
39 // generation, otherwise the compiler inserts jumps and labels between
40 // the matmul kernel and the loading/storing of C, which will cause it to
41 // place C_reg on the stack, resulting in many unnecessary loads and stores.
42 BATMAT_ASSUME(k > 0);
43 // Check dimensions in the triangular case
44 if constexpr (Conf.struc_A != General)
46 if constexpr (Conf.struc_B != General)
48 // Load accumulator into registers
49 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
50 if (C) [[likely]] {
51 const auto C_cached = with_cached_access<RowsReg, ColsReg>(*C);
52 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
53 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
54 C_reg[ii][jj] = rotl<Conf.rotate_C>(C_cached.load(ii, jj));
55 } else {
56 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
57 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
58 C_reg[ii][jj] = simd{0};
59 }
60
61 const auto A_cached = with_cached_access<RowsReg, 0>(A);
62 const auto B_cached = with_cached_access<0, ColsReg>(B);
63
64 // Triangular matrix multiplication kernel
65 index_t l = 0;
66 if constexpr (Conf.struc_A == UpperTriangular && Conf.struc_B == LowerTriangular) {
67 l += std::max(RowsReg, ColsReg);
68 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
69 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
70 UNROLL_FOR (index_t ll = std::max(ii, jj); ll < std::max(RowsReg, ColsReg); ++ll) {
71 simd &Cij = C_reg[ii][jj];
72 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
73 simd Blj = shiftl<Conf.shift_B>(B_cached.load(ll, jj));
74 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
75 }
76 }
77 }
78 } else if constexpr (Conf.struc_A == UpperTriangular) {
79 l += RowsReg;
80 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
81 UNROLL_FOR (index_t ll = ii; ll < RowsReg; ++ll) {
82 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
83 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
84 simd &Cij = C_reg[ii][jj];
85 simd Blj = shiftl<Conf.shift_B>(B_cached.load(ll, jj));
86 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
87 }
88 }
89 }
90 } else if constexpr (Conf.struc_B == LowerTriangular) {
91 l += ColsReg;
92 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
93 UNROLL_FOR (index_t ll = 0; ll < ColsReg; ++ll) {
94 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
95 UNROLL_FOR (index_t jj = min_col(ii); jj <= std::min(ll, max_col(ii)); ++jj) {
96 simd &Cij = C_reg[ii][jj];
97 simd Blj = shiftl<Conf.shift_B>(B_cached.load(ll, jj));
98 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
99 }
100 }
101 }
102 }
103
104 // Rectangular matrix multiplication kernel
105 const index_t l_end_A = Conf.struc_A == LowerTriangular ? k - RowsReg : k;
106 const index_t l_end_B = Conf.struc_B == UpperTriangular ? k - ColsReg : k;
107 for (; l < std::min(l_end_A, l_end_B); ++l) {
108 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
109 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l));
110 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
111 simd &Cij = C_reg[ii][jj];
112 simd Blj = shiftl<Conf.shift_B>(B_cached.load(l, jj));
113 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
114 }
115 }
116 }
117
118 // Triangular matrix multiplication kernel
119 if constexpr (Conf.struc_A == LowerTriangular && Conf.struc_B == UpperTriangular) {
120 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
121 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
122 const index_t lmax = std::min(ii, jj) + std::max(ColsReg, RowsReg) - RowsReg;
123 UNROLL_FOR (index_t ll = 0; ll <= lmax; ++ll) {
124 simd &Cij = C_reg[ii][jj];
125 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
126 simd Blj = shiftl<Conf.shift_B>(B_cached.load(l + ll, jj));
127 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
128 }
129 }
130 }
131 } else if constexpr (Conf.struc_A == LowerTriangular) {
132 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
133 UNROLL_FOR (index_t ll = 0; ll <= ii; ++ll) {
134 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
135 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
136 simd &Cij = C_reg[ii][jj];
137 simd Blj = shiftl<Conf.shift_B>(B_cached.load(l + ll, jj));
138 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
139 }
140 }
141 }
142 } else if constexpr (Conf.struc_B == UpperTriangular) {
143 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
144 UNROLL_FOR (index_t ll = 0; ll < ColsReg; ++ll) {
145 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
146 UNROLL_FOR (index_t jj = std::max(ll, min_col(ii)); jj <= max_col(ii); ++jj) {
147 simd &Cij = C_reg[ii][jj];
148 simd Blj = shiftl<Conf.shift_B>(B_cached.load(l + ll, jj));
149 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
150 }
151 }
152 }
153 }
154
155 const auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
156 // Store accumulator to memory again
157 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
158 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
159 D_cached.template store<Conf.mask_D>(rotr<Conf.rotate_D>(C_reg[ii][jj]), ii, jj);
160}
161
162/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Using register blocking.
163template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
164 StorageOrder OD>
166 const std::optional<view<const T, Abi, OC>> C,
167 const view<T, Abi, OD> D) noexcept {
168 using enum MatrixStructure;
169 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
170 // Check dimensions
171 const index_t I = D.rows(), J = D.cols(), K = A.cols();
172 BATMAT_ASSUME(A.rows() == I);
173 BATMAT_ASSUME(B.rows() == K);
174 BATMAT_ASSUME(B.cols() == J);
175 if constexpr (Conf.struc_A != General)
176 BATMAT_ASSUME(I == K);
177 if constexpr (Conf.struc_B != General)
178 BATMAT_ASSUME(K == J);
179 if constexpr (Conf.struc_C != General)
180 BATMAT_ASSUME(I == J);
181 BATMAT_ASSUME(I > 0);
182 BATMAT_ASSUME(J > 0);
183 BATMAT_ASSUME(K > 0);
184 // Configurations for the various micro-kernels
185 constexpr KernelConfig ConfGXG{.negate = Conf.negate,
186 .shift_A = Conf.shift_A,
187 .shift_B = Conf.shift_B,
188 .rotate_C = Conf.rotate_C,
189 .rotate_D = Conf.rotate_D,
190 .mask_D = Conf.mask_D,
191 .struc_A = General,
192 .struc_B = Conf.struc_B,
193 .struc_C = General};
194 constexpr KernelConfig ConfXGG{.negate = Conf.negate,
195 .shift_A = Conf.shift_A,
196 .shift_B = Conf.shift_B,
197 .rotate_C = Conf.rotate_C,
198 .rotate_D = Conf.rotate_D,
199 .mask_D = Conf.mask_D,
200 .struc_A = Conf.struc_A,
201 .struc_B = General,
202 .struc_C = General};
203 constexpr KernelConfig ConfXXG{.negate = Conf.negate,
204 .shift_A = Conf.shift_A,
205 .shift_B = Conf.shift_B,
206 .rotate_C = Conf.rotate_C,
207 .rotate_D = Conf.rotate_D,
208 .mask_D = Conf.mask_D,
209 .struc_A = Conf.struc_A,
210 .struc_B = Conf.struc_B,
211 .struc_C = General};
212 static const auto microkernel = gemm_copy_lut<T, Abi, Conf, OA, OB, OC, OD>;
213 static const auto microkernel_GXG = gemm_copy_lut<T, Abi, ConfGXG, OA, OB, OC, OD>;
214 static const auto microkernel_XGG = gemm_copy_lut<T, Abi, ConfXGG, OA, OB, OC, OD>;
215 static const auto microkernel_XXG = gemm_copy_lut<T, Abi, ConfXXG, OA, OB, OC, OD>;
216 // Sizeless views to partition and pass to the micro-kernels
217 const uview<const T, Abi, OA> A_ = A;
218 const uview<const T, Abi, OB> B_ = B;
219 const std::optional<uview<const T, Abi, OC>> C_ = C;
220 const uview<T, Abi, OD> D_ = D;
221
222 // Optimization for very small matrices
223 if (I <= Rows && J <= Cols)
224 return microkernel[I - 1][J - 1](A_, B_, C_, D_, K);
225
226 // Simply loop over all blocks in the given matrices.
227 auto run = [&] [[gnu::always_inline]] (index_t i, index_t ni, index_t j, index_t nj) {
228 const auto Bj = B_.middle_cols(j);
229 const auto l0A = Conf.struc_A == UpperTriangular ? i : 0;
230 const auto l1A = Conf.struc_A == LowerTriangular ? i + ni + std::max(K, I) - I : K;
231 const auto l0B = Conf.struc_B == LowerTriangular ? j : 0;
232 const auto l1B = Conf.struc_B == UpperTriangular ? j + nj + std::max(K, J) - J : K;
233 const auto l0 = std::max(l0A, l0B);
234 const auto l1 = std::min(l1A, l1B);
235 const auto Ai = A_.middle_rows(i);
236 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
237 const auto Dij = D_.block(i, j);
238 const auto Ail = Ai.middle_cols(l0);
239 const auto Blj = Bj.middle_rows(l0);
240
241 if (l1 == l0) // TODO: this is wrong.
242 return;
243 if constexpr (Conf.struc_A == LowerTriangular && Conf.struc_B == UpperTriangular) { // LU
244 if (l1A > l1B) {
245 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
246 return;
247 } else if (l1A < l1B) {
248 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
249 return;
250 }
251 }
252 if constexpr (Conf.struc_A == UpperTriangular && Conf.struc_B == LowerTriangular) { // UL
253 if (l0A > l0B) {
254 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
255 return;
256 } else if (l0A < l0B) {
257 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
258 return;
259 }
260 }
261 if constexpr (Conf.struc_C != General) { // syrk
262 if (i != j) {
263 microkernel_XXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
264 return;
265 }
266 }
267 microkernel[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
268 };
269 // Pick loop directions that allow having A=D or B=D
270 constexpr auto dir_i = Conf.struc_A == LowerTriangular ? LoopDir::Backward : LoopDir::Forward,
271 dir_j = Conf.struc_B == UpperTriangular ? LoopDir::Backward : LoopDir::Forward;
272 // Loop over block rows of A and block columns of B
273 if constexpr (OB == StorageOrder::ColMajor)
275 0, J, index_constant<Cols>(),
276 [&](index_t j, auto nj) {
277 const auto i0 = Conf.struc_C == LowerTriangular ? j : 0,
278 i1 = Conf.struc_C == UpperTriangular ? j + nj : I;
280 i0, i1, index_constant<Rows>(), [&](index_t i, auto ni) { run(i, ni, j, nj); },
281 dir_i);
282 },
283 dir_j);
284 else // swap the loops for row-major B
286 0, I, index_constant<Rows>(),
287 [&](index_t i, auto ni) {
288 const auto j0 = Conf.struc_C == UpperTriangular ? i : 0,
289 j1 = Conf.struc_C == LowerTriangular ? i + ni : J;
291 j0, j1, index_constant<Cols>(), [&](index_t j, auto nj) { run(i, ni, j, nj); },
292 dir_j);
293 },
294 dir_i);
295}
296
297} // namespace batmat::linalg::micro_kernels::gemm
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
const constinit auto gemm_copy_lut
Definition gemm.hpp:40
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
void gemm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Using register blocking.
Definition gemm.tpp:165
void gemm_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Single register block.
Definition gemm.tpp:27
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118