batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
matrix.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file
4/// Class for a batch of matrices that owns its storage.
5/// @ingroup topic-matrix
6
9
10#include <type_traits>
11#include <utility>
12
13namespace batmat::matrix {
14
15namespace detail {
16
17template <class, class I, class Stride>
19 using type = std::integral_constant<I, 0>;
20};
21template <class T, class I, class Stride>
22 requires requires {
23 { Stride::value } -> std::convertible_to<I>;
24 }
25struct default_alignment<T, I, Stride> {
26 using type = std::integral_constant<I, alignof(T) * Stride::value>;
27};
28template <class T, class I, class Stride>
30
33
34} // namespace detail
35
36/// Class for a batch of matrices that owns its storage.
37/// @tparam T
38/// Element value type.
39/// @tparam I
40/// Index type.
41/// @tparam S
42/// Inner stride (batch size).
43/// @tparam D
44/// Depth type.
45/// @tparam O
46/// Storage order (column or row major).
47/// @tparam A
48/// Batch alignment type.
49/// @ingroup topic-matrix
50template <class T, class I = index_t, class S = std::integral_constant<I, 1>, class D = I,
51 StorageOrder O = StorageOrder::ColMajor, class A = detail::default_alignment_t<T, I, S>>
52struct Matrix {
53 static_assert(!std::is_const_v<T>);
57 using plain_layout_type = typename layout_type::PlainLayout;
58 using value_type = T;
59 using index_type = typename layout_type::index_type;
60 using batch_size_type = typename layout_type::batch_size_type;
61 using depth_type = typename layout_type::depth_type;
62 using standard_stride_type = typename layout_type::standard_stride_type;
63 using alignment_type = A;
64
65 private:
67
68 static constexpr auto default_alignment(layout_type layout) {
69 if constexpr (std::is_integral_v<alignment_type>)
70 return alignof(T) * static_cast<size_t>(layout.batch_size); // TODO
71 if constexpr (alignment_type::value == 0)
72 return alignof(T) * static_cast<size_t>(layout.batch_size);
73 else
74 return alignment_type{};
75 }
76 [[nodiscard]] static auto allocate(layout_type layout) {
77 const auto alignment = default_alignment(layout);
78 return make_aligned_unique_ptr<T>(layout.padded_size(), alignment);
79 }
80 [[nodiscard]] static auto allocate(layout_type layout, uninitialized_t init) {
81 const auto alignment = default_alignment(layout);
82 return make_aligned_unique_ptr<T>(layout.padded_size(), alignment, init);
83 }
84 void clear() {
85 const auto alignment = default_alignment(view_.layout);
86 if (auto d = std::exchange(view_.data_ptr, nullptr))
87 aligned_deleter<T, decltype(alignment)>(view_.layout.padded_size(), alignment)(d);
88 view_.layout.rows = 0;
89 }
90
91 public:
92 [[nodiscard, gnu::always_inline]] view_type view() { return view_; }
93 [[nodiscard, gnu::always_inline]] const_view_type view() const { return view_.as_const(); }
94
95 [[nodiscard]] layout_type layout() const { return view_.layout; }
96
97 void resize(layout_type new_layout) {
98 if (new_layout.padded_size() != layout().padded_size()) {
99 clear();
100 view_.data_ptr = allocate(new_layout).release();
101 }
102 view_.layout = new_layout;
103 }
104
106
/// Default constructor. The initial state is whatever the default member
/// initializer of `view_` provides (not visible here; presumably an empty view
/// that owns no storage — NOTE(review): confirm).
Matrix() = default;
113
114 Matrix(const Matrix &o) : Matrix{o.layout()} {
115 this->view().copy_values(o.view()); // TODO: exception safety
116 }
117 Matrix(Matrix &&o) noexcept : view_{o.view()} { o.view_.reassign({}); }
119 if (&o != this) {
120 clear();
121 view_.reassign({allocate(o.layout()).release(), o.layout()});
122 // TODO: use allocate_for_overwrite or similar to avoid copy
123 // assignment
124 this->view_.copy_values(o.view_); // TODO: exception safety
125 }
126 return *this;
127 }
128 Matrix &operator=(Matrix &&o) noexcept {
129 using std::swap;
130 if (&o != this) {
131 swap(o.view_.data_ptr, this->view_.data_ptr);
132 swap(o.view_.layout, this->view_.layout);
133 }
134 return *this;
135 }
136 ~Matrix() { clear(); }
137
138 operator view_type() { return view(); }
139 operator const_view_type() const { return view(); }
140 operator View<T, I, S, D, I, O>() { return view(); }
141 operator View<const T, I, S, D, I, O>() const { return view(); }
142
147 return view()(l, r, c);
148 }
151 return view()(l);
152 }
153 [[nodiscard]] const value_type &operator()(index_type l, index_type r, index_type c) const {
154 return view()(l, r, c);
155 }
156
157 [[nodiscard]] auto batch(index_type b) { return view().batch(b); }
158 [[nodiscard]] auto batch(index_type b) const { return view().batch(b); }
159 [[nodiscard]] auto batch_dyn(index_type b) { return view().batch_dyn(b); }
160 [[nodiscard]] auto batch_dyn(index_type b) const { return view().batch_dyn(b); }
161
162 [[nodiscard]] auto reshaped(index_type rows, index_type cols) {
163 return view().reshaped(rows, cols);
164 }
165 [[nodiscard]] auto reshaped(index_type rows, index_type cols) const {
166 return view().reshaped(rows, cols);
167 }
168 [[nodiscard]] auto top_rows(index_type n) { return view().top_rows(n); }
169 [[nodiscard]] auto top_rows(index_type n) const { return view().top_rows(n); }
170 [[nodiscard]] auto left_cols(index_type n) { return view().left_cols(n); }
171 [[nodiscard]] auto left_cols(index_type n) const { return view().left_cols(n); }
172 [[nodiscard]] auto bottom_rows(index_type n) { return view().bottom_rows(n); }
173 [[nodiscard]] auto bottom_rows(index_type n) const { return view().bottom_rows(n); }
174 [[nodiscard]] auto right_cols(index_type n) { return view().right_cols(n); }
175 [[nodiscard]] auto right_cols(index_type n) const { return view().right_cols(n); }
176 [[nodiscard]] auto middle_rows(index_type r, index_type n) { return view().middle_rows(r, n); }
177 [[nodiscard]] auto middle_rows(index_type r, index_type n) const {
178 return view().middle_rows(r, n);
179 }
180 [[nodiscard]] auto middle_cols(index_type c, index_type n) { return view().middle_cols(c, n); }
181 [[nodiscard]] auto middle_cols(index_type c, index_type n) const {
182 return view().middle_cols(c, n);
183 }
184 [[nodiscard]] auto top_left(index_type nr, index_type nc) { return view().top_left(nr, nc); }
185 [[nodiscard]] auto top_left(index_type nr, index_type nc) const {
186 return view().top_left(nr, nc);
187 }
188 [[nodiscard]] auto top_right(index_type nr, index_type nc) { return view().top_right(nr, nc); }
189 [[nodiscard]] auto top_right(index_type nr, index_type nc) const {
190 return view().top_right(nr, nc);
191 }
192 [[nodiscard]] auto bottom_left(index_type nr, index_type nc) {
193 return view().bottom_left(nr, nc);
194 }
195 [[nodiscard]] auto bottom_left(index_type nr, index_type nc) const {
196 return view().bottom_left(nr, nc);
197 }
198 [[nodiscard]] auto bottom_right(index_type nr, index_type nc) {
199 return view().bottom_right(nr, nc);
200 }
201 [[nodiscard]] auto bottom_right(index_type nr, index_type nc) const {
202 return view().bottom_right(nr, nc);
203 }
204 [[nodiscard]] auto block(index_type r, index_type c, index_type nr, index_type nc) {
205 return view().block(r, c, nr, nc);
206 }
207 [[nodiscard]] auto block(index_type r, index_type c, index_type nr, index_type nc) const {
208 return view().block(r, c, nr, nc);
209 }
210 [[nodiscard]] auto transposed() { return view().transposed(); }
211 [[nodiscard]] auto transposed() const { return view().transposed(); }
212
213 [[nodiscard]] auto as_const() const { return view(); }
214 [[nodiscard]] auto begin() { return view().begin(); }
215 [[nodiscard]] auto begin() const { return view().begin(); }
216 [[nodiscard]] auto end() { return view().end(); }
217 [[nodiscard]] auto end() const { return view().end(); }
218 [[nodiscard]] index_type size() const { return view().size(); }
219 [[nodiscard]] index_type padded_size() const { return view().padded_size(); }
220
221 [[gnu::always_inline]] value_type *data() { return view().data(); }
222 [[gnu::always_inline]] const value_type *data() const { return view().data(); }
223 [[gnu::always_inline]] depth_type depth() const { return view().depth(); }
224 [[gnu::always_inline]] index_type ceil_depth() const { return view().ceil_depth(); }
225 [[gnu::always_inline]] index_type num_batches() const { return view().num_batches(); }
226 [[gnu::always_inline]] index_type rows() const { return view().rows(); }
227 [[gnu::always_inline]] index_type cols() const { return view().cols(); }
228 [[gnu::always_inline]] index_type outer_stride() const { return view().outer_stride(); }
229 [[gnu::always_inline]] batch_size_type batch_size() const { return view().batch_size(); }
230};
231
232template <class T, class I, class S, class D, class A, StorageOrder O>
233constexpr auto data(Matrix<T, I, S, D, O, A> &v) {
234 return v.data();
235}
236template <class T, class I, class S, class D, class A, StorageOrder O>
237constexpr auto data(Matrix<T, I, S, D, O, A> &&v) = delete;
238template <class T, class I, class S, class D, class A, StorageOrder O>
239constexpr auto data(const Matrix<T, I, S, D, O, A> &v) {
240 return v.data();
241}
242template <class T, class I, class S, class D, class A, StorageOrder O>
243constexpr auto rows(const Matrix<T, I, S, D, O, A> &v) {
244 return v.rows();
245}
246template <class T, class I, class S, class D, class A, StorageOrder O>
247constexpr auto cols(const Matrix<T, I, S, D, O, A> &v) {
248 return v.cols();
249}
250template <class T, class I, class S, class D, class A, StorageOrder O>
251constexpr auto outer_stride(const Matrix<T, I, S, D, O, A> &v) {
252 return v.outer_stride();
253}
254template <class T, class I, class S, class D, class A, StorageOrder O>
255constexpr auto depth(const Matrix<T, I, S, D, O, A> &v) {
256 return v.depth();
257}
258
259} // namespace batmat::matrix
auto make_aligned_unique_ptr(size_t size, A align)
Returns a smart pointer to an array of T that satisfies the given alignment requirements.
Definition storage.hpp:77
std::integral_constant< I, 0 > type
Definition matrix.hpp:19
typename default_alignment< T, I, Stride >::type default_alignment_t
Definition matrix.hpp:29
std::integral_constant< I, alignof(T) *Stride::value > type
Definition matrix.hpp:26
constexpr auto cols(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:247
constexpr auto data(Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:233
constexpr auto outer_stride(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:251
constexpr auto rows(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:243
constexpr auto depth(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:255
Aligned allocation for matrix storage.
Class for a batch of matrices that owns its storage.
Definition matrix.hpp:52
auto top_right(index_type nr, index_type nc) const
Definition matrix.hpp:189
Matrix(const Matrix &o)
Definition matrix.hpp:114
auto bottom_rows(index_type n) const
Definition matrix.hpp:173
auto bottom_left(index_type nr, index_type nc) const
Definition matrix.hpp:195
auto batch_dyn(index_type b)
Definition matrix.hpp:159
View< S, index_t, simd_stride_t, index_t, DefaultStride, O > view_type
Definition matrix.hpp:54
Matrix(plain_layout_type p)
Definition matrix.hpp:111
Matrix(plain_layout_type p, uninitialized_t init)
Definition matrix.hpp:112
Matrix & operator=(Matrix &&o) noexcept
Definition matrix.hpp:128
auto reshaped(index_type rows, index_type cols) const
Definition matrix.hpp:165
static auto allocate(layout_type layout)
Definition matrix.hpp:76
index_type padded_size() const
Definition matrix.hpp:219
void resize(layout_type new_layout)
Definition matrix.hpp:97
batch_size_type batch_size() const
Definition matrix.hpp:229
value_type * data()
Definition matrix.hpp:221
index_type size() const
Definition matrix.hpp:218
Matrix & operator=(const Matrix &o)
Definition matrix.hpp:118
auto block(index_type r, index_type c, index_type nr, index_type nc)
Definition matrix.hpp:204
auto top_rows(index_type n) const
Definition matrix.hpp:169
auto transposed() const
Definition matrix.hpp:211
Matrix(layout_type layout)
Definition matrix.hpp:108
auto reshaped(index_type rows, index_type cols)
Definition matrix.hpp:162
void set_constant(value_type t)
Definition matrix.hpp:105
auto top_left(index_type nr, index_type nc) const
Definition matrix.hpp:185
auto bottom_rows(index_type n)
Definition matrix.hpp:172
auto left_cols(index_type n)
Definition matrix.hpp:170
auto left_cols(index_type n) const
Definition matrix.hpp:171
auto middle_cols(index_type c, index_type n)
Definition matrix.hpp:180
auto top_rows(index_type n)
Definition matrix.hpp:168
auto batch(index_type b)
Definition matrix.hpp:157
auto middle_cols(index_type c, index_type n) const
Definition matrix.hpp:181
auto right_cols(index_type n)
Definition matrix.hpp:174
auto bottom_right(index_type nr, index_type nc) const
Definition matrix.hpp:201
auto top_right(index_type nr, index_type nc)
Definition matrix.hpp:188
auto bottom_left(index_type nr, index_type nc)
Definition matrix.hpp:192
auto right_cols(index_type n) const
Definition matrix.hpp:175
Matrix(layout_type layout, uninitialized_t init)
Definition matrix.hpp:109
guanaqo::MatrixView< const T, I, standard_stride_type, O > operator()(index_type l) const
Definition matrix.hpp:150
static constexpr auto default_alignment(layout_type layout)
Definition matrix.hpp:68
auto top_left(index_type nr, index_type nc)
Definition matrix.hpp:184
auto batch_dyn(index_type b) const
Definition matrix.hpp:160
value_type & operator()(index_type l, index_type r, index_type c)
Definition matrix.hpp:146
index_type outer_stride() const
Definition matrix.hpp:228
Matrix(Matrix &&o) noexcept
Definition matrix.hpp:117
const_view_type view() const
Definition matrix.hpp:93
auto batch(index_type b) const
Definition matrix.hpp:158
auto begin() const
Definition matrix.hpp:215
auto block(index_type r, index_type c, index_type nr, index_type nc) const
Definition matrix.hpp:207
static auto allocate(layout_type layout, uninitialized_t init)
Definition matrix.hpp:80
const value_type & operator()(index_type l, index_type r, index_type c) const
Definition matrix.hpp:153
index_type ceil_depth() const
Definition matrix.hpp:224
auto middle_rows(index_type r, index_type n) const
Definition matrix.hpp:177
auto bottom_right(index_type nr, index_type nc)
Definition matrix.hpp:198
auto as_const() const
Definition matrix.hpp:213
guanaqo::MatrixView< T, I, standard_stride_type, O > operator()(index_type l)
Definition matrix.hpp:143
auto middle_rows(index_type r, index_type n)
Definition matrix.hpp:176
const value_type * data() const
Definition matrix.hpp:222
depth_type depth() const
Definition matrix.hpp:223
index_type num_batches() const
Definition matrix.hpp:225
Non-owning view of a batch of matrices.
Definition view.hpp:32
Layout< I, S, D, L, O > layout_type
Definition view.hpp:33
general_slice_view_type bottom_right(index_type nr, index_type nc) const
Get a view of the bottom-right nr by nc block of the matrices.
Definition view.hpp:439
general_slice_view_type block(index_type r, index_type c, index_type nr, index_type nc) const
Get a view of the nr by nc block of the matrices starting at row r and column c.
Definition view.hpp:444
void set_constant(value_type t)
Definition view.hpp:485
auto transposed() const
Get a transposed view of the matrices.
Definition view.hpp:456
constexpr index_type num_batches() const
Number of batches in the view, i.e. ceil_depth() / batch_size().
Definition view.hpp:280
col_slice_view_type middle_cols(index_type c, index_type n) const
Get a view of n columns starting at column c.
Definition view.hpp:419
constexpr index_type padded_size() const
Total number of elements in the view (including all padding).
Definition view.hpp:271
linear_iterator begin() const
Iterate linearly (in storage order) over all elements of the view.
Definition view.hpp:236
row_slice_view_type bottom_rows(index_type n) const
Get a view of the last n rows.
Definition view.hpp:384
general_slice_view_type reshaped(index_type rows, index_type cols) const
Reshape the view to the given dimensions. The total size should not change.
Definition view.hpp:344
general_slice_view_type bottom_left(index_type nr, index_type nc) const
Get a view of the bottom-left nr by nc block of the matrices.
Definition view.hpp:434
View< const T, I, S, D, L, O > const_view_type
Definition view.hpp:40
row_slice_view_type middle_rows(index_type r, index_type n) const
Get a view of n rows starting at row r.
Definition view.hpp:414
constexpr batch_size_type batch_size() const
The batch size, i.e. the number of layers in each batch. Equals the inner stride.
Definition view.hpp:306
constexpr index_type cols() const
Number of columns of the matrices.
Definition view.hpp:286
std::default_sentinel_t end() const
Sentinel for begin().
Definition view.hpp:263
col_slice_view_type right_cols(index_type n) const
Get a view of the last n columns.
Definition view.hpp:399
View< T, I, S, I, L, O > batch_dyn(index_type b) const
Same as batch(), but returns a view with a dynamic batch size.
Definition view.hpp:155
col_slice_view_type left_cols(index_type n) const
Get a view of the first n columns.
Definition view.hpp:371
row_slice_view_type top_rows(index_type n) const
Get a view of the first n rows.
Definition view.hpp:358
constexpr index_type rows() const
Number of rows of the matrices.
Definition view.hpp:284
constexpr index_type ceil_depth() const
The depth rounded up to a multiple of the batch size.
Definition view.hpp:276
constexpr index_type size() const
Total number of elements in the view (excluding padding).
Definition view.hpp:269
T * data() const
Get a pointer to the first element of the first layer.
Definition view.hpp:203
general_slice_view_type top_left(index_type nr, index_type nc) const
Get a view of the top-left nr by nc block of the matrices.
Definition view.hpp:424
general_slice_view_type top_right(index_type nr, index_type nc) const
Get a view of the top-right nr by nc block of the matrices.
Definition view.hpp:429
constexpr depth_type depth() const
Number of layers in the view (i.e. depth).
Definition view.hpp:274
void copy_values(const Other &other) const
Definition view.hpp:514
constexpr index_type outer_stride() const
Outer stride of the matrices (leading dimension in BLAS parlance).
Definition view.hpp:289
batch_view_type batch(index_type b) const
Access a batch of batch_size() layers, starting at batch index b (i.e.
Definition view.hpp:143
Deleter for aligned memory allocated with operator new(size, align_val).
Definition storage.hpp:31
Non-owning view of a batch of matrices.