10#include <guanaqo/trace.hpp>
23constexpr index_t num_elem(
const auto &A) {
return A.rows() * A.cols() * A.depth(); }
25template <
class T,
class Abi,
StorageOrder O,
class F,
class X,
class... Xs>
26[[gnu::always_inline]]
inline void iter_elems(F &&fun, X &&x, Xs &&...xs) {
27 using types = simd_view_types<T, Abi>;
28 if constexpr (O == StorageOrder::ColMajor) {
29 for (
index_t c = 0; c < x.cols(); ++c)
30 for (
index_t r = 0; r < x.rows(); ++r)
31 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
33 for (
index_t r = 0; r < x.rows(); ++r)
34 for (
index_t c = 0; c < x.cols(); ++c)
35 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
39template <
class T,
class Abi,
StorageOrder O,
class F,
class X,
class... Xs>
40[[gnu::always_inline]]
inline void iter_elems_store(F &&fun, X &&x, Xs &&...xs) {
41 using types = simd_view_types<T, Abi>;
42 if constexpr (O == StorageOrder::ColMajor) {
43 for (
index_t c = 0; c < x.cols(); ++c)
44 for (
index_t r = 0; r < x.rows(); ++r)
45 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
47 for (
index_t r = 0; r < x.rows(); ++r)
48 for (
index_t c = 0; c < x.cols(); ++c)
49 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
53template <
class T,
class Abi,
StorageOrder O,
class F,
class X0,
class X1,
class... Xs>
54[[gnu::always_inline]]
inline void iter_elems_store2(F &&fun, X0 &&x0, X1 &&x1, Xs &&...xs) {
55 using types = simd_view_types<T, Abi>;
56 if constexpr (O == StorageOrder::ColMajor) {
57 for (
index_t c = 0; c < x0.cols(); ++c)
58 for (
index_t r = 0; r < x0.rows(); ++r) {
59 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
60 types::aligned_store(r0, &x0(0, r, c));
61 types::aligned_store(r1, &x1(0, r, c));
64 for (
index_t r = 0; r < x0.rows(); ++r)
65 for (
index_t c = 0; c < x0.cols(); ++c) {
66 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
67 types::aligned_store(r0, &x0(0, r, c));
68 types::aligned_store(r1, &x1(0, r, c));
73template <
class T,
class Abi,
StorageOrder O,
class F,
class... Ys,
class... Xs>
74[[gnu::always_inline]]
inline void iter_elems_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
76 using types = simd_view_types<T, Abi>;
77 const index_t rows = std::get<0>(ys).rows(),
cols = std::get<0>(ys).cols();
78 if constexpr (O == StorageOrder::ColMajor) {
81 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
82 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
83 [&]<
size_t... Is>(std::index_sequence<Is...>) {
84 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
85 }(std::index_sequence_for<Ys...>());
90 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
91 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
92 [&]<
size_t... Is>(std::index_sequence<Is...>) {
93 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
94 }(std::index_sequence_for<Ys...>());
102template <
class T,
class Abi,
class F,
class... Ys,
class... Xs>
103[[gnu::always_inline]]
inline void iter_diag_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
105 using types = simd_view_types<T, Abi>;
106 const index_t rows = std::get<0>(ys).rows(),
cols = std::get<0>(ys).cols();
107 const index_t n = std::get<0>(ys).storage_order == StorageOrder::ColMajor
108 ? (
cols > 1 ? std::min(rows, cols) :
rows)
112 static constexpr auto all_vectors_except_first = [](
auto &x0,
auto &...x1s) {
113 return x0.rows() == x0.cols() &&
114 ((x1s.storage_order == StorageOrder::ColMajor && x1s.cols() == 1) && ...);
116 const bool all_xs_vectors_except_first = all_vectors_except_first(xs...);
117 const bool all_ys_vector_except_first = std::apply(all_vectors_except_first, ys);
118 if (all_xs_vectors_except_first && all_ys_vector_except_first) {
119 for (
index_t r = 0; r < n; ++r) {
120 auto rs = [&](
auto &x0,
auto &...x1s) {
121 return fun(types::aligned_load(&x0(0, r, r)),
122 types::aligned_load(&x1s(0, r, 0))...);
124 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
125 [&]<
size_t... Is>(std::index_sequence<Is...>) {
126 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, Is == 0 ? r : 0))), ...);
127 }(std::index_sequence_for<Ys...>());
135 static constexpr auto access = [](
auto &x,
index_t r) ->
auto & {
136 return x.storage_order == StorageOrder::ColMajor
137 ? (x.cols() > 1 ? x(0, r, r) : x(0, r, 0))
138 : (x.
rows() > 1 ? x(0, r, r) : x(0, 0, r));
140 for (
index_t r = 0; r < n; ++r) {
141 auto rs = fun(types::aligned_load(&access(xs, r))...);
142 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
143 [&]<
size_t... Is>(std::index_sequence<Is...>) {
144 ((types::aligned_store(get<Is>(rs), &access(get<Is>(ys), r))), ...);
145 }(std::index_sequence_for<Ys...>());
151template <
class T,
class Abi, StorageOrder OB, StorageOrder OC>
155 iter_elems_store<T, Abi, OC>([&](
auto Bi) {
return a * Bi; }, C, B);
159template <
class T,
class Abi, StorageOrder OA, StorageOrder OB, StorageOrder OC>
166 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai * Bi; }, C, A, B);
170template <
class T,
class Abi, StorageOrder O>
179 const auto clamp = [&](
auto xi,
auto loi,
auto hii) {
return fmax(loi, fmin(xi, hii)); };
180 iter_elems_store<T, Abi, O>(
clamp, z, x, lo, hi);
184template <
class T,
class Abi, StorageOrder O>
189 const auto clamp = [&](
auto xi) {
return fmax(lo, fmin(xi, hi)); };
190 iter_elems_store<T, Abi, O>(
clamp, z, x);
194template <
class T,
class Abi, StorageOrder O>
204 const auto clamp_resid = [&](
auto xi,
auto loi,
auto hii) {
205 return fmax(xi - hii, fmin(simd{0}, xi - loi));
207 iter_elems_store<T, Abi, O>(
clamp_resid, z, x, lo, hi);
211template <
class T,
class Abi, T Beta,
StorageOrder O,
class... Xs>
217 if constexpr (Beta == 0)
218 iter_elems_store<T, Abi, O>(
220 return [&]<std::size_t... Is>(std::index_sequence<Is...>,
auto... xis) {
221 return ((xis * alphas[Is]) + ...);
222 }(std::make_index_sequence<
sizeof...(Xs)>(), xis...);
226 iter_elems_store<T, Abi, O>(
227 [&](
auto zi,
auto... xis) {
228 return [&]<std::size_t... Is>(std::index_sequence<Is...>,
auto... xis) {
229 return zi * Beta + ((xis * alphas[Is]) + ...);
230 }(std::make_index_sequence<
sizeof...(Xs)>(), xis...);
238template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB>
243 iter_elems_store<T, Abi, OB>([&](
auto Ai) {
return -rotl<Rotate>(Ai); }, B, A);
247template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
254 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai - rotl<Rotate>(Bi); }, C, A, B);
258template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
265 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai + rotl<Rotate>(Bi); }, C, A, B);
279template <simdifiable Vx, simdifiable Vz, std::convertible_to<simdified_simd_t<Vx>> T>
281void scale(T alpha, Vx &&x, Vz &&z) {
288template <simdifiable Vx, std::convertible_to<simdified_simd_t<Vx>> T>
296template <simdifiable Vx, simdifiable Vy, simdifiable Vz>
305template <simdifiable Vx, simdifiable Vy>
314template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
316void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
323template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
332template <simdifiable Vx, simdifiable Vz>
341template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
342 std::convertible_to<simdified_simd_t<Vx>> Ta,
343 std::convertible_to<simdified_simd_t<Vx>> Tb>
345void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
352template <simdifiable Vx, simdifiable Vy,
353 std::convertible_to<simdified_simd_t<Vx>> Ta,
354 std::convertible_to<simdified_simd_t<Vx>> Tb>
356void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
363template <
auto Beta = 1, simdifiable Vy, simdifiable... Vx>
366 [[maybe_unused]]
static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
367 [[maybe_unused]]
static constexpr index_t num_fma =
sizeof...(Vx);
374template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
375 std::convertible_to<simdified_simd_t<Vx>> Ta>
377void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
378 axpby(alpha, x, Ta{1}, y, z);
382template <
auto Beta = 1, simdifiable Vx, simdifiable Vy,
383 std::convertible_to<simdified_simd_t<Vx>> Ta>
385void axpy(Ta alpha, Vx &&x, Vy &&y) {
386 [[maybe_unused]]
static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
387 [[maybe_unused]]
static constexpr index_t num_fma = 1;
394template <simdifiable VA, simdifiable VB,
int Rotate = 0>
403template <simdifiable VA,
int Rotate = 0>
411template <simdifiable VA, simdifiable VB, simdifiable VC,
int Rotate = 0>
420template <simdifiable VA, simdifiable VB,
int Rotate = 0>
429template <simdifiable VA, simdifiable VB, simdifiable VC,
int Rotate = 0>
438template <simdifiable VA, simdifiable VB,
int Rotate = 0>
447template <
class F, simdifiable VA, simdifiable... VAs>
452 std::forward<F>(fun),
simdify(A).as_const(),
simdify(As).as_const()...);
457template <
class F, simdifiable VA, simdifiable... VAs>
467template <
class F, simdifiable VA, simdifiable VB, simdifiable... VAs>
477template <
class F, simdifiable... VAs, simdifiable... VBs>
480 using VA0 = std::tuple_element_t<0,
decltype(As)>;
483 std::forward<F>(fun),
484 std::apply([](
auto &&...a) {
return std::make_tuple(
simdify(a)...); }, As),
492template <
class F, simdifiable... VAs, simdifiable... VBs>
495 constexpr auto check_size = [](
auto &x) {
496 if (x.rows() == x.cols())
498 else if (x.storage_order == StorageOrder::ColMajor) {
506 [[maybe_unused]]
const index_t n = check_size(std::get<0>(As));
507 [&]<
size_t... Is>(std::index_sequence<Is...>) {
509 }(std::index_sequence_for<VAs...>());
511 using VA0 = std::tuple_element_t<0,
decltype(As)>;
513 std::forward<F>(fun),
514 std::apply([](
auto &&...a) {
return std::make_tuple(
simdify(a)...); }, As),
522template <
class F, simdifiable VA, simdifiable VB>
525 [[maybe_unused]]
const index_t n =
526 A.storage_order == StorageOrder::ColMajor ? A.rows() : A.cols();
527 if constexpr (A.storage_order == StorageOrder::ColMajor) {
534 if constexpr (B.storage_order == StorageOrder::ColMajor) {
543 [](
auto Ai) {
return std::make_tuple(Ai); }, std::make_tuple(
simdify(B)),
548template <simdifiable VA, simdifiable VB, simdifiable VC>
554 if constexpr (b.storage_order == StorageOrder::ColMajor) {
563 [](
auto Ai,
auto bi) {
return std::make_tuple(Ai + bi); }, std::make_tuple(
simdify(C)),
568template <simdifiable VA, simdifiable VB>
571 add_diag(std::forward<VA>(A), std::forward<VB>(b), std::forward<VA>(A));
589template <simdifiable_multi Vx, simdifiable_multi Vz, std::convertible_to<simdified_simd_t<Vx>> T>
591void scale(T alpha, Vx &&x, Vz &&z) {
593 for (
index_t b = 0; b < x.num_batches(); ++b)
598template <simdifiable_multi Vx, std::convertible_to<simdified_simd_t<Vx>> T>
600 for (
index_t b = 0; b < x.num_batches(); ++b)
605template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz>
610 for (
index_t b = 0; b < x.num_batches(); ++b)
615template <simdifiable_multi Vx, simdifiable_multi Vy>
619 for (
index_t b = 0; b < x.num_batches(); ++b)
624template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
626void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
630 for (
index_t b = 0; b < x.num_batches(); ++b)
631 linalg::clamp(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
635template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
641 for (
index_t b = 0; b < x.num_batches(); ++b)
646template <simdifiable_multi Vx, simdifiable_multi Vz>
650 for (
index_t b = 0; b < x.num_batches(); ++b)
655template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
656 std::convertible_to<simdified_simd_t<Vx>> Ta,
657 std::convertible_to<simdified_simd_t<Vx>> Tb>
659void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
662 for (
index_t b = 0; b < x.num_batches(); ++b)
663 linalg::axpby(alpha, x.batch(b), beta, y.batch(b), z.batch(b));
667template <simdifiable_multi Vx, simdifiable_multi Vy,
668 std::convertible_to<simdified_simd_t<Vx>> Ta,
669 std::convertible_to<simdified_simd_t<Vx>> Tb>
671void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
673 for (
index_t b = 0; b < x.num_batches(); ++b)
678template <
auto Beta = 1, simdifiable_multi Vy, simdifiable_multi... Vx>
681 BATMAT_ASSERT(((y.num_batches() == x.num_batches()) && ...));
682 for (
index_t b = 0; b < y.num_batches(); ++b)
687template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
688 std::convertible_to<simdified_simd_t<Vx>> Ta>
690void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
691 axpby(alpha, x, 1, y, z);
695template <
auto Beta = 1, simdifiable_multi Vx, simdifiable_multi Vy,
696 std::convertible_to<simdified_simd_t<Vx>> Ta>
698void axpy(Ta alpha, Vx &&x, Vy &&y) {
700 for (
index_t b = 0; b < x.num_batches(); ++b)
705template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
709 for (
index_t b = 0; b < A.num_batches(); ++b)
714template <simdifiable_multi VA,
int Rotate = 0>
716 for (
index_t b = 0; b < A.num_batches(); ++b)
721template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC,
int Rotate = 0>
726 for (
index_t b = 0; b < A.num_batches(); ++b)
727 linalg::sub(A.batch(b), B.batch(b), C.batch(b), rot);
731template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
735 for (
index_t b = 0; b < A.num_batches(); ++b)
740template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC,
int Rotate = 0>
745 for (
index_t b = 0; b < A.num_batches(); ++b)
746 linalg::add(A.batch(b), B.batch(b), C.batch(b), rot);
750template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
754 for (
index_t b = 0; b < A.num_batches(); ++b)
759template <
class F, simdifiable_multi VA, simdifiable_multi... VAs>
762 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
763 for (
index_t b = 0; b < A.num_batches(); ++b)
769template <
class F, simdifiable_multi VA, simdifiable_multi... VAs>
772 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
773 for (
index_t b = 0; b < A.num_batches(); ++b)
779template <
class F, simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi... VAs>
783 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
784 for (
index_t b = 0; b < A.num_batches(); ++b)
790template <
class F, simdifiable_multi... VAs, simdifiable_multi... VBs>
794 auto &&a0 = get<0>(As);
795 BATMAT_ASSERT(((a0.num_batches() == Bs.num_batches()) && ...));
796 BATMAT_ASSERT([&]<std::size_t... Is>(std::index_sequence<Is...>) {
797 return ((a0.num_batches() == get<Is>(As).num_batches()) && ...);
798 }(std::make_index_sequence<
sizeof...(VAs)>()));
799 for (
index_t b = 0; b < a0.num_batches(); ++b)
801 fun, std::apply([&](
auto &&...a) {
return std::make_tuple(a.batch(b)...); }, As),
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple of output arguments.
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Add two matrices or vectors C = A + B. Rotate affects B.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void transform_n_diag(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given vectors and the diagonal elements of the given square matrices.
void axpy(Vy &&y, const std::array< simdified_simd_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first two arguments.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first argument.
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first argument.
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
void negate(VA &&A, VB &&B, with_rotate_t< Rotate > rot={})
Negate a matrix or vector B = -A.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple of output arguments.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void copy_diag(VA &&A, VB &&B)
Copy the diagonal elements of a matrix.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first two arguments.
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void axpy(Vy &&y, const std::array< simdified_simd_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void add_diag(VA &&A, VB &&b, VC &&C)
C = A + diag(b).
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
Rotates the elements of x by s positions to the left.
#define GUANAQO_TRACE_LINALG(name, gflops)
stdx::simd< Tp, Abi > simd
typename detail::simdified_value< V >::type simdified_value_t
typename detail::simdified_abi< V >::type simdified_abi_t
typename detail::simdified_simd< V >::type simdified_simd_t
typename simdified_view_type< V >::type simdified_view_t
constexpr bool simdify_compatible
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
constexpr auto cols(const MatrixView< T, I, S, O > &v)
constexpr auto rows(const MatrixView< T, I, S, O > &v)