batmat develop
Batched linear algebra routines
Loading...
Searching...
No Matches
potrf.hpp
Go to the documentation of this file.
1#pragma once
2
10#include <batmat/loop.hpp>
12#include <guanaqo/trace.hpp>
13
14namespace batmat::linalg {
15
16namespace detail {
17template <class T, class Abi, micro_kernels::potrf::KernelConfig Conf, StorageOrder OA,
18 StorageOrder OCD>
19 requires(Conf.struc_C != MatrixStructure::General)
22 // Check dimensions
23 BATMAT_ASSERT(D.rows() >= D.cols());
24 BATMAT_ASSERT(A.cols() == 0 || A.rows() == D.rows());
25 BATMAT_ASSERT(C.rows() == D.rows());
26 BATMAT_ASSERT(C.cols() == D.cols());
27 if constexpr (Conf.with_diag()) {
28 BATMAT_ASSERT(d.rows() == A.cols());
29 BATMAT_ASSERT(d.cols() == 1);
30 }
31 const index_t M = D.rows(), N = D.cols();
32 GUANAQO_TRACE_LINALG("potrf", total(flops::syrk_potrf(M, N, A.cols())) * C.depth());
33 // Degenerate case
34 if (M == 0 || N == 0) [[unlikely]]
35 return;
36 return micro_kernels::potrf::potrf_copy_register<T, Abi, Conf>(A, C, D, regularization, d);
37}
38} // namespace detail
39
40/// @addtogroup topic-linalg
41/// @{
42
43/// @name Cholesky factorization of batches of matrices
44/// @{
45
46/// D = chol(C + AAᵀ) with C symmetric, D triangular
47template <MatrixStructure SC, simdifiable VA, simdifiable VC, simdifiable VD>
50 simdified_value_t<VA> regularization = 0) {
51 detail::potrf<simdified_value_t<VA>, simdified_abi_t<VA>, {.negate_A = false, .struc_C = SC}>(
52 simdify(A).as_const(), simdify(C.value).as_const(), simdify(D.value), regularization);
53}
54/// D = chol(D + AAᵀ) with D symmetric/triangular
55template <MatrixStructure SC, simdifiable VA, simdifiable VD>
58 syrk_add_potrf(A, D.ref(), D.ref());
59}
60
61/// D = chol(C - AAᵀ) with C symmetric, D triangular
62template <MatrixStructure SC, simdifiable VA, simdifiable VC, simdifiable VD>
65 simdified_value_t<VA> regularization = 0) {
66 detail::potrf<simdified_value_t<VA>, simdified_abi_t<VA>, {.negate_A = true, .struc_C = SC}>(
67 simdify(A).as_const(), simdify(C.value).as_const(), simdify(D.value), regularization);
68}
69/// D = chol(D - AAᵀ) with D symmetric/triangular
70template <MatrixStructure SC, simdifiable VA, simdifiable VD>
72void syrk_sub_potrf(VA &&A, Structured<VD, SC> D, simdified_value_t<VA> regularization = 0) {
73 syrk_sub_potrf(A, D.ref(), D.ref(), regularization);
74}
75
76/// D = chol(C + A diag(d) Aᵀ) with C symmetric, D triangular
77template <MatrixStructure SC, simdifiable VA, simdifiable VC, simdifiable VD, simdifiable Vd>
80 simdified_value_t<VA> regularization = 0) {
83 {.negate_A = false, .diag_A = KernelConfig::diag, .struc_C = SC}>(
84 simdify(A).as_const(), simdify(C.value).as_const(), simdify(D.value), regularization,
85 simdify(d).as_const());
86}
87/// D = chol(D + A diag(d) Aᵀ) with D symmetric/triangular
88template <MatrixStructure SC, simdifiable VA, simdifiable VD, simdifiable Vd>
90void syrk_diag_add_potrf(VA &&A, Structured<VD, SC> D, Vd &&d) {
91 syrk_diag_add_potrf(A, D.ref(), D.ref(), d);
92}
93
94/// D = chol(C) with C symmetric, D triangular
95template <MatrixStructure SC, simdifiable VC, simdifiable VD>
98 decltype(simdify(C.value).as_const()) null{{.data = nullptr, .rows = 0, .cols = 0}};
100 null, simdify(C.value).as_const(), simdify(D.value), regularization);
101}
102
103/// D = chol(D) with D symmetric/triangular
104template <MatrixStructure SD, simdifiable VD>
105void potrf(Structured<VD, SD> D, simdified_value_t<VD> regularization = 0) {
106 decltype(simdify(D.value).as_const()) null{{.data = nullptr, .rows = 0, .cols = 0}};
108 null, simdify(D.value).as_const(), simdify(D.value), regularization);
109}
110
111/// @}
112
113/// @}
114
115} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
constexpr FlopCount syrk_potrf(index_t m, index_t n, index_t k)
Fused symmetric rank-k update and Cholesky factorization of an m×n matrix with m≥n.
Definition flops.hpp:182
void syrk_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
D = chol(C + AAᵀ) with C symmetric, D triangular.
Definition potrf.hpp:49
void syrk_sub_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
D = chol(C - AAᵀ) with C symmetric, D triangular.
Definition potrf.hpp:64
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
D = chol(C) with C symmetric, D triangular.
Definition potrf.hpp:97
void syrk_diag_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, Vd &&d, simdified_value_t< VA > regularization=0)
D = chol(C + A diag(d) Aᵀ) with C symmetric, D triangular.
Definition potrf.hpp:79
#define GUANAQO_TRACE_LINALG(name, gflops)
void potrf(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, micro_kernels::potrf::diag_view_type< const T, Abi, Conf > d={})
Definition potrf.hpp:20
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, diag_view_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:228
std::conditional_t< Conf.with_diag(), view< T, Abi >, std::false_type > diag_view_type
Definition potrf.hpp:24
typename detail::simdified_value< V >::type simdified_value_t
Definition simdify.hpp:202
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Aligned allocation for matrix storage.
Lightweight wrapper class used for overload resolution of triangular and symmetric matrices.