guanaqo 1.0.0-alpha.24
Utilities for scientific software
Loading...
Searching...
No Matches
blas-interface.cpp
Go to the documentation of this file.
2#include <guanaqo/blas/export.h>
4#include <guanaqo/openmp.h>
5
6namespace guanaqo::blas {
7
8template <>
9GUANAQO_BLAS_EXPORT void
10xgemv(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, blas_index_t M,
11 blas_index_t N, std::type_identity_t<double> alpha,
12 std::type_identity_t<const double *> A, blas_index_t lda,
13 std::type_identity_t<const double *> X, blas_index_t incX,
14 std::type_identity_t<double> beta, double *Y, blas_index_t incY) {
15 cblas_dgemv(Layout, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
16}
17template <>
18GUANAQO_BLAS_EXPORT void
19xgemv(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, blas_index_t M,
20 blas_index_t N, std::type_identity_t<float> alpha,
21 std::type_identity_t<const float *> A, blas_index_t lda,
22 std::type_identity_t<const float *> X, blas_index_t incX,
23 std::type_identity_t<float> beta, float *Y, blas_index_t incY) {
24 cblas_sgemv(Layout, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
25}
26
27template <>
28GUANAQO_BLAS_EXPORT void
29xgemm(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
30 index_t M, index_t N, index_t K, std::type_identity_t<double> alpha,
31 std::type_identity_t<const double *> A, index_t lda,
32 std::type_identity_t<const double *> B, index_t ldb,
33 std::type_identity_t<double> beta, double *C, index_t ldc) {
34 cblas_dgemm(Layout, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
35 ldc);
36}
37template <>
38GUANAQO_BLAS_EXPORT void
39xgemm(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
40 index_t M, index_t N, index_t K, std::type_identity_t<float> alpha,
41 std::type_identity_t<const float *> A, index_t lda,
42 std::type_identity_t<const float *> B, index_t ldb,
43 std::type_identity_t<float> beta, float *C, index_t ldc) {
44 cblas_sgemm(Layout, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
45 ldc);
46}
47
48template <>
49GUANAQO_BLAS_EXPORT void
50xgemmt(CBLAS_LAYOUT Layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE TransA,
51 CBLAS_TRANSPOSE TransB, index_t N, index_t K,
52 std::type_identity_t<double> alpha,
53 std::type_identity_t<const double *> A, index_t lda,
54 std::type_identity_t<const double *> B, index_t ldb,
55 std::type_identity_t<double> beta, double *C, index_t ldc) {
56 cblas_dgemmt(Layout, uplo, TransA, TransB, N, K, alpha, A, lda, B, ldb,
57 beta, C, ldc);
58}
59template <>
60GUANAQO_BLAS_EXPORT void
61xgemmt(CBLAS_LAYOUT Layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE TransA,
62 CBLAS_TRANSPOSE TransB, index_t N, index_t K,
63 std::type_identity_t<float> alpha, std::type_identity_t<const float *> A,
64 index_t lda, std::type_identity_t<const float *> B, index_t ldb,
65 std::type_identity_t<float> beta, float *C, index_t ldc) {
66 cblas_sgemmt(Layout, uplo, TransA, TransB, N, K, alpha, A, lda, B, ldb,
67 beta, C, ldc);
68}
69
70template <>
71GUANAQO_BLAS_EXPORT void
72xsymv(CBLAS_LAYOUT Layout, CBLAS_UPLO uplo, blas_index_t N,
73 std::type_identity_t<double> alpha,
74 std::type_identity_t<const double *> A, blas_index_t lda,
75 std::type_identity_t<const double *> X, blas_index_t incX,
76 std::type_identity_t<double> beta, double *Y, blas_index_t incY) {
77 cblas_dsymv(Layout, uplo, N, alpha, A, lda, X, incX, beta, Y, incY);
78}
79template <>
80GUANAQO_BLAS_EXPORT void
81xsymv(CBLAS_LAYOUT Layout, CBLAS_UPLO uplo, blas_index_t N,
82 std::type_identity_t<float> alpha, std::type_identity_t<const float *> A,
83 blas_index_t lda, std::type_identity_t<const float *> X,
84 blas_index_t incX, std::type_identity_t<float> beta, float *Y,
85 blas_index_t incY) {
86 cblas_ssymv(Layout, uplo, N, alpha, A, lda, X, incX, beta, Y, incY);
87}
88
89template <>
90GUANAQO_BLAS_EXPORT void
91xtrmv(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA,
92 CBLAS_DIAG Diag, index_t N, std::type_identity_t<const double *> A,
93 index_t lda, double *X, index_t incX) {
94 cblas_dtrmv(Layout, Uplo, TransA, Diag, N, A, lda, X, incX);
95}
96
97template <>
98GUANAQO_BLAS_EXPORT void xtrmv(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo,
99 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag,
100 index_t N, std::type_identity_t<const float *> A,
101 index_t lda, float *X, index_t incX) {
102 cblas_strmv(Layout, Uplo, TransA, Diag, N, A, lda, X, incX);
103}
104
105template <>
106GUANAQO_BLAS_EXPORT void
107xtrsv(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA,
108 CBLAS_DIAG Diag, index_t N, std::type_identity_t<const double *> A,
109 index_t lda, double *X, index_t incX) {
110 cblas_dtrsv(Layout, Uplo, TransA, Diag, N, A, lda, X, incX);
111}
112
113template <>
114GUANAQO_BLAS_EXPORT void xtrsv(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo,
115 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag,
116 index_t N, std::type_identity_t<const float *> A,
117 index_t lda, float *X, index_t incX) {
118 cblas_strsv(Layout, Uplo, TransA, Diag, N, A, lda, X, incX);
119}
120
121template <>
122GUANAQO_BLAS_EXPORT void xtrmm(CBLAS_LAYOUT Layout, CBLAS_SIDE Side,
123 CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA,
124 CBLAS_DIAG Diag, index_t M, index_t N,
125 std::type_identity_t<double> alpha,
126 std::type_identity_t<const double *> A,
127 index_t lda, double *B, index_t ldb) {
128 cblas_dtrmm(Layout, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
129}
130template <>
131GUANAQO_BLAS_EXPORT void
132xtrmm(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo,
133 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M, index_t N,
134 std::type_identity_t<float> alpha, std::type_identity_t<const float *> A,
135 index_t lda, float *B, index_t ldb) {
136 cblas_strmm(Layout, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
137}
138
139template <>
140GUANAQO_BLAS_EXPORT void
141xsyrk(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, index_t N,
142 index_t K, std::type_identity_t<double> alpha,
143 std::type_identity_t<const double *> A, index_t lda,
144 std::type_identity_t<double> beta, double *C, index_t ldc) {
145 cblas_dsyrk(Layout, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
146}
147template <>
148GUANAQO_BLAS_EXPORT void
149xsyrk(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, index_t N,
150 index_t K, std::type_identity_t<float> alpha,
151 std::type_identity_t<const float *> A, index_t lda,
152 std::type_identity_t<float> beta, float *C, index_t ldc) {
153 cblas_ssyrk(Layout, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
154}
155
156template <>
157GUANAQO_BLAS_EXPORT void xtrsm(CBLAS_LAYOUT Layout, CBLAS_SIDE Side,
158 CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA,
159 CBLAS_DIAG Diag, index_t M, index_t N,
160 std::type_identity_t<double> alpha,
161 std::type_identity_t<const double *> A,
162 index_t lda, double *B, index_t ldb) {
163 cblas_dtrsm(Layout, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
164}
165template <>
166GUANAQO_BLAS_EXPORT void
167xtrsm(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo,
168 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M, index_t N,
169 std::type_identity_t<float> alpha, std::type_identity_t<const float *> A,
170 index_t lda, float *B, index_t ldb) {
171 cblas_strsm(Layout, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
172}
173
174template <>
175GUANAQO_BLAS_EXPORT void xsytrf_rk(const char *uplo, const index_t *n,
176 double *a, const index_t *lda, double *e,
177 index_t *ipiv, double *work,
178 const index_t *lwork, index_t *info) {
179 dsytrf_rk(uplo, n, a, lda, e, ipiv, work, lwork, info);
180}
181
182template <>
183GUANAQO_BLAS_EXPORT void xsytrf_rk(const char *uplo, const index_t *n, float *a,
184 const index_t *lda, float *e, index_t *ipiv,
185 float *work, const index_t *lwork,
186 index_t *info) {
187 ssytrf_rk(uplo, n, a, lda, e, ipiv, work, lwork, info);
188}
189
190template <>
191GUANAQO_BLAS_EXPORT void
192xtrtrs(const char *uplo, const char *trans, const char *diag, const index_t *n,
193 const index_t *nrhs, std::type_identity_t<const double *> A,
194 const index_t *ldA, double *B, const index_t *ldB, index_t *info) {
195 dtrtrs(uplo, trans, diag, n, nrhs, A, ldA, B, ldB, info);
196}
197
198template <>
199GUANAQO_BLAS_EXPORT void
200xtrtrs(const char *uplo, const char *trans, const char *diag, const index_t *n,
201 const index_t *nrhs, std::type_identity_t<const float *> A,
202 const index_t *ldA, float *B, const index_t *ldB, index_t *info) {
203 strtrs(uplo, trans, diag, n, nrhs, A, ldA, B, ldB, info);
204}
205
206template <>
207GUANAQO_BLAS_EXPORT void xscal(index_t N, std::type_identity_t<double> alpha,
208 double *X, index_t incX) {
209 cblas_dscal(N, alpha, X, incX);
210}
211
212template <>
213GUANAQO_BLAS_EXPORT void xscal(index_t N, std::type_identity_t<float> alpha,
214 float *X, index_t incX) {
215 cblas_sscal(N, alpha, X, incX);
216}
217
218template <>
219GUANAQO_BLAS_EXPORT void xpotrf(const char *uplo, index_t n, double *a,
220 index_t lda, index_t *info) {
221 dpotrf(uplo, &n, a, &lda, info);
222}
223template <>
224GUANAQO_BLAS_EXPORT void xpotrf(const char *uplo, index_t n, float *a,
225 index_t lda, index_t *info) {
226 spotrf(uplo, &n, a, &lda, info);
227}
228
229template <>
230GUANAQO_BLAS_EXPORT void xlauum(const char *uplo, index_t n, double *a,
231 index_t lda, index_t *info) {
232 dlauum(uplo, &n, a, &lda, info);
233}
234template <>
235GUANAQO_BLAS_EXPORT void xlauum(const char *uplo, index_t n, float *a,
236 index_t lda, index_t *info) {
237 slauum(uplo, &n, a, &lda, info);
238}
239
240template <>
241GUANAQO_BLAS_EXPORT void xtrtri(const char *uplo, const char *diag, index_t n,
242 double *a, index_t lda, index_t *info) {
243 dtrtri(uplo, diag, &n, a, &lda, info);
244}
245template <>
246GUANAQO_BLAS_EXPORT void xtrtri(const char *uplo, const char *diag, index_t n,
247 float *a, index_t lda, index_t *info) {
248 strtri(uplo, diag, &n, a, &lda, info);
249}
250
251template <class T, class I>
252void xgemv_batch_strided(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, I m, I n,
253 std::type_identity_t<T> alpha,
254 std::type_identity_t<const T *> a, I lda, I stridea,
255 std::type_identity_t<const T *> x, I incx, I stridex,
256 std::type_identity_t<T> beta, T *y, I incy, I stridey,
257 I batch_size) {
258 GUANAQO_OMP(parallel for)
259 for (I i = 0; i < batch_size; ++i) {
260 auto offset_a = i * stridea;
261 auto offset_x = i * stridex;
262 auto offset_y = i * stridey;
263 xgemv(layout, trans, m, n, alpha, a + offset_a, lda, x + offset_x, incx,
264 beta, y + offset_y, incy);
265 }
266}
267
268#if GUANAQO_WITH_MKL
269template <>
270GUANAQO_BLAS_EXPORT void xgemv_batch_strided(
271 CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, index_t m, index_t n,
272 std::type_identity_t<double> alpha, std::type_identity_t<const double *> a,
273 index_t lda, index_t stridea, std::type_identity_t<const double *> x,
274 index_t incx, index_t stridex, std::type_identity_t<double> beta, double *y,
275 index_t incy, index_t stridey, index_t batch_size) {
276 if (m == 0 || n == 0)
277 return;
278 cblas_dgemv_batch_strided(layout, trans, m, n, alpha, a, lda, stridea, x,
279 incx, stridex, beta, y, incy, stridey,
280 batch_size);
281}
282
283template <>
284GUANAQO_BLAS_EXPORT void xgemv_batch_strided(
285 CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, index_t m, index_t n,
286 std::type_identity_t<float> alpha, std::type_identity_t<const float *> a,
287 index_t lda, index_t stridea, std::type_identity_t<const float *> x,
288 index_t incx, index_t stridex, std::type_identity_t<float> beta, float *y,
289 index_t incy, index_t stridey, index_t batch_size) {
290 if (m == 0 || n == 0)
291 return;
292 cblas_sgemv_batch_strided(layout, trans, m, n, alpha, a, lda, stridea, x,
293 incx, stridex, beta, y, incy, stridey,
294 batch_size);
295}
296#endif
297#if !GUANAQO_WITH_MKL
298template GUANAQO_BLAS_EXPORT void xgemv_batch_strided<double, index_t>(
299 CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, index_t m, index_t n,
300 std::type_identity_t<double> alpha, std::type_identity_t<const double *> a,
301 index_t lda, index_t stridea, std::type_identity_t<const double *> x,
302 index_t incx, index_t stridex, std::type_identity_t<double> beta, double *y,
303 index_t incy, index_t stridey, index_t batch_size);
304template GUANAQO_BLAS_EXPORT void xgemv_batch_strided<float, index_t>(
305 CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, index_t m, index_t n,
306 std::type_identity_t<float> alpha, std::type_identity_t<const float *> a,
307 index_t lda, index_t stridea, std::type_identity_t<const float *> x,
308 index_t incx, index_t stridex, std::type_identity_t<float> beta, float *y,
309 index_t incy, index_t stridey, index_t batch_size);
310#endif
311
312template <class T, class I>
313void xgemm_batch_strided(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA,
314 CBLAS_TRANSPOSE TransB, I M, I N, I K,
315 std::type_identity_t<T> alpha,
316 std::type_identity_t<const T *> A, I lda, I stridea,
317 std::type_identity_t<const T *> B, I ldb, I strideb,
318 std::type_identity_t<T> beta, T *C, I ldc, I stridec,
319 I batch_size) {
320 GUANAQO_OMP(parallel for)
321 for (I i = 0; i < batch_size; ++i) {
322 xgemm(Layout, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
323 ldc);
324 A += stridea;
325 B += strideb;
326 C += stridec;
327 }
328}
329
330#if GUANAQO_WITH_MKL
331template <>
332GUANAQO_BLAS_EXPORT void xgemm_batch_strided(
333 CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
334 index_t M, index_t N, index_t K, std::type_identity_t<double> alpha,
335 std::type_identity_t<const double *> A, index_t lda, index_t stridea,
336 std::type_identity_t<const double *> B, index_t ldb, index_t strideb,
337 std::type_identity_t<double> beta, double *C, index_t ldc, index_t stridec,
338 index_t batch_size) {
339 if (M == 0 || N == 0 || K == 0)
340 return;
341 cblas_dgemm_batch_strided(Layout, TransA, TransB, M, N, K, alpha, A, lda,
342 stridea, B, ldb, strideb, beta, C, ldc, stridec,
343 batch_size);
344}
345
346template <>
347GUANAQO_BLAS_EXPORT void xgemm_batch_strided(
348 CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
349 index_t M, index_t N, index_t K, std::type_identity_t<float> alpha,
350 std::type_identity_t<const float *> A, index_t lda, index_t stridea,
351 std::type_identity_t<const float *> B, index_t ldb, index_t strideb,
352 std::type_identity_t<float> beta, float *C, index_t ldc, index_t stridec,
353 index_t batch_size) {
354 if (M == 0 || N == 0 || K == 0)
355 return;
356 cblas_sgemm_batch_strided(Layout, TransA, TransB, M, N, K, alpha, A, lda,
357 stridea, B, ldb, strideb, beta, C, ldc, stridec,
358 batch_size);
359}
360#endif
361#if !GUANAQO_WITH_MKL
362template GUANAQO_BLAS_EXPORT void xgemm_batch_strided<double, index_t>(
363 CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
364 index_t M, index_t N, index_t K, std::type_identity_t<double> alpha,
365 std::type_identity_t<const double *> A, index_t lda, index_t stridea,
366 std::type_identity_t<const double *> B, index_t ldb, index_t strideb,
367 std::type_identity_t<double> beta, double *C, index_t ldc, index_t stridec,
368 index_t batch_size);
369template GUANAQO_BLAS_EXPORT void xgemm_batch_strided<float, index_t>(
370 CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
371 index_t M, index_t N, index_t K, std::type_identity_t<float> alpha,
372 std::type_identity_t<const float *> A, index_t lda, index_t stridea,
373 std::type_identity_t<const float *> B, index_t ldb, index_t strideb,
374 std::type_identity_t<float> beta, float *C, index_t ldc, index_t stridec,
375 index_t batch_size);
376#endif
377
378template <class T, class I>
379void xsyrk_batch_strided(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo,
380 CBLAS_TRANSPOSE Trans, I N, I K,
381 std::type_identity_t<T> alpha,
382 std::type_identity_t<const T *> A, I lda, I stridea,
383 std::type_identity_t<T> beta, T *C, I ldc, I stridec,
384 I batch_size) {
385 GUANAQO_OMP(parallel for)
386 for (I i = 0; i < batch_size; ++i) {
387 xsyrk(Layout, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
388 A += stridea;
389 C += stridec;
390 }
391}
392
393#if GUANAQO_WITH_MKL
394template <>
395GUANAQO_BLAS_EXPORT void
396xsyrk_batch_strided(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans,
397 index_t N, index_t K, std::type_identity_t<double> alpha,
398 std::type_identity_t<const double *> A, index_t lda,
399 index_t stridea, std::type_identity_t<double> beta,
400 double *C, index_t ldc, index_t stridec,
401 index_t batch_size) {
402 cblas_dsyrk_batch_strided(Layout, Uplo, Trans, N, K, alpha, A, lda, stridea,
403 beta, C, ldc, stridec, batch_size);
404}
405
406template <>
407GUANAQO_BLAS_EXPORT void
408xsyrk_batch_strided(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans,
409 index_t N, index_t K, std::type_identity_t<float> alpha,
410 std::type_identity_t<const float *> A, index_t lda,
411 index_t stridea, std::type_identity_t<float> beta, float *C,
412 index_t ldc, index_t stridec, index_t batch_size) {
413 cblas_ssyrk_batch_strided(Layout, Uplo, Trans, N, K, alpha, A, lda, stridea,
414 beta, C, ldc, stridec, batch_size);
415}
416#endif
417#if !GUANAQO_WITH_MKL
418template GUANAQO_BLAS_EXPORT void xsyrk_batch_strided<double, index_t>(
419 CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, index_t N,
420 index_t K, std::type_identity_t<double> alpha,
421 std::type_identity_t<const double *> A, index_t lda, index_t stridea,
422 std::type_identity_t<double> beta, double *C, index_t ldc, index_t stridec,
423 index_t batch_size);
424template GUANAQO_BLAS_EXPORT void xsyrk_batch_strided<float, index_t>(
425 CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, index_t N,
426 index_t K, std::type_identity_t<float> alpha,
427 std::type_identity_t<const float *> A, index_t lda, index_t stridea,
428 std::type_identity_t<float> beta, float *C, index_t ldc, index_t stridec,
429 index_t batch_size);
430#endif
431
432template <class T, class I>
433void xtrsm_batch_strided(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo,
434 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, I M, I N,
435 std::type_identity_t<T> alpha,
436 std::type_identity_t<const T *> A, I lda, I stridea,
437 T *B, I ldb, I strideb, I batch_size) {
438 GUANAQO_OMP(parallel for)
439 for (I i = 0; i < batch_size; ++i) {
440 xtrsm(Layout, Side, Uplo, TransA, Diag, M, N, alpha, A, lda, B, ldb);
441 A += stridea;
442 B += strideb;
443 }
444}
445
446#if GUANAQO_WITH_MKL
447template <>
448GUANAQO_BLAS_EXPORT void
449xtrsm_batch_strided(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo,
450 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M,
451 index_t N, std::type_identity_t<double> alpha,
452 std::type_identity_t<const double *> A, index_t lda,
453 index_t stridea, double *B, index_t ldb, index_t strideb,
454 index_t batch_size) {
455 cblas_dtrsm_batch_strided(Layout, Side, Uplo, TransA, Diag, M, N, alpha, A,
456 lda, stridea, B, ldb, strideb, batch_size);
457}
458
459template <>
460GUANAQO_BLAS_EXPORT void
461xtrsm_batch_strided(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo,
462 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M,
463 index_t N, std::type_identity_t<float> alpha,
464 std::type_identity_t<const float *> A, index_t lda,
465 index_t stridea, float *B, index_t ldb, index_t strideb,
466 index_t batch_size) {
467 cblas_strsm_batch_strided(Layout, Side, Uplo, TransA, Diag, M, N, alpha, A,
468 lda, stridea, B, ldb, strideb, batch_size);
469}
470#endif
471#if !GUANAQO_WITH_MKL
472template GUANAQO_BLAS_EXPORT void xtrsm_batch_strided<double, index_t>(
473 CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo,
474 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M, index_t N,
475 std::type_identity_t<double> alpha, const double *A, index_t lda,
476 index_t stridea, double *B, index_t ldb, index_t strideb,
477 index_t batch_size);
478template GUANAQO_BLAS_EXPORT void xtrsm_batch_strided<float, index_t>(
479 CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo,
480 CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M, index_t N,
481 std::type_identity_t<float> alpha, const float *A, index_t lda,
482 index_t stridea, float *B, index_t ldb, index_t strideb,
483 index_t batch_size);
484#endif
485
486template <class T, class I>
487void xpotrf_batch_strided(const char *Uplo, I N, T *A, I lda, I stridea,
488 I batch_size) {
489 I info_all = 0;
490 GUANAQO_OMP(parallel for reduction(+:info_all))
491 for (I i = 0; i < batch_size; ++i) {
492 I info = 0;
493 I offset = i * stridea;
494 xpotrf(Uplo, N, A + offset, lda, &info);
495 if (info > 0)
496 info = 0; // Ignore factorization failure
497 info_all += info;
498 }
499 // TODO: proper error handling
500 lapack_throw_on_err("xpotrf_batch_strided", info_all);
501}
502
503template GUANAQO_BLAS_EXPORT void
504xpotrf_batch_strided<double, index_t>(const char *Uplo, index_t N, double *A,
505 index_t lda, index_t stridea,
506 index_t batch_size);
507template GUANAQO_BLAS_EXPORT void
508xpotrf_batch_strided<float, index_t>(const char *Uplo, index_t N, float *A,
509 index_t lda, index_t stridea,
510 index_t batch_size);
511
512} // namespace guanaqo::blas
This file provides simple overloaded wrappers around standard BLAS functions.
void xlauum(const char *uplo, I n, T *a, I lda, I *info)
void xtrsm(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, I M, I N, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, T *B, I ldb)
void xtrmm(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, I M, I N, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, T *B, I ldb)
void xpotrf(const char *uplo, I n, T *a, I lda, I *info)
void xsymv(CBLAS_LAYOUT Layout, CBLAS_UPLO uplo, I N, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, std::type_identity_t< const T * > X, I incX, std::type_identity_t< T > beta, T *Y, I incY)
void xtrmv(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, I N, std::type_identity_t< const T * > A, I lda, T *X, I incX)
void xtrtrs(const char *uplo, const char *trans, const char *diag, const I *n, const I *nrhs, std::type_identity_t< const T * > A, const I *ldA, T *B, const I *ldB, I *info)
void lapack_throw_on_err(Name &&name, index_t info)
Definition lapack.hpp:35
void xpotrf_batch_strided(const char *Uplo, I N, T *A, I lda, I stridea, I batch_size)
void xsyrk_batch_strided(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, I N, I K, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, I stridea, std::type_identity_t< T > beta, T *C, I ldc, I stridec, I batch_size)
void xtrsm_batch_strided(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, I M, I N, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, I stridea, T *B, I ldb, I strideb, I batch_size)
void xtrsv(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, I N, std::type_identity_t< const T * > A, I lda, T *X, I incX)
void xgemv_batch_strided(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, I m, I n, std::type_identity_t< T > alpha, std::type_identity_t< const T * > a, I lda, I stridea, std::type_identity_t< const T * > x, I incx, I stridex, std::type_identity_t< T > beta, T *y, I incy, I stridey, I batch_size)
void xscal(I N, std::type_identity_t< T > alpha, T *X, I incX)
void xgemmt(CBLAS_LAYOUT Layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, I N, I K, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, std::type_identity_t< const T * > B, I ldb, std::type_identity_t< T > beta, T *C, I ldc)
void xsyrk(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, I N, I K, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, std::type_identity_t< T > beta, T *C, I ldc)
void xsytrf_rk(const char *uplo, const I *n, T *a, const I *lda, T *e, I *ipiv, T *work, const I *lwork, I *info)
void xgemm(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, I M, I N, I K, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, std::type_identity_t< const T * > B, I ldb, std::type_identity_t< T > beta, T *C, I ldc)
void xgemv(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, I M, I N, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, std::type_identity_t< const T * > X, I incX, std::type_identity_t< T > beta, T *Y, I incY)
void xgemm_batch_strided(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, I M, I N, I K, std::type_identity_t< T > alpha, std::type_identity_t< const T * > A, I lda, I stridea, std::type_identity_t< const T * > B, I ldb, I strideb, std::type_identity_t< T > beta, T *C, I ldc, I stridec, I batch_size)
void xtrtri(const char *uplo, const char *diag, I n, T *a, I lda, I *info)
#define GUANAQO_OMP(X)
Emit the OpenMP pragma X if OpenMP is enabled.
Definition openmp.h:28
LAPACK error handling.
#define dsytrf_rk(...)
Definition lapack.hpp:56
#define strtri(...)
Definition lapack.hpp:51
#define strtrs(...)
Definition lapack.hpp:55
#define dtrtrs(...)
Definition lapack.hpp:54
#define slauum(...)
Definition lapack.hpp:53
#define dtrtri(...)
Definition lapack.hpp:50
#define dlauum(...)
Definition lapack.hpp:52
#define ssytrf_rk(...)
Definition lapack.hpp:57
#define dpotrf(...)
Definition lapack.hpp:48
#define spotrf(...)
Definition lapack.hpp:49
template GUANAQO_BLAS_EXPORT void xtrsm_batch_strided< double, index_t >(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M, index_t N, std::type_identity_t< double > alpha, const double *A, index_t lda, index_t stridea, double *B, index_t ldb, index_t strideb, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xpotrf_batch_strided< float, index_t >(const char *Uplo, index_t N, float *A, index_t lda, index_t stridea, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xtrsm_batch_strided< float, index_t >(CBLAS_LAYOUT Layout, CBLAS_SIDE Side, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE TransA, CBLAS_DIAG Diag, index_t M, index_t N, std::type_identity_t< float > alpha, const float *A, index_t lda, index_t stridea, float *B, index_t ldb, index_t strideb, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xpotrf_batch_strided< double, index_t >(const char *Uplo, index_t N, double *A, index_t lda, index_t stridea, index_t batch_size)
int blas_index_t
Definition blas.hpp:20
template GUANAQO_BLAS_EXPORT void xgemv_batch_strided< double, index_t >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, index_t m, index_t n, std::type_identity_t< double > alpha, std::type_identity_t< const double * > a, index_t lda, index_t stridea, std::type_identity_t< const double * > x, index_t incx, index_t stridex, std::type_identity_t< double > beta, double *y, index_t incy, index_t stridey, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xgemm_batch_strided< double, index_t >(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, index_t M, index_t N, index_t K, std::type_identity_t< double > alpha, std::type_identity_t< const double * > A, index_t lda, index_t stridea, std::type_identity_t< const double * > B, index_t ldb, index_t strideb, std::type_identity_t< double > beta, double *C, index_t ldc, index_t stridec, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xsyrk_batch_strided< double, index_t >(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, index_t N, index_t K, std::type_identity_t< double > alpha, std::type_identity_t< const double * > A, index_t lda, index_t stridea, std::type_identity_t< double > beta, double *C, index_t ldc, index_t stridec, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xgemm_batch_strided< float, index_t >(CBLAS_LAYOUT Layout, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, index_t M, index_t N, index_t K, std::type_identity_t< float > alpha, std::type_identity_t< const float * > A, index_t lda, index_t stridea, std::type_identity_t< const float * > B, index_t ldb, index_t strideb, std::type_identity_t< float > beta, float *C, index_t ldc, index_t stridec, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xsyrk_batch_strided< float, index_t >(CBLAS_LAYOUT Layout, CBLAS_UPLO Uplo, CBLAS_TRANSPOSE Trans, index_t N, index_t K, std::type_identity_t< float > alpha, std::type_identity_t< const float * > A, index_t lda, index_t stridea, std::type_identity_t< float > beta, float *C, index_t ldc, index_t stridec, index_t batch_size)
template GUANAQO_BLAS_EXPORT void xgemv_batch_strided< float, index_t >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, index_t m, index_t n, std::type_identity_t< float > alpha, std::type_identity_t< const float * > a, index_t lda, index_t stridea, std::type_identity_t< const float * > x, index_t incx, index_t stridex, std::type_identity_t< float > beta, float *y, index_t incy, index_t stridey, index_t batch_size)
OpenMP helpers.