batmat 0.0.18
Batched linear algebra routines
Loading...
Searching...
No Matches
gather.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/ops/mask.hpp>
4#include <batmat/simd.hpp>
5#include <concepts>
6#include <type_traits>
7
8namespace batmat::ops::detail {
9
10template <class T, class AbiT, class I, class AbiI>
11[[gnu::always_inline]] inline datapar::simd<T, AbiT>
13 datapar::simd<I, AbiI> vindex, const T *base_addr) {
14 return datapar::simd<T, AbiT>{[=](auto i) { return mask[i] ? base_addr[vindex[i]] : src[i]; }};
15}
16
/// Compile-time predicate over two types: true when both @p T1 and @p T2 are integral, have the
/// same size, and are not the same type.
///
/// NOTE(review): listing line 18 is missing between the two lines below. It presumably holds the
/// introducer (`concept <name> =` or a `constexpr bool` variable template) that names this
/// predicate — restore it from the original header.
17template <class T1, class T2>
19 std::integral<T1> && std::integral<T2> && sizeof(T1) == sizeof(T2) && !std::same_as<T1, T2>;
20
21template <std::integral I>
22auto convert_int(I i) {
23 if constexpr (std::is_signed_v<I> && sizeof(I) == sizeof(int32_t)) {
24 return static_cast<int32_t>(i);
25 } else if constexpr (std::is_unsigned_v<I> && sizeof(I) == sizeof(int32_t)) {
26 return static_cast<uint32_t>(i);
27 } else if constexpr (std::is_signed_v<I> && sizeof(I) == sizeof(int64_t)) {
28 return static_cast<int64_t>(i);
29 } else if constexpr (std::is_unsigned_v<I> && sizeof(I) == sizeof(int64_t)) {
30 return static_cast<uint64_t>(i);
31 }
32}
33template <std::integral I>
34using convert_int_t = decltype(convert_int(std::declval<I>()));
35
36} // namespace batmat::ops::detail
37
38#if defined(__AVX512F__)
39#include <batmat/ops/avx-512/gather.hpp>
40#elif defined(__AVX2__)
41#include <batmat/ops/avx2/gather.hpp>
42#endif
43
44namespace batmat::ops {
45
46/// @addtogroup topic-low-level-ops
47/// @{
48
49/// @name Gathering elements from memory
50/// @{
51
52/// Gathers elements from memory at the addresses specified by @p idx, which should be an integer
53/// SIMD vector, and returns them in a SIMD vector of type `datapar::simd<T, AbiT>`. The elements
54/// are gathered relative to the base address @p p. The gathering is masked by @p mask:
/// the mask and the indices are converted and forwarded to `detail::gather`, together with a
/// value-initialized `simd{}` as the source vector.
///
/// NOTE(review): masked-off lanes therefore presumably come out as zero — confirm for the
/// architecture-specific `detail::gather` overloads, which are not visible here.
///
/// @param p     Base address the indices are relative to.
/// @param idx   Integer SIMD vector of element indices.
/// @param mask  SIMD mask selecting which lanes are gathered.
/// @return The gathered values as a `datapar::simd<T, AbiT>`.
55template <class T, class AbiT, class I, class AbiI, class M>
56[[gnu::always_inline]] inline datapar::simd<T, AbiT> gather(const T *p, datapar::simd<I, AbiI> idx,
57 M mask) {
58 using simd = datapar::simd<T, AbiT>;
// NOTE(review): `isimd` and `msimd` are used below but are not declared in the visible listing
// (its numbering jumps from 58 to 61); presumably they are aliases declared on the two missing
// lines — restore them from the original header.
// The two preprocessor branches differ only in the cast that is used: `std::bit_cast` when
// BATMAT_WITH_GSI_HPC_SIMD is set, `simd_cast` otherwise.
61#if BATMAT_WITH_GSI_HPC_SIMD // TODO
62 auto mask_ = detail::convert_mask<T, AbiT>(std::bit_cast<msimd>(mask));
63 const auto idx_ = std::bit_cast<isimd>(idx);
64#else
65 auto mask_ = detail::convert_mask<T, AbiT>(simd_cast<msimd>(mask));
66 const auto idx_ = simd_cast<isimd>(idx);
67#endif
// A value-initialized vector is passed as the source for the lanes that are masked off.
68 return detail::gather(simd{}, mask_, idx_, p);
69}
70
71/// @}
72
73/// @}
74
75} // namespace batmat::ops
datapar::simd< T, AbiT > gather(const T *p, datapar::simd< I, AbiI > idx, M mask)
Gathers elements from memory at the addresses specified by idx, which should be an integer SIMD vector, and returns them in a SIMD vector of type datapar::simd<T, AbiT>.
Definition gather.hpp:56
stdx::rebind_simd_t< T, V > rebind_simd_t
Definition simd.hpp:141
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
datapar::simd< T, AbiT > gather(datapar::simd< T, AbiT > src, typename datapar::simd< T, AbiT >::mask_type mask, datapar::simd< I, AbiI > vindex, const T *base_addr)
Definition gather.hpp:12
auto convert_int(I i)
Definition gather.hpp:22
mask_type_t< T, AbiT > convert_mask(M mask)
Convert a SIMD mask to the appropriate intrinsic type.
Definition mask.hpp:21
decltype(convert_int(std::declval< I >())) convert_int_t
Definition gather.hpp:34