batmat 0.0.16
Batched linear algebra routines
Loading...
Searching...
No Matches
trsm.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
7#include <batmat/loop.hpp>
9
// Local shorthand: forwards its arguments unchanged to BATMAT_FULLY_UNROLLED_FOR.
10#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
11
13
// Register-blocked triangular-solve micro-kernel for one RowsReg×ColsReg block.
// Steps: load the block of B into registers, subtract the product of the non-triangular columns
// of A with the matching rows of D, solve against the triangular part of A by substitution, and
// store the result into D.
//
// Layout, as used by the indexing below:
//  - Lower: the non-triangular part of A is columns [0, k), the triangle is columns
//    [k, k+RowsReg). Rows [0, k) of D are read, rows [k, k+RowsReg) of D are written.
//  - Upper: the triangle of A is columns [0, RowsReg), the non-triangular part is columns
//    [RowsReg, k+RowsReg). Rows [RowsReg, k+RowsReg) of D are read, rows [0, RowsReg) are written.
//
// The diagonal of A is inverted without any check for zero entries.
// NOTE(review): the listing lines carrying the function name and the A/B parameters are missing
// from this extract; the page's cross-reference names this function trsm_copy_microkernel.
14/// @param A Lower or upper trapezoidal RowsReg×(k+RowsReg).
15/// @param B RowsReg×ColsReg.
16/// @param D (k+RowsReg)×ColsReg.
17/// @param k Number of columns in the non-triangular part of A.
18template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder OA,
20[[gnu::hot, gnu::flatten]] void
22 const uview<T, Abi, OD> D, const index_t k) noexcept {
   // Only triangular structures are supported; `lower` selects the code path at compile time.
23 static_assert(Conf.struc_A == MatrixStructure::LowerTriangular ||
24 Conf.struc_A == MatrixStructure::UpperTriangular);
25 constexpr bool lower = Conf.struc_A == MatrixStructure::LowerTriangular;
26 static_assert(RowsReg > 0 && ColsReg > 0);
27 using namespace ops;
28 using simd = datapar::simd<T, Abi>;
29 // Pre-compute the offsets of the columns/rows of B
30 const auto B_cached = with_cached_access<RowsReg, ColsReg>(B);
31 // Load accumulator into registers
32 simd B_reg[RowsReg][ColsReg]; // NOLINT(*-c-arrays)
33 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
34 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
   // rotl<Conf.rotate_B> is applied to every loaded element of B.
   // NOTE(review): presumably a SIMD lane rotation by Conf.rotate_B — confirm against ops::rotl.
35 B_reg[ii][jj] = rotl<Conf.rotate_B>(B_cached.load(ii, jj));
36 // Matrix multiplication: B_reg(ii, jj) -= A(ii, l) * D(l, jj) for every column l of the
   // non-triangular part of A (l in [0, k) for lower, [RowsReg, k+RowsReg) for upper).
37 const auto D_cached = with_cached_access<0, ColsReg>(D);
38 const index_t l0 = lower ? 0 : RowsReg, l1 = lower ? k : k + RowsReg;
39 for (index_t l = l0; l < l1; ++l)
40 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
   // D(l, jj) is loaded once and reused for all RowsReg rows.
41 simd Xlj = D_cached.load(l, jj);
42 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
43 simd Ail = A.load(ii, l);
44 simd &Bij = B_reg[ii][jj];
45 Bij -= Ail * Xlj;
46 }
47 }
48 // Triangular solve, performed in place on the register block B_reg
49 if constexpr (lower) {
   // Forward substitution: rows are processed from 0 up to RowsReg-1.
50 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii) {
   // Reciprocal of the diagonal entry, computed once per row and reused for every column.
51 simd Aii = simd{1} / A.load(ii, k + ii);
52 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
53 simd &Xij = B_reg[ii][jj];
   // Eliminate the contributions of the already-solved rows ll < ii.
54 UNROLL_FOR (index_t ll = 0; ll < ii; ++ll) {
55 simd Ail = A.load(ii, k + ll);
56 simd &Xlj = B_reg[ll][jj];
57 Xij -= Ail * Xlj;
58 }
59 Xij *= Aii; // Diagonal already inverted
60 }
61 }
62 } else {
   // Backward substitution: ii runs from RowsReg-1 down to 0.
63 UNROLL_FOR (index_t ii = RowsReg; ii-- > 0;) {
   // Reciprocal of the diagonal entry, computed once per row and reused for every column.
64 simd Aii = simd{1} / A.load(ii, ii);
65 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj) {
66 simd &Xij = B_reg[ii][jj];
   // Eliminate the contributions of the already-solved rows ll > ii.
67 UNROLL_FOR (index_t ll = ii + 1; ll < RowsReg; ++ll) {
68 simd Ail = A.load(ii, ll);
69 simd &Xlj = B_reg[ll][jj];
70 Xij -= Ail * Xlj;
71 }
72 Xij *= Aii; // Diagonal already inverted
73 }
74 }
75 }
76 // Store accumulator to memory again: rows [k, k+RowsReg) of D for lower, rows [0, RowsReg)
   // of D for upper.
