13#include <guanaqo/trace.hpp>
50 const index_t M = D.rows(), N = D.cols(), K = A.cols();
56 if (M == 0 || N == 0) [[unlikely]]
58 if (K == 0) [[unlikely]] {
61 .rotate = Conf.rotate_C - Conf.rotate_D, .mask = Conf.mask_D, .struc = Conf.struc_C};
62 constexpr detail::copy::FillConfig msk{.mask = Conf.mask_D, .struc = Conf.struc_C};
72 if (M <= M_reg && N <= N_reg) [[likely]]
73 return gemm_copy_lut<T, Abi, Conf, OA, OB, OC, OD>[M - 1][N - 1](A, B, C, D, K);
77 static const index_t L1_cache_size = 48_KiB;
78 static const index_t L2_cache_size = 512_KiB;
79 static const index_t L3_cache_size = 16_MiB;
80 static const index_t n_cores = 8;
82 static const index_t K_cache_default = L1_cache_size /
sizeof(T) / simd_stride / N_reg;
83 static const index_t M_cache_default = (L2_cache_size /
sizeof(T) / simd_stride / K_cache_default / M_reg) * M_reg;
84 static const index_t N_cache_default = std::max<index_t>(L3_cache_size /
sizeof(T) / simd_stride / K_cache_default / n_cores / M_cache_default, 1) * M_cache_default;
86 const index_t K_cache = packing.k_c ? packing.k_c : K_cache_default;
87 const index_t M_cache = packing.m_c ? packing.m_c : M_cache_default;
88 const index_t N_cache = packing.n_c ? packing.n_c : N_cache_default;
91 if ((M <= M_cache && N <= N_cache && K <= K_cache) || packing.no_tiling) [[likely]]
96 const index_t B_pack_size = B.ceil_depth() * K_cache * N_cache;
97 const index_t A_pack_size = A.ceil_depth() * M_cache * K_cache;
98 const index_t B_size = B.ceil_depth() * K * N;
99 const index_t A_size = A.ceil_depth() * M * K;
100 const bool select_pack_B =
103 const bool select_pack_A =
106 const bool pack_B = select_pack_B && B_size >= 2 * B_pack_size;
107 const bool pack_A = select_pack_A && A_size >= 2 * A_pack_size;
109 auto B_pack = make_aligned_unique_ptr<T>(pack_B ?
static_cast<size_t>(B_pack_size) : 0,
110 simd_align_t(), uninitialized);
111 auto A_pack = make_aligned_unique_ptr<T>(pack_A ?
static_cast<size_t>(A_pack_size) : 0,
112 simd_align_t(), uninitialized);
120 auto Bkj = B.block(p_c, j_c, k_c, n_c);
122 Bkj_pack.reassign({{.data = B_pack.get(), .rows = k_c, .cols = n_c}});
125 auto Cij = C ? std::make_optional(C->block(i_c, j_c, m_c, n_c)) : std::nullopt;
126 auto Dij = D.block(i_c, j_c, m_c, n_c);
127 auto Aik = A.block(i_c, p_c, m_c, k_c);
129 Aik_pack.reassign({{.data = A_pack.get(), .rows = m_c, .cols = k_c}});
131 gemm_copy_register<T, Abi, Conf>(Aik_pack.as_const(), Bkj_pack.as_const(),
132 p_c == 0 ? Cij : Dij, Dij);
134 gemm_copy_register<T, Abi, Conf>(Aik, Bkj_pack.as_const(),
135 p_c == 0 ? Cij : Dij, Dij);
140 auto Cij = C ? std::make_optional(C->block(i_c, j_c, m_c, n_c)) : std::nullopt;
141 auto Dij = D.block(i_c, j_c, m_c, n_c);
142 auto Aik = A.block(i_c, p_c, m_c, k_c);
144 Aik_pack.reassign({{.data = A_pack.get(), .rows = m_c, .cols = k_c}});
146 gemm_copy_register<T, Abi, Conf>(Aik_pack.as_const(), Bkj,
147 p_c == 0 ? Cij : Dij, Dij);
149 gemm_copy_register<T, Abi, Conf>(Aik, Bkj, p_c == 0 ? Cij : Dij, Dij);
157template <
class T,
class Abi, micro_kernels::gemm::KernelConfig Conf = {},
StorageOrder OA,
172 const index_t M = D.rows(), N = D.cols(), K = A.cols();
173 [[maybe_unused]]
const auto fc =
flops::trmm(M, N, K, Conf.struc_A, Conf.struc_B, Conf.struc_C);
175 if (M == 0 || N == 0) [[unlikely]]
177 if (K == 0) [[unlikely]] {
180 .rotate = Conf.rotate_C - Conf.rotate_D, .mask = Conf.mask_D, .struc = Conf.struc_C};
197 static_assert(Conf.struc_A != Conf.struc_B,
198 "lower times lower or upper times upper currently not supported");
208 const index_t M = D.rows(), N = D.cols(), K = A.cols();
209 [[maybe_unused]]
const auto fc =
flops::trmm(M, N, K, Conf.struc_A, Conf.struc_B, Conf.struc_C);
211 if (M == 0 || N == 0) [[unlikely]]
213 if (K == 0) [[unlikely]] {
216 .rotate = Conf.rotate_C - Conf.rotate_D, .mask = Conf.mask_D, .struc = Conf.struc_C};
253template <simdifiable VA, simdifiable VB, simdifiable VD, shift_opt... Opts>
256 std::optional<
decltype(
simdify(D).as_const())> null;
266 std::optional<
decltype(
simdify(D).as_const())> null;
284 return gemm_add(A, B, D, D, packing, opts...);
288template <simdifiable VA, simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
299 return gemm_sub(A, B, D, D, packing, opts...);
313 std::optional<
decltype(
simdify(D.
value).as_const())> null;
315 {.negate =
false, .struc_A = SA, .struc_B =
transpose(SA), .struc_C = SD}, opts...);
322template <
MatrixStructure SD,
class TA, simdifiable VD, shift_opt... Opts>
324 syrk(
Structured{std::forward<TA>(A)}, std::move(D), std::forward<Opts>(opts)...);
332 std::optional<
decltype(
simdify(D.
value).as_const())> null;
334 {.negate =
false, .struc_A = SD, .struc_B =
transpose(SD), .struc_C = SD}, opts...);
346 std::optional<
decltype(
simdify(D.
value).as_const())> null;
348 {.negate =
true, .struc_A = SA, .struc_B =
transpose(SA), .struc_C = SD}, opts...);
355template <
MatrixStructure SD,
class TA, simdifiable VD, shift_opt... Opts>
365 std::optional<
decltype(
simdify(D.
value).as_const())> null;
367 {.negate =
true, .struc_A = SD, .struc_B =
transpose(SD), .struc_C = SD}, opts...);
374template <
MatrixStructure SD, simdifiable VA, simdifiable VC, simdifiable VD, shift_opt... Opts>
385template <
MatrixStructure SD, simdifiable VA, simdifiable VD, shift_opt... Opts>
391template <
MatrixStructure SD, simdifiable VA, simdifiable VC, simdifiable VD, shift_opt... Opts>
402template <
MatrixStructure SD, simdifiable VA, simdifiable VD, shift_opt... Opts>
414 simdifiable VB, simdifiable VD, shift_opt... Opts>
417 std::optional<
decltype(
simdify(D.
value).as_const())> null;
419 {.negate =
false, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
424template <
class TA,
class TB,
class TD, shift_opt... Opts>
425void trmm(TA &&A, TB &&B, TD &&D, Opts... opts) {
430template <
MatrixStructure SA, simdifiable VA, simdifiable VD, shift_opt... Opts>
435template <
MatrixStructure SB, simdifiable VB, simdifiable VD, shift_opt... Opts>
442 simdifiable VB, simdifiable VD, shift_opt... Opts>
445 std::optional<
decltype(
simdify(D.
value).as_const())> null;
447 {.negate =
true, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
452template <
class TA,
class TB,
class TD, shift_opt... Opts>
453void trmm_neg(TA &&A, TB &&B, TD &&D, Opts... opts) {
460 simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
465 {.negate =
false, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
471template <
class TA,
class TB,
class TC,
class TD, shift_opt... Opts>
472void trmm_add(TA &&A, TB &&B, TC &&C, TD &&D, Opts... opts) {
479 simdifiable VB, simdifiable VC, simdifiable VD, shift_opt... Opts>
484 {.negate =
true, .struc_A = SA, .struc_B = SB, .struc_C = SD}, opts...);
490template <
class TA,
class TB,
class TC,
class TD, shift_opt... Opts>
491void trmm_sub(TA &&A, TB &&B, TC &&C, TD &&D, Opts... opts) {
constexpr FlopCount gemm(index_t m, index_t n, index_t k)
Matrix-matrix multiplication of m×k and k×n matrices.
constexpr FlopCount trmm(index_t m, index_t n, index_t k, MatrixStructure sA, MatrixStructure sB, MatrixStructure sC)
Matrix-matrix multiplication of m×k and k×n matrices where one or more of the matrices are triangular...
PackingSelector pack_B
When to pack matrix B.
index_t m_c
Cache block size in the M dimension (rows of A, C and D).
index_t n_c
Cache block size in the N dimension (columns of B, C and D).
bool no_tiling
Don't use cache tiling.
PackingSelector pack_A
When to pack matrix A.
index_t k_c
Cache block size in the K dimension (columns of A, rows of B).
void syrk_neg(Structured< VA, SA > A, Structured< VD, SD > D, Opts... opts)
D = -A Aᵀ with D symmetric.
void gemm_add(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing={}, Opts... opts)
D = C + A B.
void trmm_neg(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
D = -A B with A and/or B triangular.
void syrk_sub(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C - A Aᵀ with C, D symmetric.
void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
D = -A B.
PackingSelector
Decides which matrices to pack during large matrix-matrix multiplication.
void trmm_sub(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C - A B with A and/or B triangular.
void gemm(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
D = A B.
void syrk(Structured< VA, SA > A, Structured< VD, SD > D, Opts... opts)
D = A Aᵀ with D symmetric.
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
D = A B with A and/or B triangular.
void gemm_sub(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing={}, Opts... opts)
D = C - A B.
void syrk_add(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C + A Aᵀ with C, D symmetric.
constexpr MatrixStructure transpose(MatrixStructure s)
void trmm_add(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
D = C + A B with A and/or B triangular.
@ Always
Always pack the blocks of the matrix in a contiguous workspace.
@ Never
Access the original matrices directly in the micro-kernels.
@ Transpose
Pack the blocks of the matrix only if it is not in the optimal storage order.
Packing and tiling options for matrix-matrix multiplication.
struct batmat::matrix::uninitialized_t uninitialized
Tag type to indicate that memory should not be initialized.
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
#define GUANAQO_TRACE_LINALG(name, gflops)
void fill(T a, view< T, Abi, OB > B)
void copy(view< const T, Abi, OA > A, view< T, Abi, OB > B)
void trmm(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D)
void gemmt(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D)
constexpr micro_kernels::gemm::KernelConfig apply_gemm_options(micro_kernels::gemm::KernelConfig conf, Opts...)
void gemm(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, TilingOptions packing={})
const constinit auto gemm_copy_lut
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
void gemm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾, using register blocking.
constexpr index_t ColsReg
typename detail::simdified_abi< V >::type simdified_abi_t
constexpr std::optional< int > rotate_C
constexpr bool simdify_compatible
constexpr std::optional< int > mask_D
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
constexpr std::optional< int > shift_B
constexpr std::optional< int > rotate_D
constexpr std::optional< int > shift_A
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Aligned allocation for matrix storage.
Light-weight wrapper class used for overload resolution of triangular and symmetric matrices.
datapar::simd_align< T, Abi > simd_align_t
static constexpr auto simd_stride