batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
sytrd.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
8#include <batmat/ops/cneg.hpp>
10#include <guanaqo/trace.hpp>
11
// Local shorthand: UNROLL_FOR(init; cond; step) forwards to BATMAT_FULLY_UNROLLED_FOR (defined
// elsewhere; presumably a for-loop the compiler is asked to fully unroll — confirm). Used below for
// loops whose bounds depend only on the template parameter R and other unrolled indices.
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
// Dispatch table of diagonal-block micro-kernels, indexed by (block width - 1): sytrd_register
// calls microkernel_diag_lut<T, Abi, Conf, OD>[rem_j - 1](k - j, Wj, Djj, Y_).
// NOTE(review): the initializer is not visible in this listing; presumably it is built with
// make_1d_lut over sytrd_diag_microkernel<T, Abi, Conf, R, OD> for R = 1..SizeR — confirm.
16template <class T, class Abi, KernelConfig Conf, StorageOrder OD>
17inline const constinit auto microkernel_diag_lut =
20 });
21
// Register-blocked micro-kernel of the symmetric tridiagonal reduction (listed as
// sytrd_diag_microkernel in the cross-references): reduces R consecutive columns of the
// trailing matrix D with Householder reflectors, one column j at a time.
// Parameters (per the signature in the cross-reference list):
//   k: number of rows of the trailing matrix D; must exceed R (BATMAT_ASSUME(k > R)).
//   W: triangular accessor; for column j it receives W(i, j), i < j, the inner product of the
//      (unit-first-component) Householder vectors i and j, and W(j, j) = inv_τ.
//   D: trailing matrix starting at the current diagonal block; only entries on or below the
//      diagonal are accessed. For column j it receives the sub-diagonal entry -c̃j at (j+1, j),
//      the updated diagonal entry at (j+1, j+1), the scaled Householder vector b in rows
//      j+2..k-1 of column j, the updated ã32 in rows j+2..k-1 of column j+1, and the
//      symmetric rank-2 update A3 -= b yᵀ + y bᵀ on the rest of the trailing lower triangle.
//   Y: column-major workspace; on return, column j holds y = w - ½γ b in rows j+2..k-1.
22template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OD>
23[[gnu::hot, gnu::flatten]] void
26 using std::copysign;
27 using std::sqrt;
28 using simd = datapar::simd<T, Abi>;
29 BATMAT_ASSUME(k > R);
30
31 // j j+1 j+2 R
32 // ┌─────┬─────┬─────┬─────┐
33 // │ D │ · │ · │ · │
34 // ├─────┼─────┼─────┼─────┤
35 // │ d1 │ a11 │ × │ × │ j
36 // ├─────┼─────┼─────┼─────┤
37 // │ b2 │ a21 │ a22 │ × │ j+1
38 // ├─────┼─────┼─────┼─────┤
39 // │ B3 │ a31 │ a32 │ A3 │ j+2
40 // └─────┴─────┴─────┴─────┘
41
42 // j j+1 j+2
43 // ┌──────────────┬────┬────┬───────────────┐
44 // │ d · · │ · │ · │ · · · │ d: diagonal elements of tridiagonal matrix
45 // │ │ │ │ │ c: off-diagonal elements of tridiagonal matrix
46 // │ c d · │ · │ · │ · · · │ b: Householder reflectors
47 // │ │ │ │ │ a: original matrix
48 // │ b c d │ · │ · │ · · · │ ×: implicitly symmetric part
49 // │ ├────┼────┼───────────────┤
50 // │ b b c │ a │ × │ × × × │ j
51 // ├──────────────┼────┼────┼───────────────┤
52 // │ b b b │ a │ a │ × × × │ j+1
53 // ├──────────────┼────┼────┼───────────────┤
54 // │ b b b │ a │ a │ a × × │ j+2
55 // │ │ │ │ │
56 // │ b b b │ a │ a │ a a × │
57 // │ │ │ │ │
58 // │ b b b │ a │ a │ a a a │
59 // └──────────────┴────┴────┴───────────────┘
60 // │ │ │ ╰─ symv A3 by a31
61 // │ │ ╰─ dot a32 with a31
62 // │ ╰─ ±norm a31 and a21 becomes c(j)
63 // ╰─ dot B3 with a31 for block Householder
64
65 // symv:
66 // [ W(j+2, j) ] [ A(j+2, j+2) A(j+3, j+2) A(j+4, j+2) ... ] [ A(j+2, j) ]
67 // [ W(j+3, j) ] = [ A(j+3, j+2) A(j+3, j+3) A(j+4, j+3) ... ] [ A(j+3, j) ]
68 // [ W(j+4, j) ] [ A(j+4, j+2) A(j+4, j+3) A(j+4, j+4) ... ] [ A(j+4, j) ]
69 // [ ... ] [ ... ... ... ... ] [ ... ]
70 //
71 // W(j+2, j) = sum(l=j+2..k) A(l, j+2) A(l, j)
72 // W(j+3, j) = A(j+3, j+2) A(j+2, j) + sum(l=j+3..k) A(l, j+3) A(l, j)
73 // W(j+4, j) = A(j+4, j+2) A(j+2, j) + A(j+4, j+3) A(j+3, j) + sum(l=j+4..k) A(l, j+4) A(l, j)
74 // W(q, j) = sum(p=j+2..q-1) A(q, p) A(p, j) + sum(l=q..k) A(l, q) A(l, j)
75
76 UNROLL_FOR (index_t j = 0; j < R; ++j) {
77 using std::max;
78 simd Axj[R + 1];
79 UNROLL_FOR (index_t i = j + 1; i < R + 1; ++i)
80 Axj[i] = D.load(i, j);
81 // Compute inner products between a(j) and b(i<j), a(j), and a(i>j) (symv)
82 simd bb[R + 1]{};
83 // Triangular part
84 UNROLL_FOR (index_t q = j + 2; q < R + 1; ++q) {
85 UNROLL_FOR (index_t i = 0; i <= q; ++i)
86 bb[i] += D.load(q, i) * Axj[q];
87 UNROLL_FOR (index_t p = j + 2; p < q; ++p)
88 bb[q] += D.load(q, p) * Axj[p];
89 }
90 // Rectangular part
 // Since j < R, max(R + 1, j + 2) == R + 1: these are the rows below the R + 1 register rows.
91 for (index_t q = max(R + 1, j + 2); q < k; ++q) {
92 simd Aqx[R + 1];
93 UNROLL_FOR (index_t i = 0; i < R + 1; ++i)
94 Aqx[i] = D.load(q, i);
95 UNROLL_FOR (index_t i = 0; i < R + 1; ++i)
96 bb[i] += Aqx[i] * Aqx[j];
97 simd Yl{};
98 UNROLL_FOR (index_t p = j + 2; p < R + 1; ++p)
99 Yl += Aqx[p] * Axj[p];
100 Y.store(Yl, q, j); // W(q, j) = sum(p=j+2..q-1) A(q, p) A(p, j)
101 }
102 const simd a21 = Axj[j + 1];
103 bb[j] += a21 * a21;
104 // bb[i<j] now contain the inner products of a31 with the previous Householder vectors
105 // (except for the first components, which are implicitly 1 and are added later).
106 // bb[j] contains the squared norm of (a21, a31).
107 // bb[j+1] contains the dot product of a31 and a32.
108 // bb[i>j+1] contain the top rows of the symmetric product A3 a31 (complete).
109 // Y[i>R, j] contains part of the symmetric product A3 a31, but still requires adding
110 // the contributions from all columns >R (including dot products with the upper
111 // triangle of A3)
112
113 // Energy condition and Householder coefficients
 // NOTE(review): if (a21, a31) is exactly zero in a lane (e.g. an input that is already
 // tridiagonal with a zero sub-diagonal entry), bb[j] == 0 gives c̃j == 0 and β == 0, so
 // inv_τ = 0/0 (NaN) and inv_β = 1/0 (inf), which then propagate into W, D and Y.
 // LAPACK's xLARFG returns τ = 0 (H = I) in this case; confirm callers exclude it.
114 const simd c̃j = copysign(sqrt(bb[j]), a21), β = a21 + c̃j;
115 const simd inv_τ = β / c̃j, inv_β = simd{1} / β;
116
117 // Save block Householder matrix W
118 UNROLL_FOR (index_t i = 0; i < j; ++i)
119 // Multiply implicit first component of the current Householder vector by the
120 // corresponding row of the previous Householder vectors, and add it to the previously
121 // computed inner products with a31, scaled by β⁻¹ to go from a31 to the normalized
122 // Householder vector.
123 W.store(bb[i] * inv_β + D.load(j + 1, i), i, j);
124 W.store(inv_τ, j, j); // inverse of diagonal
125
126 // Finish the symmetric product A3 a31
127 for (index_t i = max(R + 1, j + 2); i < k; ++i) {
128 simd yi = Y.load(i, j);
129 const simd xi = D.load(i, j);
130 yi += D.load(i, i) * xi; // diagonal term
131 for (index_t l = max(R + 1, j + 2); l < i; ++l) { // lower triangle l < i
132 simd yl = Y.load(l, j);
133 const simd xl = D.load(l, j), ail = D.load(i, l); // TODO: access D column-wise
134 yi += ail * xl;
135 yl += ail * xi; // symmetric contribution to y[l]
136 Y.store(yl, l, j); // TODO: optimize by unrolling to avoid load/store of yl
137 }
138 Y.store(yi, i, j);
139 }
140 // Y[i>R, j] now contains the complete bottom rows of the symmetric product A3 a31.
141
142 // Now compute the vector w = τ⁻¹(A3 b + a32) = τ⁻¹(β⁻¹ A3 a31 + a32).
143 simd Axj1[R + 1];
144 UNROLL_FOR (index_t i = j + 1; i < R + 1; ++i)
145 Axj1[i] = D.load(i, j + 1);
146 UNROLL_FOR (index_t i = j + 2; i < R + 1; ++i)
147 Y.store(bb[i] = inv_τ * (inv_β * bb[i] + Axj1[i]), i, j);
148 for (index_t i = max(R + 1, j + 2); i < k; ++i) {
149 simd yi = Y.load(i, j);
150 Y.store(inv_τ * (inv_β * yi + D.load(i, j + 1)), i, j);
151 }
152 // bb[i>j+1] now contain w[i].
153
154 const simd a2 = Axj1[j + 1];
155 const simd a31ᵀa32 = bb[j + 1];
156 const simd ω = inv_τ * (inv_β * a31ᵀa32 + a2); // ω = τ⁻¹(a32ᵀb + a22)
157 // Scale a31 to obtain b, and dot it with w.
158 simd wᵀb_ω = ω; // accumulator for wᵀb + ω
159 simd b[R + 1];
160 UNROLL_FOR (index_t l = j + 2; l < R + 1; ++l) {
161 b[l] = inv_β * Axj[l];
162 D.store(b[l], l, j);
163 wᵀb_ω += b[l] * bb[l];
164 }
165 for (index_t l = max(R + 1, j + 2); l < k; ++l) {
166 simd bl = inv_β * D.load(l, j);
167 D.store(bl, l, j);
168 wᵀb_ω += bl * Y.load(l, j);
169 }
170 const simd γ = inv_τ * wᵀb_ω; // γ = τ⁻¹ (wᵀb + ω)
171 const simd d2 = a2 - 2 * ω + γ;
172 D.store(-c̃j, j + 1, j);
173 D.store(d2, j + 1, j + 1);
174
175 // Compute and store ã32 = a32 + (γ - ω) b - w and y = w - ½γ b
176 const simd γ_ω = γ - ω;
177 UNROLL_FOR (index_t l = j + 2; l < R + 1; ++l) {
178 simd ã32 = Axj1[l] + γ_ω * b[l] - bb[l];
179 D.store(ã32, l, j + 1);
180 simd yl = bb[l] - simd{0.5} * γ * b[l];
181 Y.store(yl, l, j);
182 }
183 for (index_t l = max(R + 1, j + 2); l < k; ++l) {
184 simd bl = D.load(l, j);
185 simd ã32 = D.load(l, j + 1) + γ_ω * bl - Y.load(l, j);
186 D.store(ã32, l, j + 1);
187 simd yl = Y.load(l, j) - simd{0.5} * γ * bl;
188 Y.store(yl, l, j);
189 }
190
 // This rank-2 update covers the whole remaining lower triangle (columns and rows j+2..k-1),
 // not only the R + 1 register rows.
191 // Update the trailing submatrix A3 = A3 - byᵀ - ybᵀ
192 // TODO: optimize memory accesses
193 for (index_t i = j + 2; i < k; ++i) // column of A3
194 for (index_t l = i; l < k; ++l) { // row of A3 (lower triangle)
195 simd Ail = D.load(l, i);
196 Ail -= Y.load(i, j) * D.load(l, j) + Y.load(l, j) * D.load(i, j);
197 D.store(Ail, l, i);
198 }
199 }
200}
201
202/// Symmetric block tridiagonalization.
///
/// Reduces the square matrix @p D to tridiagonal form in place with Householder reflectors,
/// processing its columns in chunks of R = SizeR<T, Abi> through microkernel_diag_lut. The
/// micro-kernel only accesses entries of @p D on or below the diagonal.
/// @param D  Square matrix (D.rows() == D.cols() == k, k > 0). On return, for each column
///           j < k - 1: D(j+1, j) holds the sub-diagonal entry, D(j+1, j+1) the updated
///           diagonal entry, and D(j+2.., j) the Householder vector (whose implicit unit entry
///           is at row j+1).
/// @param W  Block Householder output; one of:
///           - empty (W.rows() == 0): nothing is stored;
///           - a (k-1) x 1 column: receives only the diagonal entries Wj(l, l) of each chunk
///             (the inv_τ values written by the micro-kernel);
///           - of size sytrd_W_size(D): receives each chunk's full block at column block j / R.
/// @param Y  Workspace of size sytrd_Y_size(D); the same Y_ view is passed to every chunk.
203template <class T, class Abi, KernelConfig Conf, StorageOrder OD>
204void sytrd_register(const view<T, Abi, OD> D, const view<T, Abi> W, const view<T, Abi> Y) noexcept {
205 static constexpr index_constant<SizeR<T, Abi>> R;
206 const index_t k = D.rows();
207 BATMAT_ASSUME(k > 0);
208 BATMAT_ASSUME(D.rows() == D.cols());
209 BATMAT_ASSUME(W.rows() == 0 ||
210 (W.cols() == 1 && W.rows() == std::max<index_t>(D.cols(), 1) - 1) ||
211 std::make_pair(W.rows(), W.cols()) == (sytrd_W_size<T, Abi>)(D));
212 BATMAT_ASSUME(std::make_pair(Y.rows(), Y.cols()) == (sytrd_Y_size<T, Abi>)(D));
213
 // Stack scratch for one chunk's W block, used when the caller did not provide the full W.
215 alignas(W_t::alignment()) T W_sto[W_t::size()];
216
217 // Sizeless views to partition and pass to the micro-kernels
218 const uview<T, Abi, OD> D_ = D;
221 const bool store_full_W = std::make_pair(W.rows(), W.cols()) == (sytrd_W_size<T, Abi>)(D);
222
223 // Process all diagonal blocks (in multiples of R, except the last).
 // Only columns [0, k - 1) get a reflector: the last column has no sub-diagonal entry. Since
 // each chunk lies inside [0, k - 1), k - j > rem_j, as the micro-kernel requires.
224 foreach_chunked_merged(0, k - 1, R, [&](index_t j, auto rem_j) {
225 auto Wj = store_full_W ? W_t{W_.middle_cols(j / R).data} : W_t{W_sto};
226 auto Djj = D_.block(j, j);
227 microkernel_diag_lut<T, Abi, Conf, OD>[rem_j - 1](k - j, Wj, Djj, Y_);
 // With a (k-1) x 1 W, keep only the diagonal of this chunk's W block.
228 if (!store_full_W && W.rows() > 0) [[unlikely]]
229 for (index_t l = 0; l < rem_j; ++l)
230 W_.store(Wj.load(l, l), j + l, 0);
231 });
232}
233
234} // namespace batmat::linalg::micro_kernels::sytrd
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chunk...
Definition loop.hpp:43
stdx::simd< Tp, Abi > simd
Definition simd.hpp:148
void sytrd_diag_microkernel(index_t k, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OD > D, uview< T, Abi, StorageOrder::ColMajor > Y) noexcept
Definition sytrd.tpp:24
constexpr std::pair< index_t, index_t > sytrd_W_size(view< T, Abi, OD > D)
Definition sytrd.hpp:24
void sytrd_register(view< T, Abi, OD > D, view< T, Abi > W, view< T, Abi > Y) noexcept
Symmetric block tridiagonalization.
Definition sytrd.tpp:204
const constinit auto microkernel_diag_lut
Definition sytrd.tpp:17
constexpr std::pair< index_t, index_t > sytrd_Y_size(view< T, Abi, OD > D)
Definition sytrd.hpp:32
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
int index_t
Definition config.hpp:13
Self block(this const Self &self, index_t r, index_t c) noexcept
Definition uview.hpp:110
void store(simd x, index_t r, index_t c) const noexcept
Definition uview.hpp:104
Self middle_cols(this const Self &self, index_t c) noexcept
Definition uview.hpp:118