batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
elementwise.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
4#include <batmat/config.hpp>
9#include <batmat/simd.hpp>
10#include <guanaqo/trace.hpp>
11#include <array>
12#include <cmath>
13#include <concepts>
14#include <tuple>
15#include <utility>
16
17namespace batmat::linalg {
18
19/// @cond DETAIL
20
21namespace detail {
22
23constexpr index_t num_elem(const auto &A) { return A.rows() * A.cols() * A.depth(); }
24
25template <class T, class Abi, StorageOrder O, class F, class X, class... Xs>
26[[gnu::always_inline]] inline void iter_elems(F &&fun, X &&x, Xs &&...xs) {
27 using types = simd_view_types<T, Abi>;
28 if constexpr (O == StorageOrder::ColMajor) {
29 for (index_t c = 0; c < x.cols(); ++c)
30 for (index_t r = 0; r < x.rows(); ++r)
31 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
32 } else {
33 for (index_t r = 0; r < x.rows(); ++r)
34 for (index_t c = 0; c < x.cols(); ++c)
35 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
36 }
37}
38
39template <class T, class Abi, StorageOrder O, class F, class X, class... Xs>
40[[gnu::always_inline]] inline void iter_elems_store(F &&fun, X &&x, Xs &&...xs) {
41 using types = simd_view_types<T, Abi>;
42 if constexpr (O == StorageOrder::ColMajor) {
43 for (index_t c = 0; c < x.cols(); ++c)
44 for (index_t r = 0; r < x.rows(); ++r)
45 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
46 } else {
47 for (index_t r = 0; r < x.rows(); ++r)
48 for (index_t c = 0; c < x.cols(); ++c)
49 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
50 }
51}
52
53template <class T, class Abi, StorageOrder O, class F, class X0, class X1, class... Xs>
54[[gnu::always_inline]] inline void iter_elems_store2(F &&fun, X0 &&x0, X1 &&x1, Xs &&...xs) {
55 using types = simd_view_types<T, Abi>;
56 if constexpr (O == StorageOrder::ColMajor) {
57 for (index_t c = 0; c < x0.cols(); ++c)
58 for (index_t r = 0; r < x0.rows(); ++r) {
59 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
60 types::aligned_store(r0, &x0(0, r, c));
61 types::aligned_store(r1, &x1(0, r, c));
62 }
63 } else {
64 for (index_t r = 0; r < x0.rows(); ++r)
65 for (index_t c = 0; c < x0.cols(); ++c) {
66 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
67 types::aligned_store(r0, &x0(0, r, c));
68 types::aligned_store(r1, &x1(0, r, c));
69 }
70 }
71}
72
73template <class T, class Abi, StorageOrder O, class F, class... Ys, class... Xs>
74[[gnu::always_inline]] inline void iter_elems_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
75 using std::get;
76 using types = simd_view_types<T, Abi>;
77 const index_t rows = std::get<0>(ys).rows(), cols = std::get<0>(ys).cols();
78 if constexpr (O == StorageOrder::ColMajor) {
79 for (index_t c = 0; c < cols; ++c)
80 for (index_t r = 0; r < rows; ++r) {
81 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
82 static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
83 [&]<size_t... Is>(std::index_sequence<Is...>) {
84 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
85 }(std::index_sequence_for<Ys...>());
86 }
87 } else {
88 for (index_t r = 0; r < rows; ++r)
89 for (index_t c = 0; c < cols; ++c) {
90 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
91 static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
92 [&]<size_t... Is>(std::index_sequence<Is...>) {
93 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
94 }(std::index_sequence_for<Ys...>());
95 }
96 }
97}
98
/// Iterate element-wise over the diagonal elements of matrices and over the elements of vectors.
/// Any argument can either be a square matrix or a vector, but this function has a fast path
/// when the first input and output arguments are square matrices and all others are column vectors.
template <class T, class Abi, class F, class... Ys, class... Xs>
[[gnu::always_inline]] inline void iter_diag_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
    using std::get;
    using types = simd_view_types<T, Abi>;
    const index_t rows = std::get<0>(ys).rows(), cols = std::get<0>(ys).cols();
    // Iteration length, derived from the first output: for a matrix it is min(rows, cols); for a
    // vector (single column in col-major, single row in row-major) it is the vector length.
    const index_t n = std::get<0>(ys).storage_order == StorageOrder::ColMajor
                          ? (cols > 1 ? std::min(rows, cols) : rows)
                          : (rows > 1 ? std::min(rows, cols) : cols);
    // Optimized implementation for the special case where the first input and the first output are
    // square matrices and all others are column vectors.
    static constexpr auto all_vectors_except_first = [](auto &x0, auto &...x1s) {
        return x0.rows() == x0.cols() &&
               ((x1s.storage_order == StorageOrder::ColMajor && x1s.cols() == 1) && ...);
    };
    const bool all_xs_vectors_except_first = all_vectors_except_first(xs...);
    const bool all_ys_vector_except_first = std::apply(all_vectors_except_first, ys);
    if (all_xs_vectors_except_first && all_ys_vector_except_first) {
        for (index_t r = 0; r < n; ++r) {
            // First input indexed on its diagonal (r, r), remaining inputs as vectors (r, 0).
            auto rs = [&](auto &x0, auto &...x1s) {
                return fun(types::aligned_load(&x0(0, r, r)),
                           types::aligned_load(&x1s(0, r, 0))...);
            }(xs...);
            static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
            // First output stored on its diagonal, remaining outputs stored as vectors.
            [&]<size_t... Is>(std::index_sequence<Is...>) {
                ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, Is == 0 ? r : 0))), ...);
            }(std::index_sequence_for<Ys...>());
        }
    }
    // Fully generic implementation
    // (GCC does not seem to hoist the conditionals in the access function out of the loop, so we
    // pay for some cmovs here, even at -O3. Should be fine though, since we're most likely memory-
    // bound anyway.)
    else {
        // Per-argument addressing: diagonal element for matrices, r-th element for vectors.
        static constexpr auto access = [](auto &x, index_t r) -> auto & {
            return x.storage_order == StorageOrder::ColMajor
                       ? (x.cols() > 1 ? x(0, r, r) : x(0, r, 0))
                       : (x.rows() > 1 ? x(0, r, r) : x(0, 0, r));
        };
        for (index_t r = 0; r < n; ++r) {
            auto rs = fun(types::aligned_load(&access(xs, r))...);
            static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
            [&]<size_t... Is>(std::index_sequence<Is...>) {
                ((types::aligned_store(get<Is>(rs), &access(get<Is>(ys), r))), ...);
            }(std::index_sequence_for<Ys...>());
        }
    }
}
149
150/// Scalar product.
151template <class T, class Abi, StorageOrder OB, StorageOrder OC>
153 BATMAT_ASSERT(B.rows() == C.rows());
154 BATMAT_ASSERT(B.cols() == C.cols());
155 iter_elems_store<T, Abi, OC>([&](auto Bi) { return a * Bi; }, C, B);
156}
157
158/// Hadamard (elementwise) product.
159template <class T, class Abi, StorageOrder OA, StorageOrder OB, StorageOrder OC>
162 BATMAT_ASSERT(A.rows() == B.rows());
163 BATMAT_ASSERT(A.cols() == B.cols());
164 BATMAT_ASSERT(A.rows() == C.rows());
165 BATMAT_ASSERT(A.cols() == C.cols());
166 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai * Bi; }, C, A, B);
167}
168
169/// Elementwise clamping z = max(lo, min(x, hi)).
170template <class T, class Abi, StorageOrder O>
171[[gnu::flatten]] void clamp(view<const T, Abi, O> x, view<const T, Abi, O> lo,
173 BATMAT_ASSERT(x.rows() == lo.rows());
174 BATMAT_ASSERT(x.cols() == lo.cols());
175 BATMAT_ASSERT(x.rows() == hi.rows());
176 BATMAT_ASSERT(x.cols() == hi.cols());
177 BATMAT_ASSERT(x.rows() == z.rows());
178 BATMAT_ASSERT(x.cols() == z.cols());
179 const auto clamp = [&](auto xi, auto loi, auto hii) { return fmax(loi, fmin(xi, hii)); };
180 iter_elems_store<T, Abi, O>(clamp, z, x, lo, hi);
181}
182
183/// Elementwise clamping z = max(lo, min(x, hi)), with scalar lo and hi.
184template <class T, class Abi, StorageOrder O>
185[[gnu::flatten]] void clamp(view<const T, Abi, O> x, datapar::simd<T, Abi> lo,
187 BATMAT_ASSERT(x.rows() == z.rows());
188 BATMAT_ASSERT(x.cols() == z.cols());
189 const auto clamp = [&](auto xi) { return fmax(lo, fmin(xi, hi)); };
190 iter_elems_store<T, Abi, O>(clamp, z, x);
191}
192
193/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
194template <class T, class Abi, StorageOrder O>
197 BATMAT_ASSERT(x.rows() == lo.rows());
198 BATMAT_ASSERT(x.cols() == lo.cols());
199 BATMAT_ASSERT(x.rows() == hi.rows());
200 BATMAT_ASSERT(x.cols() == hi.cols());
201 BATMAT_ASSERT(x.rows() == z.rows());
202 BATMAT_ASSERT(x.cols() == z.cols());
204 const auto clamp_resid = [&](auto xi, auto loi, auto hii) {
205 return fmax(xi - hii, fmin(simd{0}, xi - loi));
206 };
207 iter_elems_store<T, Abi, O>(clamp_resid, z, x, lo, hi);
208}
209
210/// Linear combination of vectors z = beta * z + sum_i alpha_i * x_i.
211template <class T, class Abi, T Beta, StorageOrder O, class... Xs>
212[[gnu::flatten]] void gaxpby(view<T, Abi, O> z,
213 const std::array<datapar::simd<T, Abi>, sizeof...(Xs)> &alphas,
214 const Xs &...xs) {
215 BATMAT_ASSERT(((z.rows() == xs.rows()) && ...));
216 BATMAT_ASSERT(((z.cols() == xs.cols()) && ...));
217 if constexpr (Beta == 0)
218 iter_elems_store<T, Abi, O>(
219 [&](auto... xis) {
220 return [&]<std::size_t... Is>(std::index_sequence<Is...>, auto... xis) {
221 return ((xis * alphas[Is]) + ...);
222 }(std::make_index_sequence<sizeof...(Xs)>(), xis...);
223 },
224 z, xs...);
225 else
226 iter_elems_store<T, Abi, O>(
227 [&](auto zi, auto... xis) {
228 return [&]<std::size_t... Is>(std::index_sequence<Is...>, auto... xis) {
229 return zi * Beta + ((xis * alphas[Is]) + ...);
230 }(std::make_index_sequence<sizeof...(Xs)>(), xis...);
231 },
232 z, z, xs...);
233}
234
235/// Negate a matrix or vector.
236/// @todo: add Negate option to batmat::linalg::copy and remove this function, then this also
237/// supports transposition.
238template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB>
239[[gnu::flatten]] void negate(view<const T, Abi, OA> A, view<T, Abi, OB> B) {
240 BATMAT_ASSERT(A.rows() == B.rows());
241 BATMAT_ASSERT(A.cols() == B.cols());
242 using ops::rotl;
243 iter_elems_store<T, Abi, OB>([&](auto Ai) { return -rotl<Rotate>(Ai); }, B, A);
244}
245
246/// Subtract two matrices or vectors C = A - B.
247template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
249 BATMAT_ASSERT(A.rows() == B.rows());
250 BATMAT_ASSERT(A.cols() == B.cols());
251 BATMAT_ASSERT(A.rows() == C.rows());
252 BATMAT_ASSERT(A.cols() == C.cols());
253 using ops::rotl;
254 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai - rotl<Rotate>(Bi); }, C, A, B);
255}
256
257/// Add two matrices or vectors C = A + B.
258template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
260 BATMAT_ASSERT(A.rows() == B.rows());
261 BATMAT_ASSERT(A.cols() == B.cols());
262 BATMAT_ASSERT(A.rows() == C.rows());
263 BATMAT_ASSERT(A.cols() == C.cols());
264 using ops::rotl;
265 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai + rotl<Rotate>(Bi); }, C, A, B);
266}
267
268} // namespace detail
269
270/// @endcond
271
272/// @addtogroup topic-linalg
273/// @{
274
275/// @name Single-batch elementwise operations
276/// @{
277
278/// Multiply a vector by a scalar z = αx.
279template <simdifiable Vx, simdifiable Vz, std::convertible_to<simdified_simd_t<Vx>> T>
281void scale(T alpha, Vx &&x, Vz &&z) {
282 GUANAQO_TRACE_LINALG("scale", detail::num_elem(simdify(x)));
283 detail::scale<simdified_value_t<Vx>, simdified_abi_t<Vx>>(alpha, simdify(x).as_const(),
284 simdify(z));
285}
286
287/// Multiply a vector by a scalar x = αx.
288template <simdifiable Vx, std::convertible_to<simdified_simd_t<Vx>> T>
289void scale(T alpha, Vx &&x) {
290 GUANAQO_TRACE_LINALG("scale", detail::num_elem(simdify(x)));
291 detail::scale<simdified_value_t<Vx>, simdified_abi_t<Vx>>(alpha, simdify(x).as_const(),
292 simdify(x));
293}
294
295/// Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
296template <simdifiable Vx, simdifiable Vy, simdifiable Vz>
298void hadamard(Vx &&x, Vy &&y, Vz &&z) {
299 GUANAQO_TRACE_LINALG("hadamard", detail::num_elem(simdify(x)));
300 detail::hadamard<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
301 simdify(y).as_const(), simdify(z));
302}
303
304/// Compute the Hadamard (elementwise) product of two vectors x = x ⊙ y.
305template <simdifiable Vx, simdifiable Vy>
307void hadamard(Vx &&x, Vy &&y) {
308 GUANAQO_TRACE_LINALG("hadamard", detail::num_elem(simdify(x)));
309 detail::hadamard<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
310 simdify(y).as_const(), simdify(x));
311}
312
313/// Elementwise clamping z = max(lo, min(x, hi)).
314template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
316void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
317 GUANAQO_TRACE_LINALG("clamp", 2 * detail::num_elem(simdify(x))); // max, min
318 detail::clamp<simdified_value_t<Vx>, simdified_abi_t<Vx>>(
319 simdify(x).as_const(), simdify(lo).as_const(), simdify(hi).as_const(), simdify(z));
320}
321
322/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
323template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
325void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
326 GUANAQO_TRACE_LINALG("clamp_resid", 3 * detail::num_elem(simdify(x))); // sub, max, min
327 detail::clamp_resid<simdified_value_t<Vx>, simdified_abi_t<Vx>>(
328 simdify(x).as_const(), simdify(lo).as_const(), simdify(hi).as_const(), simdify(z));
329}
330
331/// Elementwise clamping z = max(lo, min(x, hi)), with scalar lo and hi.
332template <simdifiable Vx, simdifiable Vz>
334void clamp(Vx &&x, simdified_simd_t<Vx> lo, simdified_simd_t<Vx> hi, Vz &&z) {
335 GUANAQO_TRACE_LINALG("clamp", 2 * detail::num_elem(simdify(x))); // max, min
336 detail::clamp<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(), lo, hi,
337 simdify(z));
338}
339
340/// Add scaled vector z = αx + βy.
341template <simdifiable Vx, simdifiable Vy, simdifiable Vz, //
342 std::convertible_to<simdified_simd_t<Vx>> Ta,
343 std::convertible_to<simdified_simd_t<Vx>> Tb>
345void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
346 GUANAQO_TRACE_LINALG("axpby", 2 * detail::num_elem(simdify(x))); // mul, fma
347 detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{0}>(
348 simdify(z), {{alpha, beta}}, simdify(x).as_const(), simdify(y).as_const());
349}
350
351/// Add scaled vector y = αx + βy.
352template <simdifiable Vx, simdifiable Vy, //
353 std::convertible_to<simdified_simd_t<Vx>> Ta,
354 std::convertible_to<simdified_simd_t<Vx>> Tb>
356void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
357 GUANAQO_TRACE_LINALG("axpby", 2 * detail::num_elem(simdify(x))); // mul, fma
358 detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{0}>(
359 simdify(y), {{alpha, beta}}, simdify(x).as_const(), simdify(y).as_const());
360}
361
362/// Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
363template <auto Beta = 1, simdifiable Vy, simdifiable... Vx>
364 requires simdify_compatible<Vy, Vx...>
365void axpy(Vy &&y, const std::array<simdified_simd_t<Vy>, sizeof...(Vx)> &alphas, Vx &&...x) {
366 [[maybe_unused]] static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
367 [[maybe_unused]] static constexpr index_t num_fma = sizeof...(Vx);
368 GUANAQO_TRACE_LINALG("axpy", (num_mul + num_fma) * detail::num_elem(simdify(y))); // mul, fma
369 detail::gaxpby<simdified_value_t<Vy>, simdified_abi_t<Vy>, simdified_value_t<Vy>{Beta}>(
370 simdify(y), alphas, simdify(x).as_const()...);
371}
372
373/// Add scaled vector z = αx + y.
374template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
375 std::convertible_to<simdified_simd_t<Vx>> Ta>
377void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
378 axpby(alpha, x, Ta{1}, y, z);
379}
380
381/// Add scaled vector y = αx + βy (where β is a compile-time constant).
382template <auto Beta = 1, simdifiable Vx, simdifiable Vy,
383 std::convertible_to<simdified_simd_t<Vx>> Ta>
385void axpy(Ta alpha, Vx &&x, Vy &&y) {
386 [[maybe_unused]] static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
387 [[maybe_unused]] static constexpr index_t num_fma = 1;
388 GUANAQO_TRACE_LINALG("axpy", (num_mul + num_fma) * detail::num_elem(simdify(y))); // mul, fma
389 detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{Beta}>(
390 simdify(y), {{alpha}}, simdify(x).as_const());
391}
392
393/// Negate a matrix or vector B = -A.
394template <simdifiable VA, simdifiable VB, int Rotate = 0>
396void negate(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
397 GUANAQO_TRACE_LINALG("negate", detail::num_elem(simdify(A)));
398 detail::negate<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(simdify(A).as_const(),
399 simdify(B));
400}
401
402/// Negate a matrix or vector A = -A.
403template <simdifiable VA, int Rotate = 0>
404void negate(VA &&A, with_rotate_t<Rotate> = {}) {
405 GUANAQO_TRACE_LINALG("negate", detail::num_elem(simdify(A)));
406 detail::negate<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(simdify(A).as_const(),
407 simdify(A));
408}
409
410/// Subtract two matrices or vectors C = A - B. Rotate affects B.
411template <simdifiable VA, simdifiable VB, simdifiable VC, int Rotate = 0>
413void sub(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> = {}) {
414 GUANAQO_TRACE_LINALG("sub", detail::num_elem(simdify(A)));
415 detail::sub<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
416 simdify(A).as_const(), simdify(B).as_const(), simdify(C));
417}
418
419/// Subtract two matrices or vectors A = A - B. Rotate affects B.
420template <simdifiable VA, simdifiable VB, int Rotate = 0>
422void sub(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
423 GUANAQO_TRACE_LINALG("sub", detail::num_elem(simdify(A)));
424 detail::sub<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
425 simdify(A).as_const(), simdify(B).as_const(), simdify(A));
426}
427
428/// Add two matrices or vectors C = A + B. Rotate affects B.
429template <simdifiable VA, simdifiable VB, simdifiable VC, int Rotate = 0>
431void add(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> = {}) {
432 GUANAQO_TRACE_LINALG("add", detail::num_elem(simdify(A)));
433 detail::add<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
434 simdify(A).as_const(), simdify(B).as_const(), simdify(C));
435}
436
437/// Add two matrices or vectors A = A + B. Rotate affects B.
438template <simdifiable VA, simdifiable VB, int Rotate = 0>
440void add(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
441 GUANAQO_TRACE_LINALG("add", detail::num_elem(simdify(A)));
442 detail::add<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
443 simdify(A).as_const(), simdify(B).as_const(), simdify(A));
444}
445
446/// Apply a function to all elements of the given matrices or vectors.
447template <class F, simdifiable VA, simdifiable... VAs>
448 requires simdify_compatible<VA, VAs...>
449void for_each_elementwise(F &&fun, VA &&A, VAs &&...As) {
450 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
451 detail::iter_elems<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
452 std::forward<F>(fun), simdify(A).as_const(), simdify(As).as_const()...);
453}
454
455/// Apply a function to all elements of the given matrices or vectors, storing the result in the
456/// first argument.
457template <class F, simdifiable VA, simdifiable... VAs>
458 requires simdify_compatible<VA, VAs...>
459void transform_elementwise(F &&fun, VA &&A, VAs &&...As) {
460 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
461 detail::iter_elems_store<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
462 std::forward<F>(fun), simdify(A), simdify(As).as_const()...);
463}
464
465/// Apply a function to all elements of the given matrices or vectors, storing the results in the
466/// first two arguments.
467template <class F, simdifiable VA, simdifiable VB, simdifiable... VAs>
468 requires simdify_compatible<VA, VB, VAs...>
469void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As) {
470 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
471 detail::iter_elems_store2<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
472 std::forward<F>(fun), simdify(A), simdify(B), simdify(As).as_const()...);
473}
474
475/// Apply a function to all elements of the given matrices or vectors, storing the results in the
476/// tuple of matrices given as the first argument.
477template <class F, simdifiable... VAs, simdifiable... VBs>
478 requires simdify_compatible<VAs..., VBs...>
479void transform_n_elementwise(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
480 using VA0 = std::tuple_element_t<0, decltype(As)>;
481 static constexpr auto storage_order = simdified_view_t<VA0>::storage_order;
482 detail::iter_elems_store_n<simdified_value_t<VA0>, simdified_abi_t<VA0>, storage_order>(
483 std::forward<F>(fun),
484 std::apply([](auto &&...a) { return std::make_tuple(simdify(a)...); }, As),
485 simdify(Bs).as_const()...);
486}
487
/// Apply a function to all elements of the given vectors and the diagonal elements of the given
/// square matrices, storing the results in the tuple of vectors or matrices given as the first
/// argument. Most efficient if only the first argument contains matrices, and all other arguments
/// are column vectors.
template <class F, simdifiable... VAs, simdifiable... VBs>
    requires simdify_compatible<VAs..., VBs...>
