batmat 0.0.15
Batched linear algebra routines
Loading...
Searching...
No Matches
hyhound.hpp
Go to the documentation of this file.
1#pragma once
2
5#include <batmat/lut.hpp>
7#include <batmat/simd.hpp>
8
10
12 bool sign_only = false;
13};
14
15template <class T, class Abi, index_t R>
17 using value_type = T;
19
21 static constexpr ptrdiff_t inner_stride =
23
24 static constexpr index_t num_elem_per_layer() { return R * (R + 1) / 2; }
25 static constexpr size_t size() {
26 return simd::size() * static_cast<size_t>(num_elem_per_layer());
27 }
28 static constexpr size_t alignment() {
30 }
31
/// Returns a reference to element (r, c) of the packed upper-triangular storage.
/// Only the upper triangle is stored, so r <= c is required (checked by assert in debug builds).
/// Column c begins at packed offset c*(c+1)/2, and row r is added within that column.
/// The packed index is scaled by inner_stride, so consecutive logical elements are
/// inner_stride value_type entries apart. NOTE(review): inner_stride's initializer is elided
/// in this listing; presumably it equals the SIMD width — confirm.
/// The method is const but yields a mutable reference: constness of the accessor (a view over
/// `data`) does not propagate to the pointed-to elements; use T = const U for read-only access.
32 [[gnu::always_inline]] value_type &operator()(index_t r, index_t c) const noexcept {
33 assert(r <= c);
34 return data[(r + c * (c + 1) / 2) * inner_stride];
35 }
/// Loads the SIMD vector stored at packed position (r, c) (requires r <= c, see operator()).
/// Uses datapar::aligned_load, so the element address must satisfy the simd alignment.
/// NOTE(review): this assumes `data` was allocated with at least alignment() — confirm at call sites.
36 [[gnu::always_inline]] simd load(index_t r, index_t c) const noexcept {
37 return datapar::aligned_load<simd>(&operator()(r, c));
38 }
/// Stores the SIMD vector x at packed position (r, c) (requires r <= c, see operator()).
/// Uses datapar::aligned_store, so the same alignment assumption as load() applies.
/// The requires-clause removes this member for read-only accessors (T is const).
39 [[gnu::always_inline]] void store(simd x, index_t r, index_t c) const noexcept
40 requires(!std::is_const_v<T>)
41 {
42 datapar::aligned_store(x, &operator()(r, c));
43 }
44
/// Wraps an existing buffer: only the pointer is stored, no allocation or copy happens here.
/// Non-explicit, so `return {data};` (used by the conversion below) can construct an accessor.
45 [[gnu::always_inline]] triangular_accessor(value_type *data) noexcept : data{data} {}
/// Implicit conversion to a read-only accessor over the same buffer. This lets a mutable
/// accessor be passed where triangular_accessor<const T, Abi, R> is expected
/// (e.g. the W parameter of hyhound_diag_tail_microkernel).
46 operator triangular_accessor<const T, Abi, R>() const noexcept { return {data}; }
47};
48
49template <class T, class Abi>
50inline constexpr index_t SizeR = gemm::RowsReg<T, Abi>; // TODO
51template <class T, class Abi>
52inline constexpr index_t SizeS = gemm::RowsReg<T, Abi>; // TODO
53
54template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
58
59template <class T, class Abi, KernelConfig Conf, index_t R, StorageOrder OL, StorageOrder OA>
62
63enum class Structure {
65 Zero = 1,
66 Upper = 2,
67};
68
69template <class T, class Abi, KernelConfig Conf, index_t R, index_t S, StorageOrder OL,
71void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k,
72 triangular_accessor<const T, Abi, SizeR<T, Abi>> W,
76 Structure struc_L, int rotate_A) noexcept;
77
78template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
79inline const constinit auto microkernel_diag_lut =
82 });
83
84template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA>
85inline const constinit auto microkernel_full_lut =
88 });
89
90template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
91inline const constinit auto microkernel_tail_lut =
94 });
95
96template <class T, class Abi, KernelConfig Conf, StorageOrder OL, StorageOrder OA, StorageOrder OB>
98 []<index_t NR, index_t NS>(index_constant<NR>, index_constant<NS>) {
100 });
101
102// Helper function to compute the size of the storage for the matrix W (part of the hyperbolic
103// Householder representation).
104template <class T, class Abi, StorageOrder OL>
105constexpr std::pair<index_t, index_t> hyhound_W_size(view<T, Abi, OL> L) {
106 static constexpr index_constant<SizeR<std::remove_const_t<T>, Abi>> R;
108 return {W_t::num_elem_per_layer(), (L.cols() + R - 1) / R};
109}
110
111// Low-level register-blocked routines
112template <class T, class Abi, KernelConfig Conf = {}, StorageOrder OL = StorageOrder::ColMajor,
113 StorageOrder OA = StorageOrder::ColMajor>
115
116template <class T, class Abi, KernelConfig Conf = {}, StorageOrder OL = StorageOrder::ColMajor,
117 StorageOrder OA = StorageOrder::ColMajor>
119 view<T, Abi> W) noexcept;
120
121template <class T, class Abi, KernelConfig Conf = {}, StorageOrder OL = StorageOrder::ColMajor,
122 StorageOrder OA = StorageOrder::ColMajor>
126 index_t kA_in_offset = 0) noexcept;
127
128template <class T, class Abi, StorageOrder OL1 = StorageOrder::ColMajor,
129 StorageOrder OA1 = StorageOrder::ColMajor, StorageOrder OL2 = StorageOrder::ColMajor,
130 StorageOrder OA2 = StorageOrder::ColMajor, KernelConfig Conf = {}>
133
134template <class T, class Abi, StorageOrder OL = StorageOrder::ColMajor,
135 StorageOrder OW = StorageOrder::ColMajor, StorageOrder OY = StorageOrder::ColMajor,
136 StorageOrder OU = StorageOrder::ColMajor, KernelConfig Conf = {}>
140 view<T, Abi, OW> A3_out, view<const T, Abi> D) noexcept;
141
142template <class T, class Abi, StorageOrder OL = StorageOrder::ColMajor,
143 StorageOrder OA = StorageOrder::ColMajor, StorageOrder OLu = StorageOrder::ColMajor,
144 StorageOrder OAu = StorageOrder::ColMajor, KernelConfig Conf = {}>
148 view<const T, Abi> D, bool shift_A_out) noexcept;
149
150} // namespace batmat::linalg::micro_kernels::hyhound
consteval auto make_1d_lut(F f)
Returns an array of the form:
Definition lut.hpp:39
consteval auto make_2d_lut(F f)
Returns a 2D array of the form:
Definition lut.hpp:25
void aligned_store(V v, typename V::value_type *p)
Definition simd.hpp:121
stdx::memory_alignment< simd< Tp, Abi > > simd_align
Definition simd.hpp:139
stdx::simd_size< Tp, Abi > simd_size
Definition simd.hpp:137
V aligned_load(const typename V::value_type *p)
Definition simd.hpp:111
stdx::simd< Tp, Abi > simd
Definition simd.hpp:99
constexpr index_t RowsReg
Register block size of the matrix-matrix multiplication micro-kernels.
Definition avx-512.hpp:13
void hyhound_diag_full_microkernel(index_t kA, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:72
const constinit auto microkernel_full_lut
Definition hyhound.hpp:85
void hyhound_diag_cyclic_register(view< T, Abi, OL > L11, view< T, Abi, OW > A1, view< T, Abi, OY > L21, view< const T, Abi, OW > A22, view< T, Abi, OW > A2_out, view< T, Abi, OU > L31, view< const T, Abi, OW > A31, view< T, Abi, OW > A3_out, view< const T, Abi > D) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:490
void hyhound_diag_register(view< T, Abi, OL > L, view< T, Abi, OA > A, view< const T, Abi > D) noexcept
Block hyperbolic Householder factorization update using register blocking.
Definition hyhound.tpp:264
void hyhound_diag_riccati_register(view< T, Abi, OL > L11, view< T, Abi, OA > A1, view< T, Abi, OL > L21, view< const T, Abi, OA > A2, view< T, Abi, OA > A2_out, view< T, Abi, OLu > Lu1, view< T, Abi, OAu > Au_out, view< const T, Abi > D, bool shift_A_out) noexcept
Performs a factorization update of the following matrix:
Definition hyhound.tpp:584
void hyhound_diag_tail_microkernel(index_t kA_in_offset, index_t kA_in, index_t k, triangular_accessor< const T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< const T, Abi, OA > A_in, uview< T, Abi, OA > A_out, uview< const T, Abi, OB > B, uview< const T, Abi, StorageOrder::ColMajor > diag, Structure struc_L, int rotate_A) noexcept
Definition hyhound.tpp:136
constexpr std::pair< index_t, index_t > hyhound_W_size(view< T, Abi, OL > L)
Definition hyhound.hpp:105
void hyhound_diag_diag_microkernel(index_t kA, triangular_accessor< T, Abi, SizeR< T, Abi > > W, uview< T, Abi, OL > L, uview< T, Abi, OA > A, uview< const T, Abi, StorageOrder::ColMajor > diag) noexcept
Definition hyhound.tpp:19
void hyhound_diag_apply_register(view< T, Abi, OL > L, view< const T, Abi, OA > Ain, view< T, Abi, OA > Aout, view< const T, Abi, OA > B, view< const T, Abi > D, view< const T, Abi > W, index_t kA_in_offset=0) noexcept
Apply a block hyperbolic Householder transformation.
Definition hyhound.tpp:377
const constinit auto microkernel_tail_lut_2
Definition hyhound.hpp:97
const constinit auto microkernel_tail_lut
Definition hyhound.hpp:91
const constinit auto microkernel_diag_lut
Definition hyhound.hpp:79
void hyhound_diag_2_register(view< T, Abi, OL1 > L11, view< T, Abi, OA1 > A1, view< T, Abi, OL2 > L21, view< T, Abi, OA2 > A2, view< const T, Abi > D) noexcept
Same as hyhound_diag_register but for two block rows at once.
Definition hyhound.tpp:428
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
Definition uview.hpp:70
std::integral_constant< index_t, I > index_constant
Definition lut.hpp:10
simd load(index_t r, index_t c) const noexcept
Definition hyhound.hpp:36
datapar::simd< std::remove_const_t< T >, Abi > simd
Definition hyhound.hpp:20
void store(simd x, index_t r, index_t c) const noexcept
Definition hyhound.hpp:39
value_type & operator()(index_t r, index_t c) const noexcept
Definition hyhound.hpp:32