batmat 0.0.17
Batched linear algebra routines
Loading...
Searching...
No Matches
rotate.hpp
Go to the documentation of this file.
1#pragma once
2
#include <batmat/simd.hpp>

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
5
6namespace batmat::ops {
7
8namespace detail {
9
/*
 * Note: The explicit `-> F` return types on the generator lambdas below are required to work
 * around a GCC bug: https://godbolt.org/z/Phfvsq7hY
 */
13
14/// Rotate the elements of @p x to the right by @p s positions.
15/// For example, `rotr<1>([x0, x1, x2, x3]) == [x3, x0, x1, x2]`
16/// and `rotr<-1>([x0, x1, x2, x3]) == [x1, x2, x3, x0]`.
17template <class F, class Abi>
18[[gnu::always_inline]] inline datapar::simd<F, Abi> rot(datapar::simd<F, Abi> x, int s) {
19 assert(s <= 0 || static_cast<size_t>(+s) < x.size());
20 assert(s >= 0 || static_cast<size_t>(-s) < x.size());
21 const int n = x.size();
22 return datapar::simd<F, Abi>{[&](int j) -> F { return x[(n + j - s) % n]; }};
23}
24
25template <int S, class F, class Abi>
26[[gnu::always_inline]] inline datapar::simd<F, Abi> rotl(datapar::simd<F, Abi> x) {
27 static_assert(S > 0 && S < x.size());
28 const int n = x.size();
29 return datapar::simd<F, Abi>{[&](int j) -> F { return x[(j + S) % n]; }};
30}
31
32template <int S, class F, class Abi>
33[[gnu::always_inline]] inline datapar::simd<F, Abi> rotr(datapar::simd<F, Abi> x) {
34 static_assert(S > 0 && S < x.size());
35 const int n = x.size();
36 return datapar::simd<F, Abi>{[&](int j) -> F { return x[(n + j - S) % n]; }};
37}
38
39template <int S, class F, class Abi>
40[[gnu::always_inline]] inline datapar::simd<F, Abi> shiftl(datapar::simd<F, Abi> x) {
41 static_assert(S > 0 && S < x.size());
42 const int n = x.size();
43 return datapar::simd<F, Abi>{[&](int j) -> F { return j + S < n ? x[j + S] : F{}; }};
44}
45
46template <int S, class F, class Abi>
47[[gnu::always_inline]] inline datapar::simd<F, Abi> shiftr(datapar::simd<F, Abi> x) {
48 static_assert(S > 0 && S < x.size());
49 return datapar::simd<F, Abi>{[&](int j) -> F { return j >= S ? x[j - S] : F{}; }};
50}
51
52#if defined(__AVX512F__)
53
54[[gnu::always_inline]] inline auto rot(datapar::deduced_simd<double, 8> x, int s) {
55 assert(s <= 0 || static_cast<size_t>(+s) < x.size());
56 assert(s >= 0 || static_cast<size_t>(-s) < x.size());
57 constexpr size_t N = x.size();
58 static constinit std::array<int64_t, 2 * N - 1> indices_lut = [] {
59 std::array<int64_t, 2 * N - 1> lut{};
60 for (size_t i = 0; i < 2 * N - 1; ++i)
61 lut[i] = static_cast<int64_t>((i + 1) % N);
62 return lut;
63 }();
64 // rot(+1, [0, 1, 2, 3, 4, 5, 6, 7]) == [7, 0, 1, 2, 3, 4, 5, 6]
65 // rot(+2, [0, 1, 2, 3, 4, 5, 6, 7]) == [6, 7, 0, 1, 2, 3, 4, 5]
66 // rot(+7, [0, 1, 2, 3, 4, 5, 6, 7]) == [1, 2, 3, 4, 5, 6, 7, 0]
67 //
68 // rot(-1, [0, 1, 2, 3, 4, 5, 6, 7]) == [1, 2, 3, 4, 5, 6, 7, 0]
69 // rot(-2, [0, 1, 2, 3, 4, 5, 6, 7]) == [2, 3, 4, 5, 6, 7, 0, 1]
70 // rot(-7, [0, 1, 2, 3, 4, 5, 6, 7]) == [7, 0, 1, 2, 3, 4, 5, 6]
71 //
72 // [1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
73 // 0
74 // 1
75 // -7
76 // 2
77 // -6
78 // 7
79 // -1
80 static constinit const int64_t *p = indices_lut.data() + N - 1;
81 if (s < 0)
82 s += N;
83 const __m512i indices = _mm512_loadu_epi64(p - s);
84 __m512d y = _mm512_permutexvar_pd(indices, static_cast<__m512d>(x));
85 return decltype(x){y};
86}
87
// `_mm256_loadu_epi64` and `_mm256_permutexvar_pd` are AVX-512VL intrinsics (256-bit forms of
// AVX-512 instructions); AVX-512F alone is not enough. Without VL, calls fall back to the
// generic `rot` template.
#if defined(__AVX512VL__)
/// AVX-512VL overload of `rot` for 4 × double: rotate the lanes of @p x to the right by the
/// run-time amount @p s using a single `vpermpd` whose index vector is loaded from a small
/// constant lookup table.
/// For example, `rot([x0, x1, x2, x3], 1) == [x3, x0, x1, x2]`
/// and `rot([x0, x1, x2, x3], -1) == [x1, x2, x3, x0]`.
/// @pre `-4 < s < 4` (checked with `assert`).
[[gnu::always_inline]] inline auto rot(datapar::deduced_simd<double, 4> x, int s) {
    assert(s <= 0 || static_cast<size_t>(+s) < x.size());
    assert(s >= 0 || static_cast<size_t>(-s) < x.size());
    constexpr size_t N = x.size();
    // indices_lut = [1, 2, 3, 0, 1, 2, 3]: the N-entry window starting at lut[N-1-s] is
    // (N - s + j) % N, so result lane j takes x[(j - s) mod N].
    static constexpr std::array<int64_t, 2 * N - 1> indices_lut = [] {
        std::array<int64_t, 2 * N - 1> lut{};
        for (size_t i = 0; i < 2 * N - 1; ++i)
            lut[i] = static_cast<int64_t>((i + 1) % N);
        return lut;
    }();
    // p points at lut[N-1] == 0, the start of the identity window (s == 0).
    constexpr const int64_t *p = indices_lut.data() + N - 1;
    // Map a left rotation to the equivalent right rotation so that 0 <= s < N.
    if (s < 0)
        s += static_cast<int>(N);
    const __m256i indices = _mm256_loadu_epi64(p - s);
    __m256d y = _mm256_permutexvar_pd(indices, static_cast<__m256d>(x));
    return decltype(x){y};
}
#endif
105
106template <int S>
107[[gnu::always_inline]] inline auto rotl(datapar::deduced_simd<double, 8> x) {
108 static_assert(S > 0 && S < x.size());
109 constexpr size_t N = x.size();
110 const __m512i indices = _mm512_set_epi64((S + 7) % N, (S + 6) % N, (S + 5) % N, (S + 4) % N,
111 (S + 3) % N, (S + 2) % N, (S + 1) % N, S % N);
112 __m512d y = _mm512_permutexvar_pd(indices, static_cast<__m512d>(x));
113 return decltype(x){y};
114}
115
116template <int S>
117[[gnu::always_inline]] inline auto rotr(datapar::deduced_simd<double, 8> x) {
118 static_assert(S > 0 && S < x.size());
119 constexpr size_t N = x.size();
120 const __m512i indices =
121 _mm512_set_epi64((N - S + 7) % N, (N - S + 6) % N, (N - S + 5) % N, (N - S + 4) % N,
122 (N - S + 3) % N, (N - S + 2) % N, (N - S + 1) % N, (N - S) % N);
123 __m512d y = _mm512_permutexvar_pd(indices, static_cast<__m512d>(x));
124 return decltype(x){y};
125}
126
127template <int S>
128[[gnu::always_inline]] inline auto shiftl(datapar::deduced_simd<double, 8> x) {
129 static_assert(S > 0 && S < x.size());
130 constexpr uint8_t mask = (1u << (x.size() - S)) - 1u;
131 auto y = static_cast<__m512d>(rotl<S>(x));
132 y = _mm512_mask_blend_pd(mask, _mm512_set1_pd(0), y);
133 return decltype(x){y};
134}
135
136template <int S>
137[[gnu::always_inline]] inline auto shiftr(datapar::deduced_simd<double, 8> x) {
138 static_assert(S > 0 && S < x.size());
139 constexpr uint8_t mask = (1u << S) - 1u;
140 auto y = static_cast<__m512d>(rotr<S>(x));
141 y = _mm512_mask_blend_pd(mask, y, _mm512_set1_pd(0));
142 return decltype(x){y};
143}
144
145#endif
146
147#if defined(__AVX2__)
148
149template <int S>
150[[gnu::always_inline]] inline auto shiftl(datapar::deduced_simd<double, 4> x) {
151 static_assert(S > 0 && S < x.size());
152 constexpr uint8_t mask = (1u << (x.size() - S)) - 1u;
153 auto y = static_cast<__m256d>(rotl<S>(x));
154 y = _mm256_blend_pd(_mm256_set1_pd(0), y, mask);
155 return decltype(x){y};
156}
157
158template <int S>
159[[gnu::always_inline]] inline auto shiftr(datapar::deduced_simd<double, 4> x) {
160 static_assert(S > 0 && S < x.size());
161 constexpr uint8_t mask = (1u << S) - 1u;
162 auto y = static_cast<__m256d>(rotr<S>(x));
163 y = _mm256_blend_pd(y, _mm256_set1_pd(0), mask);
164 return decltype(x){y};
165}
166
167#endif
168
// `_mm256_permutexvar_pd` is the 256-bit form of an AVX-512 instruction and requires AVX-512VL
// in addition to AVX-512F. Without VL (e.g. -mavx512f alone), the AVX2 immediate-permute
// versions below are used instead (GCC and Clang enable AVX2 together with AVX-512F).
#if defined(__AVX512F__) && defined(__AVX512VL__)

/// Left rotation of 4 × double by the compile-time amount @p S using a variable-index
/// `vpermpd`: result lane `j` holds input lane `(j + S) mod 4`.
template <int S>
[[gnu::always_inline]] inline auto rotl(datapar::deduced_simd<double, 4> x) {
    static_assert(S > 0 && S < x.size());
    constexpr size_t N = x.size();
    // _mm256_set_epi64x lists lanes from the highest (3) down to lane 0.
    const __m256i indices = _mm256_set_epi64x((S + 3) % N, (S + 2) % N, (S + 1) % N, S % N);
    __m256d y = _mm256_permutexvar_pd(indices, static_cast<__m256d>(x));
    return decltype(x){y};
}

/// Right rotation of 4 × double by the compile-time amount @p S using a variable-index
/// `vpermpd`: result lane `j` holds input lane `(j - S) mod 4`.
template <int S>
[[gnu::always_inline]] inline auto rotr(datapar::deduced_simd<double, 4> x) {
    static_assert(S > 0 && S < x.size());
    constexpr size_t N = x.size();
    // _mm256_set_epi64x lists lanes from the highest (3) down to lane 0.
    const __m256i indices =
        _mm256_set_epi64x((N - S + 3) % N, (N - S + 2) % N, (N - S + 1) % N, (N - S) % N);
    __m256d y = _mm256_permutexvar_pd(indices, static_cast<__m256d>(x));
    return decltype(x){y};
}

#elif defined(__AVX2__)

/// Left rotation of 4 × double by the compile-time amount @p S using `vpermpd` with an 8-bit
/// immediate of four 2-bit source-lane indices (lane 0 in the lowest bits).
template <int S>
[[gnu::always_inline]] inline auto rotl(datapar::deduced_simd<double, 4> x) {
    static_assert(S > 0 && S < x.size());
    constexpr size_t N = x.size();
    constexpr int indices =
        (((S + 3) % N) << 6) | (((S + 2) % N) << 4) | (((S + 1) % N) << 2) | (S % N);
    __m256d y = _mm256_permute4x64_pd(static_cast<__m256d>(x), indices);
    return decltype(x){y};
}

/// Right rotation of 4 × double by the compile-time amount @p S using `vpermpd` with an 8-bit
/// immediate of four 2-bit source-lane indices (lane 0 in the lowest bits).
template <int S>
[[gnu::always_inline]] inline auto rotr(datapar::deduced_simd<double, 4> x) {
    static_assert(S > 0 && S < x.size());
    constexpr size_t N = x.size();
    constexpr int indices = (((N - S + 3) % N) << 6) | (((N - S + 2) % N) << 4) |
                            (((N - S + 1) % N) << 2) | ((N - S) % N);
    __m256d y = _mm256_permute4x64_pd(static_cast<__m256d>(x), indices);
    return decltype(x){y};
}

#endif
213
214} // namespace detail
215
216/// @addtogroup topic-low-level-ops
217/// @{
218
219/// @name Lane-wise rotations of SIMD vectors
220/// @{
221
222/// Rotates the elements of @p x by @p s positions to the left.
223/// For example, `rotl<1>([x0, x1, x2, x3]) == [x1, x2, x3, x0]`
224/// and `rotl<-1>([x0, x1, x2, x3]) == [x3, x0, x1, x2]`.
225template <int S, class F, class Abi>
226[[gnu::always_inline]] inline datapar::simd<F, Abi> rotl(datapar::simd<F, Abi> x) {
227 if constexpr (S % x.size() == 0)
228 return x;
229 else if constexpr (S < 0)
230 return detail::rotr<-S>(x);
231 else
232 return detail::rotl<S>(x);
233}
234
235/// Rotate the elements of @p x to the right by @p S positions.
236/// For example, `rotr<1>([x0, x1, x2, x3]) == [x3, x0, x1, x2]`
237/// and `rotr<-1>([x0, x1, x2, x3]) == [x1, x2, x3, x0]`.
238template <int S, class F, class Abi>
239[[gnu::always_inline]] inline datapar::simd<F, Abi> rotr(datapar::simd<F, Abi> x) {
240 if constexpr (S % x.size() == 0)
241 return x;
242 else if constexpr (S < 0)
243 return detail::rotl<-S>(x);
244 else
245 return detail::rotr<S>(x);
246}
247
248/// Shift the elements of @p x to the left by @p S positions, shifting in zeros.
249/// For example, `shiftl<1>([x0, x1, x2, x3]) == [x1, x2, x3, 0]`
250/// and `shiftl<-1>([x0, x1, x2, x3]) == [0, x0, x1, x2]`.
251template <int S, class F, class Abi>
252[[gnu::always_inline]] inline datapar::simd<F, Abi> shiftl(datapar::simd<F, Abi> x) {
253 if constexpr (S == 0)
254 return x;
255 else if constexpr (S >= static_cast<int>(x.size()) || -S >= static_cast<int>(x.size()))
256 return datapar::simd<F, Abi>{0};
257 else if constexpr (S < 0)
258 return detail::shiftr<-S>(x);
259 else
260 return detail::shiftl<S>(x);
261}
262
263/// Shift the elements of @p x to the right by @p S positions, shifting in zeros.
264/// For example, `shiftr<1>([x0, x1, x2, x3]) == [0, x0, x1, x2]`
265/// and `shiftr<-1>([x0, x1, x2, x3]) == [x1, x2, x3, 0]`.
266template <int S, class F, class Abi>
267[[gnu::always_inline]] inline datapar::simd<F, Abi> shiftr(datapar::simd<F, Abi> x) {
268 if constexpr (S == 0)
269 return x;
270 else if constexpr (S >= static_cast<int>(x.size()) || -S >= static_cast<int>(x.size()))
271 return datapar::simd<F, Abi>{0};
272 else if constexpr (S < 0)
273 return detail::shiftl<-S>(x);
274 else
275 return detail::shiftr<S>(x);
276}
277
278using detail::rot;
279
280/// @}
281
282/// @}
283
284} // namespace batmat::ops
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
Rotates the elements of x by S positions to the left.
Definition rotate.hpp:226
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
Definition rotate.hpp:239
datapar::simd< F, Abi > shiftl(datapar::simd< F, Abi > x)
Shift the elements of x to the left by S positions, shifting in zeros.
Definition rotate.hpp:252
datapar::simd< F, Abi > shiftr(datapar::simd< F, Abi > x)
Shift the elements of x to the right by S positions, shifting in zeros.
Definition rotate.hpp:267
simd< Tp, deduced_abi< Tp, Np > > deduced_simd
Definition simd.hpp:103
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
datapar::simd< F, Abi > shiftl(datapar::simd< F, Abi > x)
Definition rotate.hpp:40
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Definition rotate.hpp:33
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
Definition rotate.hpp:26
datapar::simd< F, Abi > shiftr(datapar::simd< F, Abi > x)
Definition rotate.hpp:47
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
Rotate the elements of x to the right by s positions.
Definition rotate.hpp:18