batmat 0.0.13
Batched linear algebra routines
Loading...
Searching...
No Matches
copy.hpp
Go to the documentation of this file.
1#pragma once
2
8#include <batmat/loop.hpp>
9#include <batmat/lut.hpp>
10#include <batmat/ops/rotate.hpp>
13#include <batmat/unroll.h>
14#include <guanaqo/trace.hpp>
15#include <algorithm>
16#include <concepts>
17
18namespace batmat::linalg {
19
20namespace detail::copy {
25
26template <class T, class Abi, FillConfig Conf = {}, StorageOrder OB>
27[[gnu::flatten, gnu::noinline]] void fill(T a, view<T, Abi, OB> B) {
28 using std::max;
29 using std::min;
30 using enum MatrixStructure;
31 [[maybe_unused]] static constexpr const auto trace_name =
32 Conf.struc == General ? GUANAQO_TRACE_STATIC_STR("fill")
33 : Conf.struc == LowerTriangular ? GUANAQO_TRACE_STATIC_STR("fill(L)")
34 : Conf.struc == UpperTriangular ? GUANAQO_TRACE_STATIC_STR("fill(U)")
35 : GUANAQO_TRACE_STATIC_STR("fill(?)");
36 GUANAQO_TRACE_LINALG(trace_name,
37 B.rows() * B.cols() * B.depth()); // TODO
38 const auto I = B.rows(), J = B.cols();
39 if (I == 0 || J == 0 || B.depth() == 0)
40 return;
41
42 using types = simd_view_types<T, Abi>;
43 typename types::simd A{a};
44 const index_t JI_adif = max<index_t>(0, J - I), IJ_adif = max<index_t>(0, I - J);
45 if constexpr (OB == StorageOrder::ColMajor)
46 for (index_t j = 0; j < J; ++j) {
47 const index_t i0 = Conf.struc == LowerTriangular ? max<index_t>(0, j - JI_adif) : 0;
48 const index_t i1 = Conf.struc == UpperTriangular ? min(j + 1 + IJ_adif, I) : I;
49 BATMAT_UNROLLED_IVDEP_FOR (8, index_t i = i0; i < i1; ++i)
50 types::template aligned_store<Conf.mask>(A, &B(0, i, j));
51 }
52 else
53 for (index_t i = 0; i < I; ++i) {
54 const index_t j0 = Conf.struc == UpperTriangular ? max<index_t>(0, i - IJ_adif) : 0;
55 const index_t j1 = Conf.struc == LowerTriangular ? min(i + 1 + JI_adif, J) : J;
56 BATMAT_UNROLLED_IVDEP_FOR (8, index_t j = j0; j < j1; ++j)
57 types::template aligned_store<Conf.mask>(A, &B(0, i, j));
58 }
59}
60
66
67template <class T, class Abi, CopyConfig Conf = {}, StorageOrder OA, StorageOrder OB>
68[[gnu::flatten, gnu::noinline]] void copy(view<const T, Abi, OA> A, view<T, Abi, OB> B)
69 requires(!std::same_as<Abi, datapar::scalar_abi<T>> || Conf.struc != MatrixStructure::General)
70{
71 using ops::rotl;
72 using ops::rotr;
73 using std::max;
74 using std::min;
75 using enum MatrixStructure;
76 [[maybe_unused]] static constexpr const auto trace_name =
77 Conf.struc == General ? GUANAQO_TRACE_STATIC_STR("copy")
78 : Conf.struc == LowerTriangular ? GUANAQO_TRACE_STATIC_STR("copy(L)")
79 : Conf.struc == UpperTriangular ? GUANAQO_TRACE_STATIC_STR("copy(U)")
80 : GUANAQO_TRACE_STATIC_STR("copy(?)");
81 GUANAQO_TRACE_LINALG(trace_name,
82 A.rows() * A.cols() * A.depth()); // TODO
83 assert(A.rows() == B.rows());
84 assert(A.cols() == B.cols());
85 const auto I = A.rows(), J = A.cols();
86 if (I == 0 || J == 0 || A.depth() == 0)
87 return;
88
89 using types = simd_view_types<T, Abi>;
90 const index_t JI_adif = max<index_t>(0, J - I), IJ_adif = max<index_t>(0, I - J);
91 if constexpr (OA == StorageOrder::ColMajor)
92 for (index_t j = 0; j < J; ++j) {
93 const index_t i0 = Conf.struc == LowerTriangular ? max<index_t>(0, j - JI_adif) : 0;
94 const index_t i1 = Conf.struc == UpperTriangular ? min(j + 1 + IJ_adif, I) : I;
95 BATMAT_UNROLLED_IVDEP_FOR (8, index_t i = i0; i < i1; ++i)
96 types::template aligned_store<Conf.mask>(
97 rotl<Conf.rotate>(types::aligned_load(&A(0, i, j))), &B(0, i, j));
98 }
99 else
100 for (index_t i = 0; i < I; ++i) {
101 const index_t j0 = Conf.struc == UpperTriangular ? max<index_t>(0, i - IJ_adif) : 0;
102 const index_t j1 = Conf.struc == LowerTriangular ? min(i + 1 + JI_adif, J) : J;
103 BATMAT_UNROLLED_IVDEP_FOR (8, index_t j = j0; j < j1; ++j)
104 types::template aligned_store<Conf.mask>(
105 rotl<Conf.rotate>(types::aligned_load(&A(0, i, j))), &B(0, i, j));
106 }
107}
108
109template <class T, class Abi, CopyConfig Conf = {}, StorageOrder OA, StorageOrder OB>
110[[gnu::flatten, gnu::noinline]] void copy(view<const T, Abi, OA> A, view<T, Abi, OB> B)
111 requires(std::same_as<Abi, datapar::scalar_abi<T>> && OA == OB &&
112 Conf.struc == MatrixStructure::General)
113{
114 GUANAQO_TRACE_LINALG("copy", A.rows() * A.cols() * A.depth());
115 assert(A.rows() == B.rows());
116 assert(A.cols() == B.cols());
117 if constexpr (Conf.mask != 0) // Scalar only
118 return;
119 if (A.rows() == 0 || A.cols() == 0 || A.depth() == 0)
120 return;
121
122 static_assert(typename decltype(A)::batch_size_type() == 1);
123 static_assert(typename decltype(B)::batch_size_type() == 1);
124 if constexpr (OA == StorageOrder::ColMajor)
125 for (index_t j = 0; j < A.cols(); ++j)
126 std::copy_n(&A(0, 0, j), A.rows(), &B(0, 0, j));
127 else
128 for (index_t i = 0; i < A.rows(); ++i)
129 std::copy_n(&A(0, i, 0), A.cols(), &B(0, i, 0));
130}
131
132template <class T, class Abi, CopyConfig Conf = {}, StorageOrder OA, StorageOrder OB>
133[[gnu::flatten, gnu::noinline]] void copy(view<const T, Abi, OA> A, view<T, Abi, OB> B)
134 requires(std::same_as<Abi, datapar::scalar_abi<T>> && OA != OB &&
135 Conf.struc == MatrixStructure::General)
136{
137 GUANAQO_TRACE_LINALG("copy(T)", A.rows() * A.cols() * A.depth());
138 assert(A.rows() == B.rows());
139 assert(A.cols() == B.cols());
140 if constexpr (Conf.mask != 0) // Scalar only
141 return;
142 if (A.rows() == 0 || A.cols() == 0 || A.depth() == 0)
143 return;
144
145 constexpr index_t R = ops::RowsRegTranspose<T>;
146 constexpr index_t C = ops::ColsRegTranspose<T>;
147 [[maybe_unused]] static const constinit auto lut =
148 make_2d_lut<R, C>([]<index_t Row, index_t Col>(index_constant<Row>, index_constant<Col>) {
150 });
151
152 // Always access A contiguously in the inner loop
153 if constexpr (OA == StorageOrder::ColMajor)
154 // Tiled transposition
155 foreach_chunked_merged(0, A.cols(), C, [&](index_t c, auto nc) {
156 foreach_chunked_merged(0, A.rows(), R, [&](index_t r, auto nr) {
157 lut[nr - 1][nc - 1](&A(0, r, c), A.outer_stride(), &B(0, r, c), B.outer_stride());
158 });
159 });
160 else
161 foreach_chunked_merged(0, A.rows(), R, [&](index_t r, auto nr) {
162 foreach_chunked_merged(0, A.cols(), C, [&](index_t c, auto nc) {
163 lut[nc - 1][nr - 1](&A(0, r, c), A.outer_stride(), &B(0, r, c), B.outer_stride());
164 });
165 });
166}
167
168template <class... Opts>
169constexpr CopyConfig apply_options(CopyConfig conf, Opts...) {
170 if (auto s = get_rotate<Opts...>)
171 conf.rotate = *s;
172 if (auto s = get_mask<Opts...>)
173 conf.mask = *s;
174 return conf;
175}
176} // namespace detail::copy
177
178/// @addtogroup topic-linalg
179/// @{
180
181/// @name Copying and filling batches of matrices
182/// @{
183
184/// B = A
185template <simdifiable VA, simdifiable VB, rotate_opt... Opts>
187void copy(VA &&A, VB &&B, Opts... opts) {
188 constexpr auto conf = detail::copy::apply_options({}, opts...);
190 simdify(B));
191}
192
193/// B = A
194template <MatrixStructure S, simdifiable VA, simdifiable VB, rotate_opt... Opts>
195 requires simdify_compatible<VA, VB>
196void copy(Structured<VA, S> A, Structured<VB, S> B, Opts... opts) {
197 constexpr auto conf = detail::copy::apply_options({.struc = S}, opts...);
199 simdify(A.value).as_const(), simdify(B.value));
200}
201
202/// B = a
203template <simdifiable VB>
207
208/// B = a
209template <MatrixStructure S, simdifiable VB>
214
215/// @}
216
217/// @}
218
219} // namespace batmat::linalg
void copy(VA &&A, VB &&B, Opts... opts)
B = A.
Definition copy.hpp:187
void fill(simdified_value_t< VB > a, VB &&B)
B = a.
Definition copy.hpp:204
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
Rotates the elements of x by s positions to the left.
Definition rotate.hpp:226
void transpose(const T *pa, index_t lda, T *pb, index_t ldb)
Transposes the R × C matrix at pa with leading dimension lda, writing the result to pb with leading dimension ldb.
Definition transpose.hpp:63
datapar::simd< F, Abi > rotr(datapar::simd< F, Abi > x)
Rotate the elements of x to the right by S positions.
Definition rotate.hpp:239
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk.
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
#define GUANAQO_TRACE_LINALG(name, gflops)
#define GUANAQO_TRACE_STATIC_STR(s)
void fill(T a, view< T, Abi, OB > B)
Definition copy.hpp:27
constexpr CopyConfig apply_options(CopyConfig conf, Opts...)
Definition copy.hpp:169
void copy(view< const T, Abi, OA > A, view< T, Abi, OB > B)
Definition copy.hpp:68
typename detail::simdified_value< V >::type simdified_value_t
Definition simdify.hpp:202
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
constexpr std::optional< int > get_rotate
Definition shift.hpp:92
constexpr std::optional< int > get_mask
Definition shift.hpp:99
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
constexpr index_t RowsRegTranspose
Definition avx-512.hpp:23
constexpr index_t ColsRegTranspose
Definition avx-512.hpp:25
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
Light-weight wrapper class used for overload resolution of triangular and symmetric matrices.
#define BATMAT_UNROLLED_IVDEP_FOR(N,...)
Definition unroll.h:29