batmat 0.0.21
Batched linear algebra routines
Loading...
Searching...
No Matches
reduce.hpp
Go to the documentation of this file.
1#pragma once
2
5
6namespace batmat::linalg {
7
8/// @cond DETAIL
9
10namespace detail {
11
12template <class T, class Abi, StorageOrder O0, class Tinit, class F, class R, class... Args>
13auto reduce(Tinit init, F fun, R reduce, view<const T, Abi, O0> x0, const Args &...xs) {
14 BATMAT_ASSERT(((x0.rows() == xs.rows()) && ...));
15 BATMAT_ASSERT(((x0.cols() == xs.cols()) && ...));
16 BATMAT_ASSERT(((x0.depth() == xs.depth()) && ...));
17 BATMAT_ASSERT(((x0.batch_size() == xs.batch_size()) && ...));
18 iter_elems<T, Abi, O0>([&](auto... args) { init = fun(init, args...); }, x0, xs...);
19 return reduce(init);
20}
21
22template <class T, class Abi, StorageOrder OA>
25 using norms = linalg::norms<T, simd>;
26 return reduce<T, Abi>(norms::zero_simd(), norms(), norms(), A);
27}
28
29/// Dot product.
30template <class T, class Abi, StorageOrder OA, StorageOrder OB>
31[[gnu::flatten]] T dot(view<const T, Abi, OA> a, view<const T, Abi, OB> b) {
33 auto fma = [](auto accum, auto ai, auto bi) { return ai * bi + accum; };
34 auto simd_reduce = [](auto accum) { return reduce(accum); };
35 return reduce<T, Abi>(simd{0}, fma, simd_reduce, a, b);
36}
37
38/// Squared 2-norm.
39template <class T, class Abi, StorageOrder OA>
40[[gnu::flatten]] T norm_2_sq(view<const T, Abi, OA> a) {
42 auto fma = [](auto accum, auto ai) { return ai * ai + accum; };
43 auto simd_reduce = [](auto accum) { return reduce(accum); };
44 return reduce<T, Abi>(simd{0}, fma, simd_reduce, a);
45}
46
47/// ∑ wᵢ aᵢ².
48template <class T, class Abi, StorageOrder OW, StorageOrder OA>
51 auto wnd = [](auto accum, auto wi, auto ai) { return wi * (ai * ai) + accum; };
52 auto simd_reduce = [](auto accum) { return reduce(accum); };
53 return reduce<T, Abi>(simd{0}, wnd, simd_reduce, w, a);
54}
55
56/// ∑ wᵢ(aᵢ - bᵢ)².
57template <class T, class Abi, StorageOrder OW, StorageOrder OA, StorageOrder OB>
61 auto wnd = [](auto accum, auto wi, auto ai, auto bi) {
62 auto ei = ai - bi;
63 return wi * (ei * ei) + accum;
64 };
65 auto simd_reduce = [](auto accum) { return reduce(accum); };
66 return reduce<T, Abi>(simd{0}, wnd, simd_reduce, w, a, b);
67}
68
69} // namespace detail
70
71/// @endcond
72
73/// @addtogroup topic-linalg
74/// @{
75
76/// @name Single-batch reduction operations
77/// @{
78
79/// Compute the norms (max, 1-norm, and 2-norm) of a vector.
80template <simdifiable Vx>
82 GUANAQO_TRACE_LINALG("norms_all", 3 * detail::num_elem(simdify(x))); // fma, add, max
83 return detail::norms_all<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const());
84}
85
86/// Compute the infinity norm of a vector.
87template <simdifiable Vx>
89 return norms_all(std::forward<Vx>(x)).norm_inf();
90}
91
92/// Compute the 1-norm of a vector.
93template <simdifiable Vx>
95 return norms_all(std::forward<Vx>(x)).norm_1();
96}
97
98/// Compute the squared 2-norm of a vector.
99template <simdifiable Vx>
101 GUANAQO_TRACE_LINALG("norm_2_squared", detail::num_elem(simdify(x)));
102 return detail::norm_2_sq<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const());
103}
104
105/// Compute the 2-norm of a vector.
106template <simdifiable Vx>
108 using std::sqrt;
109 return sqrt(norm_2_squared(std::forward<Vx>(x)));
110}
111
112/// Compute the dot product of two vectors.
113template <simdifiable Vx, simdifiable Vy>
115simdified_value_t<Vx> dot(Vx &&x, Vy &&y) {
116 GUANAQO_TRACE_LINALG("dot", detail::num_elem(simdify(x)));
117 return detail::dot<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
118 simdify(y).as_const());
119}
120
121/// ∑ wᵢ aᵢ².
122template <simdifiable Vw, simdifiable Va>
125 GUANAQO_TRACE_LINALG("weighted_norm_sq", 2 * detail::num_elem(simdify(w)));
126 return detail::weighted_norm_sq<simdified_value_t<Vw>, simdified_abi_t<Vw>>(
127 simdify(w).as_const(), simdify(a).as_const());
128}
129
130/// ∑ wᵢ(aᵢ - bᵢ)².
131template <simdifiable Vw, simdifiable Va, simdifiable Vb>
134 GUANAQO_TRACE_LINALG("weighted_norm_sq_difference", 3 * detail::num_elem(simdify(w)));
135 return detail::weighted_norm_sq_difference<simdified_value_t<Vw>, simdified_abi_t<Vw>>(
136 simdify(w).as_const(), simdify(a).as_const(), simdify(b).as_const());
137}
138
139/// @}
140
141/// @}
142
143// TODO: Doxygen gets confused because these overloads have the same template parameter lists as
144// the single-batch versions, so they are placed in a separate inline namespace.
145inline namespace multi {
146
147/// @addtogroup topic-linalg
148/// @{
149
150/// @name Multi-batch reduction operations
151/// @{
152
153/// Compute the norms (max, 1-norm, and 2-norm) of a vector.
154template <simdifiable_multi Vx>
156 typename norms<simdified_value_t<Vx>>::result result{};
157 for (index_t b = 0; b < x.num_batches(); ++b)
158 result = norms<simdified_value_t<Vx>>{}(result, linalg::norms_all(x.batch(b)));
159 return result;
160}
161
162/// Compute the infinity norm of a vector.
163template <simdifiable_multi Vx>
165 return norms_all(std::forward<Vx>(x)).norm_inf();
166}
167
168/// Compute the 1-norm of a vector.
169template <simdifiable_multi Vx>
171 return norms_all(std::forward<Vx>(x)).norm_1();
172}
173
174/// Compute the squared 2-norm of a vector.
175template <simdifiable_multi Vx>
177 simdified_value_t<Vx> sumsq{};
178 for (index_t b = 0; b < x.num_batches(); ++b)
179 sumsq += linalg::norm_2_squared(x.batch(b));
180 return sumsq;
181}
182
183/// Compute the 2-norm of a vector.
184template <simdifiable_multi Vx>
186 using std::sqrt;
187 return sqrt(norm_2_squared(std::forward<Vx>(x)));
188}
189
190/// Compute the dot product of two vectors.
191template <simdifiable_multi Vx, simdifiable_multi Vy>
193simdified_value_t<Vx> dot(Vx &&x, Vy &&y) {
194 BATMAT_ASSERT(x.num_batches() == y.num_batches());
195 simdified_value_t<Vx> result{};
196 for (index_t b = 0; b < x.num_batches(); ++b)
197 result += linalg::dot(x.batch(b), y.batch(b));
198 return result;
199}
200
201/// ∑ wᵢ xᵢ².
202template <simdifiable_multi Vw, simdifiable_multi Vx>
205 BATMAT_ASSERT(w.num_batches() == x.num_batches());
206 simdified_value_t<Vw> result{};
207 for (index_t b = 0; b < w.num_batches(); ++b)
208 result += linalg::weighted_norm_sq(w.batch(b), x.batch(b));
209 return result;
210}
211
212/// ∑ wᵢ(xᵢ - yᵢ)².
213template <simdifiable_multi Vw, simdifiable_multi Vx, simdifiable_multi Vy>
216 BATMAT_ASSERT(w.num_batches() == x.num_batches());
217 BATMAT_ASSERT(w.num_batches() == y.num_batches());
218 simdified_value_t<Vw> result{};
219 for (index_t b = 0; b < w.num_batches(); ++b)
220 result += linalg::weighted_norm_sq_diff(w.batch(b), x.batch(b), y.batch(b));
221 return result;
222}
223
224} // namespace multi
225
226} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
simdified_value_t< Vx > norm_inf(Vx &&x)
Compute the infinity norm of a vector.
Definition reduce.hpp:88
simdified_value_t< Vx > norm_2_squared(Vx &&x)
Compute the squared 2-norm of a vector.
Definition reduce.hpp:100
simdified_value_t< Vx > norm_2(Vx &&x)
Compute the 2-norm of a vector.
Definition reduce.hpp:185
simdified_value_t< Vx > norm_2(Vx &&x)
Compute the 2-norm of a vector.
Definition reduce.hpp:107
simdified_value_t< Vx > norm_2_squared(Vx &&x)
Compute the squared 2-norm of a vector.
Definition reduce.hpp:176
simdified_value_t< Vx > norm_1(Vx &&x)
Compute the 1-norm of a vector.
Definition reduce.hpp:94
simdified_value_t< Vw > weighted_norm_sq(Vw &&w, Vx &&x)
∑ wᵢ xᵢ².
Definition reduce.hpp:204
simdified_value_t< Vx > dot(Vx &&x, Vy &&y)
Compute the dot product of two vectors.
Definition reduce.hpp:115
simdified_value_t< Vx > norm_1(Vx &&x)
Compute the 1-norm of a vector.
Definition reduce.hpp:170
norms< simdified_value_t< Vx > >::result norms_all(Vx &&x)
Compute the norms (max, 1-norm, and 2-norm) of a vector.
Definition reduce.hpp:81
simdified_value_t< Vw > weighted_norm_sq_difference(Vw &&w, Vx &&x, Vy &&y)
∑ wᵢ(xᵢ - yᵢ)².
Definition reduce.hpp:215
simdified_value_t< Vx > norm_inf(Vx &&x)
Compute the infinity norm of a vector.
Definition reduce.hpp:164
simdified_value_t< Vx > dot(Vx &&x, Vy &&y)
Compute the dot product of two vectors.
Definition reduce.hpp:193
simdified_value_t< Vw > weighted_norm_sq(Vw &&w, Va &&a)
∑ wᵢ aᵢ².
Definition reduce.hpp:124
simdified_value_t< Vw > weighted_norm_sq_diff(Vw &&w, Va &&a, Vb &&b)
∑ wᵢ(aᵢ - bᵢ)².
Definition reduce.hpp:133
norms< simdified_value_t< Vx > >::result norms_all(Vx &&x)
Compute the norms (max, 1-norm, and 2-norm) of a vector.
Definition reduce.hpp:155
#define GUANAQO_TRACE_LINALG(name, gflops)
stdx::simd< Tp, Abi > simd
Definition simd.hpp:102
typename detail::simdified_value< V >::type simdified_value_t
Definition simdify.hpp:206
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:208
constexpr bool simdify_compatible
Definition simdify.hpp:211
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:218
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
int index_t
Definition config.hpp:13
Vector reductions.
Utilities for computing vector norms.
Definition norms.hpp:26
static result_simd zero_simd()
Definition norms.hpp:53
typename norms< T >::result result
Accumulator.
Definition norms.hpp:28