batmat 0.0.14
Batched linear algebra routines
Loading...
Searching...
No Matches
flops.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <batmat/assume.hpp>
4#include <batmat/config.hpp>
6#include <algorithm>
7
9
10/// @addtogroup topic-linalg-flops
11/// @{
12
13/// Count of individual floating point operations, broken down by type.
14struct FlopCount {
15 index_t fma = 0;
16 index_t mul = 0;
17 index_t add = 0;
18 index_t div = 0;
19 index_t sqrt = 0;
20};
21
22/// Combine two flop counts by summing the counts of each operation type.
24 return {.fma = a.fma + b.fma,
25 .mul = a.mul + b.mul,
26 .add = a.add + b.add,
27 .div = a.div + b.div,
28 .sqrt = a.sqrt + b.sqrt};
29}
30
31/// Compute the total number of floating point operations by summing the counts of all operation
32/// types.
33constexpr index_t total(FlopCount c) { return c.fma + c.mul + c.add + c.div + c.sqrt; }
34
35/// Matrix-matrix multiplication of m×k and k×n matrices.
36/// @implementation{flops-gemm}
37// [flops-gemm]
38constexpr FlopCount gemm(index_t m, index_t n, index_t k) { return {.fma = m * k * n}; }
39// [flops-gemm]
40
41/// Matrix-matrix multiplication of m×k and k×n matrices where one or more of the matrices
42/// are triangular or trapezoidal.
43/// @implementation{flops-trmm}
44// [flops-trmm]
45constexpr FlopCount trmm(index_t m, index_t n, index_t k, MatrixStructure sA, MatrixStructure sB,
46 MatrixStructure sC) {
47 using enum MatrixStructure;
48 if (sB == General && sC == General) {
49 if (sA == LowerTriangular || sA == UpperTriangular) { // trapezoidal A
50 if (m >= k) // tall A
51 // x x x x x x x x x x x x
52 // x x x x x x x x x x x x x
53 // x x x x x x x x x x x x x x
54 // x x x x x
55 // x x x x
56 return {.fma = k * (k + 1) / 2 * n + (m - k) * k * n};
57 else // wide A
58 // x x x x x x x x x x x x x x
59 // x x x x x x x x x x x x x x
60 // x x x x x x x x x x x x x x
61 // x x x x x x
62 // x x x x x x
63 return {.fma = m * (m + 1) / 2 * n + (k - m) * (k - m) * n};
64 } else if (sA == General) {
65 return {.fma = m * k * n};
66 } else {
67 BATMAT_ASSUME(!"invalid structure");
68 }
69 } else if (sA == General && sC == General) {
70 if (sB == LowerTriangular || sB == UpperTriangular) { // trapezoidal B
71 if (n >= k) // wide B
72 return {.fma = k * (k + 1) / 2 * m + (n - k) * k * m};
73 else // tall B
74 return {.fma = n * (n + 1) / 2 * m + (k - n) * (k - n) * m};
75 } else {
76 BATMAT_ASSUME(!"invalid structure");
77 }
78 } else if (sA == General && sB == General) {
79 if (sC == LowerTriangular || sC == UpperTriangular) {
80 BATMAT_ASSUME(m == n);
81 return {.fma = m * (m + 1) / 2 * k};
82 } else {
83 BATMAT_ASSUME(!"invalid structure");
84 }
85 } else if (sC == LowerTriangular || sC == UpperTriangular) {
86 if (sA == transpose(sB)) {
87 BATMAT_ASSUME(m == n);
88 BATMAT_ASSUME(m == k);
89 return {.fma = m * (m + 1) * (m + 2) / 6};
90 } else {
91 BATMAT_ASSUME(!"invalid structure");
92 }
93 } else {
94 BATMAT_ASSUME(!"unsupported structure");
95 }
96 return {};
97}
98// [flops-trmm]
99
100/// Matrix-matrix multiplication of m×k and k×n matrices where the result is symmetric.
101/// @implementation{flops-gemmt}
102// [flops-gemmt]
103constexpr FlopCount gemmt(index_t m, index_t n, index_t k, MatrixStructure sA, MatrixStructure sB,
104 MatrixStructure sC) {
105 return trmm(m, n, k, sA, sB, sC);
106}
107// [flops-gemmt]
108
109/// Symmetric rank-k update of n×n matrices.
110/// @implementation{flops-syrk}
111// [flops-syrk]
112constexpr FlopCount syrk(index_t n, index_t k) {
115}
116// [flops-syrk]
117
118/// Matrix-matrix multiplication of m×k and k×n matrices with a diagonal k×k matrix in the middle,
119/// where the result is symmetric.
120/// @implementation{flops-gemmt-diag}
121// [flops-gemmt-diag]
122constexpr FlopCount gemmt_diag(index_t m, index_t n, index_t k, MatrixStructure sC) {
123 constexpr auto sA = MatrixStructure::General, sB = sA;
124 return trmm(m, n, k, sA, sB, sC) + FlopCount{.mul = std::min(m, n) * k};
125}
126// [flops-gemmt-diag]
127
128/// Cholesky factorization and triangular solve for an m×n matrix with m≥n.
129/// @implementation{flops-potrf}
130// [flops-potrf]
131constexpr FlopCount potrf(index_t m, index_t n) {
132 BATMAT_ASSUME(m >= n);
133 return {
134 .fma = (n + 1) * n * (n - 1) / 6 // Schur complement (square)
135 + (m - n) * n * (n - 1) / 2, // (bottom)
136 .mul = n * (n - 1) / 2 // multiplication by inverse pivot (square)
137 + (m - n) * n, // (bottom)
138 .div = n, // inverting pivot
139 .sqrt = n, // square root pivot
140 };
141}
142// [flops-potrf]
143
144/// Hyperbolic Householder factorization update with L n×n and A nr×k.
145/// @implementation{flops-hyh-square}
146// [flops-hyh-square]
147constexpr FlopCount hyh_square(index_t n, index_t k) {
148 return {
149 .fma = k * n * n + 2 * n,
150 .mul = k * n + (n + 1) * n / 2 + n,
151 .add = (n + 1) * n / 2 + n,
152 .div = 2 * n,
153 .sqrt = n,
154 };
155}
156// [flops-hyh-square]
157
158/// Hyperbolic Householder factorization application to L2 nr×nc and A2 nr×k.
159/// @implementation{flops-hyh-apply}
160// [flops-hyh-apply]
161constexpr FlopCount hyh_apply(index_t nr, index_t nc, index_t k) {
162 return {
163 .fma = 2 * nr * k * nc,
164 .mul = nr * nc,
165 .add = nr * nc,
166 };
167}
168// [flops-hyh-apply]
169
170/// Hyperbolic Householder factorization update with L nr×nc and A nr×k.
171/// @implementation{flops-hyh}
172// [flops-hyh]
173constexpr FlopCount hyh(index_t nr, index_t nc, index_t k) {
174 BATMAT_ASSUME(nr >= nc);
175 return hyh_square(nc, k) + hyh_apply(nr - nc, nc, k);
176}
177// [flops-hyh]
178
179/// Fused symmetric rank-k update and Cholesky factorization of an m×n matrix with m≥n.
180/// @implementation{flops-syrk-potrf}
181// [flops-syrk-potrf]
182constexpr FlopCount syrk_potrf(index_t m, index_t n, index_t k) {
183 BATMAT_ASSUME(m >= n);
184 return potrf(m, n) + FlopCount{.fma = n * (n + 1) * k / 2 + (m - n) * n * k};
185}
186// [flops-syrk-potrf]
187
188/// Triangular solve of m×n matrices.
189/// @implementation{flops-trsm}
190// [flops-trsm]
191constexpr FlopCount trsm(index_t m, index_t n) {
192 return {.fma = m * (m - 1) * n / 2, .mul = m * n, .div = m};
193}
194// [flops-trsm]
195
196/// Triangular inversion of an m×m matrix.
197/// @implementation{flops-trtri}
198// [flops-trtri]
199constexpr FlopCount trtri(index_t m) {
200 return {.fma = (m + 1) * m * (m - 1) / 6, .div = m}; // TODO
201}
202// [flops-trtri]
203
204/// @}
205
206} // namespace batmat::linalg::flops
#define BATMAT_ASSUME(x)
Invokes undefined behavior if the expression x does not evaluate to true.
Definition assume.hpp:17
constexpr FlopCount hyh(index_t nr, index_t nc, index_t k)
Hyperbolic Householder factorization update with L nr×nc and A nr×k.
Definition flops.hpp:173
constexpr FlopCount gemmt_diag(index_t m, index_t n, index_t k, MatrixStructure sC)
Matrix-matrix multiplication of m×k and k×n matrices with a diagonal k×k matrix in the middle,...
Definition flops.hpp:122
constexpr index_t total(FlopCount c)
Compute the total number of floating point operations by summing the counts of all operation types.
Definition flops.hpp:33
constexpr FlopCount syrk_potrf(index_t m, index_t n, index_t k)
Fused symmetric rank-k update and Cholesky factorization of an m×n matrix with m≥n.
Definition flops.hpp:182
constexpr FlopCount hyh_square(index_t n, index_t k)
Hyperbolic Householder factorization update with L n×n and A n×k.
Definition flops.hpp:147
constexpr FlopCount trtri(index_t m)
Triangular inversion of an m×m matrix.
Definition flops.hpp:199
constexpr FlopCount gemm(index_t m, index_t n, index_t k)
Matrix-matrix multiplication of m×k and k×n matrices.
Definition flops.hpp:38
constexpr FlopCount trmm(index_t m, index_t n, index_t k, MatrixStructure sA, MatrixStructure sB, MatrixStructure sC)
Matrix-matrix multiplication of m×k and k×n matrices where one or more of the matrices are triangular...
Definition flops.hpp:45
constexpr FlopCount operator+(FlopCount a, FlopCount b)
Combine two flop counts by summing the counts of each operation type.
Definition flops.hpp:23
constexpr FlopCount potrf(index_t m, index_t n)
Cholesky factorization and triangular solve for an m×n matrix with m≥n.
Definition flops.hpp:131
constexpr FlopCount hyh_apply(index_t nr, index_t nc, index_t k)
Hyperbolic Householder factorization application to L2 nr×nc and A2 nr×k.
Definition flops.hpp:161
constexpr FlopCount trsm(index_t m, index_t n)
Triangular solve of m×n matrices.
Definition flops.hpp:191
constexpr FlopCount syrk(index_t n, index_t k)
Symmetric rank-k update of n×n matrices.
Definition flops.hpp:112
constexpr FlopCount gemmt(index_t m, index_t n, index_t k, MatrixStructure sA, MatrixStructure sB, MatrixStructure sC)
Matrix-matrix multiplication of m×k and k×n matrices where the result is symmetric.
Definition flops.hpp:103
Count of individual floating point operations, broken down by type.
Definition flops.hpp:14
constexpr MatrixStructure transpose(MatrixStructure s)
Definition structure.hpp:11