guanaqo develop
Utilities for scientific software
Loading...
Searching...
No Matches
mat-view.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file
4/// @ingroup linalg_views
5/// Non-owning matrix view.
6
7#include <algorithm>
8#include <cassert>
9#include <concepts>
10#include <cstddef>
11#include <span>
12#include <type_traits>
13
14namespace guanaqo {
15
16/// @addtogroup linalg_views
17/// @{
18
19/// Storage order of matrices.
20enum class StorageOrder : bool {
21 /// Column-major storage order (Fortran-style, layout left, unit row stride).
23 /// Row-major storage order (C-style, layout right, unit column stride).
25};
26
27/// Transpose the storage order (swaps row and column major).
32
/// Selects which triangular part of a matrix an operation refers to.
/// @todo Unit-diagonal variants?
enum class Triangular {
    /// Lower triangular (including diagonal).
    Lower,
    /// Strictly lower triangular (excluding diagonal).
    StrictlyLower,
    /// Upper triangular (including diagonal).
    Upper,
    /// Strictly upper triangular (excluding diagonal).
    StrictlyUpper,
};
41
42/// Get the default stride value for a given integer type (primary template).
43template <class S>
45
46/// Get the default stride value for a run-time integer type.
47template <std::integral S>
48struct default_stride<S> {
49 static constexpr S value{1};
50};
51
52/// Get the default stride value for a compile-time integer type such as `std::integral_constant`.
53template <class S>
54 requires(std::default_initializable<S> && !std::constructible_from<S, int>)
55struct default_stride<S> {
56 static constexpr S value{};
57};
58
59/// A lightweight view of a 2D matrix.
60/// @tparam T Element type (may be const-qualified).
61/// @tparam I Index type for dimensions and strides.
62/// @tparam S Stride type for inner dimension. Can be an integral constant for
63/// compile-time known strides, or can be the same as @p I for
64/// runtime strides.
65/// @tparam O Storage order (column-major or row-major).
66template <class T, class I = ptrdiff_t, class S = std::integral_constant<I, 1>,
67 StorageOrder O = StorageOrder::ColMajor>
68struct MatrixView {
69 using value_type = T;
70 using index_type = I;
72 static constexpr StorageOrder storage_order = O;
73 static constexpr bool is_column_major = O == StorageOrder::ColMajor;
74 static constexpr bool is_row_major = O == StorageOrder::RowMajor;
75
79 [[no_unique_address]] inner_stride_type inner_stride;
81
82 [[nodiscard]] constexpr index_type outer_size() const {
83 return is_row_major ? rows : cols;
84 }
85 [[nodiscard]] constexpr index_type inner_size() const {
86 return is_row_major ? cols : rows;
87 }
88 [[nodiscard]] constexpr auto row_stride() const {
89 if constexpr (is_row_major)
90 return outer_stride;
91 else
92 return inner_stride;
93 }
94 [[nodiscard]] constexpr index_type col_stride() const {
95 if constexpr (is_column_major)
96 return outer_stride;
97 else
98 return inner_stride;
99 }
100
101 /// POD type for designated initializers
110
111 MatrixView(PlainMatrixView p)
112 : data{p.data}, rows{p.rows}, cols{p.cols},
114#ifdef NDEBUG
115 [[gnu::always_inline]]
116#endif
118 assert(0 <= r && r < rows);
119 assert(0 <= c && c < cols);
120 return data[c * col_stride() + r * row_stride()];
121 }
#if __cpp_multidimensional_subscript >= 202110L
/// Two-index subscript `view[r, c]`, available when the compiler supports
/// multidimensional subscript operators. Forwards to the call operator.
#ifdef NDEBUG
[[gnu::always_inline]]
#endif
value_type &operator[](index_type r, index_type c) const {
    return (*this)(r, c);
}
#endif
130 [[nodiscard]] bool empty() const { return rows == 0 || cols == 0; }
131
133 assert(0 <= n && n <= rows);
134 return {{
135 .data = data,
136 .rows = n,
137 .cols = cols,
138 .inner_stride = inner_stride,
139 .outer_stride = outer_stride,
140 }};
141 }
143 assert(0 <= n && n <= cols);
144 return {{
145 .data = data,
146 .rows = rows,
147 .cols = n,
148 .inner_stride = inner_stride,
149 .outer_stride = outer_stride,
150 }};
151 }
153 assert(0 <= n && n <= rows);
154 return {{
155 .data = data + row_stride() * (rows - n),
156 .rows = n,
157 .cols = cols,
160 }};
161 }
163 assert(0 <= n && n <= cols);
164 return {{
165 .data = data + col_stride() * (cols - n),
166 .rows = rows,
167 .cols = n,
170 }};
171 }
173 return bottom_rows(rows - r).top_rows(n);
174 }
176 return right_cols(cols - c).left_cols(n);
177 }
179 return top_rows(nr).left_cols(nc);
180 }
182 return top_rows(nr).right_cols(nc);
183 }
185 return bottom_rows(nr).left_cols(nc);
186 }
188 return bottom_rows(nr).right_cols(nc);
189 }
191 index_type nc) const {
192 return middle_rows(r, nr).middle_cols(c, nc);
193 }
194
195 static MatrixView as_column(std::span<T> v)
196 requires is_column_major
197 {
198 return {{
199 .data = v.data(),
200 .rows = static_cast<index_type>(v.size()),
201 .cols = 1,
202 }};
203 }
204
205 static MatrixView as_row(std::span<T> v)
206 requires is_row_major
207 {
208 return {{
209 .data = v.data(),
210 .rows = 1,
211 .cols = static_cast<index_type>(v.size()),
212 }};
213 }
214
216 return {{
217 .data = data,
218 .rows = rows,
219 .cols = cols,
220 .inner_stride = inner_stride,
221 .outer_stride = outer_stride,
222 }};
223 }
224
225 MatrixView<T, I, S, transpose(O)> transposed() const {
226 return {{
227 .data = data,
228 .rows = cols,
229 .cols = rows,
230 .inner_stride = inner_stride,
231 .outer_stride = outer_stride,
232 }};
233 }
234
235 void set_constant(const value_type &t) {
236 if (inner_stride == 1)
237 for (index_type c = 0; c < outer_size(); ++c)
238 std::fill_n(data + c * outer_stride, inner_size(), t);
239 else if constexpr (is_column_major)
240 for (index_type c = 0; c < cols; ++c)
241 for (index_type r = 0; r < rows; ++r)
242 (*this)(r, c) = t;
243 else
244 for (index_type r = 0; r < rows; ++r)
245 for (index_type c = 0; c < cols; ++c)
246 (*this)(r, c) = t;
247 }
249 requires(inner_stride_type::value == 1 && is_column_major)
250 {
251 using std::fill_n;
252 auto n = std::min(rows, cols);
253 switch (tr) {
255 for (index_type c = 0; c < n; ++c)
256 fill_n(data + c + c * col_stride(), rows - c, t);
257 break;
259 for (index_type c = 0; c < n; ++c)
260 fill_n(data + c + 1 + c * col_stride(), rows - c - 1, t);
261 break;
263 for (index_type c = 0; c < n; ++c)
264 fill_n(data + c * col_stride(), 1 + c, t);
265 for (index_type c = n; c < cols; ++c)
266 fill_n(data + c * col_stride(), rows, t);
267 break;
269 for (index_type c = 1; c < n; ++c)
270 fill_n(data + c * col_stride(), c, t);
271 for (index_type c = n; c < cols; ++c)
272 fill_n(data + c * col_stride(), rows, t);
273 break;
274 default: assert(!"Unexpected value for guanaqo::Triangular");
275 }
276 }
278 requires(inner_stride_type::value == 1 && is_row_major)
279 {
280 using std::fill_n;
281 auto n = std::min(rows, cols);
282 switch (tr) {
284 for (index_type r = 0; r < n; ++r)
285 fill_n(data + r * row_stride(), 1 + r, t);
286 for (index_type r = n; r < rows; ++r)
287 fill_n(data + r * row_stride(), cols, t);
288 break;
290 for (index_type r = 1; r < n; ++r)
291 fill_n(data + r * row_stride(), r, t);
292 for (index_type r = n; r < rows; ++r)
293 fill_n(data + r * row_stride(), cols, t);
294 break;
296 for (index_type r = 0; r < n; ++r)
297 fill_n(data + r + r * row_stride(), cols - r, t);
298 break;
300 for (index_type r = 0; r < n; ++r)
301 fill_n(data + r + 1 + r * row_stride(), cols - r - 1, t);
302 break;
303 default: assert(!"Unexpected value for guanaqo::Triangular");
304 }
305 }
307 auto n = std::max(rows, cols);
308 return {{
309 .data = data,
310 .rows = n,
311 .cols = 1,
312 .inner_stride = outer_stride + inner_stride,
313 }};
314 }
315 [[gnu::always_inline]] void iter_diagonal(const auto f) {
316 auto *p = data;
317 auto n = std::max(rows, cols);
318 for (index_type i = 0; i < n; ++i) {
319 f(i, *p);
321 }
322 }
323 [[gnu::always_inline]] void iter_diagonal(const auto f) const {
324 auto *p = data;
325 auto n = std::max(rows, cols);
326 for (index_type i = 0; i < n; ++i) {
327 f(i, *p);
329 }
330 }
331 void set_diagonal(const value_type &t) {
332 iter_diagonal([&t](index_type, value_type &value) { value = t; });
333 }
335 iter_diagonal([&t](index_type, value_type &value) { value += t; });
336 }
337 template <class Other>
338 void copy_values(const Other &other) {
339 static_assert(storage_order == Other::storage_order);
340 assert(other.rows == this->rows);
341 assert(other.cols == this->cols);
342 const auto *src = other.data;
343 auto *dst = this->data;
344 for (index_type c = 0; c < this->outer_size(); ++c) {
345 if (other.inner_stride == 1 && this->inner_stride == 1)
346 std::copy_n(src, this->inner_size(), dst);
347 else {
348 const auto *src_ = src;
349 auto *dst_ = dst;
350 for (index_type r = 0; r < this->inner_size(); ++r) {
351 *dst_ = *src_;
352 src_ += other.inner_stride;
353 dst_ += this->inner_stride;
354 }
355 }
356 src += other.outer_stride;
357 dst += this->outer_stride;
358 }
359 }
360 MatrixView(const MatrixView &) = default;
362 if (this != &other)
363 copy_values(other);
364 return *this;
365 }
366 template <class U, class J, class R>
367 requires(!std::is_const_v<T> &&
368 std::convertible_to<U, std::remove_cv_t<T>> &&
369 std::equality_comparable_with<I, J>)
371 copy_values(other);
372 return *this;
373 }
374 // TODO: abstract logic into generic function (and check performance)
375 template <class U, class J, class R>
376 requires(!std::is_const_v<T> &&
377 std::convertible_to<U, std::remove_cv_t<T>> &&
378 std::equality_comparable_with<I, J>)
380 assert(other.rows == this->rows);
381 assert(other.cols == this->cols);
382 const auto *src = other.data;
383 auto *dst = this->data;
384 for (index_type c = 0; c < this->outer_size(); ++c) {
385 const auto *src_ = src;
386 auto *dst_ = dst;
387 for (index_type r = 0; r < this->inner_size(); ++r) {
388 *dst_ += *src_;
389 src_ += other.inner_stride;
390 dst_ += this->inner_stride;
391 }
392 src += other.outer_stride;
393 dst += this->outer_stride;
394 }
395 return *this;
396 }
397 // TODO: abstract logic into generic function (and check performance)
398 template <class U>
399 MatrixView &operator+=(const U &u) {
400 auto *dst = this->data;
401 for (index_type c = 0; c < this->outer_size(); ++c) {
402 auto *dst_ = dst;
403 for (index_type r = 0; r < this->inner_size(); ++r) {
404 *dst_ += u;
405 dst_ += this->inner_stride;
406 }
407 dst += this->outer_stride;
408 }
409 return *this;
410 }
411 template <class Generator>
412 void generate(Generator gen) {
413 for (index_type c = 0; c < cols; ++c)
414 for (index_type r = 0; r < rows; ++r)
415 (*this)(r, c) = gen();
416 }
418 this->data = other.data;
419 this->rows = other.rows;
420 this->cols = other.cols;
421 this->inner_stride = other.inner_stride;
422 this->outer_stride = other.outer_stride;
423 return *this;
424 }
425};
426
427/// Convenience alias for a row-major @ref MatrixView.
428template <class T, class I = ptrdiff_t, class S = std::integral_constant<I, 1>>
430
431/// @}
432
433} // namespace guanaqo
MatrixView< T, I, S, StorageOrder::RowMajor > MatrixViewRM
Convenience alias for a row-major MatrixView.
Definition mat-view.hpp:429
StorageOrder
Storage order of matrices.
Definition mat-view.hpp:20
constexpr StorageOrder transpose(StorageOrder o)
Transpose the storage order (swaps row and column major).
Definition mat-view.hpp:28
Triangular
Triangular matrix structure.
Definition mat-view.hpp:35
@ ColMajor
Column-major storage order (Fortran-style, layout left, unit row stride).
Definition mat-view.hpp:22
@ RowMajor
Row-major storage order (C-style, layout right, unit column stride).
Definition mat-view.hpp:24
@ StrictlyLower
Strictly lower triangular (excluding diagonal).
Definition mat-view.hpp:37
@ Upper
Upper triangular (including diagonal).
Definition mat-view.hpp:38
@ StrictlyUpper
Strictly upper triangular (excluding diagonal).
Definition mat-view.hpp:39
@ Lower
Lower triangular (including diagonal).
Definition mat-view.hpp:36
Get the default stride value for a given integer type (primary template).
Definition mat-view.hpp:44
A lightweight view of a 2D matrix.
Definition mat-view.hpp:68
static MatrixView as_row(std::span< T > v)
Definition mat-view.hpp:205
MatrixView(PlainMatrixView p)
Definition mat-view.hpp:111
MatrixView middle_cols(index_type c, index_type n) const
Definition mat-view.hpp:175
MatrixView right_cols(index_type n) const
Definition mat-view.hpp:162
MatrixView< T, I, I, O > diagonal()
Definition mat-view.hpp:306
void set_constant(const value_type &t)
Definition mat-view.hpp:235
constexpr index_type col_stride() const
Definition mat-view.hpp:94
MatrixView< T, I, S, transpose(O)> transposed() const
Definition mat-view.hpp:225
MatrixView bottom_right(index_type nr, index_type nc) const
Definition mat-view.hpp:187
MatrixView top_right(index_type nr, index_type nc) const
Definition mat-view.hpp:181
MatrixView top_rows(index_type n) const
Definition mat-view.hpp:132
void iter_diagonal(const auto f)
Definition mat-view.hpp:315
MatrixView block(index_type r, index_type c, index_type nr, index_type nc) const
Definition mat-view.hpp:190
MatrixView middle_rows(index_type r, index_type n) const
Definition mat-view.hpp:172
MatrixView bottom_left(index_type nr, index_type nc) const
Definition mat-view.hpp:184
void copy_values(const Other &other)
Definition mat-view.hpp:338
MatrixView bottom_rows(index_type n) const
Definition mat-view.hpp:152
constexpr auto row_stride() const
Definition mat-view.hpp:88
value_type & operator()(index_type r, index_type c) const
Definition mat-view.hpp:117
MatrixView & operator+=(const U &u)
Definition mat-view.hpp:399
void generate(Generator gen)
Definition mat-view.hpp:412
MatrixView top_left(index_type nr, index_type nc) const
Definition mat-view.hpp:178
constexpr index_type outer_size() const
Definition mat-view.hpp:82
void set_diagonal(const value_type &t)
Definition mat-view.hpp:331
void add_to_diagonal(const value_type &t)
Definition mat-view.hpp:334
MatrixView & reassign(MatrixView other)
Definition mat-view.hpp:417
static MatrixView as_column(std::span< T > v)
Definition mat-view.hpp:195
MatrixView & operator=(const MatrixView &other)
Definition mat-view.hpp:361
void iter_diagonal(const auto f) const
Definition mat-view.hpp:323
MatrixView left_cols(index_type n) const
Definition mat-view.hpp:142
MatrixView(const MatrixView &)=default
void set_constant(const value_type &t, Triangular tr)
Definition mat-view.hpp:277
constexpr index_type inner_size() const
Definition mat-view.hpp:85
bool empty() const
Definition mat-view.hpp:130
void set_constant(const value_type &t, Triangular tr)
Definition mat-view.hpp:248
POD type for designated initializers.
Definition mat-view.hpp:102
static constexpr S value
Definition mat-view.hpp:49