batmat 0.0.19
Batched linear algebra routines
Loading...
Searching...
No Matches
gemm.hpp
Go to the documentation of this file.
1#pragma once
2
5#include <batmat/lut.hpp>
6#include <batmat/micro-kernels/gemm/export.h>
8#include <optional>
9
11
23
// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾ for a single register block of size
// RowsReg × ColsReg, operating on uview arguments (A, B, optional C, D, and k). Defined in
// gemm.tpp. The micro-kernels are not exported from the shared library (no
// BATMAT_LINALG_GEMM_EXPORT here); presumably they are only reached through the LUT of
// instantiations declared further down in this header.
// The RowsReg/ColsReg template parameters are per-instantiation block sizes; inside this
// declaration they hide the namespace-scope RowsReg/ColsReg variable templates.
// k is presumably the inner (reduction) dimension of the product — TODO confirm.
// NOTE(review): source lines 25-27 (remaining template parameters OB/OC/OD and the A, B, C, D
// parameter list) are missing from this listing.
24template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
28 index_t k) noexcept;
29
// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾ using register blocking.
// C is optional (std::optional). If it is absent, D presumably receives ±A⁽ᵀ⁾ B⁽ᵀ⁾ only — confirm
// in gemm.tpp.
// OA/OB/OC/OD are the storage orders of the respective operand views.
// Exported from the shared library via BATMAT_LINALG_GEMM_EXPORT; defined in gemm.tpp.
// NOTE(review): source line 33 is missing from this listing. It presumably holds the function
// name and first parameters: gemm_copy_register(view<const T, Abi, OA> A, view<const T, Abi, OB> B,
30template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
31 StorageOrder OD>
32BATMAT_LINALG_GEMM_EXPORT void
34 std::optional<view<const T, Abi, OC>> C, view<T, Abi, OD> D) noexcept;
35
36// Square block sizes greatly simplify handling of triangular matrices.
37template <class T, class Abi>
38constexpr index_t ColsReg = RowsReg<T, Abi>;
39
40namespace detail {
41// Initialization of the LUT. The actual LUT is not defined here: it needs to be exported from
42// the shared library, and we don't want the compiler calling the micro-kernels directly, since
43// those are not exported. We could move this value to the .tpp file, but then we'd still have
44// to spell out its type here.
45template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
46 StorageOrder OD>
// Generator lambda, presumably passed to make_2d_lut (lut.hpp), which builds a 2D array by
// invoking it with index_constant<Row>, index_constant<Col> (std::integral_constant<index_t, I>).
// Presumably each entry is the gemm_copy_microkernel instantiation for one register block size.
// NOTE(review): source lines 47 (variable definition) and 49 (lambda body) are missing from this
// listing — confirm against gemm.hpp.
48 []<index_t Row, index_t Col>(index_constant<Row>, index_constant<Col>) {
50 });
51} // namespace detail
52
// Exported declaration of the micro-kernel LUT. Its type is taken from the detail::gemm_copy_lut
// initializer via decltype, so the type is not spelled out twice. `constinit` requires the
// definition (gemm.tpp) to be constant-initialized.
// Defining BATMAT_LINALG_GEMM_NO_DECLARE_LUT suppresses this extern declaration. Presumably this
// is done by the translation unit(s) that provide the definitions — TODO confirm.
// NOTE(review): source line 58 (the declared name, gemm_copy_lut) is missing from this listing.
53#ifndef BATMAT_LINALG_GEMM_NO_DECLARE_LUT
54template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
55 StorageOrder OD>
56BATMAT_LINALG_GEMM_EXPORT extern const constinit decltype(detail::gemm_copy_lut<T, Abi, Conf, OA,
57 OB, OC, OD>)
59#endif
60
61} // namespace batmat::linalg::micro_kernels::gemm
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
const constinit decltype(detail::gemm_copy_lut< T, Abi, Conf, OA, OB, OC, OD >) gemm_copy_lut
Definition gemm.tpp:20
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
void gemm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Using register blocking.
Definition gemm.tpp:174
void gemm_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Single register block.
Definition gemm.tpp:36
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10