batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
view.hpp
Go to the documentation of this file.
1#pragma once
2
3/// @file
4/// Non-owning view of a batch of matrices.
5/// @ingroup topic-matrix
6
7#include <batmat/assume.hpp>
8#include <batmat/config.hpp>
10#include <guanaqo/mat-view.hpp>
11#include <type_traits>
12
13namespace batmat::matrix {
14
15/// Non-owning view of an array of matrices, stored in an efficient batched format.
16///
17/// A view is a pointer to a buffer of elements, together with a layout describing how the elements
18/// are arranged in memory. Conceptually, a batmat::matrix::View represents a 3D array of shape
19/// `(depth, rows, cols)`, where `depth` is the number of matrices in the array. We refer to a
20/// single matrix in the batch as a "layer". A single layer can be accessed using
21/// [`view(layer)`](@ref View::operator()(index_type)const), and individual elements can be accessed
22/// using [`view(layer, row, col)`](@ref View::operator()(index_type,index_type,index_type)const).
23/// Each layer is either stored in column-major or row-major order, depending on the storage order
24/// of the view. Slicing and tiling are made possible by dynamic outer strides (_leading dimension_
25/// in BLAS parlance) and layer strides.
26///
27/// To enable efficient vectorization of linear algebra operations on views, the data is not stored
28/// in memory as a 3D array, but rather as a 4D array with an additional "batch" dimension, i.e.
29/// with shape `(batch_size, rows, cols, num_batches)`, where `batch_size` is the number of layers
30/// in each batch, and `num_batches = ceil(depth / batch_size)`. The corresponding strides are
31/// `(1, batch_size, outer_stride, layer_stride)` for the column-major case, and
32/// `(1, outer_stride, batch_size, layer_stride)` for the row-major case.
33///
34/// Batches can be accessed using the @ref batch() method, which returns a view of the batch with
35/// the same layout but with `depth` equal to `batch_size`. If the total depth is not a multiple of
36/// the batch size, the last batch is padded.
37///
38/// @see @ref batmat::matrix::Matrix for an owning array of matrices that handles allocation and
39/// converts to a @ref View.
40///
41/// @tparam T
42/// Element value type (possibly const-qualified).
43/// @tparam I
44/// Index and size type. Usually `std::ptrdiff_t` or `int`.
45/// @tparam S
46/// Inner stride type (batch size). Usually `std::integral_constant<I, N>` for some `N`.
47/// @tparam D
48/// Batch depth type. Usually equal to @p S for a single batch, or @p I for a dynamic depth.
49/// @tparam L
50/// Layer stride type. Usually @ref DefaultStride (which implies that the layer stride is
51/// equal to `outer_stride() * outer_size()`), or @p I for a dynamic layer stride.
52/// Dynamic strides are used for subviews of views with a larger `outer_size()`.
53/// @tparam O
54/// %Matrix storage order, @ref guanaqo::StorageOrder::RowMajor "RowMajor" or
55/// @ref guanaqo::StorageOrder::ColMajor "ColMajor".
56/// @ingroup topic-matrix
57template <class T, class I = index_t, class S = std::integral_constant<I, 1>, class D = I,
58 class L = DefaultStride, StorageOrder O = StorageOrder::ColMajor>
59struct View {
61 using value_type = T;
70 static constexpr bool is_row_major = layout_type::is_row_major;
71
72 /// True if @ref batch_size() and @ref depth() are compile-time constants and are equal.
73 /// @note Views with dynamic batch size and depth may still have a single batch at runtime,
74 /// but this cannot be statically asserted.
75 static constexpr bool has_single_batch_at_compile_time = requires {
76 S::value;
77 D::value;
78 } && S{} == D{};
79 /// True if @ref depth() is a compile-time constant and is equal to one.
80 static constexpr bool has_single_layer_at_compile_time = requires { D::value; } && D{} == 1;
81 /// When extracting a single batch, the depth equals the batch size, and the layer stride is no
82 /// longer relevant.
84 /// When slicing along the outer dimension, the layer stride stays the same, but the outer size
85 /// may be smaller, which means that even if the original view has a default layer stride, the
86 /// sliced view may require a dynamic layer stride. For a single batch, the layer stride is not
87 /// relevant, so it is preserved.
89 std::conditional_t<has_single_batch_at_compile_time, View, View<T, I, S, D, I, O>>;
90 /// View with the correct layer stride when slicing along the column dimension.
91 /// For row-major storage, slicing along columns does not change the outer size, in which case
92 /// the layer stride is still correct. For column-major storage, slicing along columns does
93 /// change the outer size, in which case we need a dynamic layer stride.
94 using col_slice_view_type = std::conditional_t<is_row_major, View, general_slice_view_type>;
95 /// View with the correct layer stride when slicing along the row dimension.
96 /// @see @ref col_slice_view_type
97 using row_slice_view_type = std::conditional_t<is_column_major, View, general_slice_view_type>;
98
99 /// Pointer to the first element of the first layer.
101 /// Layout describing the dimensions and strides of the view.
103
104 /// @name Constructors
105 /// @{
106
107 /// POD helper struct to enable designated initializers during construction.
119
120 /// Create a new view.
121 /// @note It is recommended to use designated initializers for the arguments to avoid mistakes.
122 constexpr View(PlainBatchedMatrixView p = {})
123 : data_ptr{p.data}, layout{{.depth = p.depth,
124 .rows = p.rows,
125 .cols = p.cols,
126 .outer_stride = p.outer_stride,
127 .batch_size = p.batch_size,
128 .layer_stride = p.layer_stride}} {}
129 /// Create a new view with the given layout, using the given buffer.
130 constexpr View(std::span<T> data, layout_type layout) : data_ptr{data.data()}, layout{layout} {
131 BATMAT_ASSERT(data.size() == layout.padded_size());
132 }
133 /// Create a new view with the given layout, using the given buffer.
135
136 /// Copy a view. No data is copied.
137 View(const View &) = default;
138
139 /// Reassign the buffer and layout of this view to those of another view. No data is copied.
141 this->data_ptr = other.data_ptr;
142 this->layout = other.layout;
143 return *this;
144 }
145
146 /// @}
147
148 /// @name Element access
149 /// @{
150
151 /// Access a single element at layer @p l, row @p r and column @p c.
152 [[nodiscard]] value_type &operator()(index_type l, index_type r, index_type c) const {
153 return layout(data_ptr, l, r, c);
154 }
155
156 /// @}
157
158 /// @name Batch-wise slicing
159 /// @{
160
161 /// Access a batch of @ref batch_size() layers, starting at batch index @p b (i.e. starting at
162 /// layer `b * batch_size()`).
163 [[nodiscard]] batch_view_type batch(index_type b) const {
164 const auto layer = b * static_cast<index_t>(batch_size());
165 return {{.data = data() ? data() + layout.layer_index(layer) : nullptr,
166 .depth = batch_size(),
167 .rows = rows(),
168 .cols = cols(),
169 .outer_stride = outer_stride(),
170 .batch_size = batch_size()}};
171 }
172
173 /// Same as @ref batch(), but returns a view with a dynamic batch size. If the total depth is
174 /// not a multiple of the batch size, the last batch will have a smaller size.
176 const auto d = static_cast<I>(depth());
177 const auto layer = b * static_cast<index_t>(batch_size());
178 const auto last = b == d / batch_size();
179 return {{.data = data() ? data() + layout.layer_index(layer) : nullptr,
180 .depth = last ? d - layout.floor_depth() : batch_size(),
181 .rows = rows(),
182 .cols = cols(),
183 .outer_stride = outer_stride(),
184 .batch_size = batch_size(),
185 .layer_stride = layout.layer_stride}};
186 }
187
188 /// Get a view of @p n batches starting at batch @p b.
190 const auto bs = static_cast<I>(batch_size());
191 const auto layer = b * bs;
192 BATMAT_ASSERT(n == 0 || layer + (n - 1) * bs + bs <= depth());
193 return {{.data = data() ? data() + layout.layer_index(layer) : nullptr,
194 .depth = n * bs,
195 .rows = rows(),
196 .cols = cols(),
197 .outer_stride = outer_stride(),
198 .batch_size = batch_size(),
199 .layer_stride = layout.layer_stride}};
200 }
201
202 /// Get a view of @p n batches starting at batch @p b, with a stride of @p stride batches.
204 index_type stride) const {
205 const auto bs = static_cast<I>(batch_size());
206 const auto layer = b * bs;
207 BATMAT_ASSERT(n == 0 || layer + (n - 1) * stride * bs + bs <= depth());
208 return {{.data = data() ? data() + layout.layer_index(layer) : nullptr,
209 .depth = n * bs,
210 .rows = rows(),
211 .cols = cols(),
212 .outer_stride = outer_stride(),
213 .batch_size = batch_size(),
214 .layer_stride = layout.get_layer_stride() * stride}};
215 }
216
217 /// @}
218
219 /// @name Layer-wise slicing
220 /// @{
221
222 /// Access a single layer @p l as a non-batched view.
223 /// @note The inner stride of the returned view is equal to the batch size of this view, so it
224 /// cannot be used directly with functions that require unit inner stride (e.g. BLAS).
227 return layout(data_ptr, l);
228 }
229
230 /// Get a view of the first @p n layers. Note that @p n can be a compile-time constant.
231 template <class N>
232 [[nodiscard]] View<T, I, S, N, L, O> first_layers(N n) const {
233 BATMAT_ASSERT(n <= depth());
234 return {{.data = data(),
235 .depth = n,
236 .rows = rows(),
237 .cols = cols(),
238 .outer_stride = outer_stride(),
239 .batch_size = batch_size(),
240 .layer_stride = layout.layer_stride}};
241 }
242
243 /// Get a view of @p n layers starting at layer @p l. Note that @p n can be a compile-time
244 /// constant.
245 /// @pre `l % batch_size() == 0` (i.e. the starting layer must be at the start of a batch).
246 template <class N>
247 [[nodiscard]] View<T, I, S, N, L, O> middle_layers(index_type l, N n) const {
248 BATMAT_ASSERT(l + n <= depth());
249 BATMAT_ASSERT(l % static_cast<I>(batch_size()) == 0);
250 return {{.data = data() + layout.layer_index(l),
251 .depth = n,
252 .rows = rows(),
253 .cols = cols(),
254 .outer_stride = outer_stride(),
255 .batch_size = batch_size(),
256 .layer_stride = layout.layer_stride}};
257 }
258
259 /// @}
260
261 /// @name Iterators and buffer access
262 /// @{
263
264 /// Get a pointer to the first element of the first layer.
265 T *data() const { return data_ptr; }
266
267 /// Iterator over all elements of a view.
269 using value_type = T;
270 using reference = T &;
273 T *end;
277
279 ++data;
280 if (data == next_jump && data != end) {
283 }
284 return *this;
285 }
287 linear_iterator t = *this;
288 ++*this;
289 return t;
290 }
291 reference operator*() const { return *data; }
292 bool operator==(std::default_sentinel_t) const { return data == end; }
293 };
294
295 /// Iterate linearly (in storage order) over all elements of the view.
296 /// @pre `has_full_outer_stride()` (i.e. no padding within layers).
297 /// @pre `has_full_layer_stride()` (i.e. no padding between batches).
298 [[nodiscard]] linear_iterator begin() const {
301 // Number of elements in each layer
302 const auto size = layout.rows * layout.cols;
303 // How many layers are in batches that are completely full?
304 const auto contig_layers = layout.floor_depth();
305 // How many layers in total?
306 const auto depth = static_cast<I>(layout.depth);
307 // Remaining layers have padding we should skip over (in the last batch)
308 const auto remaining_layers = depth - contig_layers;
309 // Index of the first padding element
310 const auto first_jump = contig_layers * size + remaining_layers;
311 // Index of last layer in our storage
312 const auto padded_end = layout.ceil_depth();
313 const auto padding_layers = padded_end - depth;
314 const auto end = padded_end * size - padding_layers;
315 const auto batch_size = static_cast<I>(layout.batch_size);
316 return {
317 .data = data(),
318 .end = data() + end,
319 .next_jump = remaining_layers ? data() + first_jump : nullptr,
320 .padding_size = batch_size - remaining_layers,
321 .batch_size = batch_size,
322 };
323 }
324 /// Sentinel for @ref begin().
325 [[nodiscard]] std::default_sentinel_t end() const { return {}; }
326
327 /// @}
328
329 /// @name Dimensions
330 /// @{
331
332 /// Total number of elements in the view (excluding padding).
333 [[nodiscard]] constexpr index_type size() const { return layout.size(); }
334 /// Total number of elements in the view (including all padding).
335 [[nodiscard]] constexpr index_type padded_size() const { return layout.padded_size(); }
336
337 /// Number of layers in the view (i.e. depth).
338 [[nodiscard, gnu::always_inline]] constexpr depth_type depth() const { return layout.depth; }
339 /// The depth rounded up to a multiple of the batch size.
340 [[nodiscard, gnu::always_inline]] constexpr index_type ceil_depth() const {
341 return layout.ceil_depth();
342 }
343 /// The batch size, i.e. the number of layers in each batch. Equals the inner stride.
344 [[nodiscard, gnu::always_inline]] constexpr batch_size_type batch_size() const {
345 return layout.batch_size;
346 }
347 /// Number of batches in the view, i.e. `ceil_depth() / batch_size()`.
348 [[nodiscard, gnu::always_inline]] constexpr index_type num_batches() const {
349 return layout.num_batches();
350 }
351 /// Number of rows of the matrices.
352 [[nodiscard, gnu::always_inline]] constexpr index_type rows() const { return layout.rows; }
353 /// Number of columns of the matrices.
354 [[nodiscard, gnu::always_inline]] constexpr index_type cols() const { return layout.cols; }
355 /// The size of the outer dimension, i.e. the number of columns for column-major storage, or the
356 /// number of rows for row-major storage.
357 [[nodiscard, gnu::always_inline]] constexpr index_type outer_size() const {
358 return layout.outer_size();
359 }
360 /// The size of the inner dimension, i.e. the number of rows for column-major storage, or the
361 /// number of columns for row-major storage.
362 [[nodiscard, gnu::always_inline]] constexpr index_type inner_size() const {
363 return layout.inner_size();
364 }
365
366 /// @}
367
368 /// @name Strides
369 /// @{
370
371 /// Outer stride of the matrices (leading dimension in BLAS parlance). Should be multiplied by
372 /// the batch size to get the actual number of elements.
373 [[nodiscard, gnu::always_inline]] constexpr index_type outer_stride() const {
374 return layout.outer_stride;
375 }
376 /// The inner stride of the matrices. Should be multiplied by the batch size to get the actual
377 /// number of elements.
378 [[nodiscard, gnu::always_inline]] constexpr auto inner_stride() const {
379 return layout.inner_stride;
380 }
381 /// The row stride of the matrices, i.e. the distance between elements in consecutive rows in
382 /// a given column. Should be multiplied by the batch size to get the actual number of elements.
383 [[nodiscard, gnu::always_inline]] constexpr auto row_stride() const {
384 return layout.row_stride();
385 }
386 /// The column stride of the matrices, i.e. the distance between elements in consecutive columns
387 /// in a given row. Should be multiplied by the batch size to get the actual number of elements.
388 [[nodiscard, gnu::always_inline]] constexpr auto col_stride() const {
389 return layout.col_stride();
390 }
391 /// The layer stride, i.e. the distance between the first layer of one batch and the first layer
392 /// of the next batch. Should be multiplied by the batch size to get the actual number of
393 /// elements.
394 [[nodiscard, gnu::always_inline]] constexpr index_type layer_stride() const {
395 return layout.get_layer_stride();
396 }
397 /// Whether the `layer_stride() == outer_stride() * outer_size()`.
398 [[nodiscard, gnu::always_inline]] constexpr bool has_full_layer_stride() const {
399 return layout.has_full_layer_stride();
400 }
401 /// Whether the `outer_stride() == inner_stride() * inner_size()`.
402 [[nodiscard, gnu::always_inline]] constexpr bool has_full_outer_stride() const {
403 return layout.has_full_outer_stride();
404 }
405 /// Whether the `inner_stride() == 1`. Always true.
406 [[nodiscard, gnu::always_inline]] constexpr bool has_full_inner_stride() const {
407 return layout.has_full_inner_stride();
408 }
409
410 /// @}
411
412 private:
413 template <class V>
414 constexpr auto get_layer_stride_for() const {
415 if constexpr (std::is_same_v<typename V::layer_stride_type, layer_stride_type>)
416 return layout.layer_stride;
417 else
418 return layer_stride();
419 }
420
421 public:
422 /// @name Reshaping and slicing
423 /// @{
424
425 /// Reshape the view to the given dimensions. The total size should not change.
427 BATMAT_ASSERT(rows * cols == this->rows() * this->cols());
429 return general_slice_view_type{typename general_slice_view_type::PlainBatchedMatrixView{
430 .data = data(),
431 .depth = depth(),
432 .rows = rows,
433 .cols = cols,
434 .outer_stride = is_row_major ? cols : rows,
435 .batch_size = batch_size(),
436 .layer_stride = this->get_layer_stride_for<general_slice_view_type>()}};
437 }
438
439 /// Get a view of the first @p n rows.
440 [[nodiscard]] row_slice_view_type top_rows(index_type n) const {
441 BATMAT_ASSERT(0 <= n && n <= rows());
442 return row_slice_view_type{typename row_slice_view_type::PlainBatchedMatrixView{
443 .data = data(),
444 .depth = depth(),
445 .rows = n,
446 .cols = cols(),
447 .outer_stride = outer_stride(),
448 .batch_size = batch_size(),
449 .layer_stride = this->get_layer_stride_for<row_slice_view_type>()}};
450 }
451
452 /// Get a view of the first @p n columns.
453 [[nodiscard]] col_slice_view_type left_cols(index_type n) const {
454 BATMAT_ASSERT(0 <= n && n <= cols());
455 return col_slice_view_type{typename col_slice_view_type::PlainBatchedMatrixView{
456 .data = data(),
457 .depth = depth(),
458 .rows = rows(),
459 .cols = n,
460 .outer_stride = outer_stride(),
461 .batch_size = batch_size(),
462 .layer_stride = this->get_layer_stride_for<col_slice_view_type>()}};
463 }
464
465 /// Get a view of the last @p n rows.
467 BATMAT_ASSERT(0 <= n && n <= rows());
468 const auto bs = static_cast<I>(batch_size());
469 const auto offset = (is_row_major ? outer_stride() : 1) * bs * (rows() - n);
470 return row_slice_view_type{typename row_slice_view_type::PlainBatchedMatrixView{
471 .data = data() + offset,
472 .depth = depth(),
473 .rows = n,
474 .cols = cols(),
475 .outer_stride = outer_stride(),
476 .batch_size = batch_size(),
477 .layer_stride = this->get_layer_stride_for<row_slice_view_type>()}};
478 }
479
480 /// Get a view of the last @p n columns.
482 BATMAT_ASSERT(0 <= n && n <= cols());
483 const auto bs = static_cast<I>(batch_size());
484 const auto offset = (is_row_major ? 1 : outer_stride()) * bs * (cols() - n);
485 return col_slice_view_type{typename col_slice_view_type::PlainBatchedMatrixView{
486 .data = data() + offset,
487 .depth = depth(),
488 .rows = rows(),
489 .cols = n,
490 .outer_stride = outer_stride(),
491 .batch_size = batch_size(),
492 .layer_stride = this->get_layer_stride_for<col_slice_view_type>()}};
493 }
494
495 /// Get a view of @p n rows starting at row @p r.
497 return bottom_rows(rows() - r).top_rows(n);
498 }
499
500 /// Get a view of @p n rows starting at row @p r, with stride @p stride.
502 index_type stride) const
503 requires is_row_major
504 {
505 BATMAT_ASSERT(0 <= r);
506 BATMAT_ASSERT(r + (n - 1) * stride < rows());
507 const auto bs = static_cast<I>(batch_size());
508 const auto offset = outer_stride() * bs * r;
509 return row_slice_view_type{typename row_slice_view_type::PlainBatchedMatrixView{
510 .data = data() + offset,
511 .depth = depth(),
512 .rows = n,
513 .cols = cols(),
514 .outer_stride = outer_stride() * stride,
515 .batch_size = batch_size(),
516 .layer_stride = this->get_layer_stride_for<row_slice_view_type>()}};
517 }
518
519 /// Get a view of @p n columns starting at column @p c.
521 return right_cols(cols() - c).left_cols(n);
522 }
523
524 /// Get a view of @p n columns starting at column @p c, with stride @p stride.
526 index_type stride) const
527 requires is_column_major
528 {
529 BATMAT_ASSERT(0 <= c);
530 BATMAT_ASSERT(c + (n - 1) * stride < cols());
531 const auto bs = static_cast<I>(batch_size());
532 const auto offset = outer_stride() * bs * c;
533 return col_slice_view_type{typename col_slice_view_type::PlainBatchedMatrixView{
534 .data = data() + offset,
535 .depth = depth(),
536 .rows = rows(),
537 .cols = n,
538 .outer_stride = outer_stride() * stride,
539 .batch_size = batch_size(),
540 .layer_stride = this->get_layer_stride_for<col_slice_view_type>()}};
541 }
542
543 /// Get a view of the top-left @p nr by @p nc block of the matrices.
545 return top_rows(nr).left_cols(nc);
546 }
547
548 /// Get a view of the top-right @p nr by @p nc block of the matrices.
550 return top_rows(nr).right_cols(nc);
551 }
552
553 /// Get a view of the bottom-left @p nr by @p nc block of the matrices.
555 return bottom_rows(nr).left_cols(nc);
556 }
557
558 /// Get a view of the bottom-right @p nr by @p nc block of the matrices.
560 return bottom_rows(nr).right_cols(nc);
561 }
562
563 /// Get a view of the @p nr by @p nc block of the matrices starting at row @p r and column @p c.
565 index_type nc) const {
566 return middle_rows(r, nr).middle_cols(c, nc);
567 }
568
569 /// Get a view of the given span as a column vector.
570 [[nodiscard]] static View as_column(std::span<T> v) {
571 return {{.data = v.data(), .rows = static_cast<index_type>(v.size()), .cols = 1}};
572 }
573
574 /// Get a transposed view of the matrices. Note that the data itself is not modified, the
575 /// returned view simply accesses the same data with rows and column indices swapped.
576 [[nodiscard]] auto transposed() const {
577 using TpBm = View<T, I, S, D, L, transpose(O)>;
578 return TpBm{typename TpBm::PlainBatchedMatrixView{.data = data(),
579 .depth = depth(),
580 .rows = cols(),
581 .cols = rows(),
582 .outer_stride = outer_stride(),
583 .batch_size = batch_size(),
584 .layer_stride = layout.layer_stride}};
585 }
586
587 /// @}
588
589 /// @name Value manipulation
590 /// @{
591
592 /// Add the scalar @p t to the diagonal elements of all matrices.
594 const auto bs = static_cast<I>(batch_size());
595 const auto n = std::min(rows(), cols());
596 for (index_type b = 0; b < num_batches(); ++b) {
597 auto *p = batch(b).data();
598 for (index_type i = 0; i < n; ++i) {
599 for (index_type r = 0; r < bs; ++r)
600 *p++ += t;
601 p += bs * outer_stride();
602 }
603 }
604 }
605
606 /// Replace the diagonal elements of all matrices by the scalar @p t (without modifying any
607 /// other elements).
608 void set_diagonal(const value_type &t) {
609 const auto bs = static_cast<I>(batch_size());
610 const auto n = std::min(rows(), cols());
611 for (index_type b = 0; b < num_batches(); ++b) {
612 auto *p = batch(b).data();
613 for (index_type i = 0; i < n; ++i) {
614 for (index_type r = 0; r < bs; ++r)
615 *p++ = t;
616 p += bs * outer_stride();
617 }
618 }
619 }
620
621 /// Replace the elements of all matrices by the scalar @p t.
623 const auto bs = static_cast<I>(batch_size());
624 for (index_type b = 0; b < num_batches(); ++b) {
625 auto *dst = this->batch(b).data();
626 for (index_type c = 0; c < this->outer_size(); ++c) {
627 auto *dst_ = dst;
628 const index_type n = inner_size() * bs;
629 for (index_type r = 0; r < n; ++r)
630 *dst_++ = t;
631 dst += bs * this->outer_stride();
632 }
633 }
634 }
635
636 void negate() {
637 const auto bs = static_cast<I>(batch_size());
638 for (index_type b = 0; b < num_batches(); ++b) {
639 auto *dst = this->batch(b).data();
640 for (index_type c = 0; c < this->outer_size(); ++c) {
641 auto *dst_ = dst;
642 const index_type n = inner_size() * bs;
643 for (index_type r = 0; r < n; ++r, ++dst_)
644 *dst_ = -*dst_;
645 dst += bs * this->outer_stride();
646 }
647 }
648 }
649
650 template <class Other>
651 void copy_values(const Other &other) const {
652 static_assert(is_row_major == Other::is_row_major);
653 assert(other.rows() == this->rows());
654 assert(other.cols() == this->cols());
655 assert(other.batch_size() == this->batch_size());
656 const auto bs = static_cast<I>(batch_size());
657 for (index_type b = 0; b < num_batches(); ++b) {
658 const auto *src = other.batch(b).data();
659 auto *dst = this->batch(b).data();
660 for (index_type c = 0; c < this->outer_size(); ++c) {
661 const auto *src_ = src;
662 auto *dst_ = dst;
663 const index_type n = inner_size() * bs;
664 for (index_type r = 0; r < n; ++r)
665 *dst_++ = *src_++;
666 src += bs * other.outer_stride();
667 dst += bs * this->outer_stride();
668 }
669 }
670 }
671
672 /// Copy assignment copies the values from another view with the same layout to this view.
673 View &operator=(const View &other) {
674 if (this != &other)
675 copy_values(other);
676 return *this;
677 }
678 /// Copy values from another view with a compatible value type and the same layout to this view.
679 template <class U, class J, class R, class E, class M>
680 requires(!std::is_const_v<T> && std::convertible_to<U, std::remove_cv_t<T>> &&
681 std::equality_comparable_with<I, J>)
682 View &operator=(View<U, J, R, E, M, O> other) {
683 copy_values(other);
684 return *this;
685 }
686 // TODO: abstract logic into generic function (and check performance)
687 template <class U, class J, class R, class E, class M>
688 requires(!std::is_const_v<T> && std::convertible_to<U, std::remove_cv_t<T>> &&
689 std::equality_comparable_with<I, J>)
690 View &operator+=(View<U, J, R, E, M, O> other) {
691 assert(other.rows() == this->rows());
692 assert(other.cols() == this->cols());
693 assert(other.batch_size() == this->batch_size());
694 const auto bs = static_cast<I>(batch_size());
695 for (index_type b = 0; b < num_batches(); ++b) {
696 const auto *src = other.batch(b).data();
697 auto *dst = this->batch(b).data();
698 for (index_type c = 0; c < this->outer_size(); ++c) {
699 const auto *src_ = src;
700 auto *dst_ = dst;
701 const index_type n = inner_size() * bs;
702 for (index_type r = 0; r < n; ++r)
703 *dst_++ += *src_++;
704 src += bs * other.outer_stride();
705 dst += bs * this->outer_stride();
706 }
707 }
708 return *this;
709 }
710
711 /// @}
712
713 /// @name View conversions
714 /// @{
715
716 /// Returns the same view. For consistency with @ref Matrix.
717 [[nodiscard]] View view() const { return *this; }
718 /// Explicit conversion to a const view.
719 [[nodiscard]] const_view_type as_const() const { return *this; }
720
721 /// Non-const views implicitly convert to const views.
722 operator const_view_type() const
723 requires(!std::is_const_v<T>)
724 {
725 return {data_ptr, layout};
726 }
727
728 /// If we have a single layer at compile time, we can implicitly convert to a non-batched view.
734
735 /// Implicit conversion to a view with a dynamic depth.
737 requires(!std::same_as<integral_value_type_t<D>, D>)
738 {
739 return {{.data = data(),
740 .depth = depth(),
741 .rows = rows(),
742 .cols = cols(),
743 .outer_stride = outer_stride(),
744 .batch_size = batch_size(),
745 .layer_stride = layout.layer_stride}};
746 }
747 /// Implicit conversion to a view with a dynamic depth, going from non-const to const.
749 requires(!std::is_const_v<T> && !std::same_as<integral_value_type_t<D>, D>)
750 {
751 return {{.data = data(),
752 .depth = depth(),
753 .rows = rows(),
754 .cols = cols(),
755 .outer_stride = outer_stride(),
756 .batch_size = batch_size(),
757 .layer_stride = layout.layer_stride}};
758 }
759 /// Implicit conversion to a view with a dynamic layer stride.
760 operator View<T, I, S, D, I, O>() const
761 requires(!std::same_as<I, L>)
762 {
763 return {{.data = data(),
764 .depth = depth(),
765 .rows = rows(),
766 .cols = cols(),
767 .outer_stride = outer_stride(),
768 .batch_size = batch_size(),
769 .layer_stride = layer_stride()}};
770 }
771 /// Implicit conversion to a view with a dynamic layer stride, going from non-const to const.
773 requires(!std::is_const_v<T> && !std::same_as<I, L>)
774 {
775 return {{.data = data(),
776 .depth = depth(),
777 .rows = rows(),
778 .cols = cols(),
779 .outer_stride = outer_stride(),
780 .batch_size = batch_size(),
781 .layer_stride = layer_stride()}};
782 }
783
784 /// @}
785};
786
787template <class T, class I, class S, class D, class L, StorageOrder P>
788bool operator==(std::default_sentinel_t s, typename View<T, I, S, D, L, P>::linear_iterator i) {
789 return i == s;
790}
791template <class T, class I, class S, class D, class L, StorageOrder P>
792bool operator!=(std::default_sentinel_t s, typename View<T, I, S, D, L, P>::linear_iterator i) {
793 return !(i == s);
794}
795template <class T, class I, class S, class D, class L, StorageOrder P>
796bool operator!=(typename View<T, I, S, D, L, P>::linear_iterator i, std::default_sentinel_t s) {
797 return !(i == s);
798}
799
800// TODO: tag-invoke style CPOs instead of free functions with ADL?
801
802template <class T, class I, class S, class D, class L, StorageOrder O>
803constexpr auto data(const View<T, I, S, D, L, O> &v) {
804 return v.data();
805}
806template <class T, class I, class S, class D, class L, StorageOrder O>
807constexpr auto rows(const View<T, I, S, D, L, O> &v) {
808 return v.rows();
809}
810template <class T, class I, class S, class D, class L, StorageOrder O>
811constexpr auto cols(const View<T, I, S, D, L, O> &v) {
812 return v.cols();
813}
814template <class T, class I, class S, class D, class L, StorageOrder O>
815constexpr auto outer_stride(const View<T, I, S, D, L, O> &v) {
816 return v.outer_stride();
817}
818template <class T, class I, class S, class D, class L, StorageOrder O>
819constexpr auto depth(const View<T, I, S, D, L, O> &v) {
820 return v.depth();
821}
822
823} // namespace batmat::matrix
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
Layout description for a batch of matrices, independent of any storage.
constexpr auto cols(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:507
typename integral_value_type< T >::type integral_value_type_t
Definition layout.hpp:25
constexpr auto data(Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:493
constexpr auto outer_stride(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:511
bool operator!=(std::default_sentinel_t s, typename View< T, I, S, D, L, P >::linear_iterator i)
Definition view.hpp:792
constexpr auto rows(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:503
bool operator==(std::default_sentinel_t s, typename View< T, I, S, D, L, P >::linear_iterator i)
Definition view.hpp:788
constexpr auto depth(const Matrix< T, I, S, D, O, A > &v)
Definition matrix.hpp:515
int index_t
Definition config.hpp:13
Shape and strides describing a batch of matrices, independent of any storage.
Definition layout.hpp:50
static constexpr bool is_column_major
Definition layout.hpp:58
static constexpr bool is_row_major
Definition layout.hpp:59
static constexpr StorageOrder storage_order
Definition layout.hpp:57
std::conditional_t< requires { S::value; }, std::integral_constant< index_t, S::value >, index_t > standard_stride_type
Definition layout.hpp:62
Iterator over all elements of a view.
Definition view.hpp:268
bool operator==(std::default_sentinel_t) const
Definition view.hpp:292
linear_iterator operator++(int)
Definition view.hpp:286
linear_iterator & operator++()
Definition view.hpp:278
Non-owning view of an array of matrices, stored in an efficient batched format.
Definition view.hpp:59
constexpr View(PlainBatchedMatrixView p={})
Create a new view.
Definition view.hpp:122
Layout< index_t, stride, stride, layer_stride, O > layout_type
Definition view.hpp:60
View< T, I, S, N, L, O > first_layers(N n) const
Get a view of the first n layers. Note that n can be a compile-time constant.
Definition view.hpp:232
typename layout_type::layer_stride_type layer_stride_type
Definition view.hpp:65
guanaqo::MatrixView< T, I, standard_stride_type, O > operator()(index_type l) const
Access a single layer l as a non-batched view.
Definition view.hpp:226
general_slice_view_type bottom_right(index_type nr, index_type nc) const
Get a view of the bottom-right nr by nc block of the matrices.
Definition view.hpp:559
typename layout_type::batch_size_type batch_size_type
Definition view.hpp:63
typename layout_type::standard_stride_type standard_stride_type
Definition view.hpp:66
View< T, index_t, stride, stride, DefaultStride, O > batch_view_type
Definition view.hpp:83
general_slice_view_type block(index_type r, index_type c, index_type nr, index_type nc) const
Get a view of the nr by nc block of the matrices starting at row r and column c.
Definition view.hpp:564
void set_constant(value_type t)
Replace the elements of all matrices by the scalar t.
Definition view.hpp:622
value_type & operator()(index_type l, index_type r, index_type c) const
Access a single element at layer l, row r and column c.
Definition view.hpp:152
auto transposed() const
Get a transposed view of the matrices.
Definition view.hpp:576
constexpr auto inner_stride() const
The inner stride of the matrices.
Definition view.hpp:378
constexpr index_type num_batches() const
Number of batches in the view, i.e. ceil_depth() / batch_size().
Definition view.hpp:348
col_slice_view_type middle_cols(index_type c, index_type n, index_type stride) const
Get a view of n columns starting at column c, with stride stride.
Definition view.hpp:525
constexpr View(value_type *data, layout_type layout)
Create a new view with the given layout, using the given buffer.
Definition view.hpp:134
col_slice_view_type middle_cols(index_type c, index_type n) const
Get a view of n columns starting at column c.
Definition view.hpp:520
constexpr index_type padded_size() const
Total number of elements in the view (including all padding).
Definition view.hpp:335
void add_to_diagonal(const value_type &t)
Add the scalar t to the diagonal elements of all matrices.
Definition view.hpp:593
linear_iterator begin() const
Iterate linearly (in storage order) over all elements of the view.
Definition view.hpp:298
constexpr auto get_layer_stride_for() const
Definition view.hpp:414
row_slice_view_type bottom_rows(index_type n) const
Get a view of the last n rows.
Definition view.hpp:466
general_slice_view_type reshaped(index_type rows, index_type cols) const
Reshape the view to the given dimensions. The total size should not change.
Definition view.hpp:426
View< T, I, S, N, L, O > middle_layers(index_type l, N n) const
Get a view of n layers starting at layer l.
Definition view.hpp:247
View(const View &)=default
Copy a view. No data is copied.
constexpr auto col_stride() const
The column stride of the matrices, i.e.
Definition view.hpp:388
View< T, I, S, I, I, O > middle_batches(index_type b, index_type n, index_type stride) const
Get a view of n batches starting at batch b, with a stride of stride batches.
Definition view.hpp:203
general_slice_view_type bottom_left(index_type nr, index_type nc) const
Get a view of the bottom-left nr by nc block of the matrices.
Definition view.hpp:554
View< const T, index_t, stride, stride, layer_stride, O > const_view_type
Definition view.hpp:67
row_slice_view_type middle_rows(index_type r, index_type n) const
Get a view of n rows starting at row r.
Definition view.hpp:496
constexpr index_type cols() const
Number of columns of the matrices.
Definition view.hpp:354
col_slice_view_type right_cols(index_type n) const
Get a view of the last n columns.
Definition view.hpp:481
View< T, I, S, I, L, O > batch_dyn(index_type b) const
Same as batch(), but returns a view with a dynamic batch size.
Definition view.hpp:175
constexpr bool has_full_inner_stride() const
Whether the inner_stride() == 1. Always true.
Definition view.hpp:406
col_slice_view_type left_cols(index_type n) const
Get a view of the first n columns.
Definition view.hpp:453
static View as_column(std::span< T > v)
Get a view of the given span as a column vector.
Definition view.hpp:570
row_slice_view_type top_rows(index_type n) const
Get a view of the first n rows.
Definition view.hpp:440
constexpr index_type rows() const
Number of rows of the matrices.
Definition view.hpp:352
constexpr index_type inner_size() const
The size of the inner dimension, i.e.
Definition view.hpp:362
constexpr index_type ceil_depth() const
The depth rounded up to a multiple of the batch size.
Definition view.hpp:340
View & operator=(const View &other)
Copy assignment copies the values from another view with the same layout to this view.
Definition view.hpp:673
View< T, I, S, I, L, O > middle_batches(index_type b, index_type n) const
Get a view of n batches starting at batch b.
Definition view.hpp:189
std::conditional_t< has_single_batch_at_compile_time, View, View< T, index_t, stride, stride, index_t, O > > general_slice_view_type
Definition view.hpp:88
constexpr auto row_stride() const
The row stride of the matrices, i.e.
Definition view.hpp:383
std::conditional_t< is_column_major, View, general_slice_view_type > row_slice_view_type
Definition view.hpp:97
void set_diagonal(const value_type &t)
Replace the diagonal elements of all matrices by the scalar t (without modifying any other elements).
Definition view.hpp:608
row_slice_view_type middle_rows(index_type r, index_type n, index_type stride) const
Get a view of n rows starting at row r, with stride stride.
Definition view.hpp:501
constexpr View(std::span< T > data, layout_type layout)
Create a new view with the given layout, using the given buffer.
Definition view.hpp:130
general_slice_view_type top_left(index_type nr, index_type nc) const
Get a view of the top-left nr by nc block of the matrices.
Definition view.hpp:544
constexpr index_type outer_size() const
The size of the outer dimension, i.e.
Definition view.hpp:357
View view() const
Returns the same view. For consistency with Matrix.
Definition view.hpp:717
constexpr index_type layer_stride() const
The layer stride, i.e.
Definition view.hpp:394
const_view_type as_const() const
Explicit conversion to a const view.
Definition view.hpp:719
general_slice_view_type top_right(index_type nr, index_type nc) const
Get a view of the top-right nr by nc block of the matrices.
Definition view.hpp:549
std::conditional_t< is_row_major, View, general_slice_view_type > col_slice_view_type
Definition view.hpp:94
void copy_values(const Other &other) const
Definition view.hpp:651
View & reassign(View other)
Reassign the buffer and layout of this view to those of another view. No data is copied.
Definition view.hpp:140
constexpr index_type outer_stride() const
Outer stride of the matrices (leading dimension in BLAS parlance).
Definition view.hpp:373
batch_view_type batch(index_type b) const
Access a batch of batch_size() layers, starting at batch index b (i.e.
Definition view.hpp:163
POD helper struct to enable designated initializers during construction.
Definition view.hpp:108