10#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
21template <MatrixStructure Struc>
25template <index_t ColsReg, MatrixStructure Struc>
34[[gnu::hot, gnu::flatten]] std::conditional_t<Conf.track_zeros, std::pair<index_t, index_t>,
void>
38 const index_t k)
noexcept {
44 static constexpr auto min_col =
first_column<Conf.struc_C>;
57 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
58 C_reg[ii][jj] = C_cached.load(ii, jj);
61 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
62 C_reg[ii][jj] = simd{0};
69 index_t first_nonzero = -1, last_nonzero = -1;
70 for (index_t l = 0; l < k; ++l) {
74 simd Ail = dl * A_cached.load(ii, l);
75 if constexpr (Conf.track_zeros)
76 all_zero &= all_of(Ail == simd{0});
77 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj) {
78 simd &Cij = C_reg[ii][jj];
79 simd Blj = B_cached.load(l, jj);
80 Conf.negate ? (Cij -= Ail * Blj) : (Cij += Ail * Blj);
83 if constexpr (Conf.track_zeros)
86 if (first_nonzero < 0)
94 UNROLL_FOR (index_t jj = min_col(ii); jj <= max_col(ii); ++jj)
95 D_cached.store(C_reg[ii][jj], ii, jj);
97 if constexpr (Conf.track_zeros) {
98 if (first_nonzero < 0)
101 return {first_nonzero, last_nonzero + 1};
114 const index_t I = D.rows(), J = D.cols(), K = A.cols();
120 if constexpr (Conf.struc_C !=
General)
126 constexpr KernelConfig ConfSmall{.negate = Conf.negate, .struc_C = Conf.struc_C};
131 (void)microkernel_sub;
135 const std::optional<uview<const T, Abi, OC>> C_ = C;
140 if (I <= Rows && J <= Cols)
141 return microkernel_small[I - 1][J - 1](A_, B_, C_, D_, d_, K);
153 auto [l0, l1] = [&] {
154 const auto j = Conf.struc_C ==
General ? 0 : i;
155 const auto nj = std::min<index_t>(Cols, J - j);
157 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
158 const auto Dij = D_.
block(i, j);
159 if constexpr (Conf.track_zeros)
160 return microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
162 microkernel[ni - 1][nj - 1](Ai, Bj, Cij, Dij, d_, K);
163 return std::pair<index_t, index_t>{0, K};
168 D.block(i, j0, ni, j1 - j0).set_constant(T{});
169 else if (C->data() == D.data() && C->outer_stride() == D.outer_stride())
171 else if constexpr (OC == StorageOrder::ColMajor)
172 for (index_t jj = j0; jj < j1; ++jj)
173 for (index_t ii = i; ii < i + ni; ++ii)
174 D_.
store(C_->load(ii, jj), ii, jj);
176 for (index_t ii = i; ii < i + ni; ++ii)
177 for (index_t jj = j0; jj < j1; ++jj)
178 D_.
store(C_->load(ii, jj), ii, jj);
184 const auto Cij = C_ ? std::make_optional(C_->block(i, j)) : std::nullopt;
185 const auto Dij = D_.
block(i, j);
186 const auto Ail = Ai.middle_cols(l0);
187 const auto Blj = Bj.middle_rows(l0);
188 const auto dl = d_.
segment(l0);
189 microkernel_sub[ni - 1][nj - 1](Ail, Blj, Cij, Dij, dl, l1 - l0);
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
stdx::simd< Tp, Abi > simd
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
constexpr auto last_column
constexpr auto first_column
const constinit auto gemm_diag_copy_lut
constexpr index_t ColsReg
std::conditional_t< Conf.track_zeros, std::pair< index_t, index_t >, void > gemm_diag_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, std::optional< uview< const T, Abi, OC > > C, uview< T, Abi, OD > D, uview_vec< const T, Abi > diag, index_t k) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Single register block.
void gemm_diag_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, std::optional< view< const T, Abi, OC > > C, view< T, Abi, OD > D, view< const T, Abi > diag) noexcept
Generalized matrix multiplication D = C ± A⁽ᵀ⁾ diag(d) B⁽ᵀ⁾. Using register blocking.
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
std::integral_constant< index_t, I > index_constant
Self segment(this const Self &self, index_t r) noexcept
Self block(this const Self &self, index_t r, index_t c) noexcept
void store(simd x, index_t r, index_t c) const noexcept
Self middle_rows(this const Self &self, index_t r) noexcept
Self middle_cols(this const Self &self, index_t c) noexcept