void transform_n_diag(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
    // Returns the iteration length of an argument: the side length of a square matrix, or the
    // length of a (column or row) vector; asserts that the argument is one of those shapes.
    constexpr auto check_size = [](auto &x) {
        if (x.rows() == x.cols())
            return x.rows();
        else if (x.storage_order == StorageOrder::ColMajor) {
            BATMAT_ASSERT(x.cols() == 1);
            return x.rows();
        } else {
            BATMAT_ASSERT(x.rows() == 1);
            return x.cols();
        }
    };
    // All outputs and all inputs must agree on the iteration length of the first output.
    [[maybe_unused]] const index_t n = check_size(std::get<0>(As));
    [&]<size_t... Is>(std::index_sequence<Is...>) {
        BATMAT_ASSERT(((check_size(get<Is>(As)) == n) && ...));
    }(std::index_sequence_for<VAs...>());
    BATMAT_ASSERT(((check_size(Bs) == n) && ...));
    using VA0 = std::tuple_element_t<0, decltype(As)>;
    detail::iter_diag_store_n<simdified_value_t<VA0>, simdified_abi_t<VA0>>(
        std::forward<F>(fun),
        std::apply([](auto &&...a) { return std::make_tuple(simdify(a)...); }, As),
        simdify(Bs).as_const()...);
}
517
518/// Copy the diagonal elements of a matrix. The arguments @p A and @p B must either be square
519/// matrices or vectors. This function supports setting the diagonal of a matrix to the values of
520/// a vector, copying the diagonal of one matrix to the diagonal of another, or copying the diagonal
521/// elements of a matrix to a vector.
522template <class F, simdifiable VA, simdifiable VB>
524void copy_diag(VA &&A, VB &&B) {
525 [[maybe_unused]] const index_t n =
526 A.storage_order == StorageOrder::ColMajor ? A.rows() : A.cols();
527 if constexpr (A.storage_order == StorageOrder::ColMajor) {
528 BATMAT_ASSERT(A.rows() == n);
529 BATMAT_ASSERT(A.cols() == n || A.cols() == 1);
530 } else {
531 BATMAT_ASSERT(A.rows() == n || A.rows() == 1);
532 BATMAT_ASSERT(A.cols() == n);
533 }
534 if constexpr (B.storage_order == StorageOrder::ColMajor) {
535 BATMAT_ASSERT(B.rows() == n);
536 BATMAT_ASSERT(B.cols() == n || B.cols() == 1);
537 } else {
538 BATMAT_ASSERT(B.rows() == n || B.rows() == 1);
539 BATMAT_ASSERT(B.cols() == n);
540 }
541 GUANAQO_TRACE_LINALG("copy_diag", n * A.depth());
542 detail::iter_diag_store_n<simdified_value_t<VA>, simdified_abi_t<VA>>(
543 [](auto Ai) { return std::make_tuple(Ai); }, std::make_tuple(simdify(B)),
544 simdify(A).as_const());
545}
546
547/// C = A + diag(b).
548template <simdifiable VA, simdifiable VB, simdifiable VC>
550void add_diag(VA &&A, VB &&b, VC &&C) {
551 BATMAT_ASSERT(A.rows() == A.cols());
552 BATMAT_ASSERT(A.rows() == C.rows());
553 BATMAT_ASSERT(A.cols() == C.cols());
554 if constexpr (b.storage_order == StorageOrder::ColMajor) {
555 BATMAT_ASSERT(b.rows() == A.rows());
556 BATMAT_ASSERT(b.cols() == 1);
557 } else {
558 BATMAT_ASSERT(b.rows() == 1);
559 BATMAT_ASSERT(b.cols() == A.cols());
560 }
561 GUANAQO_TRACE_LINALG("add_diag", detail::num_elem(simdify(b)));
562 detail::iter_diag_store_n<simdified_value_t<VA>, simdified_abi_t<VA>>(
563 [](auto Ai, auto bi) { return std::make_tuple(Ai + bi); }, std::make_tuple(simdify(C)),
564 simdify(A).as_const(), simdify(b).as_const());
565}
566
567/// A += diag(b).
568template <simdifiable VA, simdifiable VB>
570void add_diag(VA &&A, VB &&b) {
571 add_diag(std::forward<VA>(A), std::forward<VB>(b), std::forward<VA>(A));
572}
573
574/// @}
575
576/// @}
577
578// TODO: doxygen gets confused because the template parameters are the same as the single-batch
579// versions, so put in a separate namespace
580inline namespace multi {
581
582/// @addtogroup topic-linalg
583/// @{
584
585/// @name Multi-batch elementwise operations
586/// @{
587
588/// Multiply a vector by a scalar z = αx.
589template <simdifiable_multi Vx, simdifiable_multi Vz, std::convertible_to<simdified_simd_t<Vx>> T>
591void scale(T alpha, Vx &&x, Vz &&z) {
592 BATMAT_ASSERT(x.num_batches() == z.num_batches());
593 for (index_t b = 0; b < x.num_batches(); ++b)
594 linalg::scale(alpha, x.batch(b), z.batch(b));
595}
596
597/// Multiply a vector by a scalar x = αx.
598template <simdifiable_multi Vx, std::convertible_to<simdified_simd_t<Vx>> T>
599void scale(T alpha, Vx &&x) {
600 for (index_t b = 0; b < x.num_batches(); ++b)
601 linalg::scale(alpha, x.batch(b));
602}
603
604/// Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
605template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz>
607void hadamard(Vx &&x, Vy &&y, Vz &&z) {
608 BATMAT_ASSERT(x.num_batches() == y.num_batches());
609 BATMAT_ASSERT(x.num_batches() == z.num_batches());
610 for (index_t b = 0; b < x.num_batches(); ++b)
611 linalg::hadamard(x.batch(b), y.batch(b), z.batch(b));
612}
613
614/// Compute the Hadamard (elementwise) product of two vectors x = x ⊙ y.
615template <simdifiable_multi Vx, simdifiable_multi Vy>
617void hadamard(Vx &&x, Vy &&y) {
618 BATMAT_ASSERT(x.num_batches() == y.num_batches());
619 for (index_t b = 0; b < x.num_batches(); ++b)
620 linalg::hadamard(x.batch(b), y.batch(b));
621}
622
623/// Elementwise clamping z = max(lo, min(x, hi)).
624template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
626void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
627 BATMAT_ASSERT(x.num_batches() == lo.num_batches());
628 BATMAT_ASSERT(x.num_batches() == hi.num_batches());
629 BATMAT_ASSERT(x.num_batches() == z.num_batches());
630 for (index_t b = 0; b < x.num_batches(); ++b)
631 linalg::clamp(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
632}
633
634/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
635template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
637void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
638 BATMAT_ASSERT(x.num_batches() == lo.num_batches());
639 BATMAT_ASSERT(x.num_batches() == hi.num_batches());
640 BATMAT_ASSERT(x.num_batches() == z.num_batches());
641 for (index_t b = 0; b < x.num_batches(); ++b)
642 linalg::clamp_resid(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
643}
644
645/// Elementwise clamping z = max(lo, min(x, hi)), with scalar lo and hi.
646template <simdifiable_multi Vx, simdifiable_multi Vz>
648void clamp(Vx &&x, simdified_simd_t<Vx> lo, simdified_simd_t<Vx> hi, Vz &&z) {
649 BATMAT_ASSERT(x.num_batches() == z.num_batches());
650 for (index_t b = 0; b < x.num_batches(); ++b)
651 linalg::clamp(x.batch(b), lo, hi, z.batch(b));
652}
653
654/// Add scaled vector z = αx + βy.
655template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz, //
656 std::convertible_to<simdified_simd_t<Vx>> Ta,
657 std::convertible_to<simdified_simd_t<Vx>> Tb>
659void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
660 BATMAT_ASSERT(x.num_batches() == y.num_batches());
661 BATMAT_ASSERT(x.num_batches() == z.num_batches());
662 for (index_t b = 0; b < x.num_batches(); ++b)
663 linalg::axpby(alpha, x.batch(b), beta, y.batch(b), z.batch(b));
664}
665
666/// Add scaled vector y = αx + βy.
667template <simdifiable_multi Vx, simdifiable_multi Vy, //
668 std::convertible_to<simdified_simd_t<Vx>> Ta,
669 std::convertible_to<simdified_simd_t<Vx>> Tb>
671void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
672 BATMAT_ASSERT(x.num_batches() == y.num_batches());
673 for (index_t b = 0; b < x.num_batches(); ++b)
674 linalg::axpby(alpha, x.batch(b), beta, y.batch(b));
675}
676
677/// Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
678template <auto Beta = 1, simdifiable_multi Vy, simdifiable_multi... Vx>
679 requires simdify_compatible<Vy, Vx...>
680void axpy(Vy &&y, const std::array<simdified_simd_t<Vy>, sizeof...(Vx)> &alphas, Vx &&...x) {
681 BATMAT_ASSERT(((y.num_batches() == x.num_batches()) && ...));
682 for (index_t b = 0; b < y.num_batches(); ++b)
683 linalg::axpy<Beta>(y.batch(b), alphas, x.batch(b)...);
684}
685
686/// Add scaled vector z = αx + y.
687template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
688 std::convertible_to<simdified_simd_t<Vx>> Ta>
690void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
691 axpby(alpha, x, 1, y, z);
692}
693
694/// Add scaled vector y = αx + βy (where β is a compile-time constant).
695template <auto Beta = 1, simdifiable_multi Vx, simdifiable_multi Vy,
696 std::convertible_to<simdified_simd_t<Vx>> Ta>
698void axpy(Ta alpha, Vx &&x, Vy &&y) {
699 BATMAT_ASSERT(x.num_batches() == y.num_batches());
700 for (index_t b = 0; b < x.num_batches(); ++b)
701 linalg::axpy<Beta>(alpha, x.batch(b), y.batch(b));
702}
703
704/// Negate a matrix or vector B = -A.
705template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
707void negate(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
708 BATMAT_ASSERT(A.num_batches() == B.num_batches());
709 for (index_t b = 0; b < A.num_batches(); ++b)
710 linalg::negate(A.batch(b), B.batch(b), rot);
711}
712
713/// Negate a matrix or vector A = -A.
714template <simdifiable_multi VA, int Rotate = 0>
715void negate(VA &&A, with_rotate_t<Rotate> rot = {}) {
716 for (index_t b = 0; b < A.num_batches(); ++b)
717 linalg::negate(A.batch(b), rot);
718}
719
720/// Subtract two matrices or vectors C = A - B. Rotate affects B.
721template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC, int Rotate = 0>
723void sub(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> rot = {}) {
724 BATMAT_ASSERT(A.num_batches() == B.num_batches());
725 BATMAT_ASSERT(A.num_batches() == C.num_batches());
726 for (index_t b = 0; b < A.num_batches(); ++b)
727 linalg::sub(A.batch(b), B.batch(b), C.batch(b), rot);
728}
729
730/// Subtract two matrices or vectors A = A - B. Rotate affects B.
731template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
733void sub(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
734 BATMAT_ASSERT(A.num_batches() == B.num_batches());
735 for (index_t b = 0; b < A.num_batches(); ++b)
736 linalg::sub(A.batch(b), B.batch(b), rot);
737}
738
739/// Add two matrices or vectors C = A + B. Rotate affects B.
740template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC, int Rotate = 0>
742void add(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> rot = {}) {
743 BATMAT_ASSERT(A.num_batches() == B.num_batches());
744 BATMAT_ASSERT(A.num_batches() == C.num_batches());
745 for (index_t b = 0; b < A.num_batches(); ++b)
746 linalg::add(A.batch(b), B.batch(b), C.batch(b), rot);
747}
748
749/// Add two matrices or vectors A = A + B. Rotate affects B.
750template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
752void add(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
753 BATMAT_ASSERT(A.num_batches() == B.num_batches());
754 for (index_t b = 0; b < A.num_batches(); ++b)
755 linalg::add(A.batch(b), B.batch(b), rot);
756}
757
758/// Apply a function to all elements of the given matrices or vectors.
759template <class F, simdifiable_multi VA, simdifiable_multi... VAs>
760 requires simdify_compatible<VA, VAs...>
761void for_each_elementwise(F &&fun, VA &&A, VAs &&...As) {
762 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
763 for (index_t b = 0; b < A.num_batches(); ++b)
764 linalg::for_each_elementwise(fun, A.batch(b), As.batch(b)...);
765}
766
767/// Apply a function to all elements of the given matrices or vectors, storing the result in the
768/// first argument.
769template <class F, simdifiable_multi VA, simdifiable_multi... VAs>
770 requires simdify_compatible<VA, VAs...>
771void transform_elementwise(F &&fun, VA &&A, VAs &&...As) {
772 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
773 for (index_t b = 0; b < A.num_batches(); ++b)
774 linalg::transform_elementwise(fun, A.batch(b), As.batch(b)...);
775}
776
777/// Apply a function to all elements of the given matrices or vectors, storing the results in the
778/// first two arguments.
779template <class F, simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi... VAs>
780 requires simdify_compatible<VA, VB, VAs...>
781void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As) {
782 BATMAT_ASSERT(A.num_batches() == B.num_batches());
783 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
784 for (index_t b = 0; b < A.num_batches(); ++b)
785 linalg::transform2_elementwise(fun, A.batch(b), B.batch(b), As.batch(b)...);
786}
787
788/// Apply a function to all elements of the given matrices or vectors, storing the results in the
789/// tuple of matrices given as the first argument.
790template <class F, simdifiable_multi... VAs, simdifiable_multi... VBs>
791 requires simdify_compatible<VAs..., VBs...>
792void transform_n_elementwise(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
793 using std::get;
794 auto &&a0 = get<0>(As);
795 BATMAT_ASSERT(((a0.num_batches() == Bs.num_batches()) && ...));
796 BATMAT_ASSERT([&]<std::size_t... Is>(std::index_sequence<Is...>) {
797 return ((a0.num_batches() == get<Is>(As).num_batches()) && ...);
798 }(std::make_index_sequence<sizeof...(VAs)>()));
799 for (index_t b = 0; b < a0.num_batches(); ++b)
801 fun, std::apply([&](auto &&...a) { return std::make_tuple(a.batch(b)...); }, As),
802 Bs.batch(b)...);
803}
804
805/// @}
806
807/// @}
808
809} // namespace multi
810
811} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple o...
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Add two matrices or vectors C = A + B. Rotate affects B.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void transform_n_diag(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given vectors and the diagonal elements of the given square m...
void axpy(Vy &&y, const std::array< simdified_simd_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first t...
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first ar...
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first ar...
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
void negate(VA &&A, VB &&B, with_rotate_t< Rotate > rot={})
Negate a matrix or vector B = -A.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple o...
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void copy_diag(VA &&A, VB &&B)
Copy the diagonal elements of a matrix.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first t...
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void axpy(Vy &&y, const std::array< simdified_simd_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void add_diag(VA &&A, VB &&b, VC &&C)
C = A + diag(b).
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
Rotates the elements of x by s positions to the left.
Definition rotate.hpp:226
#define GUANAQO_TRACE_LINALG(name, gflops)
stdx::simd< Tp, Abi > simd
Definition simd.hpp:102
typename detail::simdified_value< V >::type simdified_value_t
Definition simdify.hpp:214
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:216
typename detail::simdified_simd< V >::type simdified_simd_t
Definition simdify.hpp:218
typename simdified_view_type< V >::type simdified_view_t
Definition simdify.hpp:183
constexpr bool simdify_compatible
Definition simdify.hpp:221
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
int index_t
Definition config.hpp:13
constexpr auto cols(const MatrixView< T, I, S, O > &v)
Definition simdify.hpp:20
constexpr auto rows(const MatrixView< T, I, S, O > &v)
Definition simdify.hpp:16