9#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13template <MatrixStructure Struc>
17template <index_t ColsReg, MatrixStructure Struc>
26[[gnu::hot, gnu::flatten]] std::conditional_t<Conf.track_zeros, std::pair<index_t, index_t>,
void>
30 const index_t k)
noexcept {
36 static constexpr auto min_col =
first_column<Conf.struc_C>;
49 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
50 C_reg[ii][jj] = C_cached.load(ii, jj);
53 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
54 C_reg[ii][jj] = simd{0};
61 index_t first_nonzero = -1, last_nonzero = -1;
62 for (index_t l = 0; l < k; ++l) {
66 simd Ail = dl * A_cached.load(ii, l);
67 if constexpr (Conf.track_zeros)
68 all_zero &= all_of(Ail == simd{0});
69 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
70 simd &Cij = C_reg[ii][jj];
71 simd Blj = B_cached.load(l, jj);
72 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
75 if constexpr (Conf.track_zeros)
78 if (first_nonzero < 0)
86 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
87 D_cached.store(C_reg[ii][jj], ii, jj);
89 if constexpr (Conf.track_zeros) {
90 if (first_nonzero < 0)
93 return {first_nonzero, last_nonzero + 1};
106 const index_t I = D.rows(), J = D.cols(), K = A.cols();
112 if constexpr (Conf.struc_C !=
General)
118 constexpr KernelConfig ConfSmall{.negate = Conf.negate, .struc_C = Conf.struc_C};
123 (void)microkernel_sub;
127 const std::optional<uview<const T, Abi, OC>> C_ = C;
132 if (I <= Rows && J <= Cols)
133 return microkernel_small[I - 1][J - 1](A_, B_, C_, D_, d_, K);
145 auto [l0, l1] = [&] {
146 const auto j = Conf.struc_C ==
General ? 0 : i;
147 const auto nj = std::min<index_t>(Cols, J - j);
149 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
150 const auto Dij = D_.
block(i, j);
151 if constexpr (Conf.track_zeros)
152 return microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
154 microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
155 return std::pair<index_t, index_t>{0, K};
160 D.block(i, j0, ni, j1 - j0).set_constant(T{});
161 else if (C->data() == D.data() && C->outer_stride() == D.outer_stride())
163 else if constexpr (OC == StorageOrder::ColMajor)
164 for (index_t jj = j0; jj < j1; ++jj)
165 for (index_t ii = i; ii < i + ni; ++ii)
166 D_.
store(C_->load(ii, jj), ii, jj);
168 for (index_t ii = i; ii < i + ni; ++ii)
169 for (index_t jj = j0; jj < j1; ++jj)
170 D_.
store(C_->load(ii, jj), ii, jj);
176 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
177 const auto Dij = D_.
block(i, j);
178 const auto Ail = Ai.middle_cols(l0);
179 const auto Blj = Bj.middle_rows(l0);
180 const auto dl = d_.
segment(l0);
181 microkernel_sub[ni - 1][nj - 1](Ail, Blj, Cij, Dij, dl, l1 - l0);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
stdx::simd< Tp, Abi > simd
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
constexpr auto last_column
constexpr auto first_column
const constinit auto gemm_diag_copy_lut
constexpr index_t ColsReg
std::conditional_t< Conf.track_zeros, std::pair< index_t, index_t >, void > gemm_diag_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, uview_vec< const T, Abi > diag, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Single register block.
void gemm_diag_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, view< const T, Abi > diag) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
std::integral_constant< index_t, I > index_constant
Self segment(this const Self &self, index_t r) noexcept
Self block(this const Self &self, index_t r, index_t c) noexcept
void store(simd x, index_t r, index_t c) const noexcept
Self middle_rows(this const Self &self, index_t r) noexcept
Self middle_cols(this const Self &self, index_t c) noexcept