batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
geqrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
8#include <batmat/ops/cneg.hpp>
10#include <guanaqo/trace.hpp>
11
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
16template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
17inline const constinit auto microkernel_diag_lut =
20 });
21
22template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
23inline const constinit auto microkernel_full_lut =
26 });
27
28template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
29inline const constinit auto microkernel_tail_lut =
31 return geqrf_tail_microkernel<T, Abi, Conf, SizeR<T, Abi>, Row + 1, OA, OD, OB>;
32 });
33
34template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
38 });
39
/// Householder QR factorization of a k×R block column that also records the block
/// Householder matrix W (geqrf_full_microkernel performs the same reduction without W).
///
/// Parameters (per the declaration summary of this function):
///  - k: number of rows of A and D (must be positive).
///  - W: R×R triangular accessor; column j receives, for i < j,
///       W(i, j) = bb[i]·(1/β) + aa[i], i.e. the inner product of the (scaled) Householder
///       vectors of columns i and j, and W(j, j) = inv_τ = β / ãjj (stored as the inverse of
///       the diagonal, see the comment on that line).
///  - A: input block; it is only read while processing column 0 (use_A). All later columns
///       are read back from D. Callers in this file also pass the same view for A and D.
///  - D: output block. D(j, j) = -ãjj, D(j, i > j) holds row j of R, and D(l > j, j) holds the
///       Householder vector of column j divided by β (so its leading entry is implicitly 1).
/// NOTE(review): a zero column (bb[j] == 0 and aa[j] == 0) gives ãjj == β == 0, and the
/// divisions below then produce non-finite values. There is no guard; confirm that callers
/// exclude this case.
40template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
41[[gnu::hot, gnu::flatten]] void
44 using std::copysign;
45 using std::sqrt;
46 using simd = datapar::simd<T, Abi>;
47 BATMAT_ASSUME(k > 0); // TODO: fast path for k == 1
48
49 UNROLL_FOR (index_t j = 0; j < R; ++j) {
// Only column 0 reads the input A; afterwards the partially reduced data lives in D.
50 const bool use_A = j == 0;
51 // bb[i] = sum over rows l > j of x(l, i)·x(l, j), for every column i of the block
52 simd bb[R]{};
53 for (index_t l = j + 1; l < k; ++l) {
54 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
55 UNROLL_FOR (index_t i = 0; i < R; ++i)
56 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
57 }
// aa = row j of the block (for i < j these are the stored Householder vector entries)
58 simd aa[R];
59 UNROLL_FOR (index_t i = 0; i < R; ++i)
60 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
61 bb[j] += aa[j] * aa[j];
62 // Householder coefficients: bb[j] is now the squared norm of column j from row j down.
// ãjj takes the sign of the pivot aa[j], so β = aa[j] + ãjj adds two values of equal sign.
63 const simd ãjj = copysign(sqrt(bb[j]), aa[j]), β = aa[j] + ãjj;
64 simd inv_τ = β / ãjj, inv_β = simd{1} / β;
65 D.store(-ãjj, j, j);
66 // Column j of the block Householder matrix W (rows 0…j)
67 UNROLL_FOR (index_t i = 0; i < j; ++i)
68 bb[i] = bb[i] * inv_β + aa[i];
69 bb[j] = inv_τ; // inverse of diagonal
70 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
71 W.store(bb[i], i, j);
72 // Write row j of R to D (and replace bb[j+1:] with the update coefficients w)
73 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
74 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ; // w
75 D.store(aa[i] - bb[i], j, i); // R[j, i]
76 }
77 // Store the scaled Householder vector and update the trailing columns (rows > j)
78 for (index_t l = j + 1; l < k; ++l) {
79 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
80 Alj *= inv_β; // Scale Householder vector
81 D.store(Alj, l, j); // V[l, j]
82 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
83 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
84 Ali -= Alj * bb[i];
85 D.store(Ali, l, i);
86 }
87 }
88 }
89}
90
91/// A (k×R)
92/// D (k×R)
93template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
94[[gnu::hot, gnu::flatten]] void geqrf_full_microkernel(index_t k, uview<const T, Abi, OA> A,
95 uview<T, Abi, OD> D) noexcept {
96 using std::copysign;
97 using std::sqrt;
98 using simd = datapar::simd<T, Abi>;
99 BATMAT_ASSUME(k > 0); // TODO: fast path for k == 1
100
101 UNROLL_FOR (index_t j = 0; j < R; ++j) {
102 const bool use_A = j == 0;
103 // Compute all inner products between A and a
104 simd bb[R];
105 UNROLL_FOR (index_t i = j; i < R; ++i)
106 bb[i] = simd{0};
107 for (index_t l = j + 1; l < k; ++l) {
108 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
109 UNROLL_FOR (index_t i = j; i < R; ++i)
110 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
111 }
112 simd aa[R];
113 UNROLL_FOR (index_t i = j; i < R; ++i)
114 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
115 bb[j] += aa[j] * aa[j];
116 // Energy condition and Householder coefficients
117 const simd ãjj = copysign(sqrt(bb[j]), aa[j]), β = aa[j] + ãjj;
118 simd inv_τ = β / ãjj, inv_β = simd{1} / β;
119 D.store(-ãjj, j, j);
120 // Replace row j of A by R (and replace bb[j+1:] with w)
121 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
122 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ; // w
123 D.store(aa[i] - bb[i], j, i); // R[j, i]
124 }
125 // Update trailing part of A
126 for (index_t l = j + 1; l < k; ++l) {
127 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
128 Alj *= inv_β; // Scale Householder vector
129 D.store(Alj, l, j); // V[l, j]
130 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
131 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
132 Ali -= Alj * bb[i];
133 D.store(Ali, l, i);
134 }
135 }
136 }
137}
138
139// Householder vectors are stored in the strict lower triangle of B. The diagonal of B is
140// implicitly one and its strict upper triangle implicitly zero (B[l, l] = 1, B[l, >l] = 0).
141// The matrix W completes the block Householder representation Q = I - BW⁻¹Bᵀ. The diagonal of W
142// is already inverted to enable efficient application of W⁻¹.
143// B = [  1    0    0   ... ]
144//     [ b11   1    0   ... ]
145//     [ b21  b22   1   ... ]
146//     [ b31  b32  b33  ... ]
147// B: k×R (lower trapezoidal, with implicit unit diagonal)
148// W: R×R (upper triangular, with inverted diagonal)
149// A: k×S
150// D: k×S
// Computes D = A - B V with V = W⁻¹ BᵀA (transposed == false) or V = W⁻ᵀ BᵀA
// (transposed == true), i.e. D = Q A or D = Qᵀ A for the Q above.
// A and D may refer to the same storage (callers in this file pass the same block for both):
// all reads of A for BᵀA happen before the first store, and in the update phase each A(l, i)
// is loaded before D(l, i) is stored.
// NOTE(review): rows 0…R-1 of A and B are read unconditionally, so k >= R appears to be
// required even though only k > 0 is asserted.
151template <class T, class Abi, KernelConfig Conf, index_t R, index_t S, StorageOrder OA,
153[[gnu::hot, gnu::flatten]] void geqrf_tail_microkernel(
154 index_t k, bool transposed, triangular_accessor<const T, Abi, SizeR<T, Abi>> W,
156 using simd = datapar::simd<T, Abi>;
157 BATMAT_ASSUME(k > 0);
158
159 // Compute product U = BᵀA (accumulated in V, which the solve below overwrites in place)
160 simd V[R][S];
161 // Triangular part of B (top R rows)
162 UNROLL_FOR (index_t l = 0; l < R; ++l)
163 UNROLL_FOR (index_t i = 0; i < S; ++i) {
164 V[l][i] = A.load(l, i); // B[l, l] = 1
165 UNROLL_FOR (index_t j = 0; j < l; ++j) // B[l, >l] = 0
166 V[j][i] += B.load(l, j) * A.load(l, i);
167 }
168 // Remaining rectangular part of B
169 for (index_t l = R; l < k; ++l)
170 UNROLL_FOR (index_t j = 0; j < R; ++j) {
171 auto Blj = B.load(l, j);
172 UNROLL_FOR (index_t i = 0; i < S; ++i)
173 V[j][i] += Blj * A.load(l, i);
174 }
175
176 // Solve system V = W⁻¹ U (with W upper triangular, in-place): backward substitution
177 if (!transposed)
178 UNROLL_FOR (index_t j = R; j-- > 0;) // row of W
179 UNROLL_FOR (index_t i = 0; i < S; ++i) { // column of V, U
180 UNROLL_FOR (index_t l = j + 1; l < R; ++l) // column of W
181 V[j][i] -= W.load(j, l) * V[l][i];
182 V[j][i] *= W.load(j, j); // diagonal already inverted
183 }
184 // Solve system V = W⁻ᵀ U (with W upper triangular, in-place): forward substitution
185 else
186 UNROLL_FOR (index_t j = 0; j < R; ++j) // row of Wᵀ
187 UNROLL_FOR (index_t i = 0; i < S; ++i) { // column of V, U
188 UNROLL_FOR (index_t l = 0; l < j; ++l) // column of Wᵀ
189 V[j][i] -= W.load(l, j) * V[l][i];
190 V[j][i] *= W.load(j, j); // diagonal already inverted
191 }
192
193 // Update A = A - B V (result written to D)
194 simd Bl[R];
195 // Top R rows of B (unit diagonal handled by subtracting V[l][i] directly)
196 UNROLL_FOR (index_t l = 0; l < R; ++l) {
197 UNROLL_FOR (index_t j = 0; j < l; ++j)
198 Bl[j] = B.load(l, j);
199 UNROLL_FOR (index_t i = 0; i < S; ++i) {
200 simd Dli = A.load(l, i) - V[l][i];
201 UNROLL_FOR (index_t j = 0; j < l; ++j)
202 Dli -= V[j][i] * Bl[j];
203 D.store(Dli, l, i);
204 }
205 }
206 // Remaining rectangular part of B
207 for (index_t l = R; l < k; ++l) {
208 UNROLL_FOR (index_t j = 0; j < R; ++j)
209 Bl[j] = B.load(l, j);
210 UNROLL_FOR (index_t i = 0; i < S; ++i) {
211 simd Dli = A.load(l, i);
212 UNROLL_FOR (index_t j = 0; j < R; ++j)
213 Dli -= V[j][i] * Bl[j];
214 D.store(Dli, l, i);
215 }
216 }
217}
218
219/// Blocked Householder QR factorization of A, written to D, using register blocking.
220/// Optionally stores the block Householder representation W, depending on the shape of W.
///
/// On return:
///  - the upper triangle of D holds R;
///  - the strict lower part of D holds the Householder vectors, each scaled so that its
///    leading entry is implicitly 1, as written by the micro-kernels.
/// A must have at least as many rows as columns, and D must have the same size as A.
///
/// W storage modes:
///  - W.rows() == 0: W is not returned; the stack buffer W_sto is used as scratch.
///  - W has one column and A.cols() rows: only the diagonal of each block's W is stored,
///    as W(j + l, 0) = Wj(l, l).
///  - W has size geqrf_W_size(A): each block's W is written directly to
///    W_.middle_cols(j / R).
221template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
223 const view<T, Abi> W) noexcept {
224 static constexpr index_constant<SizeR<T, Abi>> R;
225 static constexpr index_constant<SizeS<T, Abi>> S;
226 const index_t k = A.rows();
227 BATMAT_ASSUME(k > 0);
228 BATMAT_ASSUME(A.rows() >= A.cols());
229 BATMAT_ASSUME(A.rows() == D.rows());
230 BATMAT_ASSUME(A.cols() == D.cols());
231 BATMAT_ASSUME(W.rows() == 0 || (W.cols() == 1 && W.rows() == A.cols()) ||
232 std::make_pair(W.rows(), W.cols()) == (geqrf_W_size<const T, Abi>)(A));
233
// Scratch storage for one block's W when W is not (fully) returned to the caller.
235 alignas(W_t::alignment()) T W_sto[W_t::size()];
236
237 // Sizeless views to partition and pass to the micro-kernels
238 const uview<const T, Abi, OA> A_ = A;
239 const uview<T, Abi, OD> D_ = D;
241 const bool store_full_W = std::make_pair(W.rows(), W.cols()) == (geqrf_W_size<const T, Abi>)(A);
242
243 // Process all diagonal blocks (in multiples of R, except the last).
// Square A without W output: the last, narrower block column (presumably only present when
// A.cols() is not a multiple of R) has no trailing columns, so geqrf_full_microkernel, which
// does not form W, is used for it.
244 if (A.rows() == A.cols() && W.rows() == 0) {
245 auto Wj = W_t{W_sto};
247 0, A.cols(), R,
248 [&](index_t j) {
249 auto Djj = D_.block(j, j);
250 // Block column 0 reads from A and writes to D; later block columns work in place in D.
251 if (j == 0) {
252 // Triangularize block column j (all rows below diagonal)
253 geqrf_diag_microkernel<T, Abi, Conf, R, OA, OD>(k, Wj, A_, Djj);
254 // Update the trailing columns (in multiples of S)
255 foreach_chunked_merged(
256 j + R, A.cols(), S,
257 [&](index_t i, auto rem_i) {
258 auto Dji = D_.block(j, i);
259 microkernel_tail_lut<T, Abi, Conf, OA, OD, OD>[rem_i - 1](
260 k, true, Wj, A_.block(j, i), Dji, Djj);
261 },
262 LoopDir::Backward); // TODO: decide on order
263 } else {
264 // Triangularize block column j (all rows below diagonal)
265 geqrf_diag_microkernel<T, Abi, Conf, R, OD, OD>(k - j, Wj, Djj, Djj);
266 // Update the trailing columns (in multiples of S)
267 foreach_chunked_merged(
268 j + R, A.cols(), S,
269 [&](index_t i, auto rem_i) {
270 auto Dji = D_.block(j, i);
271 microkernel_tail_lut<T, Abi, Conf, OD, OD, OD>[rem_i - 1](
272 k - j, true, Wj, Dji, Dji, Djj);
273 },
274 LoopDir::Backward); // TODO: decide on order
275 }
276 },
277 [&](index_t j, index_t rem_j) {
278 auto Djj = D_.block(j, j);
279 if (j == 0) // copy result from A to D
280 microkernel_full_lut<T, Abi, Conf, OA, OD>[rem_j - 1](k, A_, Djj);
281 else
282 microkernel_full_lut<T, Abi, Conf, OD, OD>[rem_j - 1](k - j, Djj, Djj);
283 });
284 } else {
285 foreach_chunked_merged(0, A.cols(), R, [&](index_t j, auto rem_j) {
// Write this block's W straight into the caller's W when it has the full size.
286 auto Wj = store_full_W ? W_t{W_.middle_cols(j / R).data} : W_t{W_sto};
287 auto Djj = D_.block(j, j);
288 // Block column 0 reads from A and writes to D; later block columns work in place in D.
289 if (j == 0) {
290 // Triangularize block column j (all rows below diagonal)
291 microkernel_diag_lut<T, Abi, Conf, OA, OD>[rem_j - 1](k, Wj, A_, Djj);
292 // Update the trailing columns (in multiples of S)
294 j + R, A.cols(), S,
295 [&](index_t i, auto rem_i) {
296 auto Dji = D_.block(j, i);
297 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OD>[rem_j - 1][rem_i - 1](
298 k, true, Wj, A_.block(j, i), Dji, Djj);
299 },
300 LoopDir::Backward); // TODO: decide on order
301 } else {
302 // Triangularize block column j (all rows below diagonal)
303 microkernel_diag_lut<T, Abi, Conf, OD, OD>[rem_j - 1](k - j, Wj, Djj, Djj);
304 // Update the trailing columns (in multiples of S)
305 foreach_chunked_merged(
306 j + R, A.cols(), S,
307 [&](index_t i, auto rem_i) {
308 auto Dji = D_.block(j, i);
309 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OD>[rem_j - 1][rem_i - 1](
310 k - j, true, Wj, Dji, Dji, Djj);
311 },
312 LoopDir::Backward); // TODO: decide on order
313 }
// W given as a single column: keep only the diagonal of this block's W.
314 if (!store_full_W && W.rows() > 0) [[unlikely]]
315 for (index_t l = 0; l < rem_j; ++l)
316 W_.store(Wj.load(l, l), j + l, 0);
317 });
318 }
319}
320
321/// Apply a block Householder transformation.
///
/// Computes D = Q A (transposed == false) or D = Qᵀ A (transposed == true). Q is defined by:
///  - the Householder vectors in B;
///  - the block matrices in W, one R-column block per block column of B, selected with
///    W_.middle_cols(j / R).
/// Each block column j applies its block reflector to rows j and below of every column of A.
/// @p reversed flips the order in which the block columns are applied. `forward`
/// (= transposed ^ reversed) selects which block is processed first. Blocks are presumably
/// visited front to back when it is true and back to front otherwise.
/// A and D may be the same storage.
322template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
325 bool transposed, bool reversed) noexcept {
326 const index_t k = A.rows();
327 BATMAT_ASSUME(k > 0);
328 BATMAT_ASSUME(A.rows() == D.rows());
329 BATMAT_ASSUME(A.cols() == D.cols());
330 BATMAT_ASSUME(B.rows() == A.rows());
331
332 static constexpr index_constant<SizeR<T, Abi>> R;
334 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (geqrf_W_size<const T, Abi>)(B));
335
336 // Sizeless views to partition and pass to the micro-kernels
337 const uview<const T, Abi, OA> A_ = A;
338 const uview<T, Abi, OD> D_ = D;
339 const uview<const T, Abi, OB> B_ = B;
341
342 // Process all block columns of B (in chunks of R; the last chunk may be narrower).
343 const bool forward = transposed ^ reversed;
345 0, B.cols(), R,
346 [&](index_t j, auto nj) {
// The first block processed reads from A; all later blocks update D in place.
347 const bool first = forward ? j == 0 : j + nj >= B.cols();
348 static constexpr index_constant<SizeS<T, Abi>> S;
349 // Householder vectors of this block column (rows j and below of B)
350 auto Bjj = B_.block(j, j);
351 auto Wj = W_t{W_.middle_cols(j / R).data};
352 // Process all columns of A (in chunks of S).
353 foreach_chunked_merged( // TODO: swap loop order?
354 0, A.cols(), S,
355 [&](index_t i, auto ni) {
356 auto Dji = D_.block(j, i);
357 if (first)
358 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OB>[nj - 1][ni - 1](
359 k - j, transposed, Wj, A_.block(j, i), Dji, Bjj);
360 else
361 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OB>[nj - 1][ni - 1](
362 k - j, transposed, Wj, Dji, Dji, Bjj);
363 // TODO: is it better to merge this copy into the next micro-kernel call?
// The first kernel call reads A but writes only rows j and below of D. When D is not A,
// rows 0…j-1 of these columns are copied over so that later in-place calls see them.
// NOTE(review): `forward` is false both for (!transposed, !reversed) and for
// (transposed, reversed), but this condition only copies in the first case. With
// transposed && reversed && D != A, later blocks appear to read rows of D that were never
// initialized from A. `!forward` (or just `first`, since the loop is empty for j == 0)
// looks like the intended test; confirm.
364 if (first && !transposed && D_.data != A_.data)
365 for (index_t l = 0; l < j; ++l)
366 for (index_t ii = i; ii < i + ni; ++ii)
367 D_.store(A_.load(l, ii), l, ii);
368 },
369 LoopDir::Backward); // TODO: decide on order
370 },
372}
373
374} // namespace batmat::linalg::micro_kernels::geqrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
int index_t
Definition config.hpp:13
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
Definition loop.hpp:20
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
stdx::simd< Tp, Abi > simd
Definition simd.hpp:148
void geqrf_copy_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< T, Abi > W) noexcept
Block Householder QR factorization of A (result written to D) using register blocking.
Definition geqrf.tpp:222
constexpr std::pair< index_t, index_t > geqrf_W_size(view< T, Abi, OA > A)
Definition geqrf.hpp:36
void geqrf_diag_microkernel(index_t k, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
Definition geqrf.tpp:42
const constinit auto microkernel_full_lut
Definition geqrf.tpp:23
void geqrf_tail_microkernel(index_t k, bool transposed, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D, uview< const T, Abi, OB > B) noexcept
Definition geqrf.tpp:153
const constinit auto microkernel_tail_lut_2
Definition geqrf.tpp:35
const constinit auto microkernel_tail_lut
Definition geqrf.tpp:29
const constinit auto microkernel_diag_lut
Definition geqrf.tpp:17
void geqrf_full_microkernel(index_t k, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
A (k×R) D (k×R).
Definition geqrf.tpp:94
void geqrf_apply_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< const T, Abi, OB > B, view< const T, Abi > W, bool transposed, bool reversed) noexcept
Apply a block Householder transformation.
Definition geqrf.tpp:323
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
int index_t
Definition config.hpp:13