10#include <guanaqo/trace.hpp>
23constexpr index_t num_elem(
const auto &A) {
return A.rows() * A.cols() * A.depth(); }
25template <
class T,
class Abi,
StorageOrder O,
class F,
class X,
class... Xs>
26[[gnu::always_inline]]
inline void iter_elems(F &&fun, X &&x, Xs &&...xs) {
27 using types = simd_view_types<T, Abi>;
28 if constexpr (O == StorageOrder::ColMajor) {
29 for (
index_t c = 0; c < x.cols(); ++c)
30 for (
index_t r = 0; r < x.rows(); ++r)
31 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
33 for (
index_t r = 0; r < x.rows(); ++r)
34 for (
index_t c = 0; c < x.cols(); ++c)
35 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
39template <
class T,
class Abi,
StorageOrder O,
class F,
class X,
class... Xs>
40[[gnu::always_inline]]
inline void iter_elems_store(F &&fun, X &&x, Xs &&...xs) {
41 using types = simd_view_types<T, Abi>;
42 if constexpr (O == StorageOrder::ColMajor) {
43 for (
index_t c = 0; c < x.cols(); ++c)
44 for (
index_t r = 0; r < x.rows(); ++r)
45 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
47 for (
index_t r = 0; r < x.rows(); ++r)
48 for (
index_t c = 0; c < x.cols(); ++c)
49 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
53template <
class T,
class Abi,
StorageOrder O,
class F,
class X0,
class X1,
class... Xs>
54[[gnu::always_inline]]
inline void iter_elems_store2(F &&fun, X0 &&x0, X1 &&x1, Xs &&...xs) {
55 using types = simd_view_types<T, Abi>;
56 if constexpr (O == StorageOrder::ColMajor) {
57 for (
index_t c = 0; c < x0.cols(); ++c)
58 for (
index_t r = 0; r < x0.rows(); ++r) {
59 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
60 types::aligned_store(r0, &x0(0, r, c));
61 types::aligned_store(r1, &x1(0, r, c));
64 for (
index_t r = 0; r < x0.rows(); ++r)
65 for (
index_t c = 0; c < x0.cols(); ++c) {
66 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
67 types::aligned_store(r0, &x0(0, r, c));
68 types::aligned_store(r1, &x1(0, r, c));
73template <
class T,
class Abi,
StorageOrder O,
class F,
class... Ys,
class... Xs>
74[[gnu::always_inline]]
inline void iter_elems_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
76 using types = simd_view_types<T, Abi>;
77 const index_t rows = std::get<0>(ys).rows(),
cols = std::get<0>(ys).cols();
78 if constexpr (O == StorageOrder::ColMajor) {
81 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
82 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
83 [&]<
size_t... Is>(std::index_sequence<Is...>) {
84 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
85 }(std::index_sequence_for<Ys...>());
90 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
91 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
92 [&]<
size_t... Is>(std::index_sequence<Is...>) {
93 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
94 }(std::index_sequence_for<Ys...>());
102template <
class T,
class Abi,
class F,
class... Ys,
class... Xs>
103[[gnu::always_inline]]
inline void iter_diag_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
105 using types = simd_view_types<T, Abi>;
106 const index_t rows = std::get<0>(ys).rows(),
cols = std::get<0>(ys).cols();
107 const index_t n = std::get<0>(ys).storage_order == StorageOrder::ColMajor
108 ? (
cols > 1 ? std::min(rows, cols) :
rows)
112 static constexpr auto all_vectors_except_first = [](
auto &x0,
auto &...x1s) {
113 return x0.rows() == x0.cols() &&
114 ((x1s.storage_order == StorageOrder::ColMajor && x1s.cols() == 1) && ...);
116 const bool all_xs_vectors_except_first = all_vectors_except_first(xs...);
117 const bool all_ys_vector_except_first = std::apply(all_vectors_except_first, ys);
118 if (all_xs_vectors_except_first && all_ys_vector_except_first) {
119 for (
index_t r = 0; r < n; ++r) {
120 auto rs = [&](
auto &x0,
auto &...x1s) {
121 return fun(types::aligned_load(&x0(0, r, r)),
122 types::aligned_load(&x1s(0, r, 0))...);
124 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
125 [&]<
size_t... Is>(std::index_sequence<Is...>) {
126 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, Is == 0 ? r : 0))), ...);
127 }(std::index_sequence_for<Ys...>());
135 static constexpr auto access = [](
auto &x,
index_t r) ->
auto & {
136 return x.storage_order == StorageOrder::ColMajor
137 ? (x.cols() > 1 ? x(0, r, r) : x(0, r, 0))
138 : (x.
rows() > 1 ? x(0, r, r) : x(0, 0, r));
140 for (
index_t r = 0; r < n; ++r) {
141 auto rs = fun(types::aligned_load(&access(xs, r))...);
142 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
143 [&]<
size_t... Is>(std::index_sequence<Is...>) {
144 ((types::aligned_store(get<Is>(rs), &access(get<Is>(ys), r))), ...);
145 }(std::index_sequence_for<Ys...>());
151template <
class T,
class Abi, StorageOrder OB, StorageOrder OC>
155 iter_elems_store<T, Abi, OC>([&](
auto Bi) {
return a * Bi; }, C, B);
159template <
class T,
class Abi, StorageOrder OA, StorageOrder OB, StorageOrder OC>
166 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai * Bi; }, C, A, B);
170template <
class T,
class Abi, StorageOrder O>
179 const auto clamp = [&](
auto xi,
auto loi,
auto hii) {
return fmax(loi, fmin(xi, hii)); };
180 iter_elems_store<T, Abi, O>(
clamp, z, x, lo, hi);
184template <
class T,
class Abi, StorageOrder O>
188 const auto clamp = [&](
auto xi) {
return fmax(
decltype(xi){lo}, fmin(xi,
decltype(xi){hi})); };
189 iter_elems_store<T, Abi, O>(
clamp, z, x);
193template <
class T,
class Abi, StorageOrder O>
203 const auto clamp_resid = [&](
auto xi,
auto loi,
auto hii) {
204 return fmax(xi - hii, fmin(simd{0}, xi - loi));
206 iter_elems_store<T, Abi, O>(
clamp_resid, z, x, lo, hi);
210template <
class T,
class Abi, T Beta,
StorageOrder O,
class... Xs>
211[[gnu::flatten]]
void gaxpby(
view<T, Abi, O> z,
const std::array<T,
sizeof...(Xs)> &alphas,
215 if constexpr (Beta == 0)
216 iter_elems_store<T, Abi, O>(
218 return [&]<std::size_t... Is>(std::index_sequence<Is...>,
auto... xis) {
219 return ((xis * alphas[Is]) + ...);
220 }(std::make_index_sequence<
sizeof...(Xs)>(), xis...);
224 iter_elems_store<T, Abi, O>(
225 [&](
auto zi,
auto... xis) {
226 return [&]<std::size_t... Is>(std::index_sequence<Is...>,
auto... xis) {
227 return zi * Beta + ((xis * alphas[Is]) + ...);
228 }(std::make_index_sequence<
sizeof...(Xs)>(), xis...);
236template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB>
241 iter_elems_store<T, Abi, OB>([&](
auto Ai) {
return -rotl<Rotate>(Ai); }, B, A);
245template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
252 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai - rotl<Rotate>(Bi); }, C, A, B);
256template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
263 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai + rotl<Rotate>(Bi); }, C, A, B);
277template <simdifiable Vx, simdifiable Vz, std::convertible_to<simdified_value_t<Vx>> T>
279void scale(T alpha, Vx &&x, Vz &&z) {
286template <simdifiable Vx, std::convertible_to<simdified_value_t<Vx>> T>
294template <simdifiable Vx, simdifiable Vy, simdifiable Vz>
303template <simdifiable Vx, simdifiable Vy>
312template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
314void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
321template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
330template <simdifiable Vx, simdifiable Vz>
339template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
340 std::convertible_to<simdified_value_t<Vx>> Ta,
341 std::convertible_to<simdified_value_t<Vx>> Tb>
343void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
350template <simdifiable Vx, simdifiable Vy,
351 std::convertible_to<simdified_value_t<Vx>> Ta,
352 std::convertible_to<simdified_value_t<Vx>> Tb>
354void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
361template <
auto Beta = 1, simdifiable Vy, simdifiable... Vx>
364 [[maybe_unused]]
static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
365 [[maybe_unused]]
static constexpr index_t num_fma =
sizeof...(Vx);
372template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
373 std::convertible_to<simdified_value_t<Vx>> Ta>
375void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
376 axpby(alpha, x, Ta{1}, y, z);
380template <
auto Beta = 1, simdifiable Vx, simdifiable Vy,
381 std::convertible_to<simdified_value_t<Vx>> Ta>
383void axpy(Ta alpha, Vx &&x, Vy &&y) {
384 [[maybe_unused]]
static constexpr index_t num_mul = Beta != 1 && Beta != 0 ? 1 : 0;
385 [[maybe_unused]]
static constexpr index_t num_fma = 1;
392template <simdifiable VA, simdifiable VB,
int Rotate = 0>
401template <simdifiable VA,
int Rotate = 0>
409template <simdifiable VA, simdifiable VB, simdifiable VC,
int Rotate = 0>
418template <simdifiable VA, simdifiable VB,
int Rotate = 0>
427template <simdifiable VA, simdifiable VB, simdifiable VC,
int Rotate = 0>
436template <simdifiable VA, simdifiable VB,
int Rotate = 0>
445template <
class F, simdifiable VA, simdifiable... VAs>
450 std::forward<F>(fun),
simdify(A).as_const(),
simdify(As).as_const()...);
455template <
class F, simdifiable VA, simdifiable... VAs>
465template <
class F, simdifiable VA, simdifiable VB, simdifiable... VAs>
475template <
class F, simdifiable... VAs, simdifiable... VBs>
478 using VA0 = std::tuple_element_t<0,
decltype(As)>;
481 std::forward<F>(fun),
482 std::apply([](
auto &&...a) {
return std::make_tuple(
simdify(a)...); }, As),
490template <
class F, simdifiable... VAs, simdifiable... VBs>
493 constexpr auto check_size = [](
auto &x) {
494 if (x.rows() == x.cols())
496 else if (x.storage_order == StorageOrder::ColMajor) {
504 [[maybe_unused]]
const index_t n = check_size(std::get<0>(As));
505 [&]<
size_t... Is>(std::index_sequence<Is...>) {
507 }(std::index_sequence_for<VAs...>());
509 using VA0 = std::tuple_element_t<0,
decltype(As)>;
511 std::forward<F>(fun),
512 std::apply([](
auto &&...a) {
return std::make_tuple(
simdify(a)...); }, As),
520template <
class F, simdifiable VA, simdifiable VB>
523 [[maybe_unused]]
const index_t n =
524 A.storage_order == StorageOrder::ColMajor ? A.rows() : A.cols();
525 if constexpr (A.storage_order == StorageOrder::ColMajor) {
532 if constexpr (B.storage_order == StorageOrder::ColMajor) {
541 [](
auto Ai) {
return std::make_tuple(Ai); }, std::make_tuple(
simdify(B)),
546template <simdifiable VA, simdifiable VB, simdifiable VC>
552 if constexpr (b.storage_order == StorageOrder::ColMajor) {
561 [](
auto Ai,
auto bi) {
return std::make_tuple(Ai + bi); }, std::make_tuple(
simdify(C)),
566template <simdifiable VA, simdifiable VB>
569 add_diag(std::forward<VA>(A), std::forward<VB>(b), std::forward<VA>(A));
587template <simdifiable_multi Vx, simdifiable_multi Vz, std::convertible_to<simdified_value_t<Vx>> T>
589void scale(T alpha, Vx &&x, Vz &&z) {
591 for (
index_t b = 0; b < x.num_batches(); ++b)
596template <simdifiable_multi Vx, std::convertible_to<simdified_value_t<Vx>> T>
598 for (
index_t b = 0; b < x.num_batches(); ++b)
603template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz>
608 for (
index_t b = 0; b < x.num_batches(); ++b)
613template <simdifiable_multi Vx, simdifiable_multi Vy>
617 for (
index_t b = 0; b < x.num_batches(); ++b)
622template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
624void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
628 for (
index_t b = 0; b < x.num_batches(); ++b)
629 linalg::clamp(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
633template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
639 for (
index_t b = 0; b < x.num_batches(); ++b)
644template <simdifiable_multi Vx, simdifiable_multi Vz>
648 for (
index_t b = 0; b < x.num_batches(); ++b)
653template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
654 std::convertible_to<simdified_value_t<Vx>> Ta,
655 std::convertible_to<simdified_value_t<Vx>> Tb>
657void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
660 for (
index_t b = 0; b < x.num_batches(); ++b)
661 linalg::axpby(alpha, x.batch(b), beta, y.batch(b), z.batch(b));
665template <simdifiable_multi Vx, simdifiable_multi Vy,
666 std::convertible_to<simdified_value_t<Vx>> Ta,
667 std::convertible_to<simdified_value_t<Vx>> Tb>
669void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
671 for (
index_t b = 0; b < x.num_batches(); ++b)
676template <
auto Beta = 1, simdifiable_multi Vy, simdifiable_multi... Vx>
679 BATMAT_ASSERT(((y.num_batches() == x.num_batches()) && ...));
680 for (
index_t b = 0; b < y.num_batches(); ++b)
685template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
686 std::convertible_to<simdified_value_t<Vx>> Ta>
688void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
689 axpby(alpha, x, 1, y, z);
693template <
auto Beta = 1, simdifiable_multi Vx, simdifiable_multi Vy,
694 std::convertible_to<simdified_value_t<Vx>> Ta>
696void axpy(Ta alpha, Vx &&x, Vy &&y) {
698 for (
index_t b = 0; b < x.num_batches(); ++b)
703template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
707 for (
index_t b = 0; b < A.num_batches(); ++b)
712template <simdifiable_multi VA,
int Rotate = 0>
714 for (
index_t b = 0; b < A.num_batches(); ++b)
719template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC,
int Rotate = 0>
724 for (
index_t b = 0; b < A.num_batches(); ++b)
725 linalg::sub(A.batch(b), B.batch(b), C.batch(b), rot);
729template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
733 for (
index_t b = 0; b < A.num_batches(); ++b)
738template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC,
int Rotate = 0>
743 for (
index_t b = 0; b < A.num_batches(); ++b)
744 linalg::add(A.batch(b), B.batch(b), C.batch(b), rot);
748template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
752 for (
index_t b = 0; b < A.num_batches(); ++b)
757template <
class F, simdifiable_multi VA, simdifiable_multi... VAs>
760 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
761 for (
index_t b = 0; b < A.num_batches(); ++b)
767template <
class F, simdifiable_multi VA, simdifiable_multi... VAs>
770 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
771 for (
index_t b = 0; b < A.num_batches(); ++b)
777template <
class F, simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi... VAs>
781 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
782 for (
index_t b = 0; b < A.num_batches(); ++b)
788template <
class F, simdifiable_multi... VAs, simdifiable_multi... VBs>
792 auto &&a0 = get<0>(As);
793 BATMAT_ASSERT(((a0.num_batches() == Bs.num_batches()) && ...));
794 BATMAT_ASSERT([&]<std::size_t... Is>(std::index_sequence<Is...>) {
795 return ((a0.num_batches() == get<Is>(As).num_batches()) && ...);
796 }(std::make_index_sequence<
sizeof...(VAs)>()));
797 for (
index_t b = 0; b < a0.num_batches(); ++b)
799 fun, std::apply([&](
auto &&...a) {
return std::make_tuple(a.batch(b)...); }, As),
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
void axpy(Vy &&y, const std::array< simdified_value_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple of output matrices or vectors.
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Add two matrices or vectors C = A + B. Rotate affects B.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void transform_n_diag(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given vectors and the diagonal elements of the given square matrices.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first two arguments.
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first argument.
void axpy(Vy &&y, const std::array< simdified_value_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first argument.
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
void negate(VA &&A, VB &&B, with_rotate_t< Rotate > rot={})
Negate a matrix or vector B = -A.
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple of output matrices or vectors.
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void copy_diag(VA &&A, VB &&B)
Copy the diagonal elements of a matrix.
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first two arguments.
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
void add_diag(VA &&A, VB &&b, VC &&C)
C = A + diag(b).
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
Rotates the elements of x by s positions to the left.
#define GUANAQO_TRACE_LINALG(name, gflops)
stdx::simd< Tp, Abi > simd
typename detail::simdified_value< V >::type simdified_value_t
typename detail::simdified_abi< V >::type simdified_abi_t
typename simdified_view_type< V >::type simdified_view_t
constexpr bool simdify_compatible
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
constexpr auto cols(const MatrixView< T, I, S, O > &v)
constexpr auto rows(const MatrixView< T, I, S, O > &v)