batmat 0.0.15
Batched linear algebra routines
Loading...
Searching...
No Matches
layout.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file
4/// Layout description for a batch of matrices, independent of any storage.
5/// @ingroup topic-matrix
6
7#include <batmat/config.hpp>
8#include <guanaqo/mat-view.hpp>
9#include <type_traits>
10
11namespace batmat::matrix {
12
14
15template <class T>
17 using type = T;
18};
19template <class IntConst>
20 requires requires { typename IntConst::value_type; }
21struct integral_value_type<IntConst> {
22 using type = typename IntConst::value_type;
23};
24template <class T>
26
28 DefaultStride() = default;
29 DefaultStride(index_t) {} // TODO: this is error prone
30};
31
32/// Shape and strides describing a batch of matrices, independent of any storage.
33/// @tparam I
34/// Index type.
35/// @tparam S
36/// Inner stride (batch size).
37/// @tparam D
38/// Depth type.
39/// @tparam L
40/// Layer stride type.
41/// @tparam O
42/// Storage order (column or row major).
43/// @ingroup topic-matrix
44template <class I = index_t, class S = std::integral_constant<I, 1>, class D = I,
45 class L = DefaultStride, StorageOrder O = StorageOrder::ColMajor>
46struct Layout {
47 /// @name Compile-time properties
48 /// @{
49 using index_type = I;
50 using batch_size_type = S;
51 using depth_type = D;
53 static constexpr StorageOrder storage_order = O;
54 static constexpr bool is_column_major = O == StorageOrder::ColMajor;
55 static constexpr bool is_row_major = O == StorageOrder::RowMajor;
56 static constexpr std::integral_constant<index_type, 1> inner_stride{};
57
58 using standard_stride_type = std::conditional_t<requires {
59 S::value;
60 }, std::integral_constant<index_t, S::value>, index_t>;
61 /// @}
62
63 /// @name Layout description
64 /// @{
65
66 [[no_unique_address]] depth_type depth;
70 [[no_unique_address]] batch_size_type batch_size;
71 [[no_unique_address]] layer_stride_type layer_stride;
72
73 /// @}
74
75 /// @name Initialization
76 /// @{
77
88
89 constexpr Layout(PlainLayout p = {})
92
93 /// @}
94
95 [[nodiscard]] constexpr index_type outer_size() const { return is_row_major ? rows : cols; }
96 [[nodiscard]] constexpr index_type inner_size() const { return is_row_major ? cols : rows; }
97 [[nodiscard]] constexpr index_type num_batches() const {
98 const auto bs = static_cast<I>(batch_size);
99 const auto d = static_cast<I>(depth);
100 return (d + bs - 1) / bs;
101 }
102 /// The row stride of the matrices, i.e. the distance between elements in consecutive rows in
103 /// a given column. Should be multiplied by the batch size to get the actual number of elements.
104 [[nodiscard, gnu::always_inline]] constexpr auto row_stride() const {
105 if constexpr (is_column_major)
106 return std::integral_constant<index_type, 1>{};
107 else
108 return outer_stride;
109 }
110 /// The column stride of the matrices, i.e. the distance between elements in consecutive columns
111 /// in a given row. Should be multiplied by the batch size to get the actual number of elements.
112 [[nodiscard, gnu::always_inline]] constexpr auto col_stride() const {
113 if constexpr (is_column_major)
114 return outer_stride;
115 else
116 return std::integral_constant<index_type, 1>{};
117 }
118 /// Round up the given size @p n to a multiple of @ref batch_size.
119 [[nodiscard]] constexpr index_type ceil_depth(index_type n) const {
120 const auto bs = static_cast<I>(batch_size);
121 return n + (bs - n % bs) % bs;
122 }
123 /// Round up the @ref depth to a multiple of @ref batch_size.
124 [[nodiscard]] constexpr index_type ceil_depth() const {
125 return ceil_depth(static_cast<I>(depth));
126 }
127 /// Round down the given size @p n to a multiple of @ref batch_size.
128 [[nodiscard]] constexpr index_type floor_depth(index_type n) const {
129 const auto bs = static_cast<I>(batch_size);
130 return n - (n % bs);
131 }
132 /// Round down the @ref depth to a multiple of @ref batch_size.
133 [[nodiscard]] constexpr index_type floor_depth() const {
134 return floor_depth(static_cast<I>(depth));
135 }
136 [[nodiscard]] constexpr auto get_layer_stride() const {
137 if constexpr (std::is_same_v<layer_stride_type, DefaultStride>)
138 return outer_stride * outer_size();
139 else
140 return layer_stride;
141 }
142 [[nodiscard]] constexpr bool has_full_layer_stride() const {
143 return static_cast<index_t>(get_layer_stride()) == outer_stride * outer_size() ||
145 }
146 [[nodiscard]] constexpr bool has_full_outer_stride() const {
147 return outer_stride == inner_size() || outer_size() == 1;
148 }
149 [[nodiscard]] constexpr bool has_full_inner_stride() const { return inner_stride == 1; }
150 [[nodiscard]] constexpr index_type layer_index(index_type l, index_type s) const {
151 assert(0 <= l && l < ceil_depth());
152 const auto bs = static_cast<I>(batch_size);
153 index_type offset = l % bs;
154 return s * (l - offset) + offset;
155 }
156 [[nodiscard]] constexpr index_type layer_index(index_type l) const {
157 return layer_index(l, get_layer_stride());
158 }
159
161 if constexpr (requires { standard_stride_type::value; })
162 return {};
163 else
164 return static_cast<standard_stride_type>(s);
165 }
166
167 template <class T>
170 return {{.data = data + layer_index(l),
171 .rows = rows,
172 .cols = cols,
173 .inner_stride = convert_to_standard_stride(batch_size),
174 .outer_stride = outer_stride * static_cast<I>(batch_size)}};
175 }
176 template <class T>
177 [[nodiscard]] T &operator()(T *data, index_type l, index_type r, index_type c) const {
178 auto *const p = data + layer_index(l);
179 const auto bs = static_cast<I>(batch_size);
180 return *(is_row_major ? p + bs * (c + outer_stride * r) : p + bs * (r + outer_stride * c));
181 }
182 /// Total number of elements in the view (excluding padding).
183 [[nodiscard]] index_type size() const { return static_cast<I>(depth) * rows * cols; }
184 [[nodiscard]] index_type padded_size() const { return ceil_depth() * get_layer_stride(); }
185};
186
187} // namespace batmat::matrix
constexpr auto cols(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:490
typename integral_value_type< T >::type integral_value_type_t
Definition layout.hpp:25
constexpr auto data(Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:476
constexpr auto outer_stride(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:494
constexpr auto rows(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:486
typename IntConst::value_type type
Definition layout.hpp:22
constexpr auto depth(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:498
constexpr index_type num_batches() const
Definition layout.hpp:97
constexpr Layout(PlainLayout p={})
Definition layout.hpp:89
constexpr index_type ceil_depth(index_type n) const
Round up the given size n to a multiple of batch_size.
Definition layout.hpp:119
T & operator()(T *data, index_type l, index_type r, index_type c) const
Definition layout.hpp:177
constexpr auto get_layer_stride() const
Definition layout.hpp:136
static standard_stride_type convert_to_standard_stride(auto s)
Definition layout.hpp:160
constexpr auto col_stride() const
The column stride of the matrices, i.e. the distance between elements in consecutive columns in a given row.
Definition layout.hpp:112
constexpr index_type floor_depth() const
Round down the depth to a multiple of batch_size.
Definition layout.hpp:133
static constexpr bool is_column_major
Definition layout.hpp:54
constexpr index_type inner_size() const
Definition layout.hpp:96
static constexpr bool is_row_major
Definition layout.hpp:55
constexpr auto row_stride() const
The row stride of the matrices, i.e. the distance between elements in consecutive rows in a given column.
Definition layout.hpp:104
static constexpr std::integral_constant< index_type, 1 > inner_stride
Definition layout.hpp:56
index_type size() const
Total number of elements in the view (excluding padding).
Definition layout.hpp:183
constexpr bool has_full_outer_stride() const
Definition layout.hpp:146
constexpr index_type floor_depth(index_type n) const
Round down the given size n to a multiple of batch_size.
Definition layout.hpp:128
constexpr bool has_full_layer_stride() const
Definition layout.hpp:142
constexpr index_type layer_index(index_type l, index_type s) const
Definition layout.hpp:150
guanaqo::MatrixView< T, I, standard_stride_type, storage_order > operator()(T *data, index_type l) const
Definition layout.hpp:169
constexpr index_type layer_index(index_type l) const
Definition layout.hpp:156
constexpr index_type ceil_depth() const
Round up the depth to a multiple of batch_size.
Definition layout.hpp:124
static constexpr StorageOrder storage_order
Definition layout.hpp:53
std::conditional_t< requires { S::value; }, std::integral_constant< index_t, S::value >, index_t > standard_stride_type
Definition layout.hpp:58
index_type padded_size() const
Definition layout.hpp:184
constexpr bool has_full_inner_stride() const
Definition layout.hpp:149
constexpr index_type outer_size() const
Definition layout.hpp:95