batmat 0.0.18
Batched linear algebra routines
Loading...
Searching...
No Matches
layout.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file
4/// Layout description for a batch of matrices, independent of any storage.
5/// @ingroup topic-matrix
6
7#include <batmat/config.hpp>
8#include <guanaqo/mat-view.hpp>
9#include <type_traits>
10
11namespace batmat::matrix {
12
14
15template <class T>
17 using type = T;
18};
19template <class IntConst>
20 requires requires { typename IntConst::value_type; }
21struct integral_value_type<IntConst> {
22 using type = typename IntConst::value_type;
23};
24template <class T>
26
28 DefaultStride() = default;
29 DefaultStride(index_t) {} // TODO: this is error prone
30};
31
32/// Shape and strides describing a batch of matrices, independent of any storage.
33///
34/// @tparam I
35/// Index and size type. Usually `std::ptrdiff_t` or `int`.
36/// @tparam S
37/// Inner stride type (batch size). Usually `std::integral_constant<I, N>` for some `N`.
38/// @tparam D
39/// Batch depth type. Usually equal to @p S for a single batch, or @p I for a dynamic depth.
40/// @tparam L
41/// Layer stride type. Usually @ref DefaultStride (which implies that the layer stride is
42/// equal to `outer_stride * outer_size()`), or @p I for a dynamic layer stride.
43/// Dynamic strides are used for subviews of views with a larger `outer_size()`.
44/// @tparam O
45/// %Matrix storage order, @ref guanaqo::StorageOrder::RowMajor "RowMajor" or
46/// @ref guanaqo::StorageOrder::ColMajor "ColMajor".
47/// @ingroup topic-matrix
48template <class I = index_t, class S = std::integral_constant<I, 1>, class D = I,
49 class L = DefaultStride, StorageOrder O = StorageOrder::ColMajor>
50struct Layout {
51 /// @name Compile-time properties
52 /// @{
53 using index_type = I; ///< Index and size type (template parameter @p I).
54 using batch_size_type = S; ///< Batch size type (template parameter @p S).
55 using depth_type = D; ///< Batch depth type (template parameter @p D).
57 static constexpr StorageOrder storage_order = O; ///< Matrix storage order (template parameter @p O).
58 static constexpr bool is_column_major = O == StorageOrder::ColMajor; ///< True iff @p O is ColMajor.
59 static constexpr bool is_row_major = O == StorageOrder::RowMajor; ///< True iff @p O is RowMajor.
60 static constexpr std::integral_constant<index_type, 1> inner_stride{}; ///< Stride between consecutive elements of an inner vector; fixed to 1 (like row_stride()/col_stride(), to be multiplied by the batch size).
61
62 using standard_stride_type = std::conditional_t<requires {
63 S::value;
64 }, std::integral_constant<index_t, S::value>, index_t>;
65 /// @}
66
67 /// @name Layout description
68 /// @{
69
70 [[no_unique_address]] depth_type depth; ///< Number of layers (matrices) described by this layout.
74 [[no_unique_address]] batch_size_type batch_size; ///< Number of layers stored interleaved in each batch.
75 [[no_unique_address]] layer_stride_type layer_stride; ///< Layer stride; ignored when layer_stride_type is DefaultStride (see get_layer_stride()).
76
77 /// @}
78
79 /// @name Initialization
80 /// @{
81
92
93 constexpr Layout(PlainLayout p = {})
96
97 /// @}
98
99 [[nodiscard]] constexpr index_type outer_size() const { return is_row_major ? rows : cols; }
100 [[nodiscard]] constexpr index_type inner_size() const { return is_row_major ? cols : rows; }
101 [[nodiscard]] constexpr index_type num_batches() const {
102 const auto bs = static_cast<I>(batch_size);
103 const auto d = static_cast<I>(depth);
104 return (d + bs - 1) / bs;
105 }
106 /// The row stride of the matrices, i.e. the distance between elements in consecutive rows in
107 /// a given column. Should be multiplied by the batch size to get the actual number of elements.
108 [[nodiscard, gnu::always_inline]] constexpr auto row_stride() const {
109 if constexpr (is_column_major)
110 return std::integral_constant<index_type, 1>{};
111 else
112 return outer_stride;
113 }
114 /// The column stride of the matrices, i.e. the distance between elements in consecutive columns
115 /// in a given row. Should be multiplied by the batch size to get the actual number of elements.
116 [[nodiscard, gnu::always_inline]] constexpr auto col_stride() const {
117 if constexpr (is_column_major)
118 return outer_stride;
119 else
120 return std::integral_constant<index_type, 1>{};
121 }
122 /// Round up the given size @p n to a multiple of @ref batch_size.
123 [[nodiscard]] constexpr index_type ceil_depth(index_type n) const {
124 const auto bs = static_cast<I>(batch_size);
125 return n + (bs - n % bs) % bs;
126 }
127 /// Round up the @ref depth to a multiple of @ref batch_size.
128 [[nodiscard]] constexpr index_type ceil_depth() const {
129 return ceil_depth(static_cast<I>(depth));
130 }
131 /// Round down the given size @p n to a multiple of @ref batch_size.
132 [[nodiscard]] constexpr index_type floor_depth(index_type n) const {
133 const auto bs = static_cast<I>(batch_size);
134 return n - (n % bs);
135 }
136 /// Round down the @ref depth to a multiple of @ref batch_size.
137 [[nodiscard]] constexpr index_type floor_depth() const {
138 return floor_depth(static_cast<I>(depth));
139 }
140 [[nodiscard]] constexpr auto get_layer_stride() const {
141 if constexpr (std::is_same_v<layer_stride_type, DefaultStride>)
142 return outer_stride * outer_size();
143 else
144 return layer_stride;
145 }
146 [[nodiscard]] constexpr bool has_full_layer_stride() const {
147 return static_cast<index_t>(get_layer_stride()) == outer_stride * outer_size() ||
149 }
150 [[nodiscard]] constexpr bool has_full_outer_stride() const {
151 return outer_stride == inner_size() || outer_size() == 1;
152 }
153 [[nodiscard]] constexpr bool has_full_inner_stride() const { return inner_stride == 1; }
154 [[nodiscard]] constexpr index_type layer_index(index_type l, index_type s) const {
155 assert(0 <= l && l < ceil_depth());
156 const auto bs = static_cast<I>(batch_size);
157 index_type offset = l % bs;
158 return s * (l - offset) + offset;
159 }
160 [[nodiscard]] constexpr index_type layer_index(index_type l) const {
161 return layer_index(l, get_layer_stride());
162 }
163
165 if constexpr (requires { standard_stride_type::value; })
166 return {};
167 else
168 return static_cast<standard_stride_type>(s);
169 }
170
171 template <class T>
174 return {{.data = data + layer_index(l),
175 .rows = rows,
176 .cols = cols,
177 .inner_stride = convert_to_standard_stride(batch_size),
178 .outer_stride = outer_stride * static_cast<I>(batch_size)}};
179 }
180 template <class T>
181 [[nodiscard]] T &operator()(T *data, index_type l, index_type r, index_type c) const {
182 auto *const p = data + layer_index(l);
183 const auto bs = static_cast<I>(batch_size);
184 return *(is_row_major ? p + bs * (c + outer_stride * r) : p + bs * (r + outer_stride * c));
185 }
186 /// Total number of elements in the view (excluding padding).
187 [[nodiscard]] index_type size() const { return static_cast<I>(depth) * rows * cols; }
188 [[nodiscard]] index_type padded_size() const { return ceil_depth() * get_layer_stride(); }
189};
190
191} // namespace batmat::matrix
constexpr auto cols(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:497
typename integral_value_type< T >::type integral_value_type_t
Definition layout.hpp:25
constexpr auto data(Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:483
constexpr auto outer_stride(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:501
constexpr auto rows(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:493
typename IntConst::value_type type
Definition layout.hpp:22
constexpr auto depth(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:505
constexpr index_type num_batches() const
Definition layout.hpp:101
constexpr Layout(PlainLayout p={})
Definition layout.hpp:93
constexpr index_type ceil_depth(index_type n) const
Round up the given size n to a multiple of batch_size.
Definition layout.hpp:123
T & operator()(T *data, index_type l, index_type r, index_type c) const
Definition layout.hpp:181
constexpr auto get_layer_stride() const
Definition layout.hpp:140
static standard_stride_type convert_to_standard_stride(auto s)
Definition layout.hpp:164
constexpr auto col_stride() const
The column stride of the matrices, i.e. the distance between elements in consecutive columns in a given row.
Definition layout.hpp:116
constexpr index_type floor_depth() const
Round down the depth to a multiple of batch_size.
Definition layout.hpp:137
static constexpr bool is_column_major
Definition layout.hpp:58
constexpr index_type inner_size() const
Definition layout.hpp:100
static constexpr bool is_row_major
Definition layout.hpp:59
constexpr auto row_stride() const
The row stride of the matrices, i.e. the distance between elements in consecutive rows in a given column.
Definition layout.hpp:108
static constexpr std::integral_constant< index_type, 1 > inner_stride
Definition layout.hpp:60
index_type size() const
Total number of elements in the view (excluding padding).
Definition layout.hpp:187
constexpr bool has_full_outer_stride() const
Definition layout.hpp:150
constexpr index_type floor_depth(index_type n) const
Round down the given size n to a multiple of batch_size.
Definition layout.hpp:132
constexpr bool has_full_layer_stride() const
Definition layout.hpp:146
constexpr index_type layer_index(index_type l, index_type s) const
Definition layout.hpp:154
guanaqo::MatrixView< T, I, standard_stride_type, storage_order > operator()(T *data, index_type l) const
Definition layout.hpp:173
constexpr index_type layer_index(index_type l) const
Definition layout.hpp:160
constexpr index_type ceil_depth() const
Round up the depth to a multiple of batch_size.
Definition layout.hpp:128
static constexpr StorageOrder storage_order
Definition layout.hpp:57
std::conditional_t< requires { S::value; }, std::integral_constant< index_t, S::value >, index_t > standard_stride_type
Definition layout.hpp:62
index_type padded_size() const
Definition layout.hpp:188
constexpr bool has_full_inner_stride() const
Definition layout.hpp:153
constexpr index_type outer_size() const
Definition layout.hpp:99