batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
reduce.hpp
Go to the documentation of this file.
1#pragma once
2
5
6namespace batmat::linalg {
7
8/// @cond DETAIL
9
10namespace detail {
11
12template <class T, class Abi, StorageOrder O0, class Tinit, class F, class... Args>
13auto vreduce(Tinit init, F fun, view<const T, Abi, O0> x0, const Args &...xs) {
14 BATMAT_ASSERT(((x0.rows() == xs.rows()) && ...));
15 BATMAT_ASSERT(((x0.cols() == xs.cols()) && ...));
16 BATMAT_ASSERT(((x0.depth() == xs.depth()) && ...));
17 BATMAT_ASSERT(((x0.batch_size() == xs.batch_size()) && ...));
18 iter_elems<T, Abi, O0>([&](auto... args) { init = fun(init, args...); }, x0, xs...);
19 return init;
20}
21
22template <class T, class Abi, StorageOrder O0, class Tinit, class F, class R, class... Args>
23auto reduce(Tinit init, F fun, R reduce_fn, view<const T, Abi, O0> x0, const Args &...xs) {
24 return reduce_fn(vreduce(init, fun, x0, xs...));
25}
26
27template <class T, class Abi, StorageOrder OA>
28[[gnu::flatten]] auto vnorms_all(view<const T, Abi, OA> A) {
30 using norms = linalg::norms<T, simd>;
31 return vreduce<T, Abi>(norms::zero_simd(), norms(), A);
32}
33
34template <class T, class Abi, StorageOrder OA>
37 return linalg::norms<T, simd>{}(vnorms_all<T, Abi>(A));
38}
39
40/// Dot product (lane-wise).
41template <class T, class Abi, StorageOrder OA, StorageOrder OB>
42[[gnu::flatten]] auto vdot(view<const T, Abi, OA> a, view<const T, Abi, OB> b) {
44 auto fma = [](auto accum, auto ai, auto bi) { return ai * bi + accum; };
45 return vreduce<T, Abi>(simd{0}, fma, a, b);
46}
47
48/// Dot product.
49template <class T, class Abi, StorageOrder OA, StorageOrder OB>
50[[gnu::flatten]] T dot(view<const T, Abi, OA> a, view<const T, Abi, OB> b) {
51 return reduce(vdot<T, Abi>(a, b));
52}
53
54/// Squared 2-norm (lane-wise).
55template <class T, class Abi, StorageOrder OA>
56[[gnu::flatten]] auto vnorm_2_squared(view<const T, Abi, OA> a) {
58 auto fma = [](auto accum, auto ai) { return ai * ai + accum; };
59 return vreduce<T, Abi>(simd{0}, fma, a);
60}
61
62/// Squared 2-norm.
63template <class T, class Abi, StorageOrder OA>
64[[gnu::flatten]] T norm_2_sq(view<const T, Abi, OA> a) {
65 return reduce(vnorm_2_squared<T, Abi>(a));
66}
67
68/// ∑ wᵢ aᵢ² (lane-wise).
69template <class T, class Abi, StorageOrder OW, StorageOrder OA>
72 auto wnd = [](auto accum, auto wi, auto ai) { return wi * (ai * ai) + accum; };
73 return vreduce<T, Abi>(simd{0}, wnd, w, a);
74}
75
76/// ∑ wᵢ aᵢ².
77template <class T, class Abi, StorageOrder OW, StorageOrder OA>
79 return reduce(weighted_vnorm_sq<T, Abi>(w, a));
80}
81
82/// ∑ wᵢ(aᵢ - bᵢ)² (lane-wise).
83template <class T, class Abi, StorageOrder OW, StorageOrder OA, StorageOrder OB>
84[[gnu::flatten]] auto weighted_vnorm_sq_difference(view<const T, Abi, OW> w,
88 auto wnd = [](auto accum, auto wi, auto ai, auto bi) {
89 auto ei = ai - bi;
90 return wi * (ei * ei) + accum;
91 };
92 return vreduce<T, Abi>(simd{0}, wnd, w, a, b);
93}
94
95/// ∑ wᵢ(aᵢ - bᵢ)².
96template <class T, class Abi, StorageOrder OW, StorageOrder OA, StorageOrder OB>
99 return reduce(weighted_vnorm_sq_difference<T, Abi>(w, a, b));
100}
101
102} // namespace detail
103
104/// @endcond
105
106/// @addtogroup topic-linalg
107/// @{
108
109/// @name Single-batch reduction operations
110/// @{
111
112/// Compute the lane-wise norms (max, 1-norm, and 2-norm) of a batch of vectors.
113template <simdifiable Vx>
115 GUANAQO_TRACE_LINALG("vnorms_all", 3 * detail::num_elem(simdify(x))); // fma, add, max
116 return detail::vnorms_all<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const());
117}
118
119/// Compute the norms (max, 1-norm, and 2-norm) of a vector.
120template <simdifiable Vx>
122 GUANAQO_TRACE_LINALG("norms_all", 3 * detail::num_elem(simdify(x))); // fma, add, max
123 return detail::norms_all<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const());
124}
125
126/// Compute the lane-wise infinity norms of a batch of vectors.
127template <simdifiable Vx>
129 return vnorms_all(std::forward<Vx>(x)).amax;
130}
131
132/// Compute the infinity norm of a vector.
133template <simdifiable Vx>
135 return norms_all(std::forward<Vx>(x)).norm_inf();
136}
137
138/// Compute the lane-wise 1-norms of a batch of vectors.
139template <simdifiable Vx>
141 return vnorms_all(std::forward<Vx>(x)).asum;
142}
143
144/// Compute the 1-norm of a vector.
145template <simdifiable Vx>
147 return norms_all(std::forward<Vx>(x)).norm_1();
148}
149
150/// Compute the lane-wise squared 2-norms of a batch of vectors.
151template <simdifiable Vx>
153 GUANAQO_TRACE_LINALG("vnorm_2_squared", detail::num_elem(simdify(x)));
154 return detail::vnorm_2_squared<simdified_value_t<Vx>, simdified_abi_t<Vx>>(
155 simdify(x).as_const());
156}
157
158/// Compute the squared 2-norm of a vector.
159template <simdifiable Vx>
161 GUANAQO_TRACE_LINALG("norm_2_squared", detail::num_elem(simdify(x)));
162 return detail::norm_2_sq<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const());
163}
164
165/// Compute the lane-wise 2-norms of a batch of vectors.
166template <simdifiable Vx>
168 using std::sqrt;
169 return sqrt(vnorm_2_squared(std::forward<Vx>(x)));
170}
171
172/// Compute the 2-norm of a vector.
173template <simdifiable Vx>
175 using std::sqrt;
176 return sqrt(norm_2_squared(std::forward<Vx>(x)));
177}
178
179/// Compute the lane-wise dot products of two batches of vectors.
180template <simdifiable Vx, simdifiable Vy>
182simdified_simd_t<Vx> vdot(Vx &&x, Vy &&y) {
183 GUANAQO_TRACE_LINALG("vdot", detail::num_elem(simdify(x)));
184 return detail::vdot<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
185 simdify(y).as_const());
186}
187
188/// Compute the dot product of two vectors.
189template <simdifiable Vx, simdifiable Vy>
191simdified_value_t<Vx> dot(Vx &&x, Vy &&y) {
192 GUANAQO_TRACE_LINALG("dot", detail::num_elem(simdify(x)));
193 return detail::dot<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
194 simdify(y).as_const());
195}
196
197/// ∑ wᵢ aᵢ² (lane-wise).
198template <simdifiable Vw, simdifiable Va>
201 GUANAQO_TRACE_LINALG("weighted_vnorm_sq", 2 * detail::num_elem(simdify(w)));
202 return detail::weighted_vnorm_sq<simdified_value_t<Vw>, simdified_abi_t<Vw>>(
203 simdify(w).as_const(), simdify(a).as_const());
204}
205
206/// ∑ wᵢ aᵢ².
207template <simdifiable Vw, simdifiable Va>
210 GUANAQO_TRACE_LINALG("weighted_norm_sq", 2 * detail::num_elem(simdify(w)));
211 return detail::weighted_norm_sq<simdified_value_t<Vw>, simdified_abi_t<Vw>>(
212 simdify(w).as_const(), simdify(a).as_const());
213}
214
215/// ∑ wᵢ(aᵢ - bᵢ)² (lane-wise).
216template <simdifiable Vw, simdifiable Va, simdifiable Vb>
219 GUANAQO_TRACE_LINALG("weighted_vnorm_sq_diff", 3 * detail::num_elem(simdify(w)));
220 return detail::weighted_vnorm_sq_difference<simdified_value_t<Vw>, simdified_abi_t<Vw>>(
221 simdify(w).as_const(), simdify(a).as_const(), simdify(b).as_const());
222}
223
224/// ∑ wᵢ(aᵢ - bᵢ)².
225template <simdifiable Vw, simdifiable Va, simdifiable Vb>
228 GUANAQO_TRACE_LINALG("weighted_norm_sq_difference", 3 * detail::num_elem(simdify(w)));
229 return detail::weighted_norm_sq_difference<simdified_value_t<Vw>, simdified_abi_t<Vw>>(
230 simdify(w).as_const(), simdify(a).as_const(), simdify(b).as_const());
231}
232
233/// @}
234
235/// @}
236
237// TODO: doxygen gets confused because the template parameters are the same as the single-batch
238// versions, so put in a separate namespace
239inline namespace multi {
240
241/// @addtogroup topic-linalg
242/// @{
243
244/// @name Multi-batch reduction operations
245/// @{
246
247/// Compute the norms (max, 1-norm, and 2-norm) of a vector.
248template <simdifiable_multi Vx>
251 typename norms::result result{};
252 for (index_t b = 0; b < x.num_batches(); ++b)
253 result = norms{}(result, linalg::norms_all(x.batch(b)));
254 return result;
255}
256
257/// Compute the infinity norm of a vector.
258template <simdifiable_multi Vx>
260 return norms_all(std::forward<Vx>(x)).norm_inf();
261}
262
263/// Compute the 1-norm of a vector.
264template <simdifiable_multi Vx>
266 return norms_all(std::forward<Vx>(x)).norm_1();
267}
268
269/// Compute the squared 2-norm of a vector.
270template <simdifiable_multi Vx>
272 simdified_value_t<Vx> sumsq{};
273 for (index_t b = 0; b < x.num_batches(); ++b)
274 sumsq += linalg::norm_2_squared(x.batch(b));
275 return sumsq;
276}
277
278/// Compute the 2-norm of a vector.
279template <simdifiable_multi Vx>
281 using std::sqrt;
282 return sqrt(norm_2_squared(std::forward<Vx>(x)));
283}
284
285/// Compute the dot product of two vectors.
286template <simdifiable_multi Vx, simdifiable_multi Vy>
288simdified_value_t<Vx> dot(Vx &&x, Vy &&y) {
289 BATMAT_ASSERT(x.num_batches() == y.num_batches());
290 simdified_value_t<Vx> result{};
291 for (index_t b = 0; b < x.num_batches(); ++b)
292 result += linalg::dot(x.batch(b), y.batch(b));
293 return result;
294}
295
296/// ∑ wᵢ xᵢ².
297template <simdifiable_multi Vw, simdifiable_multi Vx>
300 BATMAT_ASSERT(w.num_batches() == x.num_batches());
301 simdified_value_t<Vw> result{};
302 for (index_t b = 0; b < w.num_batches(); ++b)
303 result += linalg::weighted_norm_sq(w.batch(b), x.batch(b));
304 return result;
305}
306
307/// ∑ wᵢ(xᵢ - yᵢ)².
308template <simdifiable_multi Vw, simdifiable_multi Vx, simdifiable_multi Vy>
311 BATMAT_ASSERT(w.num_batches() == x.num_batches());
312 BATMAT_ASSERT(w.num_batches() == y.num_batches());
313 simdified_value_t<Vw> result{};
314 for (index_t b = 0; b < w.num_batches(); ++b)
315 result += linalg::weighted_norm_sq_diff(w.batch(b), x.batch(b), y.batch(b));
316 return result;
317}
318
319/// Compute the lane-wise norms (max, 1-norm, and 2-norm) of a batch of vectors.
320template <simdifiable_multi Vx>
322 using std::max;
324 typename norms::result_simd result{};
325 for (index_t b = 0; b < x.num_batches(); ++b)
326 result = norms{}(result, linalg::vnorms_all(x.batch(b)));
327 return result;
328}
329
330/// Compute the lane-wise infinity norms of a batch of vectors.
331template <simdifiable_multi Vx>
333 return vnorms_all(std::forward<Vx>(x)).norm_inf();
334}
335
336/// Compute the lane-wise 1-norms of a batch of vectors.
337template <simdifiable_multi Vx>
339 return vnorms_all(std::forward<Vx>(x)).norm_1();
340}
341
342/// Compute the lane-wise squared 2-norms of a batch of vectors.
343template <simdifiable_multi Vx>
345 simdified_simd_t<Vx> result{};
346 for (index_t b = 0; b < x.num_batches(); ++b)
347 result += linalg::vnorm_2_squared(x.batch(b));
348 return result;
349}
350
351/// Compute the lane-wise 2-norms of a batch of vectors.
352template <simdifiable_multi Vx>
354 using std::sqrt;
355 return sqrt(vnorm_2_squared(std::forward<Vx>(x)));
356}
357
358/// Compute the lane-wise dot products of two batches of vectors.
359template <simdifiable_multi Vx, simdifiable_multi Vy>
361simdified_simd_t<Vx> vdot(Vx &&x, Vy &&y) {
362 BATMAT_ASSERT(x.num_batches() == y.num_batches());
363 simdified_simd_t<Vx> result{};
364 for (index_t b = 0; b < x.num_batches(); ++b)
365 result += linalg::vdot(x.batch(b), y.batch(b));
366 return result;
367}
368
369/// ∑ wᵢ xᵢ² (lane-wise).
370template <simdifiable_multi Vw, simdifiable_multi Vx>
373 BATMAT_ASSERT(w.num_batches() == x.num_batches());
374 simdified_simd_t<Vw> result{};
375 for (index_t b = 0; b < w.num_batches(); ++b)
376 result += linalg::weighted_vnorm_sq(w.batch(b), x.batch(b));
377 return result;
378}
379
380/// ∑ wᵢ(xᵢ - yᵢ)² (lane-wise).
381template <simdifiable_multi Vw, simdifiable_multi Vx, simdifiable_multi Vy>
384 BATMAT_ASSERT(w.num_batches() == x.num_batches());
385 BATMAT_ASSERT(w.num_batches() == y.num_batches());
386 simdified_simd_t<Vw> result{};
387 for (index_t b = 0; b < w.num_batches(); ++b)
388 result += linalg::weighted_vnorm_sq_diff(w.batch(b), x.batch(b), y.batch(b));
389 return result;
390}
391
392/// @}
393
394} // namespace multi
395
396} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
simdified_simd_t< Vx > vnorm_2_squared(Vx &&x)
Compute the lane-wise squared 2-norms of a batch of vectors.
Definition reduce.hpp:152
simdified_value_t< Vx > norm_inf(Vx &&x)
Compute the infinity norm of a vector.
Definition reduce.hpp:134
simdified_simd_t< Vx > vnorm_1(Vx &&x)
Compute the lane-wise 1-norms of a batch of vectors.
Definition reduce.hpp:338
simdified_value_t< Vx > norm_2_squared(Vx &&x)
Compute the squared 2-norm of a vector.
Definition reduce.hpp:160
simdified_value_t< Vx > norm_2(Vx &&x)
Compute the 2-norm of a vector.
Definition reduce.hpp:280
simdified_simd_t< Vw > weighted_vnorm_sq(Vw &&w, Va &&a)
∑ wᵢ aᵢ² (lane-wise).
Definition reduce.hpp:200
simdified_value_t< Vx > norm_2(Vx &&x)
Compute the 2-norm of a vector.
Definition reduce.hpp:174
simdified_value_t< Vx > norm_2_squared(Vx &&x)
Compute the squared 2-norm of a vector.
Definition reduce.hpp:271
simdified_value_t< Vx > norm_1(Vx &&x)
Compute the 1-norm of a vector.
Definition reduce.hpp:146
simdified_simd_t< Vw > weighted_vnorm_sq_diff(Vw &&w, Va &&a, Vb &&b)
∑ wᵢ(aᵢ - bᵢ)² (lane-wise).
Definition reduce.hpp:218
simdified_simd_t< Vx > vnorm_1(Vx &&x)
Compute the lane-wise 1-norms of a batch of vectors.
Definition reduce.hpp:140
simdified_simd_t< Vx > vnorm_inf(Vx &&x)
Compute the lane-wise infinity norms of a batch of vectors.
Definition reduce.hpp:332
simdified_value_t< Vw > weighted_norm_sq(Vw &&w, Vx &&x)
∑ wᵢ xᵢ².
Definition reduce.hpp:299
simdified_value_t< Vx > dot(Vx &&x, Vy &&y)
Compute the dot product of two vectors.
Definition reduce.hpp:191
simdified_value_t< Vx > norm_1(Vx &&x)
Compute the 1-norm of a vector.
Definition reduce.hpp:265
norms< simdified_value_t< Vx > >::result norms_all(Vx &&x)
Compute the norms (max, 1-norm, and 2-norm) of a vector.
Definition reduce.hpp:121
simdified_value_t< Vw > weighted_norm_sq_difference(Vw &&w, Vx &&x, Vy &&y)
∑ wᵢ(xᵢ - yᵢ)².
Definition reduce.hpp:310
simdified_value_t< Vx > norm_inf(Vx &&x)
Compute the infinity norm of a vector.
Definition reduce.hpp:259
simdified_simd_t< Vx > vnorm_inf(Vx &&x)
Compute the lane-wise infinity norms of a batch of vectors.
Definition reduce.hpp:128
simdified_simd_t< Vx > vnorm_2(Vx &&x)
Compute the lane-wise 2-norms of a batch of vectors.
Definition reduce.hpp:353
simdified_value_t< Vx > dot(Vx &&x, Vy &&y)
Compute the dot product of two vectors.
Definition reduce.hpp:288
simdified_value_t< Vw > weighted_norm_sq(Vw &&w, Va &&a)
∑ wᵢ aᵢ².
Definition reduce.hpp:209
simdified_simd_t< Vw > weighted_vnorm_sq(Vw &&w, Vx &&x)
∑ wᵢ xᵢ² (lane-wise).
Definition reduce.hpp:372
norms< simdified_value_t< Vx >, simdified_simd_t< Vx > >::result_simd vnorms_all(Vx &&x)
Compute the lane-wise norms (max, 1-norm, and 2-norm) of a batch of vectors.
Definition reduce.hpp:321
simdified_simd_t< Vw > weighted_vnorm_sq_diff(Vw &&w, Vx &&x, Vy &&y)
∑ wᵢ(xᵢ - yᵢ)² (lane-wise).
Definition reduce.hpp:383
simdified_value_t< Vw > weighted_norm_sq_diff(Vw &&w, Va &&a, Vb &&b)
∑ wᵢ(aᵢ - bᵢ)².
Definition reduce.hpp:227
norms< simdified_value_t< Vx >, simdified_simd_t< Vx > >::result_simd vnorms_all(Vx &&x)
Compute the lane-wise norms (max, 1-norm, and 2-norm) of a batch of vectors.
Definition reduce.hpp:114
simdified_simd_t< Vx > vdot(Vx &&x, Vy &&y)
Compute the lane-wise dot products of two batches of vectors.
Definition reduce.hpp:361
simdified_simd_t< Vx > vnorm_2(Vx &&x)
Compute the lane-wise 2-norms of a batch of vectors.
Definition reduce.hpp:167
simdified_simd_t< Vx > vnorm_2_squared(Vx &&x)
Compute the lane-wise squared 2-norms of a batch of vectors.
Definition reduce.hpp:344
norms< simdified_value_t< Vx > >::result norms_all(Vx &&x)
Compute the norms (max, 1-norm, and 2-norm) of a vector.
Definition reduce.hpp:249
simdified_simd_t< Vx > vdot(Vx &&x, Vy &&y)
Compute the lane-wise dot products of two batches of vectors.
Definition reduce.hpp:182
#define GUANAQO_TRACE_LINALG(name, gflops)
stdx::simd< Tp, Abi > simd
Definition simd.hpp:102
typename detail::simdified_value< V >::type simdified_value_t
Definition simdify.hpp:214
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:216
typename detail::simdified_simd< V >::type simdified_simd_t
Definition simdify.hpp:218
constexpr bool simdify_compatible
Definition simdify.hpp:221
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:228
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
int index_t
Definition config.hpp:13
Vector reductions.
Lane-wise accumulators.
Definition norms.hpp:30
Utilities for computing vector norms.
Definition norms.hpp:26
static result_simd zero_simd()
Definition norms.hpp:72
typename norms< T >::result result
Accumulator.
Definition norms.hpp:28