batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
gemm-diag.hpp
Go to the documentation of this file.
1#pragma once
2
5#include <batmat/micro-kernels/gemm-diag/export.h>
7#include <optional>
8#include <type_traits>
9#include <utility>
10
12
13struct BATMAT_LINALG_GEMM_DIAG_EXPORT KernelConfig {
14 bool negate = false;
15 bool track_zeros = false;
17};
18
19template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
21std::conditional_t<Conf.track_zeros, std::pair<index_t, index_t>, void>
24 uview_vec<const T, Abi> diag, index_t k) noexcept;
25
26template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
27 StorageOrder OD>
28BATMAT_LINALG_GEMM_DIAG_EXPORT void
30 std::optional<view<const T, Abi, OC>> C, view<T, Abi, OD> D,
31 view<const T, Abi> diag) noexcept;
32
33// Square block sizes greatly simplify handling of triangular matrices.
34using gemm::RowsReg;
35template <class T, class Abi>
36constexpr index_t ColsReg = RowsReg<T, Abi>;
37
38} // namespace batmat::linalg::micro_kernels::gemm_diag
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
std::conditional_t< Conf.track_zeros, std::pair< index_t, index_t >, void > gemm_diag_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, uview_vec< const T, Abi > diag, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Single register block.
Definition gemm-diag.tpp:35
void gemm_diag_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, view< const T, Abi > diag) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70