batmat 0.0.23
Batched linear algebra routines
Loading...
Searching...
No Matches
sterf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
8#include <batmat/simd.hpp>
9
10#include <cmath>
11#include <expected>
12#include <limits>
13
15
/// Resolve a user-provided relative tolerance: strictly positive values are returned unchanged,
/// anything else (zero, negative or NaN) selects the machine epsilon of the same type.
template <class Tol>
[[nodiscard]] constexpr auto default_tolerance(Tol user_tol) noexcept {
    if (!(user_tol > 0))
        return std::numeric_limits<Tol>::epsilon();
    return user_tol;
}
19
20template <class T, class Abi>
22 using std::sqrt;
23 const datapar::simd<T, Abi> zero{0}, one{1};
24 static constexpr T safe_min = std::numeric_limits<T>::min();
25 static constexpr T safe_max = std::numeric_limits<T>::max();
26 static constexpr T ε = std::numeric_limits<T>::epsilon();
27 // Conservative safe range for intermediate squared/hypot-like quantities.
28 static constexpr T small = sqrt(safe_min) / ε;
29 static constexpr T large = sqrt(safe_max) * ε;
30
31 auto factor = datapar::simd<T, Abi>{one};
32 factor = datapar::select(anorm > large, large / anorm, factor);
33 factor = datapar::select(anorm < small, small / anorm, factor);
34 factor = datapar::select(anorm == zero, one, factor);
35 return factor;
36}
37
38template <class T, class Abi>
39[[nodiscard]] bool all_zero(datapar::simd<T, Abi> x) noexcept {
40 using std::all_of;
41 return all_of(x == T{0});
42}
43
/// Squared form of LAPACK's off-diagonal split test |e| <= eps * sqrt(|d0|) * sqrt(|d1|), i.e.
/// e0_sq <= ε_sq * |d0 * d1|, where e0_sq is the squared off-diagonal and ε_sq the squared
/// relative tolerance. Returns true only if the test holds in every SIMD lane.
/// NOTE(review): d0 * d1 is formed directly, so it can overflow to inf (making any e0_sq look
///               negligible) or underflow to 0 for badly scaled inputs, unlike LAPACK's
///               sqrt(|d0|) * sqrt(|d1|) form.
44template <class T, class Abi>
46 datapar::simd<T, Abi> d1, T ε_sq) noexcept {
47 using std::all_of;
48 using std::fabs;
// fabs(e0_sq): a squared off-diagonal may come out slightly negative due to rounding (TODO confirm).
49 return all_of(fabs(e0_sq) <= ε_sq * fabs(d0 * d1));
50}
51
52/// Eigenvalues of [a b; b c].
53template <class T, class Abi>
54[[nodiscard]] std::pair<datapar::simd<T, Abi>, datapar::simd<T, Abi>>
56 datapar::simd<T, Abi> c) noexcept {
57 // The hypot form avoids spurious overflow in sqrt((a-c)^2 + 4b^2) for reasonably scaled blocks.
58 using std::hypot;
59 using std::swap;
60 const T half{0.5};
61 const auto half_trace = (a + c) * half;
62 const auto half_diff = (a - c) * half;
63 const auto radius = hypot(half_diff, b);
64 auto lambda1 = half_trace - radius;
65 auto lambda2 = half_trace + radius;
66 return {lambda1, lambda2};
67}
68
/// Solve the 2×2 block at rows l, l+1 in closed form.
/// b = sqrt(|e[l]|) recovers the off-diagonal from its stored square. The eigenvalues of
/// [d[l] b; b d[l+1]] from stable_2x2_eigenvalues are written to d[l] and d[l+1], in the order
/// that function returns them, and e[l] is set to zero.
69template <class T, class Abi>
72 using std::fabs;
73 using std::sqrt;
74 const auto b = sqrt(fabs(e.load(l, 0))); // TODO: can we avoid the square root here?
75 const auto a = d.load(l, 0), c = d.load(l + 1, 0);
76 const auto [λ1, λ2] = stable_2x2_eigenvalues(a, b, c);
77 d.store(λ1, l, 0), d.store(λ2, l + 1, 0), e.store(T{0}, l, 0);
78}
79
/// Multiply the diagonal entries d[l..m] (inclusive) lane-wise by `factor`.
80template <class T, class Abi>
82 datapar::simd<T, Abi> factor) noexcept {
83 for (index_t i = l; i <= m; ++i)
84 d.store(d.load(i, 0) * factor, i, 0);
85}
86
/// Scale the block d[l..m], e[l..m-1] as if the unsquared tridiagonal matrix were multiplied by
/// `factor`. The diagonal is scaled by factor; e holds squared off-diagonals, so it is scaled by
/// factor^2.
/// NOTE(review): for factor above ~sqrt(numeric_limits<T>::max()), factor_sq overflows to inf and
///               zero entries of e become NaN (0 * inf). Confirm callers cannot produce such a factor.
87template <class T, class Abi>
90 datapar::simd<T, Abi> factor) noexcept {
91 scale_diag_only(d, l, m, factor);
92 const auto factor_sq = factor * factor;
93 for (index_t i = l; i < m; ++i)
94 e.store(e.load(i, 0) * factor_sq, i, 0);
95}
96
/// One implicit QL sweep of the Pal-Walker-Kahan iteration used by LAPACK's DSTERF.
/// It is applied lane-wise to the unreduced block d[l..m], where e[l..m-1] holds the squared
/// off-diagonals; no square roots are taken inside the inner loop.
/// The shift σ comes from the leading 2×2 block (d[l], d[l+1], e[l]): it is the eigenvalue of that
/// block closer to d[l]. The rotations are then chased from the bottom (i = m-1) up to i = l,
/// overwriting d[l..m] and e[l..m-1] in place.
97template <class T, class Abi>
100 index_t m) noexcept {
101 using simd = datapar::simd<T, Abi>;
102 using std::fabs;
103 using std::hypot;
104 using std::sqrt;
105
106 const simd zero{T{0}}, one{T{1}}, two{T{2}};
107
108 const auto p0 = d.load(l, 0);
// e stores squares; fabs guards against slightly negative values from rounding (TODO confirm).
109 const auto e0 = sqrt(fabs(e.load(l, 0)));
110 auto σ = (d.load(l + 1, 0) - p0) / (two * e0);
111 const auto rshift = hypot(σ, one);
112 σ = p0 - e0 / (σ + copysign(rshift, σ));
// Lanes with e0 == 0 divided by zero above; use a zero shift there.
113 σ = datapar::select(e0 != zero, σ, zero);
114
115 auto c = one;
116 auto s = zero;
117 auto γ = d.load(m, 0) - σ;
118 auto p = γ * γ;
119
// Chase upwards: i = m-1, ..., l. Names follow DSTERF (BB, R, OLDC, C, S, OLDGAM, ALPHA, GAMMA, P).
120 for (index_t i = m; i-- > l;) {
121 const auto bb = e.load(i, 0);
122 const auto r = p + bb;
123 if (i != m - 1)
124 e.store(s * r, i + 1, 0);
125 const auto old_c = c;
// Guard r == 0 (which would give 0/0): fall back to c = 1, s = 0 in those lanes.
126 c = datapar::select(r != zero, p / r, one);
127 s = datapar::select(r != zero, bb / r, zero);
128 const auto old_γ = γ;
129 const auto α = d.load(i, 0);
130 γ = c * (α - σ) - s * old_γ;
131 d.store(old_γ + (α - γ), i + 1, 0);
// As in DSTERF: p = gamma^2 / c, or old_c * bb where c == 0.
132 p = datapar::select(c != zero, (γ * γ) / c, old_c * bb);
133 }
134 e.store(s * p, l, 0);
135 d.store(σ + γ, l, 0);
136}
137
/// One implicit QR sweep, the mirror image of the QL sweep, applied lane-wise to the unreduced
/// block d[l..m] with squared off-diagonals in e[l..m-1].
/// The shift σ comes from the trailing 2×2 block (d[m-1], d[m], e[m-1]): it is the eigenvalue of
/// that block closer to d[m]. The rotations are chased from the top (i = l) down to i = m-1,
/// overwriting d[l..m] and e[l..m-1] in place.
138template <class T, class Abi>
141 index_t m) noexcept {
142 using simd = datapar::simd<T, Abi>;
143 using std::fabs;
144 using std::hypot;
145 using std::sqrt;
146
147 const simd zero{T{0}}, one{T{1}}, two{T{2}};
148
149 const auto p0 = d.load(m, 0);
// e stores squares; fabs guards against slightly negative values from rounding (TODO confirm).
150 const auto e0 = sqrt(fabs(e.load(m - 1, 0)));
151 auto σ = (d.load(m - 1, 0) - p0) / (two * e0);
152 const auto rshift = hypot(σ, one);
153 σ = p0 - e0 / (σ + copysign(rshift, σ));
// Lanes with e0 == 0 divided by zero above; use a zero shift there.
154 σ = datapar::select(e0 != zero, σ, zero);
155
156 auto c = one;
157 auto s = zero;
158 auto γ = d.load(l, 0) - σ;
159 auto p = γ * γ;
160
// Chase downwards: i = l, ..., m-1. Names follow DSTERF's QR branch.
161 for (index_t i = l; i < m; ++i) {
162 const auto bb = e.load(i, 0);
163 const auto r = p + bb;
164 if (i != l)
165 e.store(s * r, i - 1, 0);
166 const auto old_c = c;
// Guard r == 0 (which would give 0/0): fall back to c = 1, s = 0 in those lanes.
167 c = datapar::select(r != zero, p / r, one);
168 s = datapar::select(r != zero, bb / r, zero);
169 const auto old_γ = γ;
170 const auto α = d.load(i + 1, 0);
171 γ = c * (α - σ) - s * old_γ;
172 d.store(old_γ + (α - γ), i, 0);
// As in DSTERF: p = gamma^2 / c, or old_c * bb where c == 0.
173 p = datapar::select(c != zero, (γ * γ) / c, old_c * bb);
174 }
175 e.store(s * p, m - 1, 0);
176 d.store(σ + γ, m, 0);
177}
178
/// One sweep on the block d[l..m], with a single direction chosen for all SIMD lanes.
/// QR is used when |d[m]| < |d[l]| in more than half of the lanes, QL otherwise. LAPACK DSTERF
/// makes the same |d(lend)| < |d(l)| choice, but per matrix.
/// NOTE(review): the branch bodies are presumably calls to sterf_qr_sweep_squared_e_inplace and
///               sterf_ql_sweep_squared_e_inplace respectively — confirm.
179template <class T, class Abi>
182 index_t m) noexcept {
183 using std::fabs;
184 static constexpr index_t half_v = datapar::simd_size<T, Abi>::value / 2;
185 const bool use_qr = datapar::reduce_count(fabs(d.load(m, 0)) < fabs(d.load(l, 0))) > half_v;
186 if (use_qr)
188 else
190}
191
/// Lane-wise squared max-abs norm of the block d[l..m] with squared off-diagonals e_sq[l..m-1].
/// Returns the maximum over i of d_i^2 (i in [l, m]) and |e_sq_i| (i in [l, m-1]).
/// NOTE(review): d_i * d_i overflows to inf for |d_i| > ~sqrt(numeric_limits<T>::max()), in which
///               case the estimate is inf.
192template <class T, class Abi>
193[[nodiscard]] datapar::simd<T, Abi>
196 index_t m) noexcept {
197 using simd = datapar::simd<T, Abi>;
198 using std::fabs;
199 using std::max;
200
201 simd anorm_sq{T{0}};
202 for (index_t i = l; i <= m; ++i) {
203 const auto di = d.load(i, 0);
204 anorm_sq = max(anorm_sq, di * di);
205 }
206 for (index_t i = l; i < m; ++i) {
207 const auto ei_sq = fabs(e_sq.load(i, 0)); // may be negative due to rounding
208 anorm_sq = max(anorm_sq, ei_sq);
209 }
210 return anorm_sq;
211}
212
/// Norm estimate for the block d[l..m] given squared off-diagonals e_sq[l..m-1]. sterf passes the
/// result to safe_scaling_factor to choose the per-block scaling.
/// NOTE(review): presumably the square root of
///               squared_block_norm_estimate_from_squared_e(d, e_sq, l, m) — confirm.
213template <class T, class Abi>
214[[nodiscard]] datapar::simd<T, Abi>
221
222/// Eigenvalues of a symmetric tridiagonal matrix given by `diag` and `subdiag`, computed in-place
223/// using the Pal-Walker-Kahan variant of the implicit QR/QL method with Wilkinson shifts. Based on
224/// LAPACK 3.12.1's `STERF`:
225/// https://netlib.org/lapack//explore-html/d4/d9d/group__sterf_gad293bb81da1c7785b42796d1e197f08c.html
///
/// All operations are lane-wise over the SIMD batch. A block is only split when the split test
/// holds in every lane, and each sweep is applied to all lanes of a block at once.
///
/// @pre `diag` is n×1 with n > 1 and `subdiag` is (n-1)×1. These are checked with BATMAT_ASSUME,
///      so a violation is undefined behavior, not a reported error.
/// @return The total number of QR/QL sweeps performed; 2×2 blocks solved in closed form are not
///         counted. Returns std::unexpected holding that count once the budget
///         `options.max_iterations_per_eigenvalue * n` is exhausted.
/// NOTE(review): `d` and `e` are presumably working views of `diag` and `subdiag`, so the
///               off-diagonal storage ends up holding squared (mostly zero) values — confirm.
226template <class T, class Abi>
227std::expected<index_t, index_t> sterf(view<T, Abi, StorageOrder::ColMajor> diag,
229 SterfOptions options) noexcept {
230 static_assert(!std::is_const_v<T>);
231 BATMAT_ASSUME(diag.cols() == 1);
232 BATMAT_ASSUME(subdiag.cols() == 1);
233 const index_t n = diag.rows();
234 BATMAT_ASSUME(n > 1);
235 BATMAT_ASSUME(subdiag.rows() == n - 1);
236
237 using simd = datapar::simd<T, Abi>;
238 using std::any_of;
239
// A non-positive (or NaN) relative tolerance falls back to machine epsilon.
240 const T ε = default_tolerance(static_cast<T>(options.relative_tolerance));
241 const T ε_sq = ε * ε;
242 const simd zero{T{0}}, one{T{1}};
243 const index_t max_total_iterations = options.max_iterations_per_eigenvalue * n;
244 index_t total_iterations = 0;
245
248
249 // Square all offdiagonals globally and apply the equivalent squared LAPACK split test:
250 // |e_i| <= eps * sqrt(|d_i|) * sqrt(|d_{i+1}|)
251 // becomes
252 // e_i^2 <= eps^2 * |d_i * d_{i+1}|.
// NOTE(review): e_i is squared, and d_i * d_{i+1} formed, before any scaling; scaling only happens
// per block further down. For entries beyond ~sqrt(numeric_limits<T>::max()) the squares overflow
// to inf. For very small entries they underflow to 0, and the entry is then treated as a split.
// LAPACK DSTERF scales each block (DLASCL) before squaring.
253 for (index_t i = 0; i + 1 < n; ++i) {
254 const auto ei = e.load(i, 0);
255 const auto ei_sq = ei * ei;
256 const auto di = d.load(i, 0), di_next = d.load(i + 1, 0);
257 const bool split = negligible_squared_e(ei_sq, di, di_next, ε_sq);
258 e.store(split ? zero : ei_sq, i, 0);
259 }
260
// Repeat full passes over the matrix until a pass finds no unreduced block (m > l).
261 bool found_unreduced_block;
262 do {
263 found_unreduced_block = false;
264
265 index_t l = 0;
266 while (l < n) {
267 // Skip over converged 1x1 blocks.
268 while (l + 1 < n && all_zero(e.load(l, 0)))
269 ++l;
270 if (l + 1 >= n)
271 break;
272 // Find the end of the current unreduced block.
273 index_t m;
274 for (m = l; m + 1 < n; ++m) {
275 const auto em = e.load(m, 0);
276 if (all_zero(em))
277 break;
278 if (negligible_squared_e(em, d.load(m, 0), d.load(m + 1, 0), ε_sq)) {
279 e.store(zero, m, 0);
280 break;
281 }
282 }
283 // Active unreduced block is d[l..m]. Reduce it.
284 if (m > l) {
285 found_unreduced_block = true;
// Scale the block into a safe range (per lane) so that its squared entries stay representable.
286 const auto anorm = block_norm_estimate_from_squared_e(d, e, l, m);
287 const auto factor = safe_scaling_factor(anorm);
288 const bool scaled = any_of(factor != one);
289 if (scaled)
290 scale_squared_e(d, e, l, m, factor);
291
292 if (m == l + 1) // Solve the 2×2 block directly rather than using QR sweeps
294 else if (++total_iterations < max_total_iterations)
296
// Undo the block scaling; this also happens before bailing out below.
// Once ++total_iterations reaches the budget, the sweep guarded by the `else if` above is
// skipped, so at most max_total_iterations - 1 sweeps are actually performed.
297 if (scaled)
298 scale_squared_e(d, e, l, m, T{1} / factor);
299 if (total_iterations >= max_total_iterations)
300 return std::unexpected(total_iterations);
301 }
302
303 l = m + 1; // Beginning of next block
304 }
305 } while (found_unreduced_block);
306 return total_iterations;
307}
308
309} // namespace batmat::linalg::micro_kernels::sterf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
stdx::simd_size< Tp, Abi > simd_size
Definition simd.hpp:233
auto reduce_count(auto v)
Definition simd.hpp:244
auto select(auto cond, auto t, auto f)
Definition simd.hpp:245
stdx::simd< Tp, Abi > simd
Definition simd.hpp:148
std::expected< index_t, index_t > sterf(view< T, Abi, StorageOrder::ColMajor > diag, view< T, Abi, StorageOrder::ColMajor > subdiag, SterfOptions options) noexcept
Eigenvalues of a symmetric tridiagonal matrix given by diag and subdiag, computed in-place using the Pal-Walker-Kahan variant of the implicit QR/QL method with Wilkinson shifts.
Definition sterf.tpp:227
void sterf_ql_sweep_squared_e_inplace(uview< T, Abi, StorageOrder::ColMajor > d, uview< T, Abi, StorageOrder::ColMajor > e, index_t l, index_t m) noexcept
Definition sterf.tpp:98
datapar::simd< T, Abi > squared_block_norm_estimate_from_squared_e(uview< T, Abi, StorageOrder::ColMajor > d, uview< T, Abi, StorageOrder::ColMajor > e_sq, index_t l, index_t m) noexcept
Definition sterf.tpp:194
void sterf_dynamic_step_squared_e_inplace(uview< T, Abi, StorageOrder::ColMajor > d, uview< T, Abi, StorageOrder::ColMajor > e, index_t l, index_t m) noexcept
Definition sterf.tpp:180
datapar::simd< T, Abi > safe_scaling_factor(datapar::simd< T, Abi > anorm) noexcept
Definition sterf.tpp:21
constexpr auto default_tolerance(auto user_tol) noexcept
Definition sterf.tpp:16
datapar::simd< T, Abi > block_norm_estimate_from_squared_e(uview< T, Abi, StorageOrder::ColMajor > d, uview< T, Abi, StorageOrder::ColMajor > e_sq, index_t l, index_t m) noexcept
Definition sterf.tpp:215
void sterf_qr_sweep_squared_e_inplace(uview< T, Abi, StorageOrder::ColMajor > d, uview< T, Abi, StorageOrder::ColMajor > e, index_t l, index_t m) noexcept
Definition sterf.tpp:139
void solve_2x2_squared_e_inplace(uview< T, Abi, StorageOrder::ColMajor > d, uview< T, Abi, StorageOrder::ColMajor > e, index_t l) noexcept
Definition sterf.tpp:70
std::pair< datapar::simd< T, Abi >, datapar::simd< T, Abi > > stable_2x2_eigenvalues(datapar::simd< T, Abi > a, datapar::simd< T, Abi > b, datapar::simd< T, Abi > c) noexcept
Eigenvalues of [a b; b c].
Definition sterf.tpp:55
void scale_squared_e(uview< T, Abi, StorageOrder::ColMajor > d, uview< T, Abi, StorageOrder::ColMajor > e, index_t l, index_t m, datapar::simd< T, Abi > factor) noexcept
Definition sterf.tpp:88
bool all_zero(datapar::simd< T, Abi > x) noexcept
Definition sterf.tpp:39
bool negligible_squared_e(datapar::simd< T, Abi > e0_sq, datapar::simd< T, Abi > d0, datapar::simd< T, Abi > d1, T ε_sq) noexcept
Definition sterf.tpp:45
void scale_diag_only(uview< T, Abi, StorageOrder::ColMajor > d, index_t l, index_t m, datapar::simd< T, Abi > factor) noexcept
Definition sterf.tpp:81
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
int index_t
Definition config.hpp:13
void store(simd x, index_t r, index_t c) const noexcept
Definition uview.hpp:104
simd load(index_t r, index_t c) const noexcept
Definition uview.hpp:100