77 UNROLL_FOR (index_t ii = 0; ii < RowsReg; ++ii)
78 UNROLL_FOR (index_t jj = 0; jj < ColsReg; ++jj)
79 D_cached.store(B_reg[ii][jj], lower ? k + ii : ii, jj);
80}
81
82/// Triangular solve D = (A⁽ᵀ⁾)⁻¹ B⁽ᵀ⁾ where A⁽ᵀ⁾ is lower or upper triangular, as selected by
///   Conf.struc_A (any other structure is rejected at compile time). Using register blocking.
83/// Note: D = A⁻¹ B <=> Dᵀ = Bᵀ A⁻ᵀ
///
/// Dimensions assumed below: A is I×K with K ≥ I, B has I rows and J columns, D is K×J, and
/// I, J, K are all positive. These are stated with BATMAT_ASSUME, so violating them is
/// undefined behavior rather than a reported error.
/// The work is split into blocks of at most Rows×Cols, each handled by a micro-kernel
/// looked up in trsm_copy_lut by block size.
/// NOTE(review): the listing line carrying the function name and the A/B parameters is missing
/// from this extract; the page's cross-reference names this function trsm_copy_register.
84template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OB, StorageOrder OD>
86 const view<T, Abi, OD> D) noexcept {
87 using enum MatrixStructure;
88 static_assert(Conf.struc_A == LowerTriangular || Conf.struc_A == UpperTriangular);
   // Register block sizes for this element type and ABI.
89 constexpr auto Rows = RowsReg<T, Abi>, Cols = ColsReg<T, Abi>;
90 // Check dimensions
91 const index_t I = A.rows(), K = A.cols(), J = B.cols();
92 BATMAT_ASSUME(K >= I);
93 BATMAT_ASSUME(B.rows() == I);
94 BATMAT_ASSUME(D.rows() == K);
95 BATMAT_ASSUME(D.cols() == J);
96 BATMAT_ASSUME(I > 0);
97 BATMAT_ASSUME(J > 0);
98 BATMAT_ASSUME(K > 0);
   // Table of micro-kernel instantiations, indexed as [block rows - 1][block cols - 1].
99 static const auto microkernel = trsm_copy_lut<T, Abi, Conf, OA, OB, OD>;
100 // Sizeless views to partition and pass to the micro-kernels
101 const uview<const T, Abi, OA> A_ = A;
102 const uview<const T, Abi, OB> B_ = B;
103 const uview<T, Abi, OD> D_ = D;
104
105 // Optimization for very small matrices: everything fits in a single register block
    // NOTE(review): k is passed as 0 here, while the blocked path below would pass K - I for a
    // single block covering all of A (i = 0, ni = I). The two agree only when K == I — confirm
    // that this path is never taken with K > I.
106 if (I <= Rows && J <= Cols)
107 return microkernel[I - 1][J - 1](A_, B_, D_, 0);
108
109 // Function to compute a single block X(i,j) of ni rows and nj columns
110 auto blk = [&] [[gnu::always_inline]] (index_t i, index_t ni, index_t j, index_t nj) {
111 // i iterates backwards from I to 0, because we want to process the remainder block first,
112 // as processing it last would have poor matrix-matrix performance in the microkernel.
113 if constexpr (Conf.struc_A == LowerTriangular) {
114 i = I - i - ni; // iterate forward, smallest chunk first
115 auto Ai0 = A_.middle_rows(i); // subdiagonal block row
116 auto Bij = B_.block(i, j); // rhs block to solve now
117 auto X0j = D_.middle_cols(j); // solution up to i and solution block to fill in
    // k = i + (K - I): number of columns of A to the left of this diagonal block.
118 microkernel[ni - 1][nj - 1](Ai0, Bij, X0j, i + K - I);
119 } else {
120 auto Ai0 = A_.block(i, i); // superdiagonal block row
121 auto Bij = B_.block(i, j); // rhs block to solve now
122 auto X0j = D_.block(i, j); // solution up to i and solution block to fill in
    // k = K - i - ni: number of columns of A to the right of this diagonal block.
123 microkernel[ni - 1][nj - 1](Ai0, Bij, X0j, K - i - ni);
124 }
125 };
    // Loop nest order depends on the storage order of D: column-major puts the loop over block
    // columns outermost, otherwise the loop over diagonal blocks of A is outermost.
126 if constexpr (OD == StorageOrder::ColMajor)
127 foreach_chunked_merged( // Loop over block columns of B and D
128 0, J, Cols,
129 [&](index_t j, auto nj) {
130 foreach_chunked_merged( // Loop over the diagonal blocks of A
131 0, I, Rows, [&](index_t i, auto ni) { blk(i, ni, j, nj); }, LoopDir::Backward);
132 },
134 else
135 foreach_chunked_merged( // Loop over the diagonal blocks of A
136 0, I, Rows,
137 [&](index_t i, auto ni) {
138 foreach_chunked_merged( // Loop over block columns of B and D
139 0, J, Cols, [&](index_t j, auto nj) { blk(i, ni, j, nj); }, LoopDir::Forward);
140 },
142}
143
144} // namespace batmat::linalg::micro_kernels::trsm
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:9
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk...
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
void trsm_copy_microkernel(uview< const T, Abi, OA > A, uview< const T, Abi, OB > B, uview< T, Abi, OD > D, index_t k) noexcept
Definition trsm.tpp:21
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
void trsm_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OB > B, view< T, Abi, OD > D) noexcept
Triangular solve D = (A⁽ᵀ⁾)⁻¹ B⁽ᵀ⁾ where A⁽ᵀ⁾ is lower triangular.
Definition trsm.tpp:85
const constinit auto trsm_copy_lut
Definition trsm.hpp:30
cached_uview< Order==StorageOrder::ColMajor ? Cols :Rows, T, Abi, Order > with_cached_access(const uview< T, Abi, Order > &o) noexcept
Definition uview.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
Self middle_rows(this const Self &self, index_t r) noexcept
Definition uview.hpp:114
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118