batmat main
Batched linear algebra routines
Loading...
Searching...
No Matches
trsm.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
8#include <batmat/lut.hpp>
10
11#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
12
14
// NOTE(review): this listing skips source lines 16 and 18, so the statement below is incomplete
// as shown. The cross-reference index at the end of this page places the definition of
// trsm_copy_lut at source line 16; presumably this generic lambda is the generator handed to
// make_2d_lut and yields one micro-kernel instantiation per (Row, Col) pair. Confirm against the
// original trsm.tpp before relying on this.
15template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OD>
17 []<index_t Row, index_t Col>(index_constant<Row>, index_constant<Col>) {
19 });
20
/// Register-blocked micro-kernel for one RowsReg×ColsReg block of a triangular solve.
///
/// Steps: (1) load the block of B into registers, (2) subtract the product of the k
/// non-triangular columns of A with the matching rows of D, (3) substitute through the
/// RowsReg×RowsReg triangle of A, (4) store the solved block into D.
///
/// Layout implied by the indexing below:
///  - lower: non-triangular columns of A are [0, k), the triangle occupies columns
///    [k, k+RowsReg); rows [0, k) of D are read, rows [k, k+RowsReg) of D are written.
///  - upper: the triangle occupies columns [0, RowsReg), non-triangular columns are
///    [RowsReg, k+RowsReg); rows [RowsReg, k+RowsReg) of D are read, rows [0, RowsReg) are written.
///
/// NOTE(review): the listing omits source lines 26 and 28 (rest of the template header and the
/// function name/first parameters); the index at the end of this page names the function
/// trsm_copy_microkernel.
21/// @param A Lower or upper trapezoidal RowsReg×(k+RowsReg).
22/// @param B RowsReg×ColsReg.
23/// @param D (k+RowsReg)×ColsReg.
24/// @param k Number of columns in the non-triangular part of A.
25template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
27[[gnu::hot, gnu::flatten]] void
29 const uview<T, Abi, OD> D, const index_t k) noexcept {
// Only the two triangular structures are supported; anything else fails to compile.
30 static_assert(Conf.struc_A == MatrixStructure::LowerTriangular ||
31 Conf.struc_A == MatrixStructure::UpperTriangular);
32 constexpr bool lower = Conf.struc_A == MatrixStructure::LowerTriangular;
33 static_assert(RowsReg > 0 && ColsReg > 0);
34 using namespace ops;
35 using simd = datapar::simd<T, Abi>;
36 // Pre-compute the offsets of the columns/rows of B
37 const auto B_cached = with_cached_access<RowsReg, ColsReg>(B);
38 // Load accumulator into registers
39 simd B_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
40 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
41 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
// NOTE(review): rotl with Conf.rotate_B presumably rotates the SIMD lanes of the loaded
// value; its definition is not visible in this listing.
42 B_reg[ii][jj] = rotl<Conf.rotate_B>(B_cached.load(ii, jj));
43 // Matrix multiplication
// B_reg(ii, jj) -= A(ii, l) * D(l, jj) for every non-triangular column l of A.
44 const auto D_cached = with_cached_access<0, ColsReg>(D);
// Column range of the non-triangular part: [0, k) for lower, [RowsReg, k + RowsReg) for upper.
45 const index_t l0 = lower ? 0 : RowsReg, l1 = lower ? k : k + RowsReg;
46 for (index_t l = l0; l < l1; ++l)
47 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
48 simd Xlj = D_cached.load(l, jj);
49 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
50 simd Ail = A.load(ii, l);
51 simd &Bij = B_reg[ii][jj];
52 Bij -= Ail * Xlj;
53 }
54 }
55 // Triangular solve
56 if constexpr (lower) {
// Forward substitution: rows are processed top to bottom, the diagonal sits at column k + ii.
57 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
// Reciprocal of the diagonal entry, computed once per row and reused for all columns.
58 simd Aii = simd{1} / A.load(ii, k + ii);
59 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
60 simd &Xij = B_reg[ii][jj];
// Eliminate the already-solved rows ll < ii of this block.
61 UNROLL_FOR (index_t ll = 0; ll < ii; ++ll) {
62 simd Ail = A.load(ii, k + ll);
63 simd &Xlj = B_reg[ll][jj];
64 Xij -= Ail * Xlj;
65 }
66 Xij *= Aii; // Diagonal already inverted
67 }
68 }
69 } else {
// Backward substitution: rows are processed bottom to top (ii counts down from RowsReg - 1
// to 0), the diagonal sits at column ii.
70 UNROLL_FOR (index_t ii = RowsReg; ii-- > 0;) {
71 simd Aii = simd{1} / A.load(ii, ii);
72 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
73 simd &Xij = B_reg[ii][jj];
// Eliminate the already-solved rows ll > ii of this block.
74 UNROLL_FOR (index_t ll = ii + 1; ll < RowsReg; ++ll) {
75 simd Ail = A.load(ii, ll);
76 simd &Xlj = B_reg[ll][jj];
77 Xij -= Ail * Xlj;
78 }
79 Xij *= Aii; // Diagonal already inverted
80 }
81 }
82 }
83 // Store accumulator to memory again
// The solved block lands in rows [k, k + RowsReg) of D for lower, rows [0, RowsReg) for upper.
84 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
85 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
86 D_cached.store(B_reg[ii][jj], lower ? k + ii : ii, jj);
87}
88
89/// Triangular solve D = (A⁽ᵀ⁾)⁻¹ B⁽ᵀ⁾ where A⁽ᵀ⁾ is lower or upper triangular. Using register blocking.
90/// Note: D = A⁻¹ B <=> Dᵀ = Bᵀ A⁻ᵀ
///
/// Dimensions (taken from the assumptions checked below): A is I×K with K >= I, B is I×J and
/// D is K×J. When K > I, A is trapezoidal: its K - I non-triangular columns are multiplied with
/// rows of D that the micro-kernel reads rather than writes, so those rows must already hold
/// valid data on entry (rows [0, K-I) for lower, rows [I, K) for upper).
///
/// NOTE(review): the listing omits source lines 92, 140 and 148 (function name/first parameters
/// and the trailing arguments of the two outer foreach_chunked_merged calls).
///
/// @param A Triangular or trapezoidal coefficient matrix, I×K.
/// @param B Right-hand side, I×J.
/// @param D Output (and, for K > I, partially input) matrix, K×J.
91template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OD>
93 const view<T, Abi, OD> D) noexcept {
94 using enum MatrixStructure;
95 static_assert(Conf.struc_A == LowerTriangular || Conf.struc_A == UpperTriangular);
96 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
97 // Check dimensions
98 const index_t I = A.rows(), K = A.cols(), J = B.cols();
99 BATMAT_ASSUME(K >= I);
100 BATMAT_ASSUME(B.rows() == I);
101 BATMAT_ASSUME(D.rows() == K);
102 BATMAT_ASSUME(D.cols() == J);
103 BATMAT_ASSUME(I > 0);
104 BATMAT_ASSUME(J > 0);
105 BATMAT_ASSUME(K > 0);
// Table of micro-kernels indexed by [rows - 1][cols - 1] of the block to process.
106 static const auto microkernel = trsm_copy_lut<T, Abi, Conf, OA, OB, OD>;
107 // Sizeless views to partition and pass to the micro-kernels
108 const uview<const T, Abi, OA> A_ = A;
109 const uview<const T, Abi, OB> B_ = B;
110 const uview<T, Abi, OD> D_ = D;
111
112 // Optimization for very small matrices
// The whole problem fits in one register block. The micro-kernel must still be told how many
// non-triangular columns A has, i.e. K - I, exactly as the blocked path below does for a single
// block (lower: i + K - I with i = 0; upper: K - i - ni with i = 0, ni = I). A literal 0 is only
// correct for square A and would skip the trapezoidal part when K > I.
113 if (I <= Rows && J <= Cols)
114 return microkernel[I - 1][J - 1](A_, B_, D_, K - I);
115
116 // Function to compute a single block X(i,j)
117 auto blk = [&] [[gnu::always_inline]] (index_t i, index_t ni, index_t j, index_t nj) {
118 // i iterates backwards from I to 0, because we want to process the remainder block first,
119 // as processing it last would have poor matrix-matrix performance in the microkernel.
120 if constexpr (Conf.struc_A == LowerTriangular) {
121 i = I - i - ni; // iterate forward, smallest chunk first
122 auto Ai0 = A_.middle_rows(i); // subdiagonal block row
123 auto Bij = B_.block(i, j); // rhs block to solve now
124 auto X0j = D_.middle_cols(j); // solution up to i and solution block to fill in
// k = i + (K - I): all columns of this block row to the left of its diagonal block.
125 microkernel[ni - 1][nj - 1](Ai0, Bij, X0j, i + K - I);
126 } else {
127 auto Ai0 = A_.block(i, i); // superdiagonal block row
128 auto Bij = B_.block(i, j); // rhs block to solve now
129 auto X0j = D_.block(i, j); // solution up to i and solution block to fill in
// k = K - i - ni: all columns of this block row to the right of its diagonal block.
130 microkernel[ni - 1][nj - 1](Ai0, Bij, X0j, K - i - ni);
131 }
132 };
// Loop nesting depends on the storage order of D: for column-major D the block columns form
// the outer loop, otherwise the diagonal blocks of A do.
133 if constexpr (OD == StorageOrder::ColMajor)
134 foreach_chunked_merged( // Loop over block columns of B and D
135 0, J, Cols,
136 [&](index_t j, auto nj) {
137 foreach_chunked_merged( // Loop over the diagonal blocks of A
138 0, I, Rows, [&](index_t i, auto ni) { blk(i, ni, j, nj); }, LoopDir::Backward);
139 },
141 else
142 foreach_chunked_merged( // Loop over the diagonal blocks of A
143 0, I, Rows,
144 [&](index_t i, auto ni) {
145 foreach_chunked_merged( // Loop over block columns of B and D
146 0, J, Cols, [&](index_t j, auto nj) { blk(i, ni, j, nj); }, LoopDir::Forward);
147 },
149}
150
151} // namespace batmat::linalg::micro_kernels::trsm
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk (description truncated in this listing; see loop.hpp).
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
void trsm_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, uview< T, Abi, OD > D, index_t k) noexcept
Definition trsm.tpp:28
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
void trsm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, view< T, Abi, OD > D) noexcept
Triangular solve D = (A⁽ᵀ⁾)⁻¹ B⁽ᵀ⁾ where A⁽ᵀ⁾ is lower triangular.
Definition trsm.tpp:92
const constinit auto trsm_copy_lut
Definition trsm.tpp:16
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118