batmat 0.0.16
Batched linear algebra routines
compress.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
8#include <batmat/unroll.h>
9#include <guanaqo/trace.hpp>
10
11namespace batmat::linalg {
12
13namespace detail {
14
// Core routine shared by compress_masks and compress_masks_sqrt: scans the
// columns of A_in left to right and, independently per SIMD lane, drops the
// columns whose mask entry in S_in is zero, emitting the surviving columns
// through the writeS / writeA callbacks. Returns the number of compressed
// output columns produced.
//
// hist[N] is a small per-lane buffer: hist[k][lane] holds the 1-based column
// index of a pending nonzero mask entry (0 marks an empty slot). An output
// column is committed as soon as hist[0] is full in every lane, or a lane
// runs out of free slots; leftovers are flushed after the main loop.
//
// NOTE(review): this listing was recovered from a docs extraction; original
// source lines 19 and 28 are missing here. Line 28 presumably declared the
// `types` alias used below (likely via simd_view_types<T, Abi>) -- confirm
// against the upstream batmat sources.
15template <class T, class Abi, index_t N = 8, StorageOrder OAi>
16[[gnu::always_inline]] inline index_t compress_masks_impl(view<const T, Abi, OAi> A_in,
17 view<const T, Abi> S_in, auto writeS,
18 auto writeA) {
// A and S share the batch depth; S is a single column with one entry per
// column of A.
20 BATMAT_ASSERT(A_in.depth() == S_in.depth());
21 BATMAT_ASSERT(A_in.cols() == S_in.rows());
22 BATMAT_ASSERT(S_in.cols() == 1);
23 const auto C = A_in.cols();
24 const auto R = A_in.rows();
25 if (C == 0)
26 return 0;
27 BATMAT_ASSUME(R > 0);
29 using simd = typename types::simd;
30 using isimd = typename types::isimd;
31
// Pending 1-based column indices per lane; j counts output columns written.
32 isimd hist[N]{};
33 index_t j = 0;
// Lane offsets 0,1,2,... used to turn a per-lane column index into a flat
// gather offset.
34 static const isimd iota{[](auto i) { return i; }};
35
// Fast-path writer (currently disabled by the `#if 0` below): when every
// lane of the mask register is nonzero, the column can be stored directly
// without any gather.
36 [[maybe_unused]] const auto commit_fast = [&](auto Sc, index_t c) {
37 writeS(Sc, j);
38 for (index_t r = 0; r < R; ++r) {
39 auto Ac = types::aligned_load(&A_in(0, r, c));
40 writeA(Ac, Sc, r, j);
41 }
42 ++j;
43 };
// Emit one output column from the per-lane column indices in h_commit,
// gathering the mask values and then every row of A across lanes.
44 const auto commit_no_shift = [&](auto h_commit) {
45 auto gather_S = [&] {
46 const auto bs = static_cast<index_t>(S_in.batch_size());
// Indices in hist are 1-based, hence the `- 1` before scaling by the batch
// size; h_commit doubles as the gather mask (zero slots stay untouched).
47 const isimd offsets = (h_commit - isimd{1}) * bs + iota;
48 return gather<T, Abi>(&S_in(0, 0, 0), offsets, h_commit);
49 }();
50 writeS(gather_S, j);
51 const auto bs = static_cast<index_t>(A_in.batch_size());
// The column stride of A depends on its storage order.
52 const auto stride = OAi == StorageOrder::ColMajor ? A_in.outer_stride() : 1;
53 const isimd offsets = (h_commit - isimd{1}) * bs * stride + iota;
54 for (index_t r = 0; r < R; ++r) {
55 auto gather_A = gather<T, Abi>(&A_in(0, r, 0), offsets, h_commit);
56 writeA(gather_A, gather_S, r, j);
57 }
58 ++j;
59 };
// Commit the oldest buffer register and shift the others down one slot.
60 const auto commit = [&] [[gnu::always_inline]] () {
61 const isimd h = hist[0];
62 BATMAT_FULLY_UNROLLED_FOR (index_t k = 1; k < N; ++k)
63 hist[k - 1] = hist[k];
64 hist[N - 1] = 0;
65 commit_no_shift(h);
66 };
67
68 isimd c1_simd{0};
69 for (index_t c = 0; c < C; ++c) {
70 c1_simd += isimd{1}; // current column index + 1
71 const simd Sc = types::aligned_load(&S_in(0, c, 0));
72 auto Sc_msk = !(Sc == simd{0});
73#if 0
74 if (all_of(Sc_msk)) {
75 commit_fast(Sc, c);
76 continue; // fast path: all nonzero, no gather needed
77 }
78#endif
// Insert each still-pending lane of the current mask into the first hist
// register whose slot for that lane is free, clearing that lane from Sc_msk.
79 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist) {
80#if BATMAT_WITH_GSI_HPC_SIMD // TODO
81 auto h_ = h;
82 h = isimd{[&](int i) -> index_t { return h[i] == 0 && Sc_msk[i] ? c1_simd[i] : h[i]; }};
83 Sc_msk = decltype(Sc_msk){[&](int i) -> bool { return Sc_msk[i] && h_[i] != 0; }};
84#else
85 const auto msk = (h == 0) && Sc_msk.__cvt();
86 where(msk, h) = c1_simd;
87 Sc_msk = Sc_msk && (!msk).__cvt();
88#endif
89 }
90 // Masks of all ones can already be written to memory
91 bool first_full = none_of(hist[0] == 0);
92 // If there are nonzero elements in the mask left, we need to make room in the buffer
93 bool mask_left = any_of(Sc_msk);
94 if (first_full || mask_left) {
95 commit();
96 assert(any_of(hist[0] == 0)); // at most one commit per iteration is possible
97 // If there are still bits set in the mask.
98 if (mask_left) {
99 // Check if there's an empty slot (always at the end)
100 auto &h = hist[N - 1];
101#if BATMAT_WITH_GSI_HPC_SIMD // TODO
102 h = isimd{[&](int i) -> index_t { return Sc_msk[i] ? c1_simd[i] : h[i]; }};
103#else
104 where(Sc_msk.__cvt(), h) = c1_simd;
105#endif
106 }
107 }
108 // Invariant: first registers in the buffer contain fewest zeros
109 BATMAT_FULLY_UNROLLED_FOR (index_t i = 1; i < N; ++i)
110#if BATMAT_WITH_GSI_HPC_SIMD
111 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
112#else
113 assert(popcount(hist[i] != 0) <= popcount(hist[i - 1] != 0));
114#endif
115 }
// Flush any registers that still hold pending columns.
116 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist)
117 if (any_of(h != 0))
118 commit_no_shift(h);
119 return j;
120}
121
// Dry run of compress_masks_impl: runs the same per-lane buffering logic on
// the mask S_in only (hist slots hold 1 instead of column indices) and counts
// how many compressed output columns a real compression pass would produce,
// without gathering or writing anything. Returns that count.
//
// NOTE(review): recovered from a docs extraction; original source lines 123
// and 128 are missing here. Per the declaration index of this page, line 123
// is the signature `index_t compress_masks_count(view<const T, Abi> S_in) {`;
// line 128 presumably declared the `types` alias used below -- confirm
// against the upstream batmat sources.
122template <class T, class Abi, index_t N = 8>
124 GUANAQO_TRACE("compress_masks_count", 0, S_in.rows() * S_in.depth());
125 const auto C = S_in.rows();
126 if (C == 0)
127 return 0;
129 using simd = typename types::simd;
130 using isimd = typename types::isimd;
// Occupancy-only buffer (1 = occupied, 0 = free) and output-column counter.
131 isimd hist[N]{};
132 index_t j = 0;
133
// Counting commit: drop the oldest register, shift, and bump the count.
134 const auto commit = [&] [[gnu::always_inline]] () {
135 BATMAT_FULLY_UNROLLED_FOR (index_t k = 1; k < N; ++k)
136 hist[k - 1] = hist[k];
137 hist[N - 1] = 0;
138 ++j;
139 };
140
141 for (index_t c = 0; c < C; ++c) {
142 const simd Sc = types::aligned_load(&S_in(0, c, 0));
143 auto Sc_msk = !(Sc == simd{0});
// Mark each pending lane in the first register with a free slot for it,
// clearing the lane from Sc_msk (mirrors compress_masks_impl).
144 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist) {
145#if BATMAT_WITH_GSI_HPC_SIMD // TODO
146 auto h_ = h;
147 h = isimd{[&](int i) -> index_t { return h[i] == 0 && Sc_msk[i] ? 1 : h[i]; }};
148 Sc_msk = decltype(Sc_msk){[&](int i) -> bool { return Sc_msk[i] && h_[i] != 0; }};
149#else
150 const auto msk = (h == 0) && Sc_msk.__cvt();
151 where(msk, h) = 1;
152 Sc_msk = Sc_msk && (!msk).__cvt();
153#endif
154 }
155 // Masks of all ones can already be written to memory
156 bool first_full = none_of(hist[0] == 0);
157 // If there are nonzero elements in the mask left, we need to make room in the buffer
158 bool mask_left = any_of(Sc_msk);
159 if (first_full || mask_left) {
160 commit();
161 assert(any_of(hist[0] == 0)); // at most one commit per iteration is possible
162 // If there are still bits set in the mask.
163 if (mask_left) {
164 // Check if there's an empty slot (always at the end)
165 auto &h = hist[N - 1];
166#if BATMAT_WITH_GSI_HPC_SIMD // TODO
167 h = isimd{[&](int i) -> index_t { return Sc_msk[i] ? 1 : h[i]; }};
168#else
169 where(Sc_msk.__cvt(), h) = isimd{1};
170#endif
171 }
172 }
173 // Invariant: first registers in the buffer contain fewest zeros
174 BATMAT_FULLY_UNROLLED_FOR (index_t i = 1; i < N; ++i)
175#if BATMAT_WITH_GSI_HPC_SIMD
176 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
177#else
178 assert(popcount(hist[i] != 0) <= popcount(hist[i - 1] != 0));
179#endif
180 }
// Every register that still holds pending lanes flushes one more column.
181 BATMAT_FULLY_UNROLLED_FOR (auto &h : hist)
182 if (any_of(h != 0))
183 ++j;
184 return j;
185}
186
187template <class T, class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
189 view<T, Abi, OAo> A_out, view<T, Abi> S_out) {
190 GUANAQO_TRACE("compress_masks", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
191 BATMAT_ASSERT(A_in.rows() == A_out.rows());
192 BATMAT_ASSERT(A_in.cols() == A_out.cols());
193 BATMAT_ASSERT(A_in.depth() == A_out.depth());
194 BATMAT_ASSERT(S_in.rows() == S_out.rows());
195 BATMAT_ASSERT(S_in.cols() == S_out.cols());
196 BATMAT_ASSERT(S_in.depth() == S_out.depth());
197 auto writeS = [S_out] [[gnu::always_inline]] (auto gather_S, index_t j) {
198 datapar::aligned_store(gather_S, &S_out(0, j, 0));
199 };
200 auto writeA = [A_out] [[gnu::always_inline]] (auto gather_A, auto /*gather_S*/, index_t r,
201 index_t j) {
202 datapar::aligned_store(gather_A, &A_out(0, r, j));
203 };
204 return compress_masks_impl<T, Abi, N, OAi>(A_in, S_in, writeS, writeA);
205}
206
207template <class T, class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
209 view<T, Abi, OAo> A_out, view<T, Abi> S_sign_out = {}) {
210 GUANAQO_TRACE("compress_masks_sqrt", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
211 BATMAT_ASSERT(A_in.rows() == A_out.rows());
212 BATMAT_ASSERT(A_in.cols() == A_out.cols());
213 BATMAT_ASSERT(A_in.depth() == A_out.depth());
214 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_in.depth() == S_sign_out.depth());
215 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_out.cols() == S_sign_out.rows());
216 BATMAT_ASSERT(S_sign_out.rows() == 0 || S_sign_out.cols() == 1);
217 using std::copysign;
218 using std::fabs;
219 using std::sqrt;
220 auto writeS = [S_sign_out] [[gnu::always_inline]] (auto gather_S, index_t j) {
221 if (S_sign_out.rows() > 0)
222 datapar::aligned_store(copysign(decltype(gather_S){0}, gather_S), &S_sign_out(0, j, 0));
223 };
224 auto writeA = [A_out] [[gnu::always_inline]] (auto gather_A, auto gather_S, index_t r,
225 index_t j) {
226 datapar::aligned_store(sqrt(fabs(gather_S)) * gather_A, &A_out(0, r, j));
227 };
228 return compress_masks_impl<T, Abi, N, OAi>(A_in, S_in, writeS, writeA);
229}
230
231} // namespace detail
232
233/// @addtogroup topic-linalg
234/// @{
235
236/// @name Compression of masks containing zeros
237/// @{
238
/// Public entry point: simdifies the input/output views and forwards them to
/// detail::compress_masks, compressing the columns of Ain according to the
/// nonzero entries of the mask Sin into Aout/Sout.
/// NOTE(review): the docs extraction dropped original line 241 here -- the
/// `return detail::compress_masks<...>(` call line preceding the argument
/// list below. Restore it from the upstream batmat sources.
239template <index_t N = 8, simdifiable VA, simdifiable VS, simdifiable VAo, simdifiable VSo>
240index_t compress_masks(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout) {
242 simdify(Ain).as_const(), simdify(Sin).as_const(), simdify(Aout), simdify(Sout));
243}
244
/// Public entry point: counts how many compressed columns a compress_masks
/// call would produce for the mask Sin, without writing any output.
/// NOTE(review): the docs extraction dropped original line 247 here -- the
/// `return detail::compress_masks_count<...>(` call line preceding the
/// argument below. Restore it from the upstream batmat sources.
245template <index_t N = 8, simdifiable VS>
246index_t compress_masks_count(VS &&Sin) {
248 simdify(Sin).as_const());
249}
250
/// Public entry point (no sign output): simdifies the views and forwards to
/// detail::compress_masks_sqrt, which scales surviving columns by sqrt(|S|).
/// NOTE(review): the docs extraction dropped original line 253 here -- the
/// `return detail::compress_masks_sqrt<...>(` call line preceding the
/// argument list below. Restore it from the upstream batmat sources.
251template <index_t N = 8, simdifiable VA, simdifiable VS, simdifiable VAo>
252index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout) {
254 simdify(Ain).as_const(), simdify(Sin).as_const(), simdify(Aout));
255}
256
/// Public entry point (with sign output Sout): simdifies the views and
/// forwards to detail::compress_masks_sqrt, which scales surviving columns by
/// sqrt(|S|) and records the sign of S in Sout.
/// NOTE(review): the docs extraction dropped original line 259 here -- the
/// `return detail::compress_masks_sqrt<...>(` call line preceding the
/// argument list below. Restore it from the upstream batmat sources.
257template <index_t N = 8, simdifiable VA, simdifiable VS, simdifiable VAo, simdifiable VSo>
258index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout) {
260 simdify(Ain).as_const(), simdify(Sin).as_const(), simdify(Aout), simdify(Sout));
261}
262
263/// @}
264
265/// @}
266
267} // namespace batmat::linalg
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
index_t compress_masks_count(VS &&Sin)
Definition compress.hpp:246
index_t compress_masks(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout)
Definition compress.hpp:240
index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout)
Definition compress.hpp:252
datapar::simd< T, AbiT > gather(const T *p, datapar::simd< I, AbiI > idx, M mask)
Gathers elements from memory at the addresses specified by idx, which should be an integer SIMD vector.
Definition gather.hpp:56
#define GUANAQO_TRACE(name, instance,...)
void aligned_store(V v, typename V::value_type *p)
Definition simd.hpp:121
index_t compress_masks_impl(view< const T, Abi, OAi > A_in, view< const T, Abi > S_in, auto writeS, auto writeA)
Definition compress.hpp:16
index_t compress_masks_count(view< const T, Abi > S_in)
Definition compress.hpp:123
index_t compress_masks(view< const T, Abi, OAi > A_in, view< const T, Abi > S_in, view< T, Abi, OAo > A_out, view< T, Abi > S_out)
Definition compress.hpp:188
index_t compress_masks_sqrt(view< const T, Abi, OAi > A_in, view< const T, Abi > S_in, view< T, Abi, OAo > A_out, view< T, Abi > S_sign_out={})
Definition compress.hpp:208
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
#define BATMAT_FULLY_UNROLLED_FOR(...)
Definition unroll.h:27