29 const index_t k)
noexcept {
35 static constexpr auto min_col =
first_column<Conf.struc_C>;
44 if constexpr (Conf.struc_A !=
General)
46 if constexpr (Conf.struc_B !=
General)
53 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
54 C_reg[ii][jj] = rotl<Conf.rotate_C>(C_cached.load(ii, jj));
57 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
58 C_reg[ii][jj] = simd{0};
69 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
71 simd &Cij = C_reg[ii][jj];
72 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
73 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
74 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
82 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
83 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
84 simd &Cij = C_reg[ii][jj];
85 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
86 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
94 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, ll));
95 UNROLL_FOR (index_t jj = min_col(ii); jj <= std::min(ll, max_col(ii)); ++jj) {
96 simd &Cij = C_reg[ii][jj];
97 simd Blj = rotl<Conf.rotate_B>(B_cached.load(ll, jj));
98 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
107 for (; l < std::min(l_end_A, l_end_B); ++l) {
109 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l));
110 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
111 simd &Cij = C_reg[ii][jj];
112 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l, jj));
113 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
121 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
123 UNROLL_FOR (index_t ll = 0; ll <= lmax; ++ll) {
124 simd &Cij = C_reg[ii][jj];
125 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
126 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
127 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
134 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
135 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
136 simd &Cij = C_reg[ii][jj];
137 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
138 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
145 simd Ail = shiftl<Conf.shift_A>(A_cached.load(ii, l + ll));
146 UNROLL_FOR (index_t jj = std::max(ll, min_col(ii)); jj <= max_col(ii); ++jj) {
147 simd &Cij = C_reg[ii][jj];
148 simd Blj = rotl<Conf.rotate_B>(B_cached.load(l + ll, jj));
149 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
158 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
159 D_cached.template store<Conf.mask_D>(rotr<Conf.rotate_D>(C_reg[ii][jj]), ii, jj);
171 const index_t I = D.rows(), J = D.cols(), K = A.cols();
175 if constexpr (Conf.struc_A !=
General)
177 if constexpr (Conf.struc_B !=
General)
179 if constexpr (Conf.struc_C !=
General)
186 .shift_A = Conf.shift_A,
187 .rotate_B = Conf.rotate_B,
188 .rotate_C = Conf.rotate_C,
189 .rotate_D = Conf.rotate_D,
190 .mask_D = Conf.mask_D,
192 .struc_B = Conf.struc_B,
195 .shift_A = Conf.shift_A,
196 .rotate_B = Conf.rotate_B,
197 .rotate_C = Conf.rotate_C,
198 .rotate_D = Conf.rotate_D,
199 .mask_D = Conf.mask_D,
200 .struc_A = Conf.struc_A,
204 .shift_A = Conf.shift_A,
205 .rotate_B = Conf.rotate_B,
206 .rotate_C = Conf.rotate_C,
207 .rotate_D = Conf.rotate_D,
208 .mask_D = Conf.mask_D,
209 .struc_A = Conf.struc_A,
210 .struc_B = Conf.struc_B,
219 const std::optional<uview<const T, Abi, OC>> C_ = C;
223 if (I <= Rows && J <= Cols)
224 return microkernel[I - 1][J - 1](A_, B_, C_, D_, K);
227 auto run = [&] [[gnu::always_inline]] (index_t i, index_t ni, index_t j, index_t nj) {
230 const auto l1A = Conf.struc_A ==
LowerTriangular ? i + ni + std::max(K, I) - I : K;
232 const auto l1B = Conf.struc_B ==
UpperTriangular ? j + nj + std::max(K, J) - J : K;
233 const auto l0 = std::max(l0A, l0B);
234 const auto l1 = std::min(l1A, l1B);
236 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
237 const auto Dij = D_.
block(i, j);
238 const auto Ail = Ai.middle_cols(l0);
239 const auto Blj = Bj.middle_rows(l0);
245 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
247 }
else if (l1A < l1B) {
248 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
254 microkernel_XGG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
256 }
else if (l0A < l0B) {
257 microkernel_GXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
261 if constexpr (Conf.struc_C !=
General) {
263 microkernel_XXG[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
267 microkernel[ni - 1][nj - 1](Ail, Blj, Cij, Dij, l1 - l0);
273 if constexpr (OB == StorageOrder::ColMajor)
276 [&](index_t j,
auto nj) {
287 [&](index_t i,
auto ni) {
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
void gemm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾, using register blocking.
void gemm_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ B⁽ᵀ⁾, computing a single register block.