38 const index_t k)
noexcept {
44 static constexpr auto min_col =
first_column<Conf.struc_C>;
53 if constexpr (Conf.struc_A !=
General)
55 if constexpr (Conf.struc_B !=
General)
62 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
63 C_reg[ii][jj] = rotl<Conf.rotate_C>(C_cached.load(ii, jj));
66 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
67 C_reg[ii][jj] = simd{0};
78 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
80 simd &Cij = C_reg[ii][jj];
81 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
82 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
83 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
91 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
92 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
93 simd &Cij = C_reg[ii][jj];
94 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
95 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
103 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
104 UNROLL_FOR (index_t jj = min_col(ii); jj <= std::min(ll, max_col(ii)); ++jj) {
105 simd &Cij = C_reg[ii][jj];
106 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
107 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
116 for (; l < std::min(l_end_A, l_end_B); ++l) {
118 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l));
119 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
120 simd &Cij = C_reg[ii][jj];
121 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l, jj));
122 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
130 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
132 UNROLL_FOR (index_t ll = 0; ll <= lmax; ++ll) {
133 simd &Cij = C_reg[ii][jj];
134 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
135 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
136 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
143 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
144 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
145 simd &Cij = C_reg[ii][jj];
146 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
147 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
154 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
155 UNROLL_FOR (index_t jj = std::max(ll, min_col(ii)); jj <= max_col(ii); ++jj) {
156 simd &Cij = C_reg[ii][jj];
157 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
158 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
167 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
168 D_cached.template store<Conf.mask_D>(rotr<Conf.rotate_D>(C_reg[ii][jj]), ii, jj);
180 const index_t I = D.rows(), J = D.cols(), K = A.cols();
184 if constexpr (Conf.struc_A !=
General)
186 if constexpr (Conf.struc_B !=
General)
188 if constexpr (Conf.struc_C !=
General)
195 .shift_A = Conf.shift_A,
196 .rotate_B = Conf.rotate_B,
197 .rotate_C = Conf.rotate_C,
198 .rotate_D = Conf.rotate_D,
199 .mask_D = Conf.mask_D,
201 .struc_B = Conf.struc_B,
204 .shift_A = Conf.shift_A,
205 .rotate_B = Conf.rotate_B,
206 .rotate_C = Conf.rotate_C,
207 .rotate_D = Conf.rotate_D,
208 .mask_D = Conf.mask_D,
209 .struc_A = Conf.struc_A,
213 .shift_A = Conf.shift_A,
214 .rotate_B = Conf.rotate_B,
215 .rotate_C = Conf.rotate_C,
216 .rotate_D = Conf.rotate_D,
217 .mask_D = Conf.mask_D,
218 .struc_A = Conf.struc_A,
219 .struc_B = Conf.struc_B,
228 const std::optional<uview<const T, Abi, OC>> C_ = C;
232 if (I <= Rows && J <= Cols)
233 return microkernel[I - 1][J - 1](A_, B_, C_, D_, K);
236 auto run = [&] [[gnu::always_inline]] (index_t i, index_t ni, index_t j, index_t nj) {
239 const auto l1A = Conf.struc_A ==
LowerTriangular ? i + ni + std::max(K, I) - I : K;
241 const auto l1B = Conf.struc_B ==
UpperTriangular ? j + nj + std::max(K, J) - J : K;
242 const auto l0 = std::max(l0A, l0B);
243 const auto l1 = std::min(l1A, l1B);
245 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
246 const auto Dij = D_.
block(i, j);
247 const auto Ail = Ai.middle_cols(l0);
248 const auto Blj = Bj.middle_rows(l0);
254 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
256 }
else if (l1A < l1B) {
257 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
263 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
265 }
else if (l0A < l0B) {
266 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
270 if constexpr (Conf.struc_C !=
General) {
272 microkernel_XXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
276 microkernel[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
282 if constexpr (OB == StorageOrder::ColMajor)
285 [&](index_t j,
auto nj) {
296 [&](index_t i,
auto ni) {
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
void gemm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Using register blocking.
void gemm_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾. Single register block.