batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
layout.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file
4/// Layout description for a batch of matrices, independent of any storage.
5/// @ingroup topic-matrix
6
7#include <batmat/config.hpp>
8#include <guanaqo/mat-view.hpp>
9
10namespace batmat::matrix {
11
13
14template <class T>
16 using type = T;
17};
18template <class IntConst>
19 requires requires { typename IntConst::value_type; }
20struct integral_value_type<IntConst> {
21 using type = typename IntConst::value_type;
22};
23template <class T>
25
27 DefaultStride() = default;
28 DefaultStride(index_t) {} // TODO: this is error prone
29};
30
31/// Shape and strides describing a batch of matrices, independent of any storage.
32/// @tparam I
33/// Index type.
34/// @tparam S
35/// Inner stride (batch size).
36/// @tparam D
37/// Depth type.
38/// @tparam L
39/// Layer stride type.
40/// @tparam O
41/// Storage order (column or row major).
42/// @ingroup topic-matrix
43template <class I = index_t, class S = std::integral_constant<I, 1>, class D = I,
44 class L = DefaultStride, StorageOrder O = StorageOrder::ColMajor>
45struct Layout {
46 /// @name Compile-time properties
47 /// @{
48 using index_type = I;
49 using batch_size_type = S;
50 using depth_type = D;
52 static constexpr StorageOrder storage_order = O;
53 static constexpr bool is_column_major = O == StorageOrder::ColMajor;
54 static constexpr bool is_row_major = O == StorageOrder::RowMajor;
55
56 using standard_stride_type = std::conditional_t<requires {
57 S::value;
58 }, std::integral_constant<index_t, S::value>, index_t>;
59 /// @}
60
61 /// @name Layout description
62 /// @{
63
64 [[no_unique_address]] depth_type depth;
68 [[no_unique_address]] batch_size_type batch_size;
69 [[no_unique_address]] layer_stride_type layer_stride;
70
71 /// @}
72
73 /// @name Initialization
74 /// @{
75
86
87 constexpr Layout(PlainLayout p = {})
90
91 /// @}
92
93 [[nodiscard]] constexpr index_type outer_size() const { return is_row_major ? rows : cols; }
94 [[nodiscard]] constexpr index_type inner_size() const { return is_row_major ? cols : rows; }
95 [[nodiscard]] constexpr index_type num_batches() const {
96 const auto bs = static_cast<I>(batch_size);
97 const auto d = static_cast<I>(depth);
98 return (d + bs - 1) / bs;
99 }
100 /// Round up the given size @p n to a multiple of @ref batch_size.
101 [[nodiscard]] constexpr index_type ceil_depth(index_type n) const {
102 const auto bs = static_cast<I>(batch_size);
103 return n + (bs - n % bs) % bs;
104 }
105 /// Round up the @ref depth to a multiple of @ref batch_size.
106 [[nodiscard]] constexpr index_type ceil_depth() const {
107 return ceil_depth(static_cast<I>(depth));
108 }
109 /// Round down the given size @p n to a multiple of @ref batch_size.
110 [[nodiscard]] constexpr index_type floor_depth(index_type n) const {
111 const auto bs = static_cast<I>(batch_size);
112 return n - (n % bs);
113 }
114 /// Round down the @ref depth to a multiple of @ref batch_size.
115 [[nodiscard]] constexpr index_type floor_depth() const {
116 return floor_depth(static_cast<I>(depth));
117 }
118 [[nodiscard]] constexpr auto get_layer_stride() const {
119 if constexpr (std::is_same_v<layer_stride_type, DefaultStride>)
120 return outer_stride * outer_size();
121 else
122 return layer_stride;
123 }
124 [[nodiscard]] constexpr bool has_full_layer_stride() const {
125 return static_cast<index_t>(get_layer_stride()) == outer_stride * outer_size() ||
127 }
128 [[nodiscard]] constexpr bool has_full_outer_stride() const {
129 return outer_stride == inner_size() || outer_size() == 1;
130 }
131 [[nodiscard]] constexpr bool has_full_inner_stride() const { return true; }
132 [[nodiscard]] constexpr index_type layer_index(index_type l, index_type s) const {
133 assert(0 <= l && l < ceil_depth());
134 const auto bs = static_cast<I>(batch_size);
135 index_type offset = l % bs;
136 return s * (l - offset) + offset;
137 }
138 [[nodiscard]] constexpr index_type layer_index(index_type l) const {
139 return layer_index(l, get_layer_stride());
140 }
141
143 if constexpr (requires { standard_stride_type::value; })
144 return {};
145 else
146 return static_cast<standard_stride_type>(s);
147 }
148
149 template <class T>
152 return {{.data = data + layer_index(l),
153 .rows = rows,
154 .cols = cols,
155 .inner_stride = convert_to_standard_stride(batch_size),
156 .outer_stride = outer_stride * static_cast<I>(batch_size)}};
157 }
158 template <class T>
159 [[nodiscard]] T &operator()(T *data, index_type l, index_type r, index_type c) const {
160 auto *const p = data + layer_index(l);
161 const auto bs = static_cast<I>(batch_size);
162 return *(is_row_major ? p + bs * (c + outer_stride * r) : p + bs * (r + outer_stride * c));
163 }
164 /// Total number of elements in the view (excluding padding).
165 [[nodiscard]] index_type size() const { return static_cast<I>(depth) * rows * cols; }
166 [[nodiscard]] index_type padded_size() const { return ceil_depth() * get_layer_stride(); }
167};
168
169} // namespace batmat::matrix
constexpr auto cols(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:247
typename integral_value_type< T >::type integral_value_type_t
Definition layout.hpp:24
constexpr auto data(Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:233
constexpr auto outer_stride(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:251
constexpr auto rows(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:243
typename IntConst::value_type type
Definition layout.hpp:21
constexpr auto depth(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:255
constexpr index_type num_batches() const
Definition layout.hpp:95
constexpr Layout(PlainLayout p={})
Definition layout.hpp:87
constexpr index_type ceil_depth(index_type n) const
Round up the given size n to a multiple of batch_size.
Definition layout.hpp:101
T & operator()(T *data, index_type l, index_type r, index_type c) const
Definition layout.hpp:159
constexpr auto get_layer_stride() const
Definition layout.hpp:118
static standard_stride_type convert_to_standard_stride(auto s)
Definition layout.hpp:142
constexpr index_type floor_depth() const
Round down the depth to a multiple of batch_size.
Definition layout.hpp:115
static constexpr bool is_column_major
Definition layout.hpp:53
constexpr index_type inner_size() const
Definition layout.hpp:94
static constexpr bool is_row_major
Definition layout.hpp:54
index_type size() const
Total number of elements in the view (excluding padding).
Definition layout.hpp:165
constexpr bool has_full_outer_stride() const
Definition layout.hpp:128
constexpr index_type floor_depth(index_type n) const
Round down the given size n to a multiple of batch_size.
Definition layout.hpp:110
constexpr bool has_full_layer_stride() const
Definition layout.hpp:124
constexpr index_type layer_index(index_type l, index_type s) const
Definition layout.hpp:132
guanaqo::MatrixView< T, I, standard_stride_type, storage_order > operator()(T *data, index_type l) const
Definition layout.hpp:151
constexpr index_type layer_index(index_type l) const
Definition layout.hpp:138
constexpr index_type ceil_depth() const
Round up the depth to a multiple of batch_size.
Definition layout.hpp:106
static constexpr StorageOrder storage_order
Definition layout.hpp:52
std::conditional_t< requires { S::value; }, std::integral_constant< index_t, S::value >, index_t > standard_stride_type
Definition layout.hpp:56
index_type padded_size() const
Definition layout.hpp:166
constexpr bool has_full_inner_stride() const
Definition layout.hpp:131
constexpr index_type outer_size() const
Definition layout.hpp:93