batmat 0.0.13
Batched linear algebra routines
Loading...
Searching...
No Matches
example.cpp
Go to the documentation of this file.
5#include <guanaqo/print.hpp>
6#include <algorithm>
7#include <cmath>
8#include <iostream>
9#include <limits>
10#include <random>
11
12using batmat::index_t;
13using batmat::real_t;
14namespace la = batmat::linalg;
15
// Example program: builds a batch of random matrices A, forms C = AAᵀ, works
// with a second batch L intended to hold the Cholesky factors of C, prints C
// and L for every batch entry, and returns the number of lower-triangular
// entries where C and L differ by 10 * epsilon or more (0 means success).
// NOTE(review): this is a rendered listing; the leading numbers are listing
// line numbers, and listing lines 20, 34 and 41 are absent from this extract.
16int main() {
17 using batch_size = std::integral_constant<index_t, 4>;
18 constexpr auto storage_order = batmat::matrix::StorageOrder::ColMajor;
19 // Class representing a batch of four matrices.
// NOTE(review): listing line 20 is missing here; presumably it declares the
// `Mat` type used below from batch_size and storage_order (TODO confirm).
21 // Allocate some batches of matrices (initialized to zero).
// C is n-by-n (3x3) and A is n-by-m (3x8).
22 index_t n = 3, m = n + 5;
23 Mat C{{.rows = n, .cols = n}}, A{{.rows = n, .cols = m}};
24 // Fill A with random values.
// The generator uses a fixed seed and draws from the range [-1, 1).
25 std::mt19937 rng{12345};
26 std::uniform_real_distribution<real_t> uni{-1.0, 1.0};
27 std::ranges::generate(A, [&] { return uni(rng); });
28 // Compute C = AAᵀ to make it symmetric positive definite (lower triangular part only).
29 la::syrk(A, la::tril(C));
30 // Allocate L for the Cholesky factors.
// L is created with the `uninitialized` tag, so its contents are not zeroed here.
31 Mat L{{.rows = n, .cols = n}, batmat::matrix::uninitialized};
32 // Compute the Cholesky factors L of C (lower triangular).
// The fill below writes 0 through the upper-triangular view of L.
33 la::fill(0, la::triu(L));
// NOTE(review): listing line 34 is missing here; presumably it is the
// factorization call that writes the lower triangle of L (TODO confirm).
35 // Print the results.
// One C/L pair is printed per batch entry (index l runs over C.depth()).
36 for (index_t l = 0; l < C.depth(); ++l) {
37 guanaqo::print_python(std::cout << "C[" << l << "] =\n", C(l));
38 guanaqo::print_python(std::cout << "L[" << l << "] =\n", L(l));
39 }
40 // Compute LLᵀ (in-place).
// NOTE(review): listing line 41 is missing here; presumably it overwrites L
// with LLᵀ, which is what the comparison below relies on (TODO confirm).
42 // Check that LLᵀ == C.
43 int errors = 0;
44 const auto eps = std::numeric_limits<real_t>::epsilon();
// Only entries with r >= c (the lower triangle, including the diagonal) are
// compared, for every batch entry l.
45 for (index_t l = 0; l < C.depth(); ++l)
46 for (index_t c = 0; c < C.cols(); ++c)
47 for (index_t r = c; r < C.rows(); ++r)
// An entry counts as an error unless |C - L| is strictly below 10 * eps.
48 errors += std::abs(C(l, r, c) - L(l, r, c)) < 10 * eps ? 0 : 1;
// The mismatch count is returned from main.
49 return errors;
50}
int main()
Definition example.cpp:16
std::ostream & print_python(std::ostream &os, std::span< T, E > x, std::string_view end="\n", bool squeeze=true)
void syrk(Structured< VA, SA > A, Structured< VD, SD > D, Opts... opts)
D = A Aᵀ with D symmetric.
Definition gemm.hpp:310
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
D = chol(C) with C symmetric, D triangular.
Definition potrf.hpp:75
void fill(simdified_value_t< VB > a, VB &&B)
B = a.
Definition copy.hpp:204
constexpr auto triu(M &&m)
Upper-triangular view.
constexpr auto tril(M &&m)
Lower-triangular view.
struct batmat::matrix::uninitialized_t uninitialized
Tag type to indicate that memory should not be initialized.
Class for a batch of matrices that owns its storage.
Class for a batch of matrices that owns its storage.
Definition matrix.hpp:52