batmat 0.0.17
Batched linear algebra routines
Loading...
Searching...
No Matches
dtypes.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/config.hpp>
4#include <algorithm>
5#include <array>
6#include <functional>
7#include <type_traits>
8
9namespace batmat::types {
10
/// Compile-time list of types.
template <class... Elems>
struct Types {
    /// Instantiates @p Template with the leading arguments @p Prefix... followed by the
    /// elements of this list, i.e. `Template<Prefix..., Elems...>`.
    template <template <class...> class Template, class... Prefix>
    using into = Template<Prefix..., Elems...>;
};
16
17template <class>
18struct Head;
19template <class T, class... Ts>
20struct Head<Types<T, Ts...>> {
21 using type = Types<T>;
22};
23
24template <class>
25struct Tail;
26template <class T, class... Ts>
27struct Tail<Types<T, Ts...>> {
28 using type = Types<Ts...>;
29};
30
31template <class...>
32struct Concat;
33template <>
34struct Concat<> {
35 using type = Types<>;
36};
37template <class... Ts>
38struct Concat<Types<Ts...>> {
39 using type = Types<Ts...>;
40};
41template <class... Ts1, class... Ts2, class... Rest>
42struct Concat<Types<Ts1...>, Types<Ts2...>, Rest...> {
43 using type = typename Concat<Types<Ts1..., Ts2...>, Rest...>::type;
44};
45
46template <template <class> class Func, class>
47struct Map;
48template <template <class> class Func, class... Ts>
49struct Map<Func, Types<Ts...>> {
50 using type = Types<Func<Ts>...>;
51};
52template <template <class> class Func, class List>
54
55template <template <class> class Func, class>
56struct FlatMap;
57template <template <class> class Func, class... Ts>
58struct FlatMap<Func, Types<Ts...>> {
59 using type = typename Concat<Func<Ts>...>::type;
60};
61template <template <class> class Func, class List>
63
64template <template <class> class Pred, class>
65struct Filter;
66template <template <class> class Pred, class... Ts>
67struct Filter<Pred, Types<Ts...>> {
69};
70template <template <class> class Pred, class List>
72
73template <class T, index_t VL>
75 using dtype = T;
76 using vl_t = std::integral_constant<index_t, VL>;
77 static constexpr index_t vl = VL;
78};
79
80template <index_t VL>
82 template <class T>
83 using type = std::bool_constant<VL == T::vl>;
84};
85
/// Predicate factory for @ref Filter: `DTypeIs<DT>::type<Candidate>` is
/// `std::is_same<DT, typename Candidate::dtype>`, i.e. true exactly when the
/// candidate's `dtype` member is @p DT.
template <class DT>
struct DTypeIs {
    template <class Candidate>
    using type = std::is_same<DT, typename Candidate::dtype>;
};
91
92template <index_t VL, class List>
94
/// Type function returning the `dtype` member of @p DTVL (e.g. of a DTypeVectorLength).
/// Suitable as the @p Func argument of Map / Map_t.
template <class DTVL>
using GetDType = typename DTVL::dtype;
97
// Expands to its arguments preceded by a comma. Passed as the callback of the BATMAT_FOREACH_*
// X-macros (presumably defined in batmat/config.hpp — confirm there) so each entry contributes
// ", entry"; the list below is seeded with a dummy `void` that Tail then drops.
#define BATMAT_PREFIX_COMMA(...) , __VA_ARGS__
/// @ref Types containing all supported dtypes.
using dtype_all = Tail<Types<void BATMAT_FOREACH_DTYPE(BATMAT_PREFIX_COMMA)>>::type;
101
/// @ref Types containing @ref DTypeVectorLength for all supported (dtype, VL) combinations.
// BATMAT_INST_DT_VL emits ", DTypeVectorLength<DT, VL>" for each pair visited by
// BATMAT_FOREACH_DTYPE_VL; the leading dummy `void` absorbs the first comma and is removed
// by Tail. The helper macro is #undef'd immediately so it does not leak out of this header.
#define BATMAT_INST_DT_VL(DT, VL) , DTypeVectorLength<DT, VL>
using dtype_vl_all = Tail<Types<void BATMAT_FOREACH_DTYPE_VL(BATMAT_INST_DT_VL)>>::type;
#undef BATMAT_INST_DT_VL
106
107/// Array of supported vector lengths for a given dtype @p T.
108template <class DT>
109constexpr std::array vl_for_dtype = []<class... Dtvls>(Types<Dtvls...>) {
110 return std::array<index_t, sizeof...(Dtvls)>{Dtvls::vl...};
111}(Filter_t<DTypeIs<DT>::template type, dtype_vl_all>{});
112
113/// Array of supported vector lengths for the default @ref real_t.
114constexpr std::array vl_for_real_t = vl_for_dtype<real_t>;
115
116/// @ref Types containing all supported dtypes for a given vector length @p VL.
117template <index_t VL>
119
120/// @ref Types containing the given dtype and vector length combination, if supported.
121template <class DT, index_t VL>
124
125/// Check if a given (dtype, VL) combination is supported.
126template <class DT, index_t VL>
127constexpr bool is_supported_dtype_vl = !std::is_same_v<lookup_dtype_vl<DT, VL>, Types<>>;
128
129/// The smallest supported vector length for dtype @p DT that is greater than or equal to @p VL.
130/// Returns 0 if no supported vector length is large enough.
131template <class DT, index_t VL>
132constexpr index_t vl_at_least = [] {
133 if constexpr (is_supported_dtype_vl<DT, VL>) {
134 return VL;
135 } else {
136 auto options = vl_for_dtype<DT>;
137 std::ranges::sort(options, std::less{});
138 for (auto v : options)
139 if (v >= VL)
140 return v;
141 return index_t{0};
142 }
143}();
144
145/// The largest supported vector length for dtype @p DT that is less than or equal to @p VL.
146/// Returns 0 if no supported vector length is small enough.
147template <class DT, index_t VL>
148constexpr index_t vl_at_most = [] {
149 if constexpr (is_supported_dtype_vl<DT, VL>) {
150 return VL;
151 } else {
152 auto options = vl_for_dtype<DT>;
153 std::ranges::sort(options, std::greater{});
154 for (auto v : options)
155 if (v <= VL)
156 return v;
157 return index_t{0};
158 }
159}();
160
161/// @p VL if it is a supported vector length for dtype @p DT, otherwise the largest supported vector
162/// length for @p DT.
163template <class DT, index_t VL>
164constexpr index_t vl_or_largest = [] {
165 if constexpr (is_supported_dtype_vl<DT, VL>) {
166 return VL;
167 } else {
168 auto options = vl_for_dtype<DT>;
169 std::ranges::sort(options, std::greater{});
170 return options.empty() ? index_t{0} : options.front();
171 }
172}();
173
174/// Call a given function @p f for all supported (dtype, VL) combinations. @p f should be callable
175/// with signature `void(DTypeVectorLength)`.
176template <class F>
177constexpr auto foreach_dtype_vl(F &&f) {
178 return [&f]<class... Ts>(Types<Ts...>) { (f(Ts{}), ...); }(dtype_vl_all{});
179}
180
181} // namespace batmat::types
#define BATMAT_INST_DT_VL(DT, VL)
Types containing DTypeVectorLength for all supported (dtype, VL) combinations.
Definition dtypes.hpp:103
#define BATMAT_PREFIX_COMMA(...)
Definition dtypes.hpp:98
typename Map< Func, List >::type Map_t
Definition dtypes.hpp:53
typename FlatMap< Func, List >::type FlatMap_t
Definition dtypes.hpp:62
Tail< Types< void, DTypeVectorLength< double, 1 >, DTypeVectorLength< double, 4 >, DTypeVectorLength< double, 8 > > >::type dtype_vl_all
Definition dtypes.hpp:104
std::is_same< DT, typename T::dtype > type
Definition dtypes.hpp:89
Template< Args..., Ts... > into
Definition dtypes.hpp:14
typename Concat< Types< Ts1..., Ts2... >, Rest... >::type type
Definition dtypes.hpp:43
std::bool_constant< VL==T::vl > type
Definition dtypes.hpp:83
typename Concat< Func< Ts >... >::type type
Definition dtypes.hpp:59
constexpr auto foreach_dtype_vl(F &&f)
Call a given function f for all supported (dtype, VL) combinations.
Definition dtypes.hpp:177
typename Filter< Pred, List >::type Filter_t
Definition dtypes.hpp:71
constexpr index_t vl_or_largest
VL if it is a supported vector length for dtype DT, otherwise the largest supported vector length for...
Definition dtypes.hpp:164
constexpr std::array vl_for_dtype
Array of supported vector lengths for a given dtype DT.
Definition dtypes.hpp:109
constexpr std::array vl_for_real_t
Array of supported vector lengths for the default real_t.
Definition dtypes.hpp:114
typename Ts::dtype GetDType
Definition dtypes.hpp:96
constexpr index_t vl_at_most
The largest supported vector length for dtype DT that is less than or equal to VL.
Definition dtypes.hpp:148
constexpr index_t vl_at_least
The smallest supported vector length for dtype DT that is greater than or equal to VL.
Definition dtypes.hpp:132
FlatMap_t< VectorLengthIs< VL >::template type, List > FilterVL
Definition dtypes.hpp:93
Map_t< GetDType, Filter_t< VectorLengthIs< VL >::template type, dtype_vl_all > > dtypes_for_vl
Types containing all supported dtypes for a given vector length VL.
Definition dtypes.hpp:118
Tail< Types< void, double > >::type dtype_all
Types containing all supported dtypes.
Definition dtypes.hpp:100
Filter_t< DTypeIs< DT >::template type, Filter_t< VectorLengthIs< VL >::template type, dtype_vl_all > > lookup_dtype_vl
Types containing the given dtype and vector length combination, if supported.
Definition dtypes.hpp:122
typename Concat< std::conditional_t< Pred< Ts >::value, Types< Ts >, Types<> >... >::type type
Definition dtypes.hpp:68
constexpr bool is_supported_dtype_vl
Check if a given (dtype, VL) combination is supported.
Definition dtypes.hpp:127
std::integral_constant< index_t, VL > vl_t
Definition dtypes.hpp:76
static constexpr index_t vl
Definition dtypes.hpp:77