batmat 0.0.15
Batched linear algebra routines
Loading...
Searching...
No Matches
potrf.hpp
Go to the documentation of this file.
1#pragma once
2
5#include <batmat/lut.hpp>
7
9
11 bool negate_A = false; ///< Whether to compute chol(C - AAᵀ) instead of chol(C + AAᵀ)
12 enum {
13 none, ///< chol(C ± AAᵀ)
14 diag, ///< chol(C ± AΣAᵀ) with Σ diagonal
15 diag_sign_only, ///< chol(C ± AΣAᵀ) with Σ diagonal and containing only ±0 (just sign bits)
18 [[nodiscard]] constexpr bool with_diag() const noexcept { return diag_A != none; }
19};
20
21template <class T, class Abi, KernelConfig Conf>
22using diag_uview_type = std::conditional_t<Conf.with_diag(), uview_vec<T, Abi>, std::false_type>;
23template <class T, class Abi, KernelConfig Conf>
24using diag_view_type = std::conditional_t<Conf.with_diag(), view<T, Abi>, std::false_type>;
25
26template <class T, class Abi, KernelConfig Conf, index_t RowsReg, StorageOrder O1, StorageOrder O2>
28 uview<const T, Abi, O2> C, uview<T, Abi, O2> D, T *invD, index_t k1,
29 index_t k2, T regularization,
31
32template <class T, class Abi, KernelConfig Conf, index_t RowsReg, index_t ColsReg, StorageOrder O1,
33 StorageOrder O2>
37 uview<T, Abi, O2> D, index_t k1, index_t k2,
39
40template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OCD>
42 T regularization, diag_view_type<const T, Abi, Conf> diag) noexcept;
43
44// Square block sizes greatly simplify handling of triangular matrices.
45using gemm::RowsReg;
46template <class T, class Abi>
47constexpr index_t ColsReg = RowsReg<T, Abi>;
48
49template <class T, class Abi, KernelConfig Conf, StorageOrder OA, StorageOrder OC>
50inline const constinit auto potrf_copy_lut =
53 });
54
55template <class T, class Abi, KernelConfig Conf, StorageOrder O1, StorageOrder O2>
57 []<index_t Row, index_t Col>(index_constant<Row>, index_constant<Col>) {
59 });
60
61} // namespace batmat::linalg::micro_kernels::potrf
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
const constinit auto potrf_copy_lut
Definition potrf.hpp:50
std::conditional_t< Conf.with_diag(), uview_vec< T, Abi >, std::false_type > diag_uview_type
Definition potrf.hpp:22
void trsm_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O1 > B1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > B2, uview< const T, Abi, O2 > L, const T *invL, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, index_t k1, index_t k2, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:149
void potrf_copy_register(view< const T, Abi, OA > A, view< const T, Abi, OCD > C, view< T, Abi, OCD > D, T regularization, diag_view_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:211
const constinit auto trsm_copy_lut
Definition potrf.hpp:56
std::conditional_t< Conf.with_diag(), view< T, Abi >, std::false_type > diag_view_type
Definition potrf.hpp:24
void potrf_copy_microkernel(uview< const T, Abi, O1 > A1, uview< const T, Abi, O2 > A2, uview< const T, Abi, O2 > C, uview< T, Abi, O2 > D, T *invD, index_t k1, index_t k2, T regularization, diag_uview_type< const T, Abi, Conf > diag) noexcept
Definition potrf.tpp:44
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
enum batmat::linalg::micro_kernels::potrf::KernelConfig::@251003062227232146206006340175351212340342314017 diag_A
@ diag_sign_only
chol(C ± AΣAᵀ) with Σ diagonal and containing only ±0 (just sign bits)
Definition potrf.hpp:15
@ diag
chol(C ± AΣAᵀ) with Σ diagonal
Definition potrf.hpp:14
constexpr bool with_diag() const noexcept
Definition potrf.hpp:18
bool negate_A
Whether to compute chol(C - AAᵀ) instead of chol(C + AAᵀ).
Definition potrf.hpp:11