batmat main
Batched linear algebra routines
Loading...
Searching...
No Matches
compress.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
8#include <batmat/simd.hpp>
9#include <batmat/unroll.h>
10#include <guanaqo/trace.hpp>
11
12namespace batmat::linalg {
13
14namespace detail {
15
/// Core kernel shared by compress_masks() and compress_masks_sqrt().
/// Scans the columns of A_in together with the mask vector S_in (one mask
/// entry per column of A_in, one SIMD lane per batch element) and, per lane
/// independently, packs the columns whose mask entry is nonzero towards the
/// front of the output. The customization points writeS/writeA receive the
/// gathered mask values / matrix values for each committed output column j.
/// Returns the number of output columns written.
/// NOTE(review): a few original lines (embedded nos. 20, 29, 111 — likely a
/// trace call, the `types` alias, and a comment) are not visible here.
16template <class T, class Abi, index_t N = 8, StorageOrder OAi>
17[[gnu::always_inline]] inline index_t compress_masks_impl(view<const T, Abi, OAi> A_in,
18 view<const T, Abi> S_in, auto writeS,
19 auto writeA) {
21 BATMAT_ASSERT(A_in.depth() == S_in.depth());
22 BATMAT_ASSERT(A_in.cols() == S_in.rows());
23 BATMAT_ASSERT(S_in.cols() == 1);
24 const auto C = A_in.cols();
25 const auto R = A_in.rows();
26 if (C == 0)
27 return 0;
28 BATMAT_ASSUME(R > 0);
30 using simd = typename types::simd;
31 using isimd = typename types::isimd;
32
// hist is a small buffer of N index vectors: per SIMD lane, hist[k] holds the
// 1-based index of a pending column whose mask entry was nonzero in that lane
// (0 means the slot is empty for that lane). hist[0] is the oldest/fullest slot.
33 isimd hist[N]{};
// j counts committed output columns; it is the function's return value.
34 index_t j = 0;
35 static const isimd iota{[](auto i) { return i; }};
36
// Fast path (currently disabled below via `#if 0`): when the mask is nonzero in
// every lane, column c can be copied straight through without any gather.
37 [[maybe_unused]] const auto commit_fast = [&](auto Sc, index_t c) {
38 writeS(Sc, j);
39 for (index_t r = 0; r < R; ++r) {
40 auto Ac = types::aligned_load(&A_in(0, r, c));
41 writeA(Ac, Sc, r, j);
42 }
43 ++j;
44 };
// Commit one output column from the per-lane column indices in h_commit:
// gather the mask entries and every row of A at those (1-based) column
// indices, masked by h_commit so empty (zero) lanes are skipped.
45 const auto commit_no_shift = [&](auto h_commit) {
46 auto gather_S = [&] {
47 const auto bs = static_cast<index_t>(S_in.batch_size());
// Offsets are in scalars: (column index) * batch_size + lane index.
48 const isimd offsets = (h_commit - isimd{1}) * bs + iota;
49 return gather<T, Abi>(&S_in(0, 0, 0), offsets, h_commit);
50 }();
51 writeS(gather_S, j);
52 const auto bs = static_cast<index_t>(A_in.batch_size());
// For column-major A, stepping one column advances by the outer stride.
53 const auto stride = OAi == StorageOrder::ColMajor ? A_in.outer_stride() : 1;
54 const isimd offsets = (h_commit - isimd{1}) * bs * stride + iota;
55 for (index_t r = 0; r < R; ++r) {
56 auto gather_A = gather<T, Abi>(&A_in(0, r, 0), offsets, h_commit);
57 writeA(gather_A, gather_S, r, j);
58 }
59 ++j;
60 };
// Commit the oldest buffer slot and shift the remaining slots down by one,
// freeing hist[N - 1].
61 const auto commit = [&] [[gnu::always_inline]] () {
62 const isimd h = hist[0];
63 BATMAT_FULLY_UNROLLED_FOR (index_t k = 1; k < N; ++k)
64 hist[k - 1] = hist[k];
65 hist[N - 1] = 0;
66 commit_no_shift(h);
67 };
68
69 isimd c1_simd{0};
70 for (index_t c = 0; c < C; ++c) {
71 c1_simd += isimd{1}; // current column index + 1
72 const simd Sc = types::aligned_load(&S_in(0, c, 0));
// Per-lane mask of nonzero entries of S in column c.
73 auto Sc_msk = !(Sc == simd{0});
74#if 0
75 if (all_of(Sc_msk)) {
76 commit_fast(Sc, c);
77 continue; // fast path: all nonzero, no gather needed
78 }
79#endif
// Insert the current column index (c + 1) into the first empty slot of each
// lane, clearing the corresponding bits of Sc_msk as they are consumed.
80 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist) {
81#if BATMAT_WITH_GSI_HPC_SIMD // TODO
82 auto h_ = h;
83 h = isimd{[&](int i) -> index_t { return h[i] == 0 && Sc_msk[i] ? c1_simd[i] : h[i]; }};
84 Sc_msk = decltype(Sc_msk){[&](int i) -> bool { return Sc_msk[i] && h_[i] != 0; }};
85#else
86 const auto msk = (h == 0) && Sc_msk.__cvt();
87 where(msk, h) = c1_simd;
88 Sc_msk = Sc_msk && (!msk).__cvt();
89#endif
90 }
91 // Masks of all ones can already be written to memory
92 bool first_full = none_of(hist[0] == 0);
93 // If there are nonzero elements in the mask left, we need to make room in the buffer
94 bool mask_left = any_of(Sc_msk);
95 if (first_full || mask_left) {
96 commit();
97 assert(any_of(hist[0] == 0)); // at most one commit per iteration is possible
98 // If there are still bits set in the mask.
99 if (mask_left) {
100 // Check if there's an empty slot (always at the end)
101 auto &h = hist[N - 1];
102#if BATMAT_WITH_GSI_HPC_SIMD // TODO
103 h = isimd{[&](int i) -> index_t { return Sc_msk[i] ? c1_simd[i] : h[i]; }};
104#else
105 where(Sc_msk.__cvt(), h) = c1_simd;
106#endif
107 }
108 }
109 // Invariant: first registers in the buffer contain fewest zeros
110 BATMAT_FULLY_UNROLLED_FOR (index_t i = 1; i < N; ++i) {
112 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
113 }
114 }
// Flush: commit any buffer slots that still hold pending column indices.
115 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist)
116 if (any_of(h != 0))
117 commit_no_shift(h);
118 return j;
119}
120
/// Dry-run counterpart of compress_masks_impl(): runs the same per-lane slot
/// bookkeeping over the mask vector S_in, but writes nothing — it only counts
/// and returns the number of output columns compression would produce.
/// NOTE(review): the signature line (embedded no. 122, per the Doxygen index:
/// `index_t compress_masks_count(view<const T, Abi> S_in)`) and the `types`
/// alias line (embedded no. 127) are not visible in this rendering.
121template <class T, class Abi, index_t N = 8>
123 GUANAQO_TRACE("compress_masks_count", 0, S_in.rows() * S_in.depth());
124 const auto C = S_in.rows();
125 if (C == 0)
126 return 0;
128 using simd = typename types::simd;
129 using isimd = typename types::isimd;
// Same N-slot occupancy buffer as compress_masks_impl; here the stored value
// only needs to mark a slot as occupied (1), not remember which column.
130 isimd hist[N]{};
131 index_t j = 0;
132
// "Commit" just shifts the buffer and counts an output column.
133 const auto commit = [&] [[gnu::always_inline]] () {
134 BATMAT_FULLY_UNROLLED_FOR (index_t k = 1; k < N; ++k)
135 hist[k - 1] = hist[k];
136 hist[N - 1] = 0;
137 ++j;
138 };
139
140 for (index_t c = 0; c < C; ++c) {
141 const simd Sc = types::aligned_load(&S_in(0, c, 0));
// Per-lane mask of nonzero entries of S in column c.
142 auto Sc_msk = !(Sc == simd{0});
// Mark the first empty slot of each lane as occupied, consuming mask bits.
143 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist) {
144#if BATMAT_WITH_GSI_HPC_SIMD // TODO
145 auto h_ = h;
146 h = isimd{[&](int i) -> index_t { return h[i] == 0 && Sc_msk[i] ? 1 : h[i]; }};
147 Sc_msk = decltype(Sc_msk){[&](int i) -> bool { return Sc_msk[i] && h_[i] != 0; }};
148#else
149 const auto msk = (h == 0) && Sc_msk.__cvt();
150 where(msk, h) = 1;
151 Sc_msk = Sc_msk && (!msk).__cvt();
152#endif
153 }
154 // Masks of all ones can already be written to memory
155 bool first_full = none_of(hist[0] == 0);
156 // If there are nonzero elements in the mask left, we need to make room in the buffer
157 bool mask_left = any_of(Sc_msk);
158 if (first_full || mask_left) {
159 commit();
160 assert(any_of(hist[0] == 0)); // at most one commit per iteration is possible
161 // If there are still bits set in the mask.
162 if (mask_left) {
163 // Check if there's an empty slot (always at the end)
164 auto &h = hist[N - 1];
165#if BATMAT_WITH_GSI_HPC_SIMD // TODO
166 h = isimd{[&](int i) -> index_t { return Sc_msk[i] ? 1 : h[i]; }};
167#else
168 where(Sc_msk.__cvt(), h) = isimd{1};
169#endif
170 }
171 }
172 // Invariant: first registers in the buffer contain fewest zeros
173 BATMAT_FULLY_UNROLLED_FOR (index_t i = 1; i < N; ++i) {
175 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
176 }
177 }
// Flush: any still-occupied slot accounts for one more output column.
178 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist)
179 if (any_of(h != 0))
180 ++j;
181 return j;
182}
183
/// Compress (A_in, S_in) into (A_out, S_out) by packing the columns with
/// nonzero mask entries to the front, per SIMD lane, via compress_masks_impl:
/// writeS stores the gathered mask values into S_out and writeA stores the
/// gathered matrix columns into A_out. Returns the number of columns written.
/// NOTE(review): the signature line (embedded no. 185, per the Doxygen index:
/// `index_t compress_masks(view<const T, Abi, OAi> A_in, view<const T, Abi>
/// S_in, view<T, Abi, OAo> A_out, view<T, Abi> S_out)`) is not visible here.
184template <class T, class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
186 view<T, Abi, OAo> A_out, view<T, Abi> S_out) {
187 GUANAQO_TRACE("compress_masks", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
188 BATMAT_ASSERT(A_in.rows() == A_out.rows());
189 BATMAT_ASSERT(A_in.cols() == A_out.cols());
190 BATMAT_ASSERT(A_in.depth() == A_out.depth());
191 BATMAT_ASSERT(S_in.rows() == S_out.rows());
192 BATMAT_ASSERT(S_in.cols() == S_out.cols());
193 BATMAT_ASSERT(S_in.depth() == S_out.depth());
// Store the gathered mask values for output column j.
194 auto writeS = [S_out] [[gnu::always_inline]] (auto gather_S, index_t j) {
195 datapar::aligned_store(gather_S, &S_out(0, j, 0));
196 };
// Store row r of the gathered matrix column into output column j (the mask
// values are not needed here, only in the sqrt variant).
197 auto writeA = [A_out] [[gnu::always_inline]] (auto gather_A, auto /*gather_S*/, index_t r,
198 index_t j) {
199 datapar::aligned_store(gather_A, &A_out(0, r, j));
200 };
201 return compress_masks_impl<T, Abi, N, OAi>(A_in, S_in, writeS, writeA);
202}
203
/// Variant of compress_masks() that scales each compressed column of A by
/// sqrt(|S|) of its mask entry, and (optionally, when S_sign_out is nonempty)
/// stores a signed zero carrying the sign of each mask entry into S_sign_out.
/// Returns the number of columns written.
/// NOTE(review): the signature line (embedded no. 205, per the Doxygen index:
/// `index_t compress_masks_sqrt(view<const T, Abi, OAi> A_in, view<const T,
/// Abi> S_in, view<T, Abi, OAo> A_out, view<T, Abi> S_sign_out = {})`) is not
/// visible in this rendering.
204template <class T, class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
206 view<T, Abi, OAo> A_out, view<T, Abi> S_sign_out = {}) {
207 GUANAQO_TRACE("compress_masks_sqrt", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
208 BATMAT_ASSERT(A_in.rows() == A_out.rows());
209 BATMAT_ASSERT(A_in.cols() == A_out.cols());
210 BATMAT_ASSERT(A_in.depth() == A_out.depth());
// S_sign_out is optional: all its shape checks are skipped when it is empty.
211 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_in.depth() == S_sign_out.depth());
212 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_out.cols() == S_sign_out.rows());
213 BATMAT_ASSERT(S_sign_out.rows() == 0 || S_sign_out.cols() == 1);
214 using std::copysign;
215 using std::fabs;
216 using std::sqrt;
// Store copysign(0, S): a ±0 whose sign records the sign of the mask entry.
217 auto writeS = [S_sign_out] [[gnu::always_inline]] (auto gather_S, index_t j) {
218 if (S_sign_out.rows() > 0)
219 datapar::aligned_store(copysign(decltype(gather_S){0}, gather_S), &S_sign_out(0, j, 0));
220 };
// Scale the gathered column by sqrt(|S|) before storing it.
221 auto writeA = [A_out] [[gnu::always_inline]] (auto gather_A, auto gather_S, index_t r,
222 index_t j) {
223 datapar::aligned_store(sqrt(fabs(gather_S)) * gather_A, &A_out(0, r, j));
224 };
225 return compress_masks_impl<T, Abi, N, OAi>(A_in, S_in, writeS, writeA);
226}
227
228} // namespace detail
229
230/// @addtogroup topic-linalg
231/// @{
232
233/// @name Compression of masks containing zeros
234/// @{
235
/// Public entry point: simdifies the given views (inputs as const) and
/// forwards to the detail implementation.
/// NOTE(review): the call/return line (embedded no. 238, presumably
/// `return detail::compress_masks<N>(...)`) is not visible in this rendering.
236template <index_t N = 8, simdifiable VA, simdifiable VS, simdifiable VAo, simdifiable VSo>
237index_t compress_masks(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout) {
239 simdify(Ain).as_const(), simdify(Sin).as_const(), simdify(Aout), simdify(Sout));
240}
241
/// Public entry point for counting how many columns compression would
/// produce, without writing any output.
/// NOTE(review): the signature and body (embedded nos. 243-246; per the
/// Doxygen index the signature is `index_t compress_masks_count(VS &&Sin)`,
/// presumably forwarding to detail::compress_masks_count) are not visible in
/// this rendering — confirm against the original header.
242template <index_t N = 8, simdifiable VS>
247
/// Public entry point (no sign output): simdifies the given views (inputs as
/// const) and forwards to the detail implementation.
/// NOTE(review): the call/return line (embedded no. 250, presumably
/// `return detail::compress_masks_sqrt<N>(...)`) is not visible here.
248template <index_t N = 8, simdifiable VA, simdifiable VS, simdifiable VAo>
249index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout) {
251 simdify(Ain).as_const(), simdify(Sin).as_const(), simdify(Aout));
252}
253
/// Public entry point (with sign output): simdifies the given views (inputs
/// as const) and forwards to the detail implementation.
/// NOTE(review): the call/return line (embedded no. 256, presumably
/// `return detail::compress_masks_sqrt<N>(...)`) is not visible here.
254template <index_t N = 8, simdifiable VA, simdifiable VS, simdifiable VAo, simdifiable VSo>
255index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout) {
257 simdify(Ain).as_const(), simdify(Sin).as_const(), simdify(Aout), simdify(Sout));
258}
259
260/// @}
261
262/// @}
263
264} // namespace batmat::linalg
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
int index_t
Definition config.hpp:13
index_t compress_masks_count(VS &&Sin)
Definition compress.hpp:243
index_t compress_masks(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout)
Definition compress.hpp:237
index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout)
Definition compress.hpp:249
datapar::simd< T, AbiT > gather(const T *p, datapar::simd< I, AbiI > idx, M mask)
Gathers elements from memory at the addresses specified by idx, which should be an integer SIMD vecto...
Definition gather.hpp:56
#define GUANAQO_TRACE(name, instance,...)
void aligned_store(V v, typename V::value_type *p)
Definition simd.hpp:124
auto reduce_count(auto v)
Definition simd.hpp:151
index_t compress_masks_impl(view< const T, Abi, OAi > A_in, view< const T, Abi > S_in, auto writeS, auto writeA)
Definition compress.hpp:17
index_t compress_masks_count(view< const T, Abi > S_in)
Definition compress.hpp:122
index_t compress_masks(view< const T, Abi, OAi > A_in, view< const T, Abi > S_in, view< T, Abi, OAo > A_out, view< T, Abi > S_out)
Definition compress.hpp:185
index_t compress_masks_sqrt(view< const T, Abi, OAi > A_in, view< const T, Abi > S_in, view< T, Abi, OAo > A_out, view< T, Abi > S_sign_out={})
Definition compress.hpp:205
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:216
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
int index_t
Definition config.hpp:13
#define BATMAT_FULLY_UNROLLED_FOR(...)
Definition unroll.h:27