batmat 0.0.24
Batched linear algebra routines
Loading...
Searching...
No Matches
sytrd.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
8#include <batmat/ops/cneg.hpp>
10#include <guanaqo/trace.hpp>
11
// Shorthand for BATMAT_FULLY_UNROLLED_FOR; in this file it is only used for loops whose bounds
// depend on the compile-time block size R (presumably required for full unrolling — confirm).
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
/// Table of diagonal-block micro-kernels, indexed by (number of columns in the block) - 1;
/// sytrd_register calls entry `rem_j - 1` with arguments (k - j, Wj, Djj, Y_).
/// NOTE(review): the initializer lines are elided in this listing; presumably it is built with
/// make_1d_lut over sytrd_diag_microkernel<T, Abi, Conf, R, OD> for R = 1..SizeR<T, Abi> — confirm.
16template <class T, class Abi, KernelConfig Conf, StorageOrder OD>
17inline const constinit auto microkernel_diag_lut =
20 });
21
/// Micro-kernel for one diagonal block of the symmetric tridiagonal reduction.
///
/// Processes the R columns j = 0..R-1 of the block one at a time, reading and writing only the
/// lower triangle of D. For each column j it:
///  - builds a Householder reflector from (a21, a31) = (D(j+1, j), D(j+2.., j)) and stores its
///    tail b = a31 / β in D(j+2.., j) (the leading component is an implicit 1),
///  - stores the tridiagonal entries D(j+1, j) = -c̃j and D(j+1, j+1) = d2,
///  - stores in W(i, j), i < j, the inner products of this reflector with the previous ones, and
///    in W(j, j) the coefficient inv_τ = β / c̃j (0 when |c̃j| <= safe_min),
///  - applies the rank-2 update A3 - b yᵀ - y bᵀ to the trailing lower triangle and overwrites
///    column j+1 below the diagonal with the updated ã32.
///
/// @tparam R  Number of columns processed; sizes the unrolled register arrays (R + 1 entries).
/// @param  k  Number of rows of D seen by this kernel (sytrd_register passes k - j); must exceed R.
/// @param  W  Block Householder storage for this block (written as described above).
/// @param  D  Diagonal block of the matrix (its (0, 0) element lies on the diagonal), updated in place.
/// @param  Y  Column-major workspace; on return Y(j+2.., j) holds the vector y of column j.
22template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OD>
23[[gnu::hot, gnu::flatten]] void
26 using std::copysign;
27 using std::sqrt;
28 using simd = datapar::simd<T, Abi>;
29 BATMAT_ASSUME(k > R);
// Smallest positive normal value of T: columns with a smaller norm get no reflection.
30 static constexpr auto safe_min = std::numeric_limits<T>::min();
31
32 // j j+1 j+2 R
33 // ┌─────┬─────┬─────┬─────┐
34 // │ D │ · │ · │ · │
35 // ├─────┼─────┼─────┼─────┤
36 // │ d1 │ a11 │ × │ × │ j
37 // ├─────┼─────┼─────┼─────┤
38 // │ b2 │ a21 │ a22 │ × │ j+1
39 // ├─────┼─────┼─────┼─────┤
40 // │ B3 │ a31 │ a32 │ A3 │ j+2
41 // └─────┴─────┴─────┴─────┘
42
43 // j j+1 j+2
44 // ┌──────────────┬────┬────┬───────────────┐
45 // │ d · · │ · │ · │ · · · │ d: diagonal elements of tridiagonal matrix
46 // │ │ │ │ │ e: off-diagonal elements of tridiagonal matrix
47 // │ c d · │ · │ · │ · · · │ b: Householder reflectors
48 // │ │ │ │ │ a: original matrix
49 // │ b c d │ · │ · │ · · · │ ×: implicitly symmetric part
50 // │ ├────┼────┼───────────────┤
51 // │ b b c │ a │ × │ × × × │ j
52 // ├──────────────┼────┼────┼───────────────┤
53 // │ b b b │ a │ a │ × × × │ j+1
54 // ├──────────────┼────┼────┼───────────────┤
55 // │ b b b │ a │ a │ a × × │ j+2
56 // │ │ │ │ │
57 // │ b b b │ a │ a │ a a × │
58 // │ │ │ │ │
59 // │ b b b │ a │ a │ a a a │
60 // └──────────────┴────┴────┴───────────────┘
61 // │ │ │ ╰─ symv A3 by a31
62 // │ │ ╰─ dot a32 with a31
63 // │ ╰─ ±norm a31 and a21 becomes c(j)
64 // ╰─ dot B3 with a31 for block Householder
65
66 // symv:
67 // [ W(j+2, j) ] [ A(j+2, j+2) A(j+3, j+2) A(j+4, j+2) ... ] [ A(j+2, j) ]
68 // [ W(j+3, j) ] = [ A(j+3, j+2) A(j+3, j+3) A(j+4, j+3) ... ] [ A(j+3, j) ]
69 // [ W(j+4, j) ] [ A(j+4, j+2) A(j+4, j+3) A(j+4, j+4) ... ] [ A(j+4, j) ]
70 // [ ... ] [ ... ... ... ... ] [ ... ]
71 //
72 // W(j+2, j) = sum(l=j+2..k) A(l, j+2) A(l, j)
73 // W(j+3, j) = A(j+3, j+2) A(j+2, j) + sum(l=j+3..k) A(l, j+3) A(l, j)
74 // W(j+4, j) = A(j+4, j+2) A(j+2, j) + A(j+4, j+3) A(j+3, j) + sum(l=j+4..k) A(l, j+4) A(l, j)
75 // W(q, j) = sum(p=j+2..q-1) A(q, p) A(p, j) + sum(l=q..k) A(l, q) A(l, j)
// In the symv above, W denotes the product A3 a31. It is accumulated in bb[] for rows j+2..R and
// in Y(:, j) for rows > R; it is not the W parameter of this function.
76
77 UNROLL_FOR (index_t j = 0; j < R; ++j) {
78 using std::max;
79 simd Axj[R + 1];
80 UNROLL_FOR (index_t i = j + 1; i < R + 1; ++i)
81 Axj[i] = D.load(i, j);
82 // Compute inner products between a(j) and b(i<j), a(j), and a(i>j) (symv)
83 simd bb[R + 1]{};
84 // Triangular part
85 UNROLL_FOR (index_t q = j + 2; q < R + 1; ++q) {
86 UNROLL_FOR (index_t i = 0; i <= q; ++i)
87 bb[i] += D.load(q, i) * Axj[q];
88 UNROLL_FOR (index_t p = j + 2; p < q; ++p)
89 bb[q] += D.load(q, p) * Axj[p];
90 }
91 // Rectangular part
// Since j < R, max(R + 1, j + 2) == R + 1: these are the rows below the register block.
92 for (index_t q = max(R + 1, j + 2); q < k; ++q) {
93 simd Aqx[R + 1];
94 UNROLL_FOR (index_t i = 0; i < R + 1; ++i)
95 Aqx[i] = D.load(q, i);
96 UNROLL_FOR (index_t i = 0; i < R + 1; ++i)
97 bb[i] += Aqx[i] * Aqx[j];
98 simd Yl{};
99 UNROLL_FOR (index_t p = j + 2; p < R + 1; ++p)
100 Yl += Aqx[p] * Axj[p];
101 Y.store(Yl, q, j); // partial W(q, j): only terms p=j+2..R of sum(p=j+2..q-1) A(q, p) A(p, j)
102 }
103 const simd a21 = Axj[j + 1];
104 bb[j] += a21 * a21;
105 // bb[i<j] now contain the inner products of a31 with the previous Householder vectors
106 // (except for the first components, which are implicitly 1 and are added later).
107 // bb[j] contains the squared norm of (a21, a31).
108 // bb[j+1] contains the dot product of a31 and a32.
109 // bb[i>j+1] contain the top rows of the symmetric product A3 a31 (complete).
110 // Y[i>R, j] contains part of the symmetric product A3 a31, but still requires adding
111 // the contributions from all columns >R (including dot products with the upper
112 // triangle of A3)
113
114 // Energy condition and Householder coefficients
115 const simd abs_c̃jj = sqrt(bb[j]);
116 const simd c̃j = copysign(abs_c̃jj, a21), β = a21 + c̃j;
// If |c̃j| <= safe_min, inv_τ = inv_β = 0. Then b, w, ω, γ and y are all 0: the trailing matrix
// is left unchanged, and the entries a31 are overwritten with b = 0.
// NOTE(review): in that case D(j+1, j) still receives -c̃j, which has the opposite sign of a21
// although no reflection is applied. LAPACK's dlarfg keeps alpha when tau = 0. The magnitude is
// at most safe_min; confirm this is intended.
117 const simd inv_τ = datapar::select(abs_c̃jj > safe_min, β / c̃j, simd{0}),
118 inv_β = datapar::select(abs_c̃jj > safe_min, simd{1} / β, simd{0});
119
120 // Save block Householder matrix W
121 UNROLL_FOR (index_t i = 0; i < j; ++i)
122 // Multiply implicit first component of the current Householder vector by the
123 // corresponding row of the previous Householder vectors, and add it to the previously
124 // computed inner products with a31, scaled by β⁻¹ to go from a31 to the normalized
125 // Householder vector.
126 W.store(bb[i] * inv_β + D.load(j + 1, i), i, j);
127 W.store(inv_τ, j, j); // inverse of diagonal
128
129 // Finish the symmetric product A3 a31
// Only rows > R remain; rows j+2..R were completed in bb[] above.
130 for (index_t i = max(R + 1, j + 2); i < k; ++i) {
131 simd yi = Y.load(i, j);
132 const simd xi = D.load(i, j);
133 yi += D.load(i, i) * xi; // diagonal term
134 for (index_t l = max(R + 1, j + 2); l < i; ++l) { // lower triangle l < i
135 simd yl = Y.load(l, j);
136 const simd xl = D.load(l, j), ail = D.load(i, l); // TODO: access D column-wise
137 yi += ail * xl;
138 yl += ail * xi; // symmetric contribution to y[l]
139 Y.store(yl, l, j); // TODO: optimize by unrolling to avoid load/store of yl
140 }
141 Y.store(yi, i, j);
142 }
143 // Y[i>R, j] now contains the complete bottom rows of the symmetric product A3 a31.
144
145 // Now compute the vector w = τ⁻¹(A3 b + a32) = τ⁻¹(β⁻¹ A3 a31 + a32).
146 simd Axj1[R + 1];
147 UNROLL_FOR (index_t i = j + 1; i < R + 1; ++i)
148 Axj1[i] = D.load(i, j + 1);
149 UNROLL_FOR (index_t i = j + 2; i < R + 1; ++i)
150 Y.store(bb[i] = inv_τ * (inv_β * bb[i] + Axj1[i]), i, j);
151 for (index_t i = max(R + 1, j + 2); i < k; ++i) {
152 simd yi = Y.load(i, j);
153 Y.store(inv_τ * (inv_β * yi + D.load(i, j + 1)), i, j);
154 }
155 // bb[i>j+1] now contain w[i].
156
157 const simd a2 = Axj1[j + 1];
158 const simd a31ᵀa32 = bb[j + 1];
159 const simd ω = inv_τ * (inv_β * a31ᵀa32 + a2); // ω = τ⁻¹(a32ᵀb + a22)
160 // Scale a31 to obtain b, and dot it with w.
161 simd wᵀb_ω = ω; // accumulator for wᵀb + ω
162 simd b[R + 1];
163 UNROLL_FOR (index_t l = j + 2; l < R + 1; ++l) {
164 b[l] = inv_β * Axj[l];
165 D.store(b[l], l, j);
166 wᵀb_ω += b[l] * bb[l];
167 }
168 for (index_t l = max(R + 1, j + 2); l < k; ++l) {
169 simd bl = inv_β * D.load(l, j);
170 D.store(bl, l, j);
171 wᵀb_ω += bl * Y.load(l, j);
172 }
173 const simd γ = inv_τ * wᵀb_ω; // γ = τ⁻¹ (wᵀb + ω)
// Updated diagonal: a22 - 2 (ω - ½γ), i.e. a22 minus twice the first component of y.
174 const simd d2 = a2 - T{2} * ω + γ;
// Tridiagonal entries of column j: off-diagonal e = -c̃j and the next diagonal d2.
175 D.store(-c̃j, j + 1, j);
176 D.store(d2, j + 1, j + 1);
177
178 // Compute and store ã32 = a32 + (γ - ω) b - w and y = w - ½γ b
179 const simd γ_ω = γ - ω;
180 UNROLL_FOR (index_t l = j + 2; l < R + 1; ++l) {
181 simd ã32 = Axj1[l] + γ_ω * b[l] - bb[l];
182 D.store(ã32, l, j + 1);
183 simd yl = bb[l] - simd{T{0.5}} * γ * b[l];
184 Y.store(yl, l, j);
185 }
186 for (index_t l = max(R + 1, j + 2); l < k; ++l) {
187 simd bl = D.load(l, j);
188 simd ã32 = D.load(l, j + 1) + γ_ω * bl - Y.load(l, j);
189 D.store(ã32, l, j + 1);
190 simd yl = Y.load(l, j) - simd{T{0.5}} * γ * bl;
191 Y.store(yl, l, j);
192 }
193
194 // Update the trailing submatrix A3 = A3 - byᵀ - ybᵀ
195 // TODO: optimize memory accesses
// b lives in D(:, j) and y in Y(:, j). Y(i, j) and D(i, j) do not depend on l, so they are
// candidates for hoisting out of the inner loop.
196 for (index_t i = j + 2; i < k; ++i) // column of A3
197 for (index_t l = i; l < k; ++l) { // row of A3 (lower triangle)
198 simd Ail = D.load(l, i);
199 Ail -= Y.load(i, j) * D.load(l, j) + Y.load(l, j) * D.load(i, j);
200 D.store(Ail, l, i);
201 }
202 }
203}
204
205/// Symmetric block tridiagonalization.
///
/// Reduces the square matrix D in place, one diagonal block of at most R = SizeR<T, Abi> columns
/// at a time. Each block is handled by the diagonal micro-kernel taken from microkernel_diag_lut,
/// indexed by the block width minus one. Only columns in [0, k-1) are processed.
/// The grouping of the final, possibly partial chunk follows foreach_chunked_merged.
///
/// @param D  Square k×k matrix with k > 0, overwritten in place.
/// @param W  One of three accepted shapes:
///            - 0 rows: block Householder data goes to a local scratch buffer and is discarded;
///            - (k-1)×1: receives only the diagonal entries W(l, l) of each block (one per column);
///            - sytrd_W_size(D): receives the full triangular block for each chunk of R columns.
/// @param Y  Workspace whose shape must equal sytrd_Y_size(D).
206template <class T, class Abi, KernelConfig Conf, StorageOrder OD>
207void sytrd_register(const view<T, Abi, OD> D, const view<T, Abi> W, const view<T, Abi> Y) noexcept {
// Compile-time register block size.
208 static constexpr index_constant<SizeR<T, Abi>> R;
209 const index_t k = D.rows();
210 BATMAT_ASSUME(k > 0);
211 BATMAT_ASSUME(D.rows() == D.cols());
212 BATMAT_ASSUME(W.rows() == 0 ||
213 (W.cols() == 1 && W.rows() == std::max<index_t>(D.cols(), 1) - 1) ||
214 std::make_pair(W.rows(), W.cols()) == (sytrd_W_size<T, Abi>)(D));
215 BATMAT_ASSUME(std::make_pair(Y.rows(), Y.cols()) == (sytrd_Y_size<T, Abi>)(D));
216
// Scratch block for the micro-kernel's W output when the caller did not request the full W.
218 alignas(W_t::alignment()) T W_sto[W_t::size()];
219
220 // Sizeless views to partition and pass to the micro-kernels
221 const uview<T, Abi, OD> D_ = D;
224 const bool store_full_W = std::make_pair(W.rows(), W.cols()) == (sytrd_W_size<T, Abi>)(D);
225
226 // Process all diagonal blocks (in multiples of R, except the last).
227 foreach_chunked_merged(0, k - 1, R, [&](index_t j, auto rem_j) {
228 auto Wj = store_full_W ? W_t{W_.middle_cols(j / R).data} : W_t{W_sto};
229 auto Djj = D_.block(j, j);
230 microkernel_diag_lut<T, Abi, Conf, OD>[rem_j - 1](k - j, Wj, Djj, Y_);
// Vector form of W: keep only the diagonal of the block (the per-column coefficients).
231 if (!store_full_W && W.rows() > 0) [[unlikely]]
232 for (index_t l = 0; l < rem_j; ++l)
233 W_.store(Wj.load(l, l), j + l, 0);
234 });
235}
236
237} // namespace batmat::linalg::micro_kernels::sytrd
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
auto select(auto cond, auto t, auto f)
Definition simd.hpp:245
stdx::simd< Tp, Abi > simd
Definition simd.hpp:148
void sytrd_diag_microkernel(index_t k, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OD > D, uview< T, Abi, StorageOrder::ColMajor > Y) noexcept
Definition sytrd.tpp:24
constexpr std::pair< index_t, index_t > sytrd_W_size(view< T, Abi, OD > D)
Definition sytrd.hpp:24
void sytrd_register(view< T, Abi, OD > D, view< T, Abi > W, view< T, Abi > Y) noexcept
Symmetric block tridiagonalization.
Definition sytrd.tpp:207
const constinit auto microkernel_diag_lut
Definition sytrd.tpp:17
constexpr std::pair< index_t, index_t > sytrd_Y_size(view< T, Abi, OD > D)
Definition sytrd.hpp:32
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
int index_t
Definition config.hpp:13
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
void store(simd x, index_t r, index_t c) const noexcept
Definition uview.hpp:104
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118