batmat 0.0.21
Batched linear algebra routines
Loading...
Searching...
No Matches
elementwise.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
4#include <batmat/config.hpp>
9#include <batmat/simd.hpp>
10#include <guanaqo/trace.hpp>
11#include <array>
12#include <cmath>
13#include <concepts>
14#include <tuple>
15#include <utility>
16
17namespace batmat::linalg {
18
19/// @cond DETAIL
20
21namespace detail {
22
23constexpr index_t num_elem(const auto &A) { return A.rows() * A.cols() * A.depth(); }
24
25template <class T, class Abi, StorageOrder O, class F, class X, class... Xs>
26[[gnu::always_inline]] inline void iter_elems(F &&fun, X &&x, Xs &&...xs) {
27 using types = simd_view_types<T, Abi>;
28 if constexpr (O == StorageOrder::ColMajor) {
29 for (index_t c = 0; c < x.cols(); ++c)
30 for (index_t r = 0; r < x.rows(); ++r)
31 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
32 } else {
33 for (index_t r = 0; r < x.rows(); ++r)
34 for (index_t c = 0; c < x.cols(); ++c)
35 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
36 }
37}
38
39template <class T, class Abi, StorageOrder O, class F, class X, class... Xs>
40[[gnu::always_inline]] inline void iter_elems_store(F &&fun, X &&x, Xs &&...xs) {
41 using types = simd_view_types<T, Abi>;
42 if constexpr (O == StorageOrder::ColMajor) {
43 for (index_t c = 0; c < x.cols(); ++c)
44 for (index_t r = 0; r < x.rows(); ++r)
45 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
46 } else {
47 for (index_t r = 0; r < x.rows(); ++r)
48 for (index_t c = 0; c < x.cols(); ++c)
49 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
50 }
51}
52
53template <class T, class Abi, StorageOrder O, class F, class X0, class X1, class... Xs>
54[[gnu::always_inline]] inline void iter_elems_store2(F &&fun, X0 &&x0, X1 &&x1, Xs &&...xs) {
55 using types = simd_view_types<T, Abi>;
56 if constexpr (O == StorageOrder::ColMajor) {
57 for (index_t c = 0; c < x0.cols(); ++c)
58 for (index_t r = 0; r < x0.rows(); ++r) {
59 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
60 types::aligned_store(r0, &x0(0, r, c));
61 types::aligned_store(r1, &x1(0, r, c));
62 }
63 } else {
64 for (index_t r = 0; r < x0.rows(); ++r)
65 for (index_t c = 0; c < x0.cols(); ++c) {
66 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
67 types::aligned_store(r0, &x0(0, r, c));
68 types::aligned_store(r1, &x1(0, r, c));
69 }
70 }
71}
72
73template <class T, class Abi, StorageOrder O, class F, class... Ys, class... Xs>
74[[gnu::always_inline]] inline void iter_elems_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
75 using std::get;
76 using types = simd_view_types<T, Abi>;
77 const index_t rows = std::get<0>(ys).rows(), cols = std::get<0>(ys).cols();
78 if constexpr (O == StorageOrder::ColMajor) {
79 for (index_t c = 0; c < cols; ++c)
80 for (index_t r = 0; r < rows; ++r) {
81 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
82 static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
83 [&]<size_t... Is>(std::index_sequence<Is...>) {
84 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
85 }(std::index_sequence_for<Ys...>());
86 }
87 } else {
88 for (index_t r = 0; r < rows; ++r)
89 for (index_t c = 0; c < cols; ++c) {
90 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
91 static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
92 [&]<size_t... Is>(std::index_sequence<Is...>) {
93 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
94 }(std::index_sequence_for<Ys...>());
95 }
96 }
97}
98
/// Iterate element-wise over the diagonal elements of matrices and over the elements of vectors.
/// Any argument can either be a square matrix or a vector, but this function has a fast path
/// when the first input and output arguments are square matrices and all others are column vectors.
template <class T, class Abi, class F, class... Ys, class... Xs>
[[gnu::always_inline]] inline void iter_diag_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
    using std::get;
    using types = simd_view_types<T, Abi>;
    const index_t rows = std::get<0>(ys).rows(), cols = std::get<0>(ys).cols();
    // Iteration count n: min(rows, cols) for a true matrix, or the vector length when the
    // first output is a single column (ColMajor) / single row (RowMajor).
    const index_t n = std::get<0>(ys).storage_order == StorageOrder::ColMajor
                          ? (cols > 1 ? std::min(rows, cols) : rows)
                          : (rows > 1 ? std::min(rows, cols) : cols);
    // Optimized implementation for the special case where the first input and the first output are
    // square matrices and all others are column vectors.
    static constexpr auto all_vectors_except_first = [](auto &x0, auto &...x1s) {
        return x0.rows() == x0.cols() &&
               ((x1s.storage_order == StorageOrder::ColMajor && x1s.cols() == 1) && ...);
    };
    const bool all_xs_vectors_except_first = all_vectors_except_first(xs...);
    const bool all_ys_vector_except_first = std::apply(all_vectors_except_first, ys);
    if (all_xs_vectors_except_first && all_ys_vector_except_first) {
        for (index_t r = 0; r < n; ++r) {
            // First input indexed on its diagonal (r, r); remaining inputs are column
            // vectors indexed (r, 0).
            auto rs = [&](auto &x0, auto &...x1s) {
                return fun(types::aligned_load(&x0(0, r, r)),
                           types::aligned_load(&x1s(0, r, 0))...);
            }(xs...);
            // fun must return exactly one result per output view.
            static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
            // Store results: first output on its diagonal, all others as column vectors.
            [&]<size_t... Is>(std::index_sequence<Is...>) {
                ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, Is == 0 ? r : 0))), ...);
            }(std::index_sequence_for<Ys...>());
        }
    }
    // Fully generic implementation
    // (GCC does not seem to hoist the conditionals in the access function out of the loop, so we
    // pay for some cmovs here, even at -O3. Should be fine though, since we're most likely memory-
    // bound anyway.)
    else {
        // Selects the r-th diagonal element for a matrix, or the r-th element for a vector,
        // for either storage order.
        static constexpr auto access = [](auto &x, index_t r) -> auto & {
            return x.storage_order == StorageOrder::ColMajor
                       ? (x.cols() > 1 ? x(0, r, r) : x(0, r, 0))
                       : (x.rows() > 1 ? x(0, r, r) : x(0, 0, r));
        };
        for (index_t r = 0; r < n; ++r) {
            auto rs = fun(types::aligned_load(&access(xs, r))...);
            static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
            [&]<size_t... Is>(std::index_sequence<Is...>) {
                ((types::aligned_store(get<Is>(rs), &access(get<Is>(ys), r))), ...);
            }(std::index_sequence_for<Ys...>());
        }
    }
}
149
150/// Scalar product.
151template <class T, class Abi, StorageOrder OB, StorageOrder OC>
152[[gnu::flatten]] void scale(T a, view<const T, Abi, OB> B, view<T, Abi, OC> C) {
153 BATMAT_ASSERT(B.rows() == C.rows());
154 BATMAT_ASSERT(B.cols() == C.cols());
155 iter_elems_store<T, Abi, OC>([&](auto Bi) { return a * Bi; }, C, B);
156}
157
158/// Hadamard (elementwise) product.
159template <class T, class Abi, StorageOrder OA, StorageOrder OB, StorageOrder OC>
162 BATMAT_ASSERT(A.rows() == B.rows());
163 BATMAT_ASSERT(A.cols() == B.cols());
164 BATMAT_ASSERT(A.rows() == C.rows());
165 BATMAT_ASSERT(A.cols() == C.cols());
166 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai * Bi; }, C, A, B);
167}
168
169/// Elementwise clamping z = max(lo, min(x, hi)).
170template <class T, class Abi, StorageOrder O>
171[[gnu::flatten]] void clamp(view<const T, Abi, O> x, view<const T, Abi, O> lo,
173 BATMAT_ASSERT(x.rows() == lo.rows());
174 BATMAT_ASSERT(x.cols() == lo.cols());
175 BATMAT_ASSERT(x.rows() == hi.rows());
176 BATMAT_ASSERT(x.cols() == hi.cols());
177 BATMAT_ASSERT(x.rows() == z.rows());
178 BATMAT_ASSERT(x.cols() == z.cols());
179 const auto clamp = [&](auto xi, auto loi, auto hii) { return fmax(loi, fmin(xi, hii)); };
180 iter_elems_store<T, Abi, O>(clamp, z, x, lo, hi);
181}
182
183/// Elementwise clamping z = max(lo, min(x, hi)), with scalar lo and hi.
184template <class T, class Abi, StorageOrder O>
185[[gnu::flatten]] void clamp(view<const T, Abi, O> x, T lo, T hi, view<T, Abi, O> z) {
186 BATMAT_ASSERT(x.rows() == z.rows());
187 BATMAT_ASSERT(x.cols() == z.cols());
188 const auto clamp = [&](auto xi) { return fmax(decltype(xi){lo}, fmin(xi, decltype(xi){hi})); };
189 iter_elems_store<T, Abi, O>(clamp, z, x);
190}
191
192/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
193template <class T, class Abi, StorageOrder O>
196 BATMAT_ASSERT(x.rows() == lo.rows());
197 BATMAT_ASSERT(x.cols() == lo.cols());
198 BATMAT_ASSERT(x.rows() == hi.rows());
199 BATMAT_ASSERT(x.cols() == hi.cols());
200 BATMAT_ASSERT(x.rows() == z.rows());
201 BATMAT_ASSERT(x.cols() == z.cols());
203 const auto clamp_resid = [&](auto xi, auto loi, auto hii) {
204 return fmax(xi - hii, fmin(simd{0}, xi - loi));
205 };
206 iter_elems_store<T, Abi, O>(clamp_resid, z, x, lo, hi);
207}
208
/// Linear combination of vectors z = beta * z + sum_i alpha_i * x_i.
/// @tparam Beta  compile-time scaling of the previous value of z; Beta == 0 selects a
///               write-only path that never reads z.
template <class T, class Abi, T Beta, StorageOrder O, class... Xs>
[[gnu::flatten]] void gaxpby(view<T, Abi, O> z, const std::array<T, sizeof...(Xs)> &alphas,
                             const Xs &...xs) {
    BATMAT_ASSERT(((z.rows() == xs.rows()) && ...));
    BATMAT_ASSERT(((z.cols() == xs.cols()) && ...));
    if constexpr (Beta == 0)
        // Beta == 0: z is output-only, so it is not passed as an input view.
        iter_elems_store<T, Abi, O>(
            [&](auto... xis) {
                // Fold over the inputs: Σᵢ αᵢ·xᵢ (index sequence pairs each xᵢ with its αᵢ).
                return [&]<std::size_t... Is>(std::index_sequence<Is...>, auto... xis) {
                    return ((xis * alphas[Is]) + ...);
                }(std::make_index_sequence<sizeof...(Xs)>(), xis...);
            },
            z, xs...);
    else
        // General case: z is passed both as output and as the first input (zi).
        iter_elems_store<T, Abi, O>(
            [&](auto zi, auto... xis) {
                return [&]<std::size_t... Is>(std::index_sequence<Is...>, auto... xis) {
                    return zi * Beta + ((xis * alphas[Is]) + ...);
                }(std::make_index_sequence<sizeof...(Xs)>(), xis...);
            },
            z, z, xs...);
}
232
233/// Negate a matrix or vector.
234/// @todo: add Negate option to batmat::linalg::copy and remove this function, then this also
235/// supports transposition.
236template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB>
237[[gnu::flatten]] void negate(view<const T, Abi, OA> A, view<T, Abi, OB> B) {
238 BATMAT_ASSERT(A.rows() == B.rows());
239 BATMAT_ASSERT(A.cols() == B.cols());
240 using ops::rotl;
241 iter_elems_store<T, Abi, OB>([&](auto Ai) { return -rotl<Rotate>(Ai); }, B, A);
242}
243
244/// Subtract two matrices or vectors C = A - B.
245template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
247 BATMAT_ASSERT(A.rows() == B.rows());
248 BATMAT_ASSERT(A.cols() == B.cols());
249 BATMAT_ASSERT(A.rows() == C.rows());
250 BATMAT_ASSERT(A.cols() == C.cols());
251 using ops::rotl;
252 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai - rotl<Rotate>(Bi); }, C, A, B);
253}
254
255/// Add two matrices or vectors C = A + B.
256template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
258 BATMAT_ASSERT(A.rows() == B.rows());
259 BATMAT_ASSERT(A.cols() == B.cols());
260 BATMAT_ASSERT(A.rows() == C.rows());
261 BATMAT_ASSERT(A.cols() == C.cols());
262 using ops::rotl;
263 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai + rotl<Rotate>(Bi); }, C, A, B);
264}
265
266} // namespace detail
267
268/// @endcond
269
270/// @addtogroup topic-linalg
271/// @{
272
273/// @name Single-batch elementwise operations
274/// @{
275
/// Multiply a vector by a scalar z = αx.
template <simdifiable Vx, simdifiable Vz, std::convertible_to<simdified_value_t<Vx>> T>
// NOTE(review): a constraint line (presumably `requires simdify_compatible<Vx, Vz>`, as on the
// sibling overloads) appears to have been dropped from this listing — confirm against the header.
void scale(T alpha, Vx &&x, Vz &&z) {
    // Cost model for tracing: one multiply per element.
    GUANAQO_TRACE_LINALG("scale", detail::num_elem(simdify(x)));
    detail::scale<simdified_value_t<Vx>, simdified_abi_t<Vx>>(alpha, simdify(x).as_const(),
                                                              simdify(z));
}
284
285/// Multiply a vector by a scalar x = αx.
286template <simdifiable Vx, std::convertible_to<simdified_value_t<Vx>> T>
287void scale(T alpha, Vx &&x) {
288 GUANAQO_TRACE_LINALG("scale", detail::num_elem(simdify(x)));
289 detail::scale<simdified_value_t<Vx>, simdified_abi_t<Vx>>(alpha, simdify(x).as_const(),
290 simdify(x));
291}
292
/// Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
template <simdifiable Vx, simdifiable Vy, simdifiable Vz>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void hadamard(Vx &&x, Vy &&y, Vz &&z) {
    // Cost model for tracing: one multiply per element.
    GUANAQO_TRACE_LINALG("hadamard", detail::num_elem(simdify(x)));
    detail::hadamard<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
                                                                 simdify(y).as_const(), simdify(z));
}
301
/// Compute the Hadamard (elementwise) product of two vectors x = x ⊙ y (in place).
template <simdifiable Vx, simdifiable Vy>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void hadamard(Vx &&x, Vy &&y) {
    GUANAQO_TRACE_LINALG("hadamard", detail::num_elem(simdify(x)));
    // x is both input and output.
    detail::hadamard<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
                                                                 simdify(y).as_const(), simdify(x));
}
310
/// Elementwise clamping z = max(lo, min(x, hi)).
template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
    GUANAQO_TRACE_LINALG("clamp", 2 * detail::num_elem(simdify(x))); // max, min
    detail::clamp<simdified_value_t<Vx>, simdified_abi_t<Vx>>(
        simdify(x).as_const(), simdify(lo).as_const(), simdify(hi).as_const(), simdify(z));
}
319
/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
    GUANAQO_TRACE_LINALG("clamp_resid", 3 * detail::num_elem(simdify(x))); // sub, max, min
    detail::clamp_resid<simdified_value_t<Vx>, simdified_abi_t<Vx>>(
        simdify(x).as_const(), simdify(lo).as_const(), simdify(hi).as_const(), simdify(z));
}
328
/// Elementwise clamping z = max(lo, min(x, hi)), with scalar lo and hi.
template <simdifiable Vx, simdifiable Vz>
// NOTE(review): the requires-clause and the function signature line were lost from this listing —
// presumably `void clamp(Vx &&x, simdified_value_t<Vx> lo, simdified_value_t<Vx> hi, Vz &&z) {`,
// matching detail::clamp's scalar overload. Restore from the original header.
    GUANAQO_TRACE_LINALG("clamp", 2 * detail::num_elem(simdify(x))); // max, min
    detail::clamp<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(), lo, hi,
                                                              simdify(z));
}
337
/// Add scaled vector z = αx + βy.
template <simdifiable Vx, simdifiable Vy, simdifiable Vz, //
          std::convertible_to<simdified_value_t<Vx>> Ta,
          std::convertible_to<simdified_value_t<Vx>> Tb>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
    GUANAQO_TRACE_LINALG("axpby", 2 * detail::num_elem(simdify(x))); // mul, fma
    // Beta template argument 0: z is write-only; β is applied through the alphas array instead.
    detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{0}>(
        simdify(z), {{alpha, beta}}, simdify(x).as_const(), simdify(y).as_const());
}
348
/// Add scaled vector y = αx + βy (in place).
template <simdifiable Vx, simdifiable Vy, //
          std::convertible_to<simdified_value_t<Vx>> Ta,
          std::convertible_to<simdified_value_t<Vx>> Tb>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
    GUANAQO_TRACE_LINALG("axpby", 2 * detail::num_elem(simdify(x))); // mul, fma
    // y is both an input (scaled by beta via the alphas array) and the output.
    detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{0}>(
        simdify(y), {{alpha, beta}}, simdify(x).as_const(), simdify(y).as_const());
}
359
/// Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
/// @tparam Beta  compile-time β; 0 skips reading y, 1 skips the multiply.
template <auto Beta = 1, simdifiable Vy, simdifiable... Vx>
    requires simdify_compatible<Vy, Vx...>
