batmat main
Batched linear algebra routines
Loading...
Searching...
No Matches
geqrf.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
6#include <batmat/loop.hpp>
7#include <batmat/lut.hpp>
8#include <batmat/ops/cneg.hpp>
10#include <guanaqo/trace.hpp>
11
12#define UNROLL_FOR(...) BATMAT_FULLY_UNROLLED_FOR (__VA_ARGS__)
13
15
16template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
17inline const constinit auto microkernel_diag_lut =
20 });
21
22template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
23inline const constinit auto microkernel_full_lut =
26 });
27
28template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
29inline const constinit auto microkernel_tail_lut =
31 return geqrf_tail_microkernel<T, Abi, Conf, SizeR<T, Abi>, Row + 1, OA, OD, OB>;
32 });
33
34template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
38 });
39
40template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
41[[gnu::hot, gnu::flatten]] void
44 using std::copysign;
45 using std::sqrt;
46 using simd = datapar::simd<T, Abi>;
47 BATMAT_ASSUME(k > 0); // TODO: fast path for k == 1
48 static constexpr auto safe_min = std::numeric_limits<T>::min();
49
50 UNROLL_FOR (index_t j = 0; j < R; ++j) {
51 const bool use_A = j == 0;
52 // Compute all inner products between A and a
53 simd bb[R]{};
54 for (index_t l = j + 1; l < k; ++l) {
55 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
56 UNROLL_FOR (index_t i = 0; i < R; ++i)
57 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
58 }
59 simd aa[R];
60 UNROLL_FOR (index_t i = 0; i < R; ++i)
61 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
62 bb[j] += aa[j] * aa[j];
63 // Energy condition and Householder coefficients
64 const simd abs_ãjj = sqrt(bb[j]);
65 const simd ãjj = copysign(abs_ãjj, aa[j]), β = aa[j] + ãjj;
66 simd inv_τ = datapar::select(abs_ãjj > safe_min, β / ãjj, simd{0}),
67 inv_β = datapar::select(abs_ãjj > safe_min, simd{1} / β, simd{0});
68 D.store(-ãjj, j, j);
69 // Save block Householder matrix W
70 UNROLL_FOR (index_t i = 0; i < j; ++i)
71 bb[i] = bb[i] * inv_β + aa[i];
72 bb[j] = inv_τ; // inverse of diagonal
73 UNROLL_FOR (index_t i = 0; i < j + 1; ++i)
74 W.store(bb[i], i, j);
75 // Replace row j of A by R (and replace bb[j+1:] with w)
76 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
77 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ; // w
78 D.store(aa[i] - bb[i], j, i); // R[j, i]
79 }
80 // Update trailing part of A
81 for (index_t l = j + 1; l < k; ++l) {
82 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
83 Alj *= inv_β; // Scale Householder vector
84 D.store(Alj, l, j); // V[l, j]
85 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
86 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
87 Ali -= Alj * bb[i];
88 D.store(Ali, l, i);
89 }
90 }
91 }
92}
93
94/// Householder triangularization of one k×R panel, column by column. A (k×R): input panel, only read while processing column j == 0; later columns are read back from D.
95/// D (k×R): output. Row j receives -ãjj on the diagonal and the updated row entries to its right; the entries below the diagonal of column j receive the scaled Householder vector.
96template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OA, StorageOrder OD>
97[[gnu::hot, gnu::flatten]] void geqrf_full_microkernel(index_t k, uview<const T, Abi, OA> A,
98 uview<T, Abi, OD> D) noexcept {
99 using std::abs; // NOTE(review): abs does not appear to be called in this function — confirm and remove
100 using std::copysign;
101 using std::sqrt;
102 using simd = datapar::simd<T, Abi>;
103 BATMAT_ASSUME(k > 0); // TODO: fast path for k == 1
104 static constexpr auto safe_min = std::numeric_limits<T>::min(); // threshold below which a column norm is treated as zero (smallest positive normal value, assuming T is floating-point)
105
106 UNROLL_FOR (index_t j = 0; j < R; ++j) {
107 const bool use_A = j == 0; // the first column reads the input A; every later column reads what was already written to D
108 // Inner products of column j with columns j..R-1, accumulated over the rows strictly below row j
109 simd bb[R];
110 UNROLL_FOR (index_t i = j; i < R; ++i)
111 bb[i] = simd{0};
112 for (index_t l = j + 1; l < k; ++l) {
113 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
114 UNROLL_FOR (index_t i = j; i < R; ++i)
115 bb[i] += (use_A ? A.load(l, i) : D.load(l, i)) * Alj;
116 }
117 simd aa[R]; // row j of the panel; only entries j..R-1 are loaded and used
118 UNROLL_FOR (index_t i = j; i < R; ++i)
119 aa[i] = use_A ? A.load(j, i) : D.load(j, i);
120 bb[j] += aa[j] * aa[j]; // bb[j] now holds the sum of squares of column j from row j down
121 // Column norm and Householder coefficients (both coefficients become 0 when the norm does not exceed safe_min)
122 const simd abs_ãjj = sqrt(bb[j]);
123 const simd ãjj = copysign(abs_ãjj, aa[j]), β = aa[j] + ãjj; // ãjj takes the sign of aa[j], so β adds two values of equal sign
124 simd inv_τ = datapar::select(abs_ãjj > safe_min, β / ãjj, simd{0}),
125 inv_β = datapar::select(abs_ãjj > safe_min, simd{1} / β, simd{0});
126 D.store(-ãjj, j, j); // diagonal entry of the triangular factor
127 // Write row j of the triangular factor to D (and overwrite bb[j+1:] with the update weights w)
128 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
129 bb[i] = (aa[i] + bb[i] * inv_β) * inv_τ; // w
130 D.store(aa[i] - bb[i], j, i); // R[j, i]
131 }
132 // Update the rows below row j: scale column j into the Householder vector and subtract its multiple from the trailing columns
133 for (index_t l = j + 1; l < k; ++l) {
134 simd Alj = use_A ? A.load(l, j) : D.load(l, j);
135 Alj *= inv_β; // Scale Householder vector
136 D.store(Alj, l, j); // V[l, j]
137 UNROLL_FOR (index_t i = j + 1; i < R; ++i) {
138 simd Ali = use_A ? A.load(l, i) : D.load(l, i);
139 Ali -= Alj * bb[i]; // bb[i] holds w for column i at this point
140 D.store(Ali, l, i);
141 }
142 }
143 }
144}
145
146// Householder vectors are stored in the strict lower triangle of B, with the upper triangle
147// implicitly equal to the identity.
148// The matrix W completes the block Householder representation Q = I - BW⁻¹Bᵀ. The diagonal of W
149// is already inverted to enable efficient application of W⁻¹.
150// B = [ 1 0 0 ... ]
151// [ b11 1 0 ... ]
152// [ b21 b22 1 ... ]
153// [ b31 b32 b33 ... ]
154// B: k×R (lower trapezoidal, with implicit unit diagonal)
155// W: R×R (upper triangular, with inverted diagonal)
156// A: k×S
157// D: k×S
158template <class T, class Abi, KernelConfig Conf, index_t R, index_t S, StorageOrder OA,
160[[gnu::hot, gnu::flatten]] void geqrf_tail_microkernel(
161 index_t k, bool transposed, triangular_accessor<const T, Abi, SizeR<T, Abi>> W,
163 using simd = datapar::simd<T, Abi>;
164 BATMAT_ASSUME(k > 0);
165
166 // Compute product U = BᵀA
167 simd V[R][S];
168 // Triangular part of B (top R rows)
169 UNROLL_FOR (index_t l = 0; l < R; ++l)
170 UNROLL_FOR (index_t i = 0; i < S; ++i) {
171 V[l][i] = A.load(l, i); // B[l, l] = 1
172 UNROLL_FOR (index_t j = 0; j < l; ++j) // B[l, >l] = 0
173 V[j][i] += B.load(l, j) * A.load(l, i);
174 }
175 // Remaining rectangular part of B
176 for (index_t l = R; l < k; ++l)
177 UNROLL_FOR (index_t j = 0; j < R; ++j) {
178 auto Blj = B.load(l, j);
179 UNROLL_FOR (index_t i = 0; i < S; ++i)
180 V[j][i] += Blj * A.load(l, i);
181 }
182
183 // Solve system V = W⁻¹ U (with W upper triangular, in-place)
184 if (!transposed)
185 UNROLL_FOR (index_t j = R; j-- > 0;) // row of W
186 UNROLL_FOR (index_t i = 0; i < S; ++i) { // column of V, U
187 UNROLL_FOR (index_t l = j + 1; l < R; ++l) // column of W
188 V[j][i] -= W.load(j, l) * V[l][i];
189 V[j][i] *= W.load(j, j); // diagonal already inverted
190 }
191 // Solve system V = W⁻ᵀ U (with W upper triangular, in-place)
192 else
193 UNROLL_FOR (index_t j = 0; j < R; ++j) // row of Wᵀ
194 UNROLL_FOR (index_t i = 0; i < S; ++i) { // column of V, U
195 UNROLL_FOR (index_t l = 0; l < j; ++l) // column of Wᵀ
196 V[j][i] -= W.load(l, j) * V[l][i];
197 V[j][i] *= W.load(j, j); // diagonal already inverted
198 }
199
200 // Update A = A - B V
201 simd Bl[R];
202 // Top R rows of B
203 UNROLL_FOR (index_t l = 0; l < R; ++l) {
204 UNROLL_FOR (index_t j = 0; j < l; ++j)
205 Bl[j] = B.load(l, j);
206 UNROLL_FOR (index_t i = 0; i < S; ++i) {
207 simd Dli = A.load(l, i) - V[l][i];
208 UNROLL_FOR (index_t j = 0; j < l; ++j)
209 Dli -= V[j][i] * Bl[j];
210 D.store(Dli, l, i);
211 }
212 }
213 // Remaining rectangular part of B
214 for (index_t l = R; l < k; ++l) {
215 UNROLL_FOR (index_t j = 0; j < R; ++j)
216 Bl[j] = B.load(l, j);
217 UNROLL_FOR (index_t i = 0; i < S; ++i) {
218 simd Dli = A.load(l, i);
219 UNROLL_FOR (index_t j = 0; j < R; ++j)
220 Dli -= V[j][i] * Bl[j];
221 D.store(Dli, l, i);
222 }
223 }
224}
225
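226/// Blocked Householder QR factorization of A, written to D, using register blocking.
227/// W is optional: it may be empty (nothing is stored), a single column (only the inverted diagonal entries of each block's W are stored), or have the size returned by geqrf_W_size (the full block Householder representation is stored).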
228template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD>
230 const view<T, Abi> W) noexcept {
231 static constexpr index_constant<SizeR<T, Abi>> R;
232 static constexpr index_constant<SizeS<T, Abi>> S;
233 const index_t k = A.rows();
234 BATMAT_ASSUME(k > 0);
235 BATMAT_ASSUME(A.rows() >= A.cols());
236 BATMAT_ASSUME(A.rows() == D.rows());
237 BATMAT_ASSUME(A.cols() == D.cols());
238 BATMAT_ASSUME(W.rows() == 0 || (W.cols() == 1 && W.rows() == A.cols()) ||
239 std::make_pair(W.rows(), W.cols()) == (geqrf_W_size<const T, Abi>)(A));
240
242 alignas(W_t::alignment()) T W_sto[W_t::size()];
243
244 // Sizeless views to partition and pass to the micro-kernels
245 const uview<const T, Abi, OA> A_ = A;
246 const uview<T, Abi, OD> D_ = D;
248 const bool store_full_W = std::make_pair(W.rows(), W.cols()) == (geqrf_W_size<const T, Abi>)(A);
249
250 // Process all diagonal blocks (in multiples of R, except the last).
251 if (A.rows() == A.cols() && W.rows() == 0) {
252 auto Wj = W_t{W_sto};
254 0, A.cols(), R,
255 [&](index_t j) {
256 auto Djj = D_.block(j, j);
257 // Copy result from A to D
258 if (j == 0) {
259 // Triangularize block column j (all rows below diagonal)
260 geqrf_diag_microkernel<T, Abi, Conf, R, OA, OD>(k, Wj, A_, Djj);
261 // Update the trailing columns (in multiples of S)
262 foreach_chunked_merged(
263 j + R, A.cols(), S,
264 [&](index_t i, auto rem_i) {
265 auto Dji = D_.block(j, i);
266 microkernel_tail_lut<T, Abi, Conf, OA, OD, OD>[rem_i - 1](
267 k, true, Wj, A_.block(j, i), Dji, Djj);
268 },
269 LoopDir::Backward); // TODO: decide on order
270 } else {
271 // Triangularize block column j (all rows below diagonal)
272 geqrf_diag_microkernel<T, Abi, Conf, R, OD, OD>(k - j, Wj, Djj, Djj);
273 // Update the trailing columns (in multiples of S)
274 foreach_chunked_merged(
275 j + R, A.cols(), S,
276 [&](index_t i, auto rem_i) {
277 auto Dji = D_.block(j, i);
278 microkernel_tail_lut<T, Abi, Conf, OD, OD, OD>[rem_i - 1](
279 k - j, true, Wj, Dji, Dji, Djj);
280 },
281 LoopDir::Backward); // TODO: decide on order
282 }
283 },
284 [&](index_t j, index_t rem_j) {
285 auto Djj = D_.block(j, j);
286 if (j == 0) // copy result from A to D
287 microkernel_full_lut<T, Abi, Conf, OA, OD>[rem_j - 1](k, A_, Djj);
288 else
289 microkernel_full_lut<T, Abi, Conf, OD, OD>[rem_j - 1](k - j, Djj, Djj);
290 });
291 } else {
292 foreach_chunked_merged(0, A.cols(), R, [&](index_t j, auto rem_j) {
293 auto Wj = store_full_W ? W_t{W_.middle_cols(j / R).data} : W_t{W_sto};
294 auto Djj = D_.block(j, j);
295 // Copy result from A to D
296 if (j == 0) {
297 // Triangularize block column j (all rows below diagonal)
298 microkernel_diag_lut<T, Abi, Conf, OA, OD>[rem_j - 1](k, Wj, A_, Djj);
299 // Update the trailing columns (in multiples of S)
301 j + R, A.cols(), S,
302 [&](index_t i, auto rem_i) {
303 auto Dji = D_.block(j, i);
304 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OD>[rem_j - 1][rem_i - 1](
305 k, true, Wj, A_.block(j, i), Dji, Djj);
306 },
307 LoopDir::Backward); // TODO: decide on order
308 } else {
309 // Triangularize block column j (all rows below diagonal)
310 microkernel_diag_lut<T, Abi, Conf, OD, OD>[rem_j - 1](k - j, Wj, Djj, Djj);
311 // Update the trailing columns (in multiples of S)
312 foreach_chunked_merged(
313 j + R, A.cols(), S,
314 [&](index_t i, auto rem_i) {
315 auto Dji = D_.block(j, i);
316 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OD>[rem_j - 1][rem_i - 1](
317 k - j, true, Wj, Dji, Dji, Djj);
318 },
319 LoopDir::Backward); // TODO: decide on order
320 }
321 if (!store_full_W && W.rows() > 0) [[unlikely]]
322 for (index_t l = 0; l < rem_j; ++l)
323 W_.store(Wj.load(l, l), j + l, 0);
324 });
325 }
326}
327
328/// Apply a block Householder transformation.
329template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OD, StorageOrder OB>
332 bool transposed, bool reversed) noexcept {
333 const index_t k = A.rows();
334 BATMAT_ASSUME(k > 0);
335 BATMAT_ASSUME(A.rows() == D.rows());
336 BATMAT_ASSUME(A.cols() == D.cols());
337 BATMAT_ASSUME(B.rows() == A.rows());
338
339 static constexpr index_constant<SizeR<T, Abi>> R;
341 BATMAT_ASSUME(std::make_pair(W.rows(), W.cols()) == (geqrf_W_size<const T, Abi>)(B));
342
343 // Sizeless views to partition and pass to the micro-kernels
344 const uview<const T, Abi, OA> A_ = A;
345 const uview<T, Abi, OD> D_ = D;
346 const uview<const T, Abi, OB> B_ = B;
348
349 // Process all diagonal blocks (in multiples of R, except the last).
350 const bool forward = transposed ^ reversed;
352 0, B.cols(), R,
353 [&](index_t j, auto nj) {
354 const bool first = forward ? j == 0 : j + nj >= B.cols();
355 static constexpr index_constant<SizeS<T, Abi>> S;
356 // Diagonal block of B (Householder vectors) and the matching block of W
357 auto Bjj = B_.block(j, j);
358 auto Wj = W_t{W_.middle_cols(j / R).data};
359 // Process all columns of A (in chunks of at most S).
360 foreach_chunked_merged( // TODO: swap loop order?
361 0, A.cols(), S,
362 [&](index_t i, auto ni) {
363 auto Dji = D_.block(j, i);
364 if (first)
365 microkernel_tail_lut_2<T, Abi, Conf, OA, OD, OB>[nj - 1][ni - 1](
366 k - j, transposed, Wj, A_.block(j, i), Dji, Bjj);
367 else
368 microkernel_tail_lut_2<T, Abi, Conf, OD, OD, OB>[nj - 1][ni - 1](
369 k - j, transposed, Wj, Dji, Dji, Bjj);
370 // TODO: is it better to merge this copy into the next micro-kernel call?
371 if (first && !transposed && D_.data != A_.data)
372 for (index_t l = 0; l < j; ++l)
373 for (index_t ii = i; ii < i + ni; ++ii)
374 D_.store(A_.load(l, ii), l, ii);
375 },
376 LoopDir::Backward); // TODO: decide on order
377 },
379}
380
381} // namespace batmat::linalg::micro_kernels::geqrf
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
int index_t
Definition config.hpp:13
#define UNROLL_FOR(...)
Definition gemm-diag.tpp:10
void foreach_chunked(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, auto func_rem, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each ful...
Definition loop.hpp:20
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
void foreach_chunked_merged(index_t i_begin, index_t i_end, auto chunk_size, auto func_chunk, LoopDir dir=LoopDir::Forward)
Iterate over the range [i_begin, i_end) in chunks of size chunk_size, calling func_chunk for each chu...
Definition loop.hpp:43
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
auto select(auto cond, auto t, auto f)
Definition simd.hpp:245
stdx::simd< Tp, Abi > simd
Definition simd.hpp:148
void geqrf_copy_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< T, Abi > W) noexcept
Blocked Householder QR factorization of A, written to D, using register blocking.
Definition geqrf.tpp:229
constexpr std::pair< index_t, index_t > geqrf_W_size(view< T, Abi, OA > A)
Definition geqrf.hpp:36
void geqrf_diag_microkernel(index_t k, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
Definition geqrf.tpp:42
const constinit auto microkernel_full_lut
Definition geqrf.tpp:23
void geqrf_tail_microkernel(index_t k, bool transposed, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< const T, Abi, OA > A, uview< T, Abi, OD > D, uview< const T, Abi, OB > B) noexcept
Definition geqrf.tpp:160
const constinit auto microkernel_tail_lut_2
Definition geqrf.tpp:35
const constinit auto microkernel_tail_lut
Definition geqrf.tpp:29
const constinit auto microkernel_diag_lut
Definition geqrf.tpp:17
void geqrf_full_microkernel(index_t k, uview< const T, Abi, OA > A, uview< T, Abi, OD > D) noexcept
A (k×R) D (k×R).
Definition geqrf.tpp:97
void geqrf_apply_register(view< const T, Abi, OA > A, view< T, Abi, OD > D, view< const T, Abi, OB > B, view< const T, Abi > W, bool transposed, bool reversed) noexcept
Apply a block Householder transformation.
Definition geqrf.tpp:330
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
int index_t
Definition config.hpp:13