batmat 0.0.15
Batched linear algebra routines
Loading...
Searching...
No Matches
potrf.hpp
Go to the documentation of this file.
1#pragma once
2
10#include <batmat/loop.hpp>
12#include <guanaqo/trace.hpp>
13
14namespace batmat::linalg {
15
16namespace detail {
17template <class T, class Abi, micro_kernels::potrf::KernelConfig Conf, StorageOrder OA,
18 StorageOrder OCD>
19 requires(Conf.struc_C != MatrixStructure::General)
21 T regularization,
23 // Check dimensions
24 BATMAT_ASSERT(D.rows() >= D.cols());
25 BATMAT_ASSERT(A.cols() == 0 || A.rows() == D.rows());
26 BATMAT_ASSERT(C.rows() == D.rows());
27 BATMAT_ASSERT(C.cols() == D.cols());
28 if constexpr (Conf.with_diag()) {
29 BATMAT_ASSERT(d.rows() == A.cols());
30 BATMAT_ASSERT(d.cols() == 1);
31 }
32 const index_t M = D.rows(), N = D.cols();
33 GUANAQO_TRACE_LINALG("potrf", total(flops::syrk_potrf(M, N, A.cols())) * C.depth());
34 // Degenerate case
35 if (M == 0 || N == 0) [[unlikely]]
36 return;
37 return micro_kernels::potrf::potrf_copy_register<T, Abi, Conf>(A, C, D, regularization, d);
38}
39} // namespace detail
40
41/// @addtogroup topic-linalg
42/// @{
43
44/// @name Cholesky factorization of batches of matrices
45/// @{
46
47/// D = chol(C + AAᵀ) with C symmetric, D triangular
48template <MatrixStructure SC, simdifiable VA, simdifiable VC, simdifiable VD>
51 simdified_value_t<VA> regularization = 0) {
52 detail::potrf<simdified_value_t<VA>, simdified_abi_t<VA>, {.negate_A = false, .struc_C = SC}>(
53 simdify(A).as_const(), simdify(C.value).as_const(), simdify(D.value), regularization);
54}
55/// D = chol(D + AAᵀ) with D symmetric/triangular
56template <MatrixStructure SC, simdifiable VA, simdifiable VD>
59 syrk_add_potrf(A, D.ref(), D.ref());
60}
61
62/// D = chol(C - AAᵀ) with C symmetric, D triangular
63template <MatrixStructure SC, simdifiable VA, simdifiable VC, simdifiable VD>
66 simdified_value_t<VA> regularization = 0) {
67 detail::potrf<simdified_value_t<VA>, simdified_abi_t<VA>, {.negate_A = true, .struc_C = SC}>(
68 simdify(A).as_const(), simdify(C.value).as_const(), simdify(D.value), regularization);
69}
70/// D = chol(D - AAᵀ) with D symmetric/triangular
71template <MatrixStructure SC, simdifiable VA, simdifiable VD>
73void syrk_sub_potrf(VA &&A, Structured<VD, SC> D, simdified_value_t<VA> regularization = 0) {
74 syrk_sub_potrf(A, D.ref(), D.ref(), regularization);
75}
76
77/// D = chol(C + A diag(d) Aᵀ) with C symmetric, D triangular
78template <MatrixStructure SC, simdifiable VA, simdifiable VC, simdifiable VD, simdifiable Vd>
81 simdified_value_t<VA> regularization = 0) {
84 {.negate_A = false, .diag_A = KernelConfig::diag, .struc_C = SC}>(
85 simdify(A).as_const(), simdify(C.value).as_const(), simdify(D.value), regularization,
86 simdify(d).as_const());
87}
88/// D = chol(D + A diag(d) Aᵀ) with D symmetric/triangular (no regularization parameter; use the C/D overload to pass one)
89template <MatrixStructure SC, simdifiable VA, simdifiable VD, simdifiable Vd>
91void syrk_diag_add_potrf(VA &&A, Structured<VD, SC> D, Vd &&d) {
92 syrk_diag_add_potrf(A, D.ref(), D.ref(), d);
93}
94
95/// D = chol(C) with C symmetric, D triangular
96template <MatrixStructure SC, simdifiable VC, simdifiable VD>
99 decltype(simdify(C.value).as_const()) null{{.data = nullptr, .rows = 0, .cols = 0}};
101 null, simdify(C.value).as_const(), simdify(D.value), regularization);
102}
103
104/// D = chol(D) with D symmetric/triangular
105template <MatrixStructure SD, simdifiable VD>
106void potrf(Structured<VD, SD> D, simdified_value_t<VD> regularization = 0) {
107 decltype(simdify(D.value).as_const()) null{{.data = nullptr, .rows = 0, .cols = 0}};
109 null, simdify(D.value).as_const(), simdify(D.value), regularization);
110}
111
112/// @}
113
114/// @}
115
116} // namespace batmat::linalg
#define BATMAT_ASSERT(x)
Definition assume.hpp:14
constexpr FlopCount syrk_potrf(index_t m, index_t n, index_t k)
Fused symmetric rank-k update and Cholesky factorization of an m×n matrix with m≥n.
Definition flops.hpp:182
void syrk_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
D = chol(C + AAᵀ) with C symmetric, D triangular.
Definition potrf.hpp:50
void syrk_sub_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
D = chol(C - AAᵀ) with C symmetric, D triangular.
Definition potrf.hpp:65
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
D = chol(C) with C symmetric, D triangular.
Definition potrf.hpp:98
void syrk_diag_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, Vd &&d, simdified_value_t< VA > regularization=0)
D = chol(C + A diag(d) Aᵀ) with C symmetric, D triangular.
Definition potrf.hpp:80
#define GUANAQO_TRACE_LINALG(name, gflops)
void potrf(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, micro_kernels::potrf::diag_view_type< const T, Abi, Conf > d={}) noexcept
Definition potrf.hpp:20
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, diag_view_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:211
std::conditional_t< Conf.with_diag(), view< T, Abi >, std::false_type > diag_view_type
Definition potrf.hpp:24
typename detail::simdified_value< V >::type simdified_value_t
Definition simdify.hpp:202
typename detail::simdified_abi< V >::type simdified_abi_t
Definition simdify.hpp:204
constexpr bool simdify_compatible
Definition simdify.hpp:207
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
Definition simdify.hpp:214
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
Aligned allocation for matrix storage.
Lightweight wrapper class used for overload resolution of triangular and symmetric matrices.