void axpy(Vy &&y, const std::array<simdified_value_t<Vy>, sizeof...(Vx)> &alphas, Vx &&...x) {
    // Trace cost: one multiply only when β is neither 0 nor 1, plus one fma per input.
    [[maybe_unused]] static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
    [[maybe_unused]] static constexpr index_t num_fma = sizeof...(Vx);
    GUANAQO_TRACE_LINALG("axpy", (num_mul + num_fma) * detail::num_elem(simdify(y))); // mul, fma
    detail::gaxpby<simdified_value_t<Vy>, simdified_abi_t<Vy>, simdified_value_t<Vy>{Beta}>(
        simdify(y), alphas, simdify(x).as_const()...);
}
370
/// Add scaled vector z = αx + y.
template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
          std::convertible_to<simdified_value_t<Vx>> Ta>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
    // Delegates to axpby with β = 1.
    axpby(alpha, x, Ta{1}, y, z);
}
378
/// Add scaled vector y = αx + βy (where β is a compile-time constant).
template <auto Beta = 1, simdifiable Vx, simdifiable Vy,
          std::convertible_to<simdified_value_t<Vx>> Ta>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpy(Ta alpha, Vx &&x, Vy &&y) {
    // Trace cost: one multiply only when β is neither 0 nor 1, plus one fma.
    [[maybe_unused]] static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
    [[maybe_unused]] static constexpr index_t num_fma = 1;
    GUANAQO_TRACE_LINALG("axpy", (num_mul + num_fma) * detail::num_elem(simdify(y))); // mul, fma
    detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{Beta}>(
        simdify(y), {{alpha}}, simdify(x).as_const());
}
390
/// Negate a matrix or vector B = -A.
template <simdifiable VA, simdifiable VB, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void negate(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
    GUANAQO_TRACE_LINALG("negate", detail::num_elem(simdify(A)));
    detail::negate<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(simdify(A).as_const(),
                                                                       simdify(B));
}
399
400/// Negate a matrix or vector A = -A.
401template <simdifiable VA, int Rotate = 0>
402void negate(VA &&A, with_rotate_t<Rotate> = {}) {
403 GUANAQO_TRACE_LINALG("negate", detail::num_elem(simdify(A)));
404 detail::negate<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(simdify(A).as_const(),
405 simdify(A));
406}
407
/// Subtract two matrices or vectors C = A - B. Rotate affects B.
template <simdifiable VA, simdifiable VB, simdifiable VC, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> = {}) {
    GUANAQO_TRACE_LINALG("sub", detail::num_elem(simdify(A)));
    detail::sub<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
        simdify(A).as_const(), simdify(B).as_const(), simdify(C));
}
416
/// Subtract two matrices or vectors A = A - B (in place). Rotate affects B.
template <simdifiable VA, simdifiable VB, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void sub(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
    GUANAQO_TRACE_LINALG("sub", detail::num_elem(simdify(A)));
    detail::sub<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
        simdify(A).as_const(), simdify(B).as_const(), simdify(A));
}
425
/// Add two matrices or vectors C = A + B. Rotate affects B.
template <simdifiable VA, simdifiable VB, simdifiable VC, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void add(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> = {}) {
    GUANAQO_TRACE_LINALG("add", detail::num_elem(simdify(A)));
    detail::add<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
        simdify(A).as_const(), simdify(B).as_const(), simdify(C));
}
434
/// Add two matrices or vectors A = A + B (in place). Rotate affects B.
template <simdifiable VA, simdifiable VB, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void add(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
    GUANAQO_TRACE_LINALG("add", detail::num_elem(simdify(A)));
    detail::add<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
        simdify(A).as_const(), simdify(B).as_const(), simdify(A));
}
443
444/// Apply a function to all elements of the given matrices or vectors.
445template <class F, simdifiable VA, simdifiable... VAs>
446 requires simdify_compatible<VA, VAs...>
447void for_each_elementwise(F &&fun, VA &&A, VAs &&...As) {
448 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
449 detail::iter_elems<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
450 std::forward<F>(fun), simdify(A).as_const(), simdify(As).as_const()...);
451}
452
453/// Apply a function to all elements of the given matrices or vectors, storing the result in the
454/// first argument.
455template <class F, simdifiable VA, simdifiable... VAs>
456 requires simdify_compatible<VA, VAs...>
457void transform_elementwise(F &&fun, VA &&A, VAs &&...As) {
458 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
459 detail::iter_elems_store<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
460 std::forward<F>(fun), simdify(A), simdify(As).as_const()...);
461}
462
463/// Apply a function to all elements of the given matrices or vectors, storing the results in the
464/// first two arguments.
465template <class F, simdifiable VA, simdifiable VB, simdifiable... VAs>
466 requires simdify_compatible<VA, VB, VAs...>
467void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As) {
468 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
469 detail::iter_elems_store2<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
470 std::forward<F>(fun), simdify(A), simdify(B), simdify(As).as_const()...);
471}
472
473/// Apply a function to all elements of the given matrices or vectors, storing the results in the
474/// tuple of matrices given as the first argument.
475template <class F, simdifiable... VAs, simdifiable... VBs>
476 requires simdify_compatible<VAs..., VBs...>
477void transform_n_elementwise(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
478 using VA0 = std::tuple_element_t<0, decltype(As)>;
479 static constexpr auto storage_order = simdified_view_t<VA0>::storage_order;
480 detail::iter_elems_store_n<simdified_value_t<VA0>, simdified_abi_t<VA0>, storage_order>(
481 std::forward<F>(fun),
482 std::apply([](auto &&...a) { return std::make_tuple(simdify(a)...); }, As),
483 simdify(Bs).as_const()...);
484}
485
/// Apply a function to all elements of the given vectors and the diagonal elements of the given
/// square matrices, storing the results in the tuple of vectors or matrices given as the first
/// argument. Most efficient if only the first argument contains matrices, and all other arguments
/// are column vectors.
template <class F, simdifiable... VAs, simdifiable... VBs>
    requires simdify_compatible<VAs..., VBs...>
