batmat — Batched linear algebra routines
Source listing of gemm-diag.tpp, extracted from the generated documentation.
See the documentation page for this file for cross-references.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
9
// Local shorthand for BATMAT_FULLY_UNROLLED_FOR: a counted loop that is fully
// unrolled at compile time (used for the register-block loops below).
10#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
11
13
// NOTE(review): the Doxygen extraction dropped original source lines 16 and 18 here.
// Based on the surrounding code, they declared `gemm_diag_copy_lut` — a 2D lookup
// table of micro-kernel instantiations built with make_2d_lut (see lut.hpp), indexed
// below as lut[rows - 1][cols - 1] — and the lambda's return statement. Recover them
// from the original gemm-diag.tpp; this listing is not compilable as-is.
14template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
15 StorageOrder OD>
17 []<index_t Row, index_t Col>(index_constant<Row>, index_constant<Col>) {
19 });
20
21template <MatrixStructure Struc>
22inline constexpr auto first_column =
23 [](index_t row_index) { return Struc == MatrixStructure::UpperTriangular ? row_index : 0; };
24
25template <index_t ColsReg, MatrixStructure Struc>
26inline constexpr auto last_column = [](index_t row_index) {
27 return Struc == MatrixStructure::LowerTriangular ? std::min(row_index, ColsReg - 1)
28 : ColsReg - 1;
29};
30
31/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Single register block.
32template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
34[[gnu::hot, gnu::flatten]] std::conditional_t<Conf.track_zeros, std::pair<index_t, index_t>, void>
36 const std::optional<uview<const T, Abi, OC>> C,
38 const index_t k) noexcept {
39 static_assert(RowsReg > 0 && ColsReg > 0);
40 using enum MatrixStructure;
41 using namespace ops;
42 using simd = datapar::simd<T, Abi>;
43 // Column range for triangular matrix C (gemmt)
44 static constexpr auto min_col = first_column<Conf.struc_C>;
45 static constexpr auto max_col = last_column<ColsReg, Conf.struc_C>;
46 // The following assumption ensures that there is no unnecessary branch
47 // for k == 0 in between the loops. This is crucial for good code
48 // generation, otherwise the compiler inserts jumps and labels between
49 // the matmul kernel and the loading/storing of C, which will cause it to
50 // place C_reg on the stack, resulting in many unnecessary loads and stores.
51 BATMAT_ASSUME(k > 0);
52 // Load accumulator into registers
53 simd C_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
54 if (C) [[likely]] {
55 const auto C_cached = with_cached_access<RowsReg, ColsReg>(*C);
56 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
57 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
58 C_reg[ii][jj] = C_cached.load(ii, jj);
59 } else {
60 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
61 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
62 C_reg[ii][jj] = simd{0};
63 }
64
65 const auto A_cached = with_cached_access<RowsReg, 0>(A);
66 const auto B_cached = with_cached_access<0, ColsReg>(B);
67
68 // Rectangular matrix multiplication kernel
69 index_t first_nonzero = -1, last_nonzero = -1;
70 for (index_t l = 0; l < k; ++l) {
71 bool all_zero = true;
72 simd dl = d.load(l);
73 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
74 simd Ail = dl * A_cached.load(ii, l);
75 if constexpr (Conf.track_zeros)
76 all_zero &= all_of(Ail == simd{0});
77 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
78 simd &Cij = C_reg[ii][jj];
79 simd Blj = B_cached.load(l, jj);
80 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
81 }
82 }
83 if constexpr (Conf.track_zeros)
84 if (!all_zero) {
85 last_nonzero = l;
86 if (first_nonzero < 0)
87 first_nonzero = l;
88 }
89 }
90
91 const auto D_cached = with_cached_access<RowsReg, ColsReg>(D);
92 // Store accumulator to memory again
93 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
94 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
95 D_cached.store(C_reg[ii][jj], ii, jj);
96
97 if constexpr (Conf.track_zeros) {
98 if (first_nonzero < 0)
99 return {k, k};
100 else
101 return {first_nonzero, last_nonzero + 1};
102 }
103}
104
105/// Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
106template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OC,
107 StorageOrder OD>
109 const std::optional<view<const T, Abi, OC>> C,
110 const view<T, Abi, OD> D, view<const T, Abi> d) noexcept {
111 using enum MatrixStructure;
112 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
113 // Check dimensions
114 const index_t I = D.rows(), J = D.cols(), K = A.cols();
115 BATMAT_ASSUME(A.rows() == I);
116 BATMAT_ASSUME(B.rows() == K);
117 BATMAT_ASSUME(B.cols() == J);
118 BATMAT_ASSUME(d.rows() == K);
119 BATMAT_ASSUME(d.cols() == 1);
120 if constexpr (Conf.struc_C != General)
121 BATMAT_ASSUME(I == J);
122 BATMAT_ASSUME(I > 0);
123 BATMAT_ASSUME(J > 0);
124 BATMAT_ASSUME(K > 0);
125 // Configurations for the various micro-kernels
126 constexpr KernelConfig ConfSmall{.negate = Conf.negate, .struc_C = Conf.struc_C};
127 constexpr KernelConfig ConfSub{.negate = Conf.negate, .struc_C = General};
128 static const auto microkernel = gemm_diag_copy_lut<T, Abi, Conf, OA, OB, OC, OD>;
129 static const auto microkernel_small = gemm_diag_copy_lut<T, Abi, ConfSmall, OA, OB, OC, OD>;
130 static const auto microkernel_sub = gemm_diag_copy_lut<T, Abi, ConfSub, OA, OB, OC, OD>;
131 (void)microkernel_sub; // GCC incorrectly warns about unused variable
132 // Sizeless views to partition and pass to the micro-kernels
133 const uview<const T, Abi, OA> A_ = A;
134 const uview<const T, Abi, OB> B_ = B;
135 const std::optional<uview<const T, Abi, OC>> C_ = C;
136 const uview<T, Abi, OD> D_ = D;
137 const uview_vec<const T, Abi> d_{d};
138
139 // Optimization for very small matrices
140 if (I <= Rows && J <= Cols)
141 return microkernel_small[I - 1][J - 1](A_, B_, C_, D_, d_, K);
142
143 // Loop over block rows of A and block columns of B
144 foreach_chunked_merged(0, I, index_constant<Rows>(), [&](index_t i, auto ni) {
145 const auto Ai = A_.middle_rows(i);
146 // If triangular: use diagonal block (i, i) for counting zeros.
147 // If general: use first block column (i, 0).
148 const auto j0 = Conf.struc_C == UpperTriangular ? i + ni
149 : Conf.struc_C == LowerTriangular ? 0
150 : std::min(Cols, J),
151 j1 = Conf.struc_C == LowerTriangular ? i : J;
152 // First micro-kernel call that keeps track of the leading/trailing zeros in A diag(d)
153 auto [l0, l1] = [&] {
154 const auto j = Conf.struc_C == General ? 0 : i;
155 const auto nj = std::min<index_t>(Cols, J - j);
156 const auto Bj = B_.middle_cols(j);
157 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
158 const auto Dij = D_.block(i, j);
159 if constexpr (Conf.track_zeros)
160 return microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
161 else {
162 microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
163 return std::pair<index_t, index_t>{0, K};
164 }
165 }();
166 if (l1 == l0) {
167 if (!C)
168 D.block(i, j0, ni, j1 - j0).set_constant(T{});
169 else if (C->data() == D.data() && C->outer_stride() == D.outer_stride())
170 BATMAT_ASSUME(C->storage_order == D.storage_order); // Nothing to do
171 else if constexpr (OC == StorageOrder::ColMajor)
172 for (index_t jj = j0; jj < j1; ++jj) // TODO: suboptimal when transpose required
173 for (index_t ii = i; ii < i + ni; ++ii)
174 D_.store(C_->load(ii, jj), ii, jj);
175 else
176 for (index_t ii = i; ii < i + ni; ++ii) // TODO: suboptimal when transpose required
177 for (index_t jj = j0; jj < j1; ++jj)
178 D_.store(C_->load(ii, jj), ii, jj);
179 return;
180 }
181 // Process other blocks, trimming any leading/trailing zeros (before l0 or after l1)
182 foreach_chunked_merged(j0, j1, index_constant<Cols>(), [&](index_t j, auto nj) {
183 const auto Bj = B_.middle_cols(j);
184 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
185 const auto Dij = D_.block(i, j);
186 const auto Ail = Ai.middle_cols(l0);
187 const auto Blj = Bj.middle_rows(l0);
188 const auto dl = d_.segment(l0);
189 microkernel_sub[ni - 1][nj - 1](Ail, Blj, Cij, Dij, dl, l1 - l0);
190 });
191 });
192}
193
194} // namespace batmat::linalg::micro_kernels::gemm_diag
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
std::conditional_t< Conf.track_zeros, std::pair< index_t, index_t >, void > gemm_diag_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, uview_vec< const T, Abi > diag, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Single register block.
Definition gemm-diag.tpp:35
void gemm_diag_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, view< const T, Abi > diag) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self segment(this const Self &self, index_t r) noexcept
Definition uview.hpp:162
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
void store(simd x, index_t r, index_t c) const noexcept
Definition uview.hpp:104
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118