batmat 0.0.16
Batched linear algebra routines
Loading...
Searching...
No Matches
gemm.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/kib.hpp>
11#include <batmat/loop.hpp>
13#include <guanaqo/trace.hpp>
14#include <optional>
15
16namespace batmat::linalg {
17
/// Decides which matrices to pack during large matrix-matrix multiplication.
/// @ingroup topic-linalg
enum class PackingSelector : int8_t {
    /// Access the original matrices directly in the micro-kernels.
    Never,
    /// Always pack the blocks of the matrix in a contiguous workspace.
    Always,
    /// Pack the blocks of the matrix only if it is not in the optimal storage order.
    Transpose,
};
25
26/// Packing and tiling options for matrix-matrix multiplication.
27/// @ingroup topic-linalg
29 bool no_tiling = false; ///< Don't use cache tiling.
30 PackingSelector pack_A = PackingSelector::Transpose; ///< When to pack matrix A.
31 PackingSelector pack_B = PackingSelector::Always; ///< When to pack matrix B.
32 index_t n_c = 0; ///< Cache block size in the N dimension (columns of B, C and D).
33 index_t k_c = 0; ///< Cache block size in the K dimension (columns of A, rows of B).
34 index_t m_c = 0; ///< Cache block size in the M dimension (rows of A, C and D).
35};
36
37namespace detail {
38template <class T, class Abi, micro_kernels::gemm::KernelConfig Conf = {}, StorageOrder OA,
40 requires(Conf.struc_A == MatrixStructure::General && Conf.struc_B == MatrixStructure::General &&
41 Conf.struc_C == MatrixStructure::General)
43 std::optional<view<const T, Abi, OC>> C, view<T, Abi, OD> D, TilingOptions packing = {}) {
44 // Check dimensions
45 BATMAT_ASSERT(!C || C->rows() == D.rows());
46 BATMAT_ASSERT(!C || C->cols() == D.cols());
47 BATMAT_ASSERT(A.rows() == D.rows());
48 BATMAT_ASSERT(A.cols() == B.rows());
49 BATMAT_ASSERT(B.cols() == D.cols());
50 const index_t M = D.rows(), N = D.cols(), K = A.cols();
51 GUANAQO_TRACE_LINALG("gemm", total(flops::gemm(M, N, K)) * D.depth());
52 static const index_t N_reg = micro_kernels::gemm::ColsReg<T, Abi>;
53 static const index_t M_reg = micro_kernels::gemm::RowsReg<T, Abi>;
54
55 // Degenerate case
56 if (M == 0 || N == 0) [[unlikely]]
57 return;
58 if (K == 0) [[unlikely]] {
59 // https://github.com/llvm/llvm-project/issues/146272
60 constexpr detail::copy::CopyConfig rot{
61 .rotate = Conf.rotate_C - Conf.rotate_D, .mask = Conf.mask_D, .struc = Conf.struc_C};
62 constexpr detail::copy::FillConfig msk{.mask = Conf.mask_D, .struc = Conf.struc_C};
63 if (C)
65 else
67 return;
68 }
69
70 // Small matrices
72 if (M <= M_reg && N <= N_reg) [[likely]]
73 return gemm_copy_lut<T, Abi, Conf, OA, OB, OC, OD>[M - 1][N - 1](A, B, C, D, K);
74
75 // Determine block sizes for cache tiling
76 static const index_t simd_stride = simd_view_types<T, Abi>::simd_stride;
77 static const index_t L1_cache_size = 48_KiB; // TODO: determine dynamically
78 static const index_t L2_cache_size = 512_KiB;
79 static const index_t L3_cache_size = 16_MiB;
80 static const index_t n_cores = 8; // TODO: OMP
81 // clang-format off
82 static const index_t K_cache_default = L1_cache_size / sizeof(T) / simd_stride / N_reg;
83 static const index_t M_cache_default = (L2_cache_size / sizeof(T) / simd_stride / K_cache_default / M_reg) * M_reg;
84 static const index_t N_cache_default = std::max<index_t>(L3_cache_size / sizeof(T) / simd_stride / K_cache_default / n_cores / M_cache_default, 1) * M_cache_default;
85 // clang-format on
86 const index_t K_cache = packing.k_c ? packing.k_c : K_cache_default;
87 const index_t M_cache = packing.m_c ? packing.m_c : M_cache_default;
88 const index_t N_cache = packing.n_c ? packing.n_c : N_cache_default;
89
90 // Medium size (no tiling)
91 if ((M <= M_cache && N <= N_cache && K <= K_cache) || packing.no_tiling) [[likely]]
93
94 // Determine sizes for packing tiles of A and B
95 using simd_align_t = typename simd_view_types<T, Abi>::simd_align_t;
96 const index_t B_pack_size = B.ceil_depth() * K_cache * N_cache;
97 const index_t A_pack_size = A.ceil_depth() * M_cache * K_cache;
98 const index_t B_size = B.ceil_depth() * K * N;
99 const index_t A_size = A.ceil_depth() * M * K;
100 const bool select_pack_B =
101 packing.pack_B == PackingSelector::Always ||
102 (packing.pack_B == PackingSelector::Transpose && OB == StorageOrder::RowMajor);
103 const bool select_pack_A =
104 packing.pack_A == PackingSelector::Always ||
105 (packing.pack_A == PackingSelector::Transpose && OA == StorageOrder::ColMajor);
106 const bool pack_B = select_pack_B && B_size >= 2 * B_pack_size; // TODO: tune
107 const bool pack_A = select_pack_A && A_size >= 2 * A_pack_size; // TODO: tune
109 auto B_pack = make_aligned_unique_ptr<T>(pack_B ? static_cast<size_t>(B_pack_size) : 0,
110 simd_align_t(), uninitialized);
111 auto A_pack = make_aligned_unique_ptr<T>(pack_A ? static_cast<size_t>(A_pack_size) : 0,
112 simd_align_t(), uninitialized);
115
116 // Three outer loops for tiling, with optional packing of A and B
118 foreach_chunked_merged(0, N, N_cache, [&](index_t j_c, index_t n_c) {
119 foreach_chunked_merged(0, K, K_cache, [&](index_t p_c, index_t k_c) {
120 auto Bkj = B.block(p_c, j_c, k_c, n_c);
121 if (pack_B) {
122 Bkj_pack.reassign({{.data = B_pack.get(), .rows = k_c, .cols = n_c}});
123 detail::copy::copy<T, Abi>(Bkj, Bkj_pack);
124 foreach_chunked_merged(0, M, M_cache, [&](index_t i_c, index_t m_c) {
125 auto Cij = C ? std::make_optional(C->block(i_c, j_c, m_c, n_c)) : std::nullopt;
126 auto Dij = D.block(i_c, j_c, m_c, n_c);
127 auto Aik = A.block(i_c, p_c, m_c, k_c);
128 if (pack_A) {
129 Aik_pack.reassign({{.data = A_pack.get(), .rows = m_c, .cols = k_c}});
130 detail::copy::copy<T, Abi>(Aik, Aik_pack);
131 gemm_copy_register<T, Abi, Conf>(Aik_pack.as_const(), Bkj_pack.as_const(),
132 p_c == 0 ? Cij : Dij, Dij);
133 } else {
134 gemm_copy_register<T, Abi, Conf>(Aik, Bkj_pack.as_const(),
135 p_c == 0 ? Cij : Dij, Dij);
136 }
137 });
138 } else {
139 foreach_chunked_merged(0, M, M_cache, [&](index_t i_c, index_t m_c) {
140 auto Cij = C ? std::make_optional(C->block(i_c, j_c, m_c, n_c)) : std::nullopt;
141 auto Dij = D.block(i_c, j_c, m_c, n_c);
142 auto Aik = A.block(i_c, p_c, m_c, k_c);
143 if (pack_A) {
144 Aik_pack.reassign({{.data = A_pack.get(), .rows = m_c, .cols = k_c}});
145 detail::copy::copy<T, Abi>(Aik, Aik_pack);
146 gemm_copy_register<T, Abi, Conf>(Aik_pack.as_const(), Bkj,
147 p_c == 0 ? Cij : Dij, Dij);
148 } else {
149 gemm_copy_register<T, Abi, Conf>(Aik, Bkj, p_c == 0 ? Cij : Dij, Dij);
150 }
151 });
152 }
153 });
154 });
155}
156
157template <class T, class Abi, micro_kernels::gemm::KernelConfig Conf = {}, StorageOrder OA,
159 requires(Conf.struc_C != MatrixStructure::General)
161 std::optional<view<const T, Abi, OC>> C, view<T, Abi, OD> D) {
162 if (Conf.struc_A != MatrixStructure::General)
163 BATMAT_ASSERT(A.rows() == A.cols()); // TODO: could be relaxed
164 if (Conf.struc_B != MatrixStructure::General)
165 BATMAT_ASSERT(B.rows() == B.cols()); // TODO: could be relaxed
166 BATMAT_ASSERT(D.rows() == D.cols()); // TODO: could be relaxed
167 BATMAT_ASSERT(!C || C->rows() == D.rows());
168 BATMAT_ASSERT(!C || C->cols() == D.cols());
169 BATMAT_ASSERT(A.rows() == D.rows());
170 BATMAT_ASSERT(A.cols() == B.rows());
171 BATMAT_ASSERT(B.cols() == D.cols());
172 const index_t M = D.rows(), N = D.cols(), K = A.cols();
173 [[maybe_unused]] const auto fc = flops::trmm(M, N, K, Conf.struc_A, Conf.struc_B, Conf.struc_C);
174 GUANAQO_TRACE_LINALG("gemmt", total(fc) * D.depth());
175 if (M == 0 || N == 0) [[unlikely]]
176 return;
177 if (K == 0) [[unlikely]] {
178 // https://github.com/llvm/llvm-project/issues/146272
179 constexpr detail::copy::CopyConfig rot{
180 .rotate = Conf.rotate_C - Conf.rotate_D, .mask = Conf.mask_D, .struc = Conf.struc_C};
181 constexpr detail::copy::FillConfig msk{.mask = Conf.mask_D, .struc = Conf.struc_C};
182 if (C)
184 else
186 return;
187 }
188 // TODO: cache blocking
190}
191
192template <class T, class Abi, micro_kernels::gemm::KernelConfig Conf = {}, StorageOrder OA,
194 requires(Conf.struc_A != MatrixStructure::General || Conf.struc_B != MatrixStructure::General)
196 std::optional<view<const T, Abi, OC>> C, view<T, Abi, OD> D) {
197 static_assert(Conf.struc_A != Conf.struc_B,
198 "lower times lower or upper times upper currently not supported"); // TODO
199 if (Conf.struc_A != MatrixStructure::General)
200 BATMAT_ASSERT(A.rows() == A.cols()); // TODO: could be relaxed
201 if (Conf.struc_B != MatrixStructure::General)
202 BATMAT_ASSERT(B.rows() == B.cols()); // TODO: could be relaxed
203 BATMAT_ASSERT(!C || C->rows() == D.rows());
204 BATMAT_ASSERT(!C || C->cols() == D.cols());
205 BATMAT_ASSERT(A.rows() == D.rows());
206 BATMAT_ASSERT(A.cols() == B.rows());
207 BATMAT_ASSERT(B.cols() == D.cols());
208 const index_t M = D.rows(), N = D.cols(), K = A.cols();
209 [[maybe_unused]] const auto fc = flops::trmm(M, N, K, Conf.struc_A, Conf.struc_B, Conf.struc_C);
210 GUANAQO_TRACE_LINALG("trmm", total(fc) * D.depth());
211 if (M == 0 || N == 0) [[unlikely]]
212 return;
213 if (K == 0) [[unlikely]] {
214 // https://github.com/llvm/llvm-project/issues/146272
215 constexpr detail::copy::CopyConfig rot{
216 .rotate = Conf.rotate_C - Conf.rotate_D, .mask = Conf.mask_D, .struc = Conf.struc_C};
217 constexpr detail::copy::FillConfig msk{.mask = Conf.mask_D, .struc = Conf.struc_C};
218 if (C)
220 else
222 return;
223 }
224 // TODO: cache blocking
226}
227
228template <shift_opt... Opts>
231 if (auto s = shift_A<Opts...>)
232 conf.shift_A = *s;
233 if (auto s = rotate_B<Opts...>)
234 conf.rotate_B = *s;
235 if (auto s = rotate_C<Opts...>)
236 conf.rotate_C = *s;
237 if (auto s = rotate_D<Opts...>)
238 conf.rotate_D = *s;
239 if (auto s = mask_D<Opts...>)
240 conf.mask_D = *s;
241 return conf;
242}
243
244} // namespace detail
245
246/// @addtogroup topic-linalg
247/// @{
248
249/// @name Multiplication of batches of general matrices
250/// @{
251
252/// D = A B
253template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
255void gemm(VA &&A, VB &&B, VD &&D, TilingOptions packing = {}, Opts... opts) {
256 std::optional<decltype(simdify(D).as_const())> null;
257 constexpr auto conf = detail::apply_gemm_options({.negate = false}, opts...);
259 simdify(A).as_const(), simdify(B).as_const(), null, simdify(D), packing);
260}
261
262/// D = -A B
263template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
265void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing = {}, Opts... opts) {
266 std::optional<decltype(simdify(D).as_const())> null;
267 constexpr auto conf = detail::apply_gemm_options({.negate = true}, opts...);
269 simdify(A).as_const(), simdify(B).as_const(), null, simdify(D), packing);
270}
271
272/// D = C + A B
273template <simdifiable VA, simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
275void gemm_add(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing = {}, Opts... opts) {
276 constexpr auto conf = detail::apply_gemm_options({.negate = false}, opts...);
278 simdify(A).as_const(), simdify(B).as_const(), std::make_optional(simdify(C).as_const()),
279 simdify(D), packing);
280}
281/// D += A B
282template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
283void gemm_add(VA &&A, VB &&B, VD &&D, TilingOptions packing = {}, Opts... opts) {
284 return gemm_add(A, B, D, D, packing, opts...);
285}
286
287/// D = C - A B
288template <simdifiable VA, simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
290void gemm_sub(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing = {}, Opts... opts) {
291 constexpr auto conf = detail::apply_gemm_options({.negate = true}, opts...);
293 simdify(A).as_const(), simdify(B).as_const(), std::make_optional(simdify(C).as_const()),
294 simdify(D), packing);
295}
296/// D -= A B
297template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
298void gemm_sub(VA &&A, VB &&B, VD &&D, TilingOptions packing = {}, Opts... opts) {
299 return gemm_sub(A, B, D, D, packing, opts...);
300}
301
302/// @}
303
304/// @name Multiplication of batches of matrices with symmetric results
305/// @{
306
307/// D = A Aᵀ with D symmetric
308template <MatrixStructure SA, MatrixStructure SD, simdifiable VA, simdifiable VD, shift_opt... Opts>
311 using enum MatrixStructure;
312 static_assert(SD != General);
313 std::optional<decltype(simdify(D.value).as_const())> null;
314 constexpr auto conf = detail::apply_gemm_options(
315 {.negate = false, .struc_A = SA, .struc_B = transpose(SA), .struc_C = SD}, opts...);
317 simdify(A.value).as_const(), simdify(A.value).as_const().transposed(), null,
318 simdify(D.value));
319}
320
321/// D = A Aᵀ with D symmetric
322template <MatrixStructure SD, class TA, simdifiable VD, shift_opt... Opts>
323void syrk(TA &&A, Structured<VD, SD> D, Opts... opts) {
324 syrk(Structured{std::forward<TA>(A)}, std::move(D), std::forward<Opts>(opts)...);
325}
326
327/// D = D Dᵀ with D triangular on input and symmetric on output
328template <MatrixStructure SD, simdifiable VD, shift_opt... Opts>
329void syrk(Structured<VD, SD> D, Opts... opts) {
330 using enum MatrixStructure;
331 static_assert(SD != General);
332 std::optional<decltype(simdify(D.value).as_const())> null;
333 constexpr auto conf = detail::apply_gemm_options(
334 {.negate = false, .struc_A = SD, .struc_B = transpose(SD), .struc_C = SD}, opts...);
336 simdify(D.value).as_const(), simdify(D.value).as_const().transposed(), null,
337 simdify(D.value));
338}
339
340/// D = -A Aᵀ with D symmetric
341template <MatrixStructure SA, MatrixStructure SD, simdifiable VA, simdifiable VD, shift_opt... Opts>
344 using enum MatrixStructure;
345 static_assert(SD != General);
346 std::optional<decltype(simdify(D.value).as_const())> null;
347 constexpr auto conf = detail::apply_gemm_options(
348 {.negate = true, .struc_A = SA, .struc_B = transpose(SA), .struc_C = SD}, opts...);
350 simdify(A.value).as_const(), simdify(A.value).as_const().transposed(), null,
351 simdify(D.value));
352}
353
354/// D = A Aᵀ with D symmetric
355template <MatrixStructure SD, class TA, simdifiable VD, shift_opt... Opts>
356void syrk_neg(TA &&A, Structured<VD, SD> D, Opts... opts) {
357 syrk_neg(Structured{std::forward<TA>(A)}, std::move(D), std::forward<Opts>(opts)...);
358}
359
360/// D = -D Dᵀ with D triangular on input and symmetric on output
361template <MatrixStructure SD, simdifiable VD, shift_opt... Opts>
362void syrk_neg(Structured<VD, SD> D, Opts... opts) {
363 using enum MatrixStructure;
364 static_assert(SD != General);
365 std::optional<decltype(simdify(D.value).as_const())> null;
366 constexpr auto conf = detail::apply_gemm_options(
367 {.negate = true, .struc_A = SD, .struc_B = transpose(SD), .struc_C = SD}, opts...);
369 simdify(D.value).as_const(), simdify(D.value).as_const().transposed(), null,
370 simdify(D.value));
371}
372
373/// D = C + A Aᵀ with C, D symmetric
374template <MatrixStructure SD, simdifiable VA, simdifiable VC, simdifiable VD, shift_opt... Opts>
376void syrk_add(VA &&A, Structured<VC, SD> C, Structured<VD, SD> D, Opts... opts) {
377 using enum MatrixStructure;
378 static_assert(SD != General);
379 constexpr auto conf = detail::apply_gemm_options({.negate = false, .struc_C = SD}, opts...);
381 simdify(A).as_const(), simdify(A).as_const().transposed(),
382 std::make_optional(simdify(C.value).as_const()), simdify(D.value));
383}
384/// D += A Aᵀ with D symmetric
385template <MatrixStructure SD, simdifiable VA, simdifiable VD, shift_opt... Opts>
386void syrk_add(VA &&A, Structured<VD, SD> D, Opts... opts) {
387 return syrk_add(A, D.ref(), D.ref(), opts...);
388}
389
390/// D = C - A Aᵀ with C, D symmetric
391template <MatrixStructure SD, simdifiable VA, simdifiable VC, simdifiable VD, shift_opt... Opts>
393void syrk_sub(VA &&A, Structured<VC, SD> C, Structured<VD, SD> D, Opts... opts) {
394 using enum MatrixStructure;
395 static_assert(SD != General);
396 constexpr auto conf = detail::apply_gemm_options({.negate = true, .struc_C = SD}, opts...);
398 simdify(A).as_const(), simdify(A).as_const().transposed(),
399 std::make_optional(simdify(C.value).as_const()), simdify(D.value));
400}
401/// D -= A Aᵀ with D symmetric
402template <MatrixStructure SD, simdifiable VA, simdifiable VD, shift_opt... Opts>
403void syrk_sub(VA &&A, Structured<VD, SD> D, Opts... opts) {
404 return syrk_sub(A, D.ref(), D.ref(), opts...);
405}
406
407/// @}
408
409/// @name Multiplication of batches of triangular matrices
410/// @{
411
412/// D = A B with A and/or B triangular
413template <MatrixStructure SA, MatrixStructure SB, MatrixStructure SD, simdifiable VA,
414 simdifiable VB, simdifiable VD, shift_opt... Opts>
417 std::optional<decltype(simdify(D.value).as_const())> null;
418 constexpr auto conf = detail::apply_gemm_options(
419 {.negate = false, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
421 simdify(A.value).as_const(), simdify(B.value).as_const(), null, simdify(D.value));
422}
423/// D = A B with A and/or B triangular
424template <class TA, class TB, class TD, shift_opt... Opts>
425void trmm(TA &&A, TB &&B, TD &&D, Opts... opts) {
426 return trmm(Structured{std::forward<TA>(A)}, Structured{std::forward<TB>(B)},
427 Structured{std::forward<TD>(D)}, opts...);
428}
429/// D = A D with A triangular
430template <MatrixStructure SA, simdifiable VA, simdifiable VD, shift_opt... Opts>
431void trmm(Structured<VA, SA> A, VD &&D, Opts... opts) {
432 return trmm(A.ref(), Structured{D}, Structured{D}, opts...);
433}
434/// D = D B with B triangular
435template <MatrixStructure SB, simdifiable VB, simdifiable VD, shift_opt... Opts>
436void trmm(VD &&D, Structured<VB, SB> B, Opts... opts) {
437 return trmm(Structured{D}, B.ref(), Structured{D}, opts...);
438}
439
440/// D = -A B with A and/or B triangular
441template <MatrixStructure SA, MatrixStructure SB, MatrixStructure SD, simdifiable VA,
442 simdifiable VB, simdifiable VD, shift_opt... Opts>
445 std::optional<decltype(simdify(D.value).as_const())> null;
446 constexpr auto conf = detail::apply_gemm_options(
447 {.negate = true, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
449 simdify(A.value).as_const(), simdify(B.value).as_const(), null, simdify(D.value));
450}
451/// D = -A B with A and/or B triangular
452template <class TA, class TB, class TD, shift_opt... Opts>
453void trmm_neg(TA &&A, TB &&B, TD &&D, Opts... opts) {
454 return trmm_neg(Structured{std::forward<TA>(A)}, Structured{std::forward<TB>(B)},
455 Structured{std::forward<TD>(D)}, opts...);
456}
457
458/// D = C + A B with A and/or B triangular
459template <MatrixStructure SA, MatrixStructure SB, MatrixStructure SD, simdifiable VA,
460 simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
463 Structured<VD, SD> D, Opts... opts) {
464 constexpr auto conf = detail::apply_gemm_options(
465 {.negate = false, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
467 simdify(A.value).as_const(), simdify(B.value).as_const(),
468 std::make_optional(simdify(C.value).as_const()), simdify(D.value));
469}
470/// D = C + A B with A and/or B triangular
471template <class TA, class TB, class TC, class TD, shift_opt... Opts>
472void trmm_add(TA &&A, TB &&B, TC &&C, TD &&D, Opts... opts) {
473 return trmm_add(Structured{std::forward<TA>(A)}, Structured{std::forward<TB>(B)},
474 Structured{std::forward<TC>(C)}, Structured{std::forward<TD>(D)}, opts...);
475}
476
477/// D = C - A B with A and/or B triangular
478template <MatrixStructure SA, MatrixStructure SB, MatrixStructure SD, simdifiable VA,
479 simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
482 Structured<VD, SD> D, Opts... opts) {
483 constexpr auto conf = detail::apply_gemm_options(
484 {.negate = true, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
486 simdify(A.value).as_const(), simdify(B.value).as_const(),
487 std::make_optional(simdify(C.value).as_const()), simdify(D.value));
488}
489/// D = C - A B with A and/or B triangular
490template <class TA, class TB, class TC, class TD, shift_opt... Opts>
491void trmm_sub(TA &&A, TB &&B, TC &&C, TD &&D, Opts... opts) {
492 return trmm_sub(Structured{std::forward<TA>(A)}, Structured{std::forward<TB>(B)},
493 Structured{std::forward<TC>(C)}, Structured{std::forward<TD>(D)}, opts...);
494}
495
496/// @}
497
498/// @}
499
500} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
std::ptrdiff_t index_t
constexpr FlopCount gemm(index_t m, index_t n, index_t k)
Matrix-matrix multiplication of m×k and k×n matrices.
Definition flops.hpp:38
constexpr FlopCount trmm(index_t m, index_t n, index_t k, MatrixStructure sA, MatrixStructure sB, MatrixStructure sC)
Matrix-matrix multiplication of m×k and k×n matrices where one or more of the matrices are triangular...
Definition flops.hpp:45
PackingSelector pack_B
When to pack matrix B.
Definition gemm.hpp:31
index_t m_c
Cache block size in the M dimension (rows of A, C and D).
Definition gemm.hpp:34
index_t n_c
Cache block size in the N dimension (columns of B, C and D).
Definition gemm.hpp:32
bool no_tiling
Don't use cache tiling.
Definition gemm.hpp:29
PackingSelector pack_A
When to pack matrix A.
Definition gemm.hpp:30
index_t k_c
Cache block size in the K dimension (columns of A, rows of B).
Definition gemm.hpp:33
void syrk_neg(Structured< VA, SA > A, Structured< VD, SD > D, Opts... opts)
D = -A Aᵀ with D symmetric.
Definition gemm.hpp:343
void gemm_add(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing={}, Opts... opts)
D = C + A B.
Definition gemm.hpp:275
void trmm_neg(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
D = -A B with A and/or B triangular.
Definition gemm.hpp:444
void syrk_sub(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C - A Aᵀ with C, D symmetric.
Definition gemm.hpp:393
void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
D = -A B.
Definition gemm.hpp:265
PackingSelector
Decides which matrices to pack during large matrix-matrix multiplication.
Definition gemm.hpp:20
void trmm_sub(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C - A B with A and/or B triangular.
Definition gemm.hpp:481
void gemm(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
D = A B.
Definition gemm.hpp:255
void syrk(Structured< VA, SA > A, Structured< VD, SD > D, Opts... opts)
D = A Aᵀ with D symmetric.
Definition gemm.hpp:310
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
D = A B with A and/or B triangular.
Definition gemm.hpp:416
void gemm_sub(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing={}, Opts... opts)
D = C - A B.
Definition gemm.hpp:290
void syrk_add(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C + A Aᵀ with C, D symmetric.
Definition gemm.hpp:376
constexpr MatrixStructure transpose(MatrixStructure s)
Definition structure.hpp:11
void trmm_add(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C + A B with A and/or B triangular.
Definition gemm.hpp:462
@ Always
Always pack the blocks of the matrix in a contiguous workspace.
Definition gemm.hpp:22
@ Never
Access the original matrices directly in the micro-kernels.
Definition gemm.hpp:21
@ Transpose
Pack the blocks of the matrix only if it is not in the optimal storage order.
Definition gemm.hpp:23
Packing and tiling options for matrix-matrix multiplication.
Definition gemm.hpp:28
struct batmat::matrix::uninitialized_t uninitialized
Tag type to indicate that memory should not be initialized.
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
#define GUANAQO_TRACE_LINALG(name, gflops)
void fill(T a, view< T, Abi, OB > B)
Definition copy.hpp:27
void copy(view< const T, Abi, OA > A, view< T, Abi, OB > B)
Definition copy.hpp:68
void trmm(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D)
Definition gemm.hpp:195
void gemmt(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D)
Definition gemm.hpp:160
constexpr micro_kernels::gemm::KernelConfig apply_gemm_options(micro_kernels::gemm::KernelConfig conf, Opts...)
Definition gemm.hpp:230
void gemm(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, TilingOptions packing={})
Definition gemm.hpp:42
const constinit auto gemm_copy_lut
Definition gemm.hpp:40
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
void gemm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Using register blocking.
Definition gemm.tpp:165
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr std::optional< int > rotate_B
Definition shift.hpp:38
constexpr std::optional< int > rotate_C
Definition shift.hpp:45
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr std::optional< int > mask_D
Definition shift.hpp:59
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
constexpr std::optional< int > rotate_D
Definition shift.hpp:52
constexpr std::optional< int > shift_A
Definition shift.hpp:31
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Aligned allocation for matrix storage.
Light-weight wrapper class used for overload resolution of triangular and symmetric matrices.
datapar::simd_align< T, Abi > simd_align_t
Definition uview.hpp:25
static constexpr auto simd_stride
Definition uview.hpp:26