void transform_n_diag(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
    // Effective length of an argument: the diagonal length for a square matrix, or the vector
    // length otherwise (asserting it really is a single column/row).
    constexpr auto check_size = [](auto &x) {
        if (x.rows() == x.cols())
            return x.rows();
        else if (x.storage_order == StorageOrder::ColMajor) {
            BATMAT_ASSERT(x.cols() == 1);
            return x.rows();
        } else {
            BATMAT_ASSERT(x.rows() == 1);
            return x.cols();
        }
    };
    // All inputs and outputs must agree on this length.
    [[maybe_unused]] const index_t n = check_size(std::get<0>(As));
    [&]<size_t... Is>(std::index_sequence<Is...>) {
        BATMAT_ASSERT(((check_size(get<Is>(As)) == n) && ...));
    }(std::index_sequence_for<VAs...>());
    BATMAT_ASSERT(((check_size(Bs) == n) && ...));
    using VA0 = std::tuple_element_t<0, decltype(As)>;
    detail::iter_diag_store_n<simdified_value_t<VA0>, simdified_abi_t<VA0>>(
        std::forward<F>(fun),
        std::apply([](auto &&...a) { return std::make_tuple(simdify(a)...); }, As),
        simdify(Bs).as_const()...);
}
515
/// Copy the diagonal elements of a matrix. The arguments @p A and @p B must either be square
/// matrices or vectors. This function supports setting the diagonal of a matrix to the values of
/// a vector, copying the diagonal of one matrix to the diagonal of another, or copying the diagonal
/// elements of a matrix to a vector.
template <class F, simdifiable VA, simdifiable VB>
// NOTE(review): a `requires simdify_compatible<VA, VB>` line appears to have been dropped from
// this listing; also, template parameter F is not used in the body — confirm both against the
// original header.
void copy_diag(VA &&A, VB &&B) {
    // Diagonal/vector length, taken from A's leading dimension.
    [[maybe_unused]] const index_t n =
        A.storage_order == StorageOrder::ColMajor ? A.rows() : A.cols();
    // A must be an n×n matrix or a length-n vector (single column/row per storage order).
    if constexpr (A.storage_order == StorageOrder::ColMajor) {
        BATMAT_ASSERT(A.rows() == n);
        BATMAT_ASSERT(A.cols() == n || A.cols() == 1);
    } else {
        BATMAT_ASSERT(A.rows() == n || A.rows() == 1);
        BATMAT_ASSERT(A.cols() == n);
    }
    // Same shape requirements for B.
    if constexpr (B.storage_order == StorageOrder::ColMajor) {
        BATMAT_ASSERT(B.rows() == n);
        BATMAT_ASSERT(B.cols() == n || B.cols() == 1);
    } else {
        BATMAT_ASSERT(B.rows() == n || B.rows() == 1);
        BATMAT_ASSERT(B.cols() == n);
    }
    GUANAQO_TRACE_LINALG("copy_diag", n * A.depth());
    // Identity function over the diagonal: B_diag = A_diag.
    detail::iter_diag_store_n<simdified_value_t<VA>, simdified_abi_t<VA>>(
        [](auto Ai) { return std::make_tuple(Ai); }, std::make_tuple(simdify(B)),
        simdify(A).as_const());
}
544
/// C = A + diag(b).
template <simdifiable VA, simdifiable VB, simdifiable VC>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void add_diag(VA &&A, VB &&b, VC &&C) {
    // A and C must be square and of equal shape; b must be a vector matching A's diagonal.
    BATMAT_ASSERT(A.rows() == A.cols());
    BATMAT_ASSERT(A.rows() == C.rows());
    BATMAT_ASSERT(A.cols() == C.cols());
    if constexpr (b.storage_order == StorageOrder::ColMajor) {
        BATMAT_ASSERT(b.rows() == A.rows());
        BATMAT_ASSERT(b.cols() == 1);
    } else {
        BATMAT_ASSERT(b.rows() == 1);
        BATMAT_ASSERT(b.cols() == A.cols());
    }
    GUANAQO_TRACE_LINALG("add_diag", detail::num_elem(simdify(b)));
    // Only the diagonal of C is written: C_rr = A_rr + b_r.
    detail::iter_diag_store_n<simdified_value_t<VA>, simdified_abi_t<VA>>(
        [](auto Ai, auto bi) { return std::make_tuple(Ai + bi); }, std::make_tuple(simdify(C)),
        simdify(A).as_const(), simdify(b).as_const());
}
564
/// A += diag(b).
template <simdifiable VA, simdifiable VB>
// NOTE(review): a `requires simdify_compatible<VA, VB>` line appears to have been dropped from
// this listing — confirm against the original header.
void add_diag(VA &&A, VB &&b) {
    // A is used as both input and output. Forwarding A twice is presumably safe here because
    // the arguments are (non-owning) views — confirm that VA is never an owning rvalue.
    add_diag(std::forward<VA>(A), std::forward<VB>(b), std::forward<VA>(A));
}
571
572/// @}
573
574/// @}
575
576// TODO: doxygen gets confused because the template parameters are the same as the single-batch
577// versions, so put in a separate namespace
578inline namespace multi {
579
580/// @addtogroup topic-linalg
581/// @{
582
583/// @name Multi-batch elementwise operations
584/// @{
585
/// Multiply a vector by a scalar z = αx (multi-batch).
template <simdifiable_multi Vx, simdifiable_multi Vz, std::convertible_to<simdified_value_t<Vx>> T>
// NOTE(review): a `requires simdify_compatible<Vx, Vz>` line appears to have been dropped from
// this listing — confirm against the original header.
void scale(T alpha, Vx &&x, Vz &&z) {
    BATMAT_ASSERT(x.num_batches() == z.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::scale(alpha, x.batch(b), z.batch(b));
}
594
595/// Multiply a vector by a scalar x = αx.
596template <simdifiable_multi Vx, std::convertible_to<simdified_value_t<Vx>> T>
597void scale(T alpha, Vx &&x) {
598 for (index_t b = 0; b < x.num_batches(); ++b)
599 linalg::scale(alpha, x.batch(b));
600}
601
/// Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y (multi-batch).
template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void hadamard(Vx &&x, Vy &&y, Vz &&z) {
    BATMAT_ASSERT(x.num_batches() == y.num_batches());
    BATMAT_ASSERT(x.num_batches() == z.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::hadamard(x.batch(b), y.batch(b), z.batch(b));
}
611
/// Compute the Hadamard (elementwise) product of two vectors x = x ⊙ y (multi-batch, in place).
template <simdifiable_multi Vx, simdifiable_multi Vy>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void hadamard(Vx &&x, Vy &&y) {
    BATMAT_ASSERT(x.num_batches() == y.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::hadamard(x.batch(b), y.batch(b));
}
620
/// Elementwise clamping z = max(lo, min(x, hi)) (multi-batch).
template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
    BATMAT_ASSERT(x.num_batches() == lo.num_batches());
    BATMAT_ASSERT(x.num_batches() == hi.num_batches());
    BATMAT_ASSERT(x.num_batches() == z.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::clamp(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
}
631
/// Elementwise clamping residual z = x - max(lo, min(x, hi)) (multi-batch).
template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
    BATMAT_ASSERT(x.num_batches() == lo.num_batches());
    BATMAT_ASSERT(x.num_batches() == hi.num_batches());
    BATMAT_ASSERT(x.num_batches() == z.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::clamp_resid(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
}
642
/// Elementwise clamping z = max(lo, min(x, hi)), with scalar lo and hi (multi-batch).
template <simdifiable_multi Vx, simdifiable_multi Vz>
// NOTE(review): the requires-clause and the function signature line were lost from this listing —
// presumably `void clamp(Vx &&x, simdified_value_t<Vx> lo, simdified_value_t<Vx> hi, Vz &&z) {`,
// mirroring the single-batch scalar overload. Restore from the original header.
    BATMAT_ASSERT(x.num_batches() == z.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::clamp(x.batch(b), lo, hi, z.batch(b));
}
651
/// Add scaled vector z = αx + βy (multi-batch).
template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz, //
          std::convertible_to<simdified_value_t<Vx>> Ta,
          std::convertible_to<simdified_value_t<Vx>> Tb>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
    BATMAT_ASSERT(x.num_batches() == y.num_batches());
    BATMAT_ASSERT(x.num_batches() == z.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::axpby(alpha, x.batch(b), beta, y.batch(b), z.batch(b));
}
663
/// Add scaled vector y = αx + βy (multi-batch, in place).
template <simdifiable_multi Vx, simdifiable_multi Vy, //
          std::convertible_to<simdified_value_t<Vx>> Ta,
          std::convertible_to<simdified_value_t<Vx>> Tb>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
    BATMAT_ASSERT(x.num_batches() == y.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::axpby(alpha, x.batch(b), beta, y.batch(b));
}
674
675/// Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
676template <auto Beta = 1, simdifiable_multi Vy, simdifiable_multi... Vx>
677 requires simdify_compatible<Vy, Vx...>
678void axpy(Vy &&y, const std::array<simdified_value_t<Vy>, sizeof...(Vx)> &alphas, Vx &&...x) {
679 BATMAT_ASSERT(((y.num_batches() == x.num_batches()) && ...));
680 for (index_t b = 0; b < y.num_batches(); ++b)
681 linalg::axpy<Beta>(y.batch(b), alphas, x.batch(b)...);
682}
683
/// Add scaled vector z = αx + y (multi-batch).
template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
          std::convertible_to<simdified_value_t<Vx>> Ta>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
    // Delegates to the multi-batch axpby with β = 1.
    axpby(alpha, x, 1, y, z);
}
691
/// Add scaled vector y = αx + βy (multi-batch, where β is a compile-time constant).
template <auto Beta = 1, simdifiable_multi Vx, simdifiable_multi Vy,
          std::convertible_to<simdified_value_t<Vx>> Ta>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void axpy(Ta alpha, Vx &&x, Vy &&y) {
    BATMAT_ASSERT(x.num_batches() == y.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < x.num_batches(); ++b)
        linalg::axpy<Beta>(alpha, x.batch(b), y.batch(b));
}
701
/// Negate a matrix or vector B = -A (multi-batch).
template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void negate(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
    BATMAT_ASSERT(A.num_batches() == B.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < A.num_batches(); ++b)
        linalg::negate(A.batch(b), B.batch(b), rot);
}
710
711/// Negate a matrix or vector A = -A.
712template <simdifiable_multi VA, int Rotate = 0>
713void negate(VA &&A, with_rotate_t<Rotate> rot = {}) {
714 for (index_t b = 0; b < A.num_batches(); ++b)
715 linalg::negate(A.batch(b), rot);
716}
717
/// Subtract two matrices or vectors C = A - B (multi-batch). Rotate affects B.
template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> rot = {}) {
    BATMAT_ASSERT(A.num_batches() == B.num_batches());
    BATMAT_ASSERT(A.num_batches() == C.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < A.num_batches(); ++b)
        linalg::sub(A.batch(b), B.batch(b), C.batch(b), rot);
}
727
/// Subtract two matrices or vectors A = A - B (multi-batch, in place). Rotate affects B.
template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void sub(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
    BATMAT_ASSERT(A.num_batches() == B.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < A.num_batches(); ++b)
        linalg::sub(A.batch(b), B.batch(b), rot);
}
736
/// Add two matrices or vectors C = A + B (multi-batch). Rotate affects B.
template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void add(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> rot = {}) {
    BATMAT_ASSERT(A.num_batches() == B.num_batches());
    BATMAT_ASSERT(A.num_batches() == C.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < A.num_batches(); ++b)
        linalg::add(A.batch(b), B.batch(b), C.batch(b), rot);
}
746
/// Add two matrices or vectors A = A + B (multi-batch, in place). Rotate affects B.
template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
// NOTE(review): a `requires simdify_compatible<...>` line appears to have been dropped from this
// listing — confirm against the original header.
void add(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
    BATMAT_ASSERT(A.num_batches() == B.num_batches());
    // Delegate to the single-batch kernel, batch by batch.
    for (index_t b = 0; b < A.num_batches(); ++b)
        linalg::add(A.batch(b), B.batch(b), rot);
}
755
756/// Apply a function to all elements of the given matrices or vectors.
757template <class F, simdifiable_multi VA, simdifiable_multi... VAs>
758 requires simdify_compatible<VA, VAs...>
759void for_each_elementwise(F &&fun, VA &&A, VAs &&...As) {
760 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
761 for (index_t b = 0; b < A.num_batches(); ++b)
762 linalg::for_each_elementwise(fun, A.batch(b), As.batch(b)...);
763}
764
765/// Apply a function to all elements of the given matrices or vectors, storing the result in the
766/// first argument.
767template <class F, simdifiable_multi VA, simdifiable_multi... VAs>
768 requires simdify_compatible<VA, VAs...>
769void transform_elementwise(F &&fun, VA &&A, VAs &&...As) {
770 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
771 for (index_t b = 0; b < A.num_batches(); ++b)
772 linalg::transform_elementwise(fun, A.batch(b), As.batch(b)...);
773}
774
775/// Apply a function to all elements of the given matrices or vectors, storing the results in the
776/// first two arguments.
777template <class F, simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi... VAs>
778 requires simdify_compatible<VA, VB, VAs...>
779void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As) {
780 BATMAT_ASSERT(A.num_batches() == B.num_batches());
781 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
782 for (index_t b = 0; b < A.num_batches(); ++b)
783 linalg::transform2_elementwise(fun, A.batch(b), B.batch(b), As.batch(b)...);
784}
785
786/// Apply a function to all elements of the given matrices or vectors, storing the results in the
787/// tuple of matrices given as the first argument.
788template <class F, simdifiable_multi... VAs, simdifiable_multi... VBs>
789 requires simdify_compatible<VAs..., VBs...>
790void transform_n_elementwise(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
791 using std::get;
792 auto &&a0 = get<0>(As);
793 BATMAT_ASSERT(((a0.num_batches() == Bs.num_batches()) && ...));
794 BATMAT_ASSERT([&]<std::size_t... Is>(std::index_sequence<Is...>) {
795 return ((a0.num_batches() == get<Is>(As).num_batches()) && ...);
796 }(std::make_index_sequence<sizeof...(VAs)>()));
797 for (index_t b = 0; b < a0.num_batches(); ++b)
799 fun, std::apply([&](auto &&...a) { return std::make_tuple(a.batch(b)...); }, As),
800 Bs.batch(b)...);
801}
802
803/// @}
804
805/// @}
806
807} // namespace multi
808
809} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void axpy(Vy &&y, const std::array< simdified_value_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple of matrices given as the first argument.
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Add two matrices or vectors C = A + B. Rotate affects B.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void transform_n_diag(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given vectors and the diagonal elements of the given square matrices.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first two arguments.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first argument.
void axpy(Vy &&y, const std::array< simdified_value_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first argument.
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
void negate(VA &&A, VB &&B, with_rotate_t< Rotate > rot={})
Negate a matrix or vector B = -A.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple of matrices given as the first argument.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void copy_diag(VA &&A, VB &&B)
Copy the diagonal elements of a matrix.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first two arguments.
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void add_diag(VA &&A, VB &&b, VC &&C)
C = A + diag(b).
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
Rotates the elements of x by s positions to the left.
Definition rotate.hpp:226
#define GUANAQO_TRACE_LINALG(name, gflops)
stdx::simd< Tp, Abi > simd
Definition simd.hpp:102
typename detail::simdified_value< V >::type simdified_value_t
Definition simdify.hpp:206
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:208
typename simdified_view_type< V >::type simdified_view_t
Definition simdify.hpp:183
constexpr bool simdify_compatible
Definition simdify.hpp:211
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:218
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
int index_t
Definition config.hpp:13
constexpr auto cols(const MatrixView< T, I, S, O > &v)
Definition simdify.hpp:20
constexpr auto rows(const MatrixView< T, I, S, O > &v)
Definition simdify.hpp:16