batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
view.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file
4/// Non-owning view of a batch of matrices.
5/// @ingroup topic-matrix
6
7#include <batmat/assume.hpp>
8#include <batmat/config.hpp>
10#include <guanaqo/mat-view.hpp>
11#include <type_traits>
12
13namespace batmat::matrix {
14
15/// Non-owning view of an array of matrices, stored in an efficient batched format.
16///
17/// A view is a pointer to a buffer of elements, together with a layout describing how the elements
18/// are arranged in memory. Conceptually, a batmat::matrix::View represents a 3D array of shape
19/// `(depth, rows, cols)`, where `depth` is the number of matrices in the array. We refer to a
20/// single matrix in the batch as a "layer". A single layer can be accessed using
21/// [`view(layer)`](@ref View::operator()(index_type)const), and individual elements can be accessed
22/// using [`view(layer, row, col)`](@ref View::operator()(index_type,index_type,index_type)const).
23/// Each layer is either stored in column-major or row-major order, depending on the storage order
24/// of the view. Slicing and tiling are made possible by dynamic outer strides (_leading dimension_
25/// in BLAS parlance) and layer strides.
26///
27/// To enable efficient vectorization of linear algebra operations on views, the data is not stored
28/// in memory as a 3D array, but rather as a 4D array with an additional "batch" dimension, i.e.
29/// with shape `(batch_size, rows, cols, num_batches)`, where `batch_size` is the number of layers
30/// in each batch, and `num_batches = ceil(depth / batch_size)`. The corresponding strides are
31/// `(1, batch_size, outer_stride, layer_stride)` for the column-major case, and
32/// `(1, outer_stride, batch_size, layer_stride)` for the row-major case.
33///
34/// Batches can be accessed using the @ref batch() method, which returns a view of the batch with
35/// the same layout but with `depth` equal to `batch_size`. If the total depth is not a multiple of
36/// the batch size, the last batch is padded.
37///
38/// @see @ref batmat::matrix::Matrix for an owning array of matrices that handles allocation and
39/// converts to a @ref View.
40///
41/// @tparam T
42/// Element value type (possibly const-qualified).
43/// @tparam I
44/// Index and size type. Usually `std::ptrdiff_t` or `int`.
45/// @tparam S
46/// Inner stride type (batch size). Usually `std::integral_constant<I, N>` for some `N`.
47/// @tparam D
48/// Batch depth type. Usually equal to @p S for a single batch, or @p I for a dynamic depth.
49/// @tparam L
50/// Layer stride type. Usually @ref DefaultStride (which implies that the layer stride is
51/// equal to `outer_stride() * outer_size()`), or @p I for a dynamic layer stride.
52/// Dynamic strides are used for subviews of views with a larger `outer_size()`.
53/// @tparam O
54/// %Matrix storage order, @ref guanaqo::StorageOrder::RowMajor "RowMajor" or
55/// @ref guanaqo::StorageOrder::ColMajor "ColMajor".
56/// @ingroup topic-matrix
57template <class T, class I = index_t, class S = std::integral_constant<I, 1>, class D = I,
58 class L = DefaultStride, StorageOrder O = StorageOrder::ColMajor>
59struct View {
61 using value_type = T;
70 static constexpr bool is_row_major = layout_type::is_row_major;
71
72 /// True if @ref batch_size() and @ref depth() are compile-time constants and are equal.
73 /// @note Views with dynamic batch size and depth may still have a single batch at runtime,
74 /// but this cannot be statically asserted.
75 static constexpr bool has_single_batch_at_compile_time = requires {
76 S::value;
77 D::value;
78 } && S{} == D{};
79 /// True if @ref depth() is a compile-time constant and is equal to one.
80 static constexpr bool has_single_layer_at_compile_time = requires { D::value; } && D{} == 1;
81 /// When extracting a single batch, the depth equals the batch size, and the layer stride is no
82 /// longer relevant.
84 /// When slicing along the outer dimension, the layer stride stays the same, but the outer size
85 /// may be smaller, which means that even if the original view has a default layer stride, the
86 /// sliced view may require a dynamic layer stride. For a single batch, the layer stride is not
87 /// relevant, so it is preserved.
89 std::conditional_t<has_single_batch_at_compile_time, View, View<T, I, S, D, I, O>>;
90 /// View with the correct layer stride when slicing along the column dimension.
91 /// For row-major storage, slicing along columns does not change the outer size, in which case
92 /// the layer stride is still correct. For column-major storage, slicing along columns does
93 /// change the outer size, in which case we need a dynamic layer stride.
94 using col_slice_view_type = std::conditional_t<is_row_major, View, general_slice_view_type>;
95 /// View with the correct layer stride when slicing along the row dimension.
96 /// @see @ref col_slice_view_type
97 using row_slice_view_type = std::conditional_t<is_column_major, View, general_slice_view_type>;
98
99 /// Pointer to the first element of the first layer.
101 /// Layout describing the dimensions and strides of the view.
103
104 /// @name Constructors
105 /// @{
106
107 /// POD helper struct to enable designated initializers during construction.
119
120 /// Create a new view from a plain description of the buffer and its dimensions.
/// The data pointer of @p p is stored as-is, and its depth, rows, cols, outer stride, batch size
/// and layer stride are forwarded to the layout.
/// @param p
///         Pointer and dimensions describing the view. Defaults to `{}`.
121 /// @note It is recommended to use designated initializers for the arguments to avoid mistakes.
122 constexpr View(PlainBatchedMatrixView p = {})
123 : data_ptr{p.data}, layout{{.depth = p.depth,
124 .rows = p.rows,
125 .cols = p.cols,
126 .outer_stride = p.outer_stride,
127 .batch_size = p.batch_size,
128 .layer_stride = p.layer_stride}} {}
129 /// Create a new view with the given layout, using the given buffer.
/// Only the pointer to the first element of @p data is stored; the span itself is not kept.
/// @pre `data.size() == layout.padded_size()` (asserted with BATMAT_ASSERT).
130 constexpr View(std::span<T> data, layout_type layout) : data_ptr{data.data()}, layout{layout} {
// Inside the body, `layout` names the constructor parameter, which shadows the member.
131 BATMAT_ASSERT(data.size() == layout.padded_size());
132 }
133 /// Create a new view with the given layout, using the given buffer.
135
136 /// Copy a view. No data is copied.
137 View(const View &) = default;
138
139 /// Reassign the buffer and layout of this view to those of another view. No data is copied.
141 this->data_ptr = other.data_ptr;
142 this->layout = other.layout;
143 return *this;
144 }
145
146 /// @}
147
148 /// @name Element access
149 /// @{
150
151 /// Access a single element at layer @p l, row @p r and column @p c.
/// The lookup is delegated to the layout, which is applied to the stored data pointer.
/// @return Reference to the element. This member function is const, but the returned reference
///         is only const if @p T is const-qualified.
152 [[nodiscard]] value_type &operator()(index_type l, index_type r, index_type c) const {
153 return layout(data_ptr, l, r, c);
154 }
155
156 /// @}
157
158 /// @name Batch-wise slicing
159 /// @{
160
161 /// Access a batch of @ref batch_size() layers, starting at batch index @p b (i.e. starting at
162 /// layer `b * batch_size()`).
/// The returned view has a depth equal to the batch size and keeps the rows, columns and outer
/// stride of this view. No layer stride is passed on to the returned view.
/// @note NOTE(review): the batch size is cast to `index_t` here, whereas most other members
///       cast it to the index type @p I — confirm that this is intentional.
163 [[nodiscard]] batch_view_type batch(index_type b) const {
// Index of the first layer of the requested batch.
164 const auto layer = b * static_cast<index_t>(batch_size());
165 return {{.data = data() + layout.layer_index(layer),
166 .depth = batch_size(),
167 .rows = rows(),
168 .cols = cols(),
169 .outer_stride = outer_stride(),
170 .batch_size = batch_size()}};
171 }
172
173 /// Same as @ref batch(), but returns a view with a dynamic batch size. If the total depth is
174 /// not a multiple of the batch size, the last batch will have a smaller size.
176 const auto d = static_cast<I>(depth());
177 const auto layer = b * static_cast<index_t>(batch_size());
178 const auto last = b == d / batch_size();
179 return {{.data = data() + layout.layer_index(layer),
180 .depth = last ? d - layout.floor_depth() : batch_size(),
181 .rows = rows(),
182 .cols = cols(),
183 .outer_stride = outer_stride(),
184 .batch_size = batch_size(),
185 .layer_stride = layout.layer_stride}};
186 }
187
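188 /// Get a view of @p n batches starting at batch @p b, with a stride of @p stride batches
/// (the layer stride of the returned view is `stride` times the layer stride of this view).
/// @note NOTE(review): the returned view is constructed with a depth of @p n rather than
///       `n * batch_size()` — confirm that this is the intended number of layers.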
190 index_type stride = 1) const {
191 const auto bs = static_cast<I>(batch_size());
192 const auto layer = b * bs;
193 BATMAT_ASSERT(n == 0 || layer + (n - 1) * stride * bs + bs <= depth());
194 return {{.data = data() + layout.layer_index(layer),
195 .depth = n,
196 .rows = rows(),
197 .cols = cols(),
198 .outer_stride = outer_stride(),
199 .batch_size = batch_size(),
200 .layer_stride = layout.get_layer_stride() * stride}};
201 }
202
203 /// @}
204
205 /// @name Layer-wise slicing
206 /// @{
207
208 /// Access a single layer @p l as a non-batched view.
209 /// @note The inner stride of the returned view is equal to the batch size of this view, so it
210 /// cannot be used directly with functions that require unit inner stride (e.g. BLAS).
213 return layout(data_ptr, l);
214 }
215
216 /// Get a view of the first @p n layers. Note that @p n can be a compile-time constant.
/// The data pointer, matrix dimensions, batch size and strides are the same as for this view;
/// only the depth (and its type @p N) changes.
/// @pre `n <= depth()` (asserted with BATMAT_ASSERT).
217 template <class N>
218 [[nodiscard]] View<T, I, S, N, L, O> first_layers(N n) const {
219 BATMAT_ASSERT(n <= depth());
220 return {{.data = data(),
221 .depth = n,
222 .rows = rows(),
223 .cols = cols(),
224 .outer_stride = outer_stride(),
225 .batch_size = batch_size(),
226 .layer_stride = layout.layer_stride}};
227 }
228
229 /// Get a view of @p n layers starting at layer @p l. Note that @p n can be a compile-time
230 /// constant.
/// The data pointer is offset by `layout.layer_index(l)`; the matrix dimensions, batch size and
/// strides are the same as for this view.
231 /// @pre `l % batch_size() == 0` (i.e. the starting layer must be at the start of a batch).
/// @pre `l + n <= depth()`.
232 template <class N>
233 [[nodiscard]] View<T, I, S, N, L, O> middle_layers(index_type l, N n) const {
234 BATMAT_ASSERT(l + n <= depth());
235 BATMAT_ASSERT(l % static_cast<I>(batch_size()) == 0);
236 return {{.data = data() + layout.layer_index(l),
237 .depth = n,
238 .rows = rows(),
239 .cols = cols(),
240 .outer_stride = outer_stride(),
241 .batch_size = batch_size(),
242 .layer_stride = layout.layer_stride}};
243 }
244
245 /// @}
246
247 /// @name Iterators and buffer access
248 /// @{
249
250 /// Get a pointer to the first element of the first layer.
251 T *data() const { return data_ptr; }
252
253 /// Iterator over all elements of a view.
255 using value_type = T;
256 using reference = T &;
259 T *end;
263
265 ++data;
266 if (data == next_jump && data != end) {
269 }
270 return *this;
271 }
273 linear_iterator t = *this;
274 ++*this;
275 return t;
276 }
277 reference operator*() const { return *data; }
278 bool operator==(std::default_sentinel_t) const { return data == end; }
279 };
280
281 /// Iterate linearly (in storage order) over all elements of the view.
/// The returned iterator carries the end pointer together with the position and size of the
/// padding in a partially filled last batch. Iteration is finished when the iterator compares
/// equal to the sentinel returned by @ref end().
282 /// @pre `has_full_outer_stride()` (i.e. no padding within layers).
283 /// @pre `has_full_layer_stride()` (i.e. no padding between batches).
/// @note NOTE(review): if `rows * cols == 0` while the depth is not a multiple of the batch
///       size, the `end` offset computed below is negative (`0 - padding_layers`) — confirm
///       that such views are never iterated.
284 [[nodiscard]] linear_iterator begin() const {
287 // Number of elements in each layer
288 const auto size = layout.rows * layout.cols;
289 // How many layers are in batches that are completely full?
290 const auto contig_layers = layout.floor_depth();
291 // How many layers in total?
292 const auto depth = static_cast<I>(layout.depth);
293 // Remaining layers have padding we should skip over (in the last batch)
294 const auto remaining_layers = depth - contig_layers;
295 // Index of the first padding element
296 const auto first_jump = contig_layers * size + remaining_layers;
297 // Index of last layer in our storage
298 const auto padded_end = layout.ceil_depth();
299 const auto padding_layers = padded_end - depth;
// Offset one past the last non-padding element: the padding slots at the very end of the
// storage are excluded.
300 const auto end = padded_end * size - padding_layers;
301 const auto batch_size = static_cast<I>(layout.batch_size);
302 return {
303 .data = data(),
304 .end = data() + end,
// Without a partially filled last batch there is no padding to skip.
305 .next_jump = remaining_layers ? data() + first_jump : nullptr,
306 .padding_size = batch_size - remaining_layers,
307 .batch_size = batch_size,
308 };
309 }
310 /// Sentinel for @ref begin().
311 [[nodiscard]] std::default_sentinel_t end() const { return {}; }
312
313 /// @}
314
315 /// @name Dimensions
316 /// @{
317
318 /// Total number of elements in the view (excluding padding).
319 [[nodiscard]] constexpr index_type size() const { return layout.size(); }
320 /// Total number of elements in the view (including all padding).
321 [[nodiscard]] constexpr index_type padded_size() const { return layout.padded_size(); }
322
323 /// Number of layers in the view (i.e. depth).
324 [[nodiscard, gnu::always_inline]] constexpr depth_type depth() const { return layout.depth; }
325 /// The depth rounded up to a multiple of the batch size.
326 [[nodiscard, gnu::always_inline]] constexpr index_type ceil_depth() const {
327 return layout.ceil_depth();
328 }
329 /// The batch size, i.e. the number of layers in each batch. Equals the inner stride.
330 [[nodiscard, gnu::always_inline]] constexpr batch_size_type batch_size() const {
331 return layout.batch_size;
332 }
333 /// Number of batches in the view, i.e. `ceil_depth() / batch_size()`.
334 [[nodiscard, gnu::always_inline]] constexpr index_type num_batches() const {
335 return layout.num_batches();
336 }
337 /// Number of rows of the matrices.
338 [[nodiscard, gnu::always_inline]] constexpr index_type rows() const { return layout.rows; }
339 /// Number of columns of the matrices.
340 [[nodiscard, gnu::always_inline]] constexpr index_type cols() const { return layout.cols; }
341 /// The size of the outer dimension, i.e. the number of columns for column-major storage, or the
342 /// number of rows for row-major storage.
343 [[nodiscard, gnu::always_inline]] constexpr index_type outer_size() const {
344 return layout.outer_size();
345 }
346 /// The size of the inner dimension, i.e. the number of rows for column-major storage, or the
347 /// number of columns for row-major storage.
348 [[nodiscard, gnu::always_inline]] constexpr index_type inner_size() const {
349 return layout.inner_size();
350 }
351
352 /// @}
353
354 /// @name Strides
355 /// @{
356
357 /// Outer stride of the matrices (leading dimension in BLAS parlance). Should be multiplied by
358 /// the batch size to get the actual number of elements.
359 [[nodiscard, gnu::always_inline]] constexpr index_type outer_stride() const {
360 return layout.outer_stride;
361 }
362 /// The inner stride of the matrices. Should be multiplied by the batch size to get the actual
363 /// number of elements.
364 [[nodiscard, gnu::always_inline]] constexpr auto inner_stride() const {
365 return layout.inner_stride;
366 }
367 /// The row stride of the matrices, i.e. the distance between elements in consecutive rows in
368 /// a given column. Should be multiplied by the batch size to get the actual number of elements.
369 [[nodiscard, gnu::always_inline]] constexpr auto row_stride() const {
370 return layout.row_stride();
371 }
372 /// The column stride of the matrices, i.e. the distance between elements in consecutive columns
373 /// in a given row. Should be multiplied by the batch size to get the actual number of elements.
374 [[nodiscard, gnu::always_inline]] constexpr auto col_stride() const {
375 return layout.col_stride();
376 }
377 /// The layer stride, i.e. the distance between the first layer of one batch and the first layer
378 /// of the next batch. Should be multiplied by the batch size to get the actual number of
379 /// elements.
380 [[nodiscard, gnu::always_inline]] constexpr index_type layer_stride() const {
381 return layout.get_layer_stride();
382 }
383 /// Whether `layer_stride() == outer_stride() * outer_size()`.
384 [[nodiscard, gnu::always_inline]] constexpr bool has_full_layer_stride() const {
385 return layout.has_full_layer_stride();
386 }
387 /// Whether `outer_stride() == inner_stride() * inner_size()`.
388 [[nodiscard, gnu::always_inline]] constexpr bool has_full_outer_stride() const {
389 return layout.has_full_outer_stride();
390 }
391 /// Whether `inner_stride() == 1`. Always true.
392 [[nodiscard, gnu::always_inline]] constexpr bool has_full_inner_stride() const {
393 return layout.has_full_inner_stride();
394 }
395
396 /// @}
397
398 private:
/// Layer stride to pass on when constructing a view of type @p V from this view.
/// If @p V uses the same layer stride type as this view, the stored `layout.layer_stride` is
/// returned unchanged (keeping its type); otherwise, the value of @ref layer_stride() is
/// returned.
/// @tparam V
///         Type of the view that is being constructed.
399 template <class V>
400 constexpr auto get_layer_stride_for() const {
401 if constexpr (std::is_same_v<typename V::layer_stride_type, layer_stride_type>)
402 return layout.layer_stride;
403 else
404 return layer_stride();
405 }
406
407 public:
408 /// @name Reshaping and slicing
409 /// @{
410
411 /// Reshape the view to the given dimensions. The total size should not change.
413 BATMAT_ASSERT(rows * cols == this->rows() * this->cols());
415 return general_slice_view_type{typename general_slice_view_type::PlainBatchedMatrixView{
416 .data = data(),
417 .depth = depth(),
418 .rows = rows,
419 .cols = cols,
420 .outer_stride = is_row_major ? cols : rows,
421 .batch_size = batch_size(),
422 .layer_stride = this->get_layer_stride_for<general_slice_view_type>()}};
423 }
424
425 /// Get a view of the first @p n rows.
/// The data pointer, depth, number of columns, outer stride and batch size are unchanged.
/// @pre `0 <= n && n <= rows()` (asserted with BATMAT_ASSERT).
426 [[nodiscard]] row_slice_view_type top_rows(index_type n) const {
427 BATMAT_ASSERT(0 <= n && n <= rows());
428 return row_slice_view_type{typename row_slice_view_type::PlainBatchedMatrixView{
429 .data = data(),
430 .depth = depth(),
431 .rows = n,
432 .cols = cols(),
433 .outer_stride = outer_stride(),
434 .batch_size = batch_size(),
435 .layer_stride = this->get_layer_stride_for<row_slice_view_type>()}};
436 }
437
438 /// Get a view of the first @p n columns.
/// The data pointer, depth, number of rows, outer stride and batch size are unchanged.
/// @pre `0 <= n && n <= cols()` (asserted with BATMAT_ASSERT).
439 [[nodiscard]] col_slice_view_type left_cols(index_type n) const {
440 BATMAT_ASSERT(0 <= n && n <= cols());
441 return col_slice_view_type{typename col_slice_view_type::PlainBatchedMatrixView{
442 .data = data(),
443 .depth = depth(),
444 .rows = rows(),
445 .cols = n,
446 .outer_stride = outer_stride(),
447 .batch_size = batch_size(),
448 .layer_stride = this->get_layer_stride_for<col_slice_view_type>()}};
449 }
450
451 /// Get a view of the last @p n rows.
453 BATMAT_ASSERT(0 <= n && n <= rows());
454 const auto bs = static_cast<I>(batch_size());
455 const auto offset = (is_row_major ? outer_stride() : 1) * bs * (rows() - n);
456 return row_slice_view_type{typename row_slice_view_type::PlainBatchedMatrixView{
457 .data = data() + offset,
458 .depth = depth(),
459 .rows = n,
460 .cols = cols(),
461 .outer_stride = outer_stride(),
462 .batch_size = batch_size(),
463 .layer_stride = this->get_layer_stride_for<row_slice_view_type>()}};
464 }
465
466 /// Get a view of the last @p n columns.
468 BATMAT_ASSERT(0 <= n && n <= cols());
469 const auto bs = static_cast<I>(batch_size());
470 const auto offset = (is_row_major ? 1 : outer_stride()) * bs * (cols() - n);
471 return col_slice_view_type{typename col_slice_view_type::PlainBatchedMatrixView{
472 .data = data() + offset,
473 .depth = depth(),
474 .rows = rows(),
475 .cols = n,
476 .outer_stride = outer_stride(),
477 .batch_size = batch_size(),
478 .layer_stride = this->get_layer_stride_for<col_slice_view_type>()}};
479 }
480
481 /// Get a view of @p n rows starting at row @p r.
483 return bottom_rows(rows() - r).top_rows(n);
484 }
485
486 /// Get a view of @p n rows starting at row @p r, with stride @p stride.
488 index_type stride) const
489 requires is_row_major
490 {
491 BATMAT_ASSERT(0 <= r);
492 BATMAT_ASSERT(r + (n - 1) * stride < rows());
493 const auto bs = static_cast<I>(batch_size());
494 const auto offset = outer_stride() * bs * r;
495 return row_slice_view_type{typename row_slice_view_type::PlainBatchedMatrixView{
496 .data = data() + offset,
497 .depth = depth(),
498 .rows = n,
499 .cols = cols(),
500 .outer_stride = outer_stride() * stride,
501 .batch_size = batch_size(),
502 .layer_stride = this->get_layer_stride_for<row_slice_view_type>()}};
503 }
504
505 /// Get a view of @p n columns starting at column @p c.
507 return right_cols(cols() - c).left_cols(n);
508 }
509
510 /// Get a view of @p n columns starting at column @p c, with stride @p stride.
512 index_type stride) const
513 requires is_column_major
514 {
515 BATMAT_ASSERT(0 <= c);
516 BATMAT_ASSERT(c + (n - 1) * stride < cols());
517 const auto bs = static_cast<I>(batch_size());
518 const auto offset = outer_stride() * bs * c;
519 return col_slice_view_type{typename col_slice_view_type::PlainBatchedMatrixView{
520 .data = data() + offset,
521 .depth = depth(),
522 .rows = rows(),
523 .cols = n,
524 .outer_stride = outer_stride() * stride,
525 .batch_size = batch_size(),
526 .layer_stride = this->get_layer_stride_for<col_slice_view_type>()}};
527 }
528
529 /// Get a view of the top-left @p nr by @p nc block of the matrices.
531 return top_rows(nr).left_cols(nc);
532 }
533
534 /// Get a view of the top-right @p nr by @p nc block of the matrices.
536 return top_rows(nr).right_cols(nc);
537 }
538
539 /// Get a view of the bottom-left @p nr by @p nc block of the matrices.
541 return bottom_rows(nr).left_cols(nc);
542 }
543
544 /// Get a view of the bottom-right @p nr by @p nc block of the matrices.
546 return bottom_rows(nr).right_cols(nc);
547 }
548
549 /// Get a view of the @p nr by @p nc block of the matrices starting at row @p r and column @p c.
551 index_type nc) const {
552 return middle_rows(r, nr).middle_cols(c, nc);
553 }
554
555 /// Get a view of the given span as a column vector.
/// The view has `v.size()` rows and a single column, and refers to the elements of @p v
/// directly (no data is copied). Only the data pointer, rows and cols are specified here; all
/// other fields of the view description keep their defaults.
556 [[nodiscard]] static View as_column(std::span<T> v) {
557 return {{.data = v.data(), .rows = static_cast<index_type>(v.size()), .cols = 1}};
558 }
559
560 /// Get a transposed view of the matrices. Note that the data itself is not modified, the
561 /// returned view simply accesses the same data with rows and column indices swapped.
/// The returned view type uses the transposed storage order, `transpose(O)`. Its row and column
/// counts are swapped, while the data pointer, depth, outer stride, batch size and layer stride
/// are the same as for this view.
562 [[nodiscard]] auto transposed() const {
563 using TpBm = View<T, I, S, D, L, transpose(O)>;
564 return TpBm{typename TpBm::PlainBatchedMatrixView{.data = data(),
565 .depth = depth(),
566 .rows = cols(),
567 .cols = rows(),
568 .outer_stride = outer_stride(),
569 .batch_size = batch_size(),
570 .layer_stride = layout.layer_stride}};
571 }
572
573 /// @}
574
575 /// @name Value manipulation
576 /// @{
577
579 const auto bs = static_cast<I>(batch_size());
580 const auto n = std::min(rows(), cols());
581 for (index_type b = 0; b < num_batches(); ++b) {
582 auto *p = batch(b).data();
583 for (index_type i = 0; i < n; ++i) {
584 for (index_type r = 0; r < bs; ++r)
585 *p++ += t;
586 p += bs * outer_stride();
587 }
588 }
589 }
590
592 const auto bs = static_cast<I>(batch_size());
593 for (index_type b = 0; b < num_batches(); ++b) {
594 auto *dst = this->batch(b).data();
595 for (index_type c = 0; c < this->outer_size(); ++c) {
596 auto *dst_ = dst;
597 const index_type n = inner_size() * bs;
598 for (index_type r = 0; r < n; ++r)
599 *dst_++ = t;
600 dst += bs * this->outer_stride();
601 }
602 }
603 }
604
/// Negate all elements of the view in place.
/// Each batch is traversed one outer slice at a time: `inner_size() * batch_size()` consecutive
/// elements are negated, after which the pointer advances by `batch_size() * outer_stride()`.
/// @note All `batch_size()` layer slots of every batch are visited, so padding layers in a
///       partially filled last batch are negated as well.
605 void negate() {
606 const auto bs = static_cast<I>(batch_size());
607 for (index_type b = 0; b < num_batches(); ++b) {
608 auto *dst = this->batch(b).data();
609 for (index_type c = 0; c < this->outer_size(); ++c) {
610 auto *dst_ = dst;
// Number of contiguous elements in one outer slice of the batch.
611 const index_type n = inner_size() * bs;
612 for (index_type r = 0; r < n; ++r, ++dst_)
613 *dst_ = -*dst_;
614 dst += bs * this->outer_stride();
615 }
616 }
617 }
618
/// Copy the values of @p other into the buffer referenced by this view.
/// The function is const because only the referenced elements are modified, not the view itself.
/// Each batch is copied one outer slice at a time (`inner_size() * batch_size()` consecutive
/// elements), using the outer stride of each view to advance to the next slice.
/// @tparam Other
///         View type with the same storage order as this view (enforced by `static_assert`).
/// @pre @p other has the same number of rows and columns and the same batch size as this view
///      (checked using `assert`).
/// @note The depth of @p other is not checked: batches `0` to `num_batches() - 1` of @p other
///       are read.
/// @note NOTE(review): these checks use `assert` whereas the slicing functions use
///       BATMAT_ASSERT — confirm that this is intentional.
619 template <class Other>
620 void copy_values(const Other &other) const {
621 static_assert(is_row_major == Other::is_row_major);
622 assert(other.rows() == this->rows());
623 assert(other.cols() == this->cols());
624 assert(other.batch_size() == this->batch_size());
625 const auto bs = static_cast<I>(batch_size());
626 for (index_type b = 0; b < num_batches(); ++b) {
627 const auto *src = other.batch(b).data();
628 auto *dst = this->batch(b).data();
629 for (index_type c = 0; c < this->outer_size(); ++c) {
630 const auto *src_ = src;
631 auto *dst_ = dst;
// Number of contiguous elements in one outer slice of the batch.
632 const index_type n = inner_size() * bs;
633 for (index_type r = 0; r < n; ++r)
634 *dst_++ = *src_++;
// The two views may have different outer strides.
635 src += bs * other.outer_stride();
636 dst += bs * this->outer_stride();
637 }
638 }
639 }
640
641 /// Copy assignment copies the values from another view with the same layout to this view.
/// The data pointer and layout of this view are not changed; the values are copied using
/// copy_values(). Nothing is done if @p other is this same view object.
/// @return Reference to this view.
642 View &operator=(const View &other) {
643 if (this != &other)
644 copy_values(other);
645 return *this;
646 }
647 /// Copy values from another view with a compatible value type and the same layout to this view.
/// Only available if @p T is not const, @p U is convertible to `std::remove_cv_t<T>`, and the
/// index types @p I and @p J are equality comparable. The other view must have the same storage
/// order as this view. The values are copied using copy_values().
/// @return Reference to this view.
648 template <class U, class J, class R, class E, class M>
649 requires(!std::is_const_v<T> && std::convertible_to<U, std::remove_cv_t<T>> &&
650 std::equality_comparable_with<I, J>)
651 View &operator=(View<U, J, R, E, M, O> other) {
652 copy_values(other);
653 return *this;
654 }
655 // TODO: abstract logic into generic function (and check performance)
/// Add the values of another view element-wise to the values of this view.
/// Only available if @p T is not const, @p U is convertible to `std::remove_cv_t<T>`, and the
/// index types @p I and @p J are equality comparable. The traversal is the same as in
/// copy_values(): each batch is processed one outer slice at a time.
/// @pre @p other has the same number of rows and columns and the same batch size as this view
///      (checked using `assert`).
/// @note The depth of @p other is not checked: batches `0` to `num_batches() - 1` of @p other
///       are read.
/// @return Reference to this view.
656 template <class U, class J, class R, class E, class M>
657 requires(!std::is_const_v<T> && std::convertible_to<U, std::remove_cv_t<T>> &&
658 std::equality_comparable_with<I, J>)
659 View &operator+=(View<U, J, R, E, M, O> other) {
660 assert(other.rows() == this->rows());
661 assert(other.cols() == this->cols());
662 assert(other.batch_size() == this->batch_size());
663 const auto bs = static_cast<I>(batch_size());
664 for (index_type b = 0; b < num_batches(); ++b) {
665 const auto *src = other.batch(b).data();
666 auto *dst = this->batch(b).data();
667 for (index_type c = 0; c < this->outer_size(); ++c) {
668 const auto *src_ = src;
669 auto *dst_ = dst;
// Number of contiguous elements in one outer slice of the batch.
670 const index_type n = inner_size() * bs;
671 for (index_type r = 0; r < n; ++r)
672 *dst_++ += *src_++;
// The two views may have different outer strides.
673 src += bs * other.outer_stride();
674 dst += bs * this->outer_stride();
675 }
676 }
677 return *this;
678 }
679
680 /// @}
681
682 /// @name View conversions
683 /// @{
684
685 /// Returns the same view. For consistency with @ref Matrix.
686 [[nodiscard]] View view() const { return *this; }
687 /// Explicit conversion to a const view.
688 [[nodiscard]] const_view_type as_const() const { return *this; }
689
690 /// Non-const views implicitly convert to const views.
/// The const view is constructed from the same data pointer and layout, so it refers to the
/// same elements. Only available if @p T is not const.
691 operator const_view_type() const
692 requires(!std::is_const_v<T>)
693 {
694 return {data_ptr, layout};
695 }
697 /// If we have a single layer at compile time, we can implicitly convert to a non-batched view.
703
704 /// Implicit conversion to a view with a dynamic depth.
706 requires(!std::same_as<integral_value_type_t<D>, D>)
707 {
708 const auto bs = static_cast<integral_value_type_t<D>>(batch_size());
709 return {{.data = data(),
710 .depth = depth(),
711 .rows = rows(),
712 .cols = cols(),
713 .outer_stride = outer_stride(),
714 .batch_size = bs,
715 .layer_stride = layout.layer_stride}};
716 }
717 /// Implicit conversion to a view with a dynamic depth, going from non-const to const.
719 requires(!std::is_const_v<T> && !std::same_as<integral_value_type_t<D>, D>)
720 {
721 const auto bs = static_cast<integral_value_type_t<D>>(batch_size());
722 return {{.data = data(),
723 .depth = depth(),
724 .rows = rows(),
725 .cols = cols(),
726 .outer_stride = outer_stride(),
727 .batch_size = bs,
728 .layer_stride = layout.layer_stride}};
729 }
730 /// Implicit conversion to a view with a dynamic layer stride.
/// The resulting view refers to the same data and has the same dimensions, outer stride and
/// batch size; the value of @ref layer_stride() is passed on explicitly as its layer stride.
/// Only available if the layer stride type @p L is not already @p I.
731 operator View<T, I, S, D, I, O>() const
732 requires(!std::same_as<I, L>)
733 {
734 return {{.data = data(),
735 .depth = depth(),
736 .rows = rows(),
737 .cols = cols(),
738 .outer_stride = outer_stride(),
739 .batch_size = batch_size(),
740 .layer_stride = layer_stride()}};
741 }
742 /// Implicit conversion to a view with a dynamic layer stride, going from non-const to const.
744 requires(!std::is_const_v<T> && !std::same_as<I, L>)
745 {
746 return {{.data = data(),
747 .depth = depth(),
748 .rows = rows(),
749 .cols = cols(),
750 .outer_stride = outer_stride(),
751 .batch_size = batch_size(),
752 .layer_stride = layer_stride()}};
753 }
754
755 /// @}
756};
757
/// Compare a sentinel with a linear iterator (sentinel on the left-hand side).
/// Forwards to the member comparison of the iterator, `i == s`.
/// @note NOTE(review): the iterator parameter is a nested type of a class template
///       specialization, which is a non-deduced context, so none of the template parameters
///       can be deduced from a call such as `s == i` — confirm whether this overload can ever
///       be selected.
758template <class T, class I, class S, class D, class L, StorageOrder P>
759bool operator==(std::default_sentinel_t s, typename View<T, I, S, D, L, P>::linear_iterator i) {
760 return i == s;
761}
/// Inequality of a sentinel and a linear iterator (sentinel on the left-hand side).
/// Returns the negation of the member comparison of the iterator, `i == s`.
/// @note NOTE(review): the template parameters only appear in a non-deduced context (nested
///       type of a class template specialization) — confirm whether this overload can ever be
///       selected.
762template <class T, class I, class S, class D, class L, StorageOrder P>
763bool operator!=(std::default_sentinel_t s, typename View<T, I, S, D, L, P>::linear_iterator i) {
764 return !(i == s);
765}
/// Inequality of a linear iterator and a sentinel (iterator on the left-hand side).
/// Returns the negation of the member comparison of the iterator, `i == s`.
/// @note NOTE(review): the template parameters only appear in a non-deduced context (nested
///       type of a class template specialization) — confirm whether this overload can ever be
///       selected.
766template <class T, class I, class S, class D, class L, StorageOrder P>
767bool operator!=(typename View<T, I, S, D, L, P>::linear_iterator i, std::default_sentinel_t s) {
768 return !(i == s);
769}
770
771// TODO: tag-invoke style CPOs instead of free functions with ADL?
772
/// Free-function equivalent of View::data(): returns `v.data()`.
773template <class T, class I, class S, class D, class L, StorageOrder O>
774constexpr auto data(const View<T, I, S, D, L, O> &v) {
775 return v.data();
776}
/// Free-function equivalent of View::rows(): returns `v.rows()`.
777template <class T, class I, class S, class D, class L, StorageOrder O>
778constexpr auto rows(const View<T, I, S, D, L, O> &v) {
779 return v.rows();
780}
/// Free-function equivalent of View::cols(): returns `v.cols()`.
781template <class T, class I, class S, class D, class L, StorageOrder O>
782constexpr auto cols(const View<T, I, S, D, L, O> &v) {
783 return v.cols();
784}
/// Free-function equivalent of View::outer_stride(): returns `v.outer_stride()`.
785template <class T, class I, class S, class D, class L, StorageOrder O>
786constexpr auto outer_stride(const View<T, I, S, D, L, O> &v) {
787 return v.outer_stride();
788}
/// Free-function equivalent of View::depth(): returns `v.depth()`.
789template <class T, class I, class S, class D, class L, StorageOrder O>
790constexpr auto depth(const View<T, I, S, D, L, O> &v) {
791 return v.depth();
792}
793
794} // namespace batmat::matrix
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
Layout description for a batch of matrices, independent of any storage.
constexpr auto cols(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:497
typename integral_value_type< T >::type integral_value_type_t
Definition layout.hpp:25
constexpr auto data(Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:483
constexpr auto outer_stride(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:501
bool operator!=(std::default_sentinel_t s, typename View< T, I, S, D, L, P >::linear_iterator i)
Definition view.hpp:763
constexpr auto rows(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:493
bool operator==(std::default_sentinel_t s, typename View< T, I, S, D, L, P >::linear_iterator i)
Definition view.hpp:759
constexpr auto depth(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:505
Shape and strides describing a batch of matrices, independent of any storage.
Definition layout.hpp:50
static constexpr bool is_column_major
Definition layout.hpp:58
static constexpr bool is_row_major
Definition layout.hpp:59
static constexpr StorageOrder storage_order
Definition layout.hpp:57
std::conditional_t< requires { S::value; }, std::integral_constant< index_t, S::value >, index_t > standard_stride_type
Definition layout.hpp:62
Iterator over all elements of a view.
Definition view.hpp:254
bool operator==(std::default_sentinel_t) const
Definition view.hpp:278
linear_iterator operator++(int)
Definition view.hpp:272
linear_iterator & operator++()
Definition view.hpp:264
Non-owning view of an array of matrices, stored in an efficient batched format.
Definition view.hpp:59
constexpr View(PlainBatchedMatrixView p={})
Create a new view.
Definition view.hpp:122
Layout< index_t, stride, stride, layer_stride, O > layout_type
Definition view.hpp:60
View< T, I, S, N, L, O > first_layers(N n) const
Get a view of the first n layers. Note that n can be a compile-time constant.
Definition view.hpp:218
typename layout_type::layer_stride_type layer_stride_type
Definition view.hpp:65
guanaqo::MatrixView< T, I, standard_stride_type, O > operator()(index_type l) const
Access a single layer l as a non-batched view.
Definition view.hpp:212
general_slice_view_type bottom_right(index_type nr, index_type nc) const
Get a view of the bottom-right nr by nc block of the matrices.
Definition view.hpp:545
typename layout_type::batch_size_type batch_size_type
Definition view.hpp:63
typename layout_type::standard_stride_type standard_stride_type
Definition view.hpp:66
View< T, index_t, stride, stride, DefaultStride, O > batch_view_type
Definition view.hpp:83
general_slice_view_type block(index_type r, index_type c, index_type nr, index_type nc) const
Get a view of the nr by nc block of the matrices starting at row r and column c.
Definition view.hpp:550
void set_constant(value_type t)
Definition view.hpp:591
value_type & operator()(index_type l, index_type r, index_type c) const
Access a single element at layer l, row r and column c.
Definition view.hpp:152
auto transposed() const
Get a transposed view of the matrices.
Definition view.hpp:562
constexpr auto inner_stride() const
The inner stride of the matrices.
Definition view.hpp:364
constexpr index_type num_batches() const
Number of batches in the view, i.e. ceil_depth() / batch_size().
Definition view.hpp:334
col_slice_view_type middle_cols(index_type c, index_type n, index_type stride) const
Get a view of n columns starting at column c, with stride stride.
Definition view.hpp:511
constexpr View(value_type *data, layout_type layout)
Create a new view with the given layout, using the given buffer.
Definition view.hpp:134
col_slice_view_type middle_cols(index_type c, index_type n) const
Get a view of n columns starting at column c.
Definition view.hpp:506
constexpr index_type padded_size() const
Total number of elements in the view (including all padding).
Definition view.hpp:321
void add_to_diagonal(const value_type &t)
Definition view.hpp:578
linear_iterator begin() const
Iterate linearly (in storage order) over all elements of the view.
Definition view.hpp:284
constexpr auto get_layer_stride_for() const
Definition view.hpp:400
row_slice_view_type bottom_rows(index_type n) const
Get a view of the last n rows.
Definition view.hpp:452
general_slice_view_type reshaped(index_type rows, index_type cols) const
Reshape the view to the given dimensions. The total size should not change.
Definition view.hpp:412
View< T, I, S, N, L, O > middle_layers(index_type l, N n) const
Get a view of n layers starting at layer l.
Definition view.hpp:233
View(const View &)=default
Copy a view. No data is copied.
constexpr auto col_stride() const
The column stride of the matrices, i.e. the distance between elements in consecutive columns of a given row.
Definition view.hpp:374
general_slice_view_type bottom_left(index_type nr, index_type nc) const
Get a view of the bottom-left nr by nc block of the matrices.
Definition view.hpp:540
View< const T, index_t, stride, stride, layer_stride, O > const_view_type
Definition view.hpp:67
row_slice_view_type middle_rows(index_type r, index_type n) const
Get a view of n rows starting at row r.
Definition view.hpp:482
constexpr index_type cols() const
Number of columns of the matrices.
Definition view.hpp:340
col_slice_view_type right_cols(index_type n) const
Get a view of the last n columns.
Definition view.hpp:467
View< T, I, S, I, L, O > batch_dyn(index_type b) const
Same as batch(), but returns a view with a dynamic batch size.
Definition view.hpp:175
constexpr bool has_full_inner_stride() const
Whether inner_stride() equals 1. Always true.
Definition view.hpp:392
col_slice_view_type left_cols(index_type n) const
Get a view of the first n columns.
Definition view.hpp:439
static View as_column(std::span< T > v)
Get a view of the given span as a column vector.
Definition view.hpp:556
row_slice_view_type top_rows(index_type n) const
Get a view of the first n rows.
Definition view.hpp:426
constexpr index_type rows() const
Number of rows of the matrices.
Definition view.hpp:338
constexpr index_type inner_size() const
The size of the inner dimension, i.e. the number of rows for column-major storage, or the number of columns for row-major storage.
Definition view.hpp:348
constexpr index_type ceil_depth() const
The depth rounded up to a multiple of the batch size.
Definition view.hpp:326
View & operator=(const View &other)
Copy assignment copies the values from another view with the same layout to this view.
Definition view.hpp:642
std::conditional_t< has_single_batch_at_compile_time, View, View< T, index_t, stride, stride, index_t, O > > general_slice_view_type
Definition view.hpp:88
constexpr auto row_stride() const
The row stride of the matrices, i.e. the distance between elements in consecutive rows of a given column.
Definition view.hpp:369
std::conditional_t< is_column_major, View, general_slice_view_type > row_slice_view_type
Definition view.hpp:97
row_slice_view_type middle_rows(index_type r, index_type n, index_type stride) const
Get a view of n rows starting at row r, with stride stride.
Definition view.hpp:487
constexpr View(std::span< T > data, layout_type layout)
Create a new view with the given layout, using the given buffer.
Definition view.hpp:130
general_slice_view_type top_left(index_type nr, index_type nc) const
Get a view of the top-left nr by nc block of the matrices.
Definition view.hpp:530
constexpr index_type outer_size() const
The size of the outer dimension, i.e. the number of columns for column-major storage, or the number of rows for row-major storage.
Definition view.hpp:343
View view() const
Returns the same view. For consistency with Matrix.
Definition view.hpp:686
constexpr index_type layer_stride() const
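The layer stride, i.e. the stride between consecutive batches of layers (the stride of the num_batches dimension of the storage).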
Definition view.hpp:380
View< T, I, S, I, I, O > middle_batches(index_type b, index_type n, index_type stride=1) const
Get a view of n batches starting at batch b, with a stride of stride layers.
Definition view.hpp:189
const_view_type as_const() const
Explicit conversion to a const view.
Definition view.hpp:688
general_slice_view_type top_right(index_type nr, index_type nc) const
Get a view of the top-right nr by nc block of the matrices.
Definition view.hpp:535
std::conditional_t< is_row_major, View, general_slice_view_type > col_slice_view_type
Definition view.hpp:94
void copy_values(const Other &other) const
Definition view.hpp:620
View & reassign(View other)
Reassign the buffer and layout of this view to those of another view. No data is copied.
Definition view.hpp:140
constexpr index_type outer_stride() const
Outer stride of the matrices (leading dimension in BLAS parlance).
Definition view.hpp:359
batch_view_type batch(index_type b) const
Access a batch of batch_size() layers, starting at batch index b (i.e. at layer b * batch_size()).
Definition view.hpp:163
POD helper struct to enable designated initializers during construction.
Definition view.hpp:108