6#include "CWC_utility.H"
13#include "cusolverDn.h"
// gpuErrchk(ans): evaluates a CUDA runtime call and forwards its cudaError_t,
// together with the call site's __FILE__ and __LINE__, to gpuAssert.
// NOTE(review): the macro expands to a bare `{ ... }` block rather than
// `do { ... } while (0)`, so `if (c) gpuErrchk(x); else ...` will not compile.
16#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
// gpuAssert: when `code` is not cudaSuccess, prints the CUDA error string and
// the originating file/line to stderr.
// NOTE(review): the lines after the fprintf are not present in this view, so
// whether the process is terminated after the message cannot be confirmed.
18inline void gpuAssert(cudaError_t code,
const char *file,
int line)
20 if (code != cudaSuccess)
22 fprintf(stderr,
"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
30 void fortran_name(dcopy,DCOPY)(
int *n,
double *dx,
int *incx,
double *dy,
int *incy);
31 void fortran_name(daxpy,DAXPY)(
int *n,
double *da,
double *dx,
int *incx,
double *dy,\
33 double fortran_name(dnrm2,DNRM2)(
int *n,
double *dx,
int *incx);
34 float fortran_name(snrm2,SNRM2)(
int *n,
float *dx,
int *incx);
35 double fortran_name(dznrm2,DZNRM2)(
int *n,CPX *dx,
int *incx);
36 void fortran_name(dscal,DSCAL)(
int *n,
double *da,
double *dx,
int *incx);
37 void fortran_name(zscal,ZSCAL)(
int *n,CPX *da,CPX *dx,
int *incx);
38 void fortran_name(dgemm,DGEMM)(
char *transa,
char *transb,
int *m,
int *n,
int *k, \
39 double *alpha,
double *a,
int *lda,
double *b,
int *ldb, \
40 double *beta,
double *c,
int *ldc);
41 void fortran_name(sgemm,SGEMM)(
char *transa,
char *transb,
int *m,
int *n,
int *k, \
42 float *alpha,
float *a,
int *lda,
float *b,
int *ldb, \
43 float *beta,
float *c,
int *ldc);
44 void fortran_name(dsymm,DSYMM)(
char *side,
char *uplo,
int *m,
int *n,
double *alpha, \
45 double *a,
int *lda,
double *b,
int *ldb,
double *beta, \
47 void fortran_name(dtrsm,DTRSM)(
char *side,
char *uplo,
char *transa,
char *diag, \
49 double *alpha,
double *a,
int *lda,
double *b,
int *ldb);
50 void fortran_name(strsm,STRSM)(
char *side,
char *uplo,
char *transa,
char *diag, \
52 float *alpha,
float *a,
int *lda,
float *b,
int *ldb);
53 void fortran_name(zgemm,ZGEMM)(
char *transa,
char *transb,
int *m,
int *n,
int *k, \
54 CPX *alpha, CPX *a,
int *lda, CPX *b,
int *ldb, \
55 CPX *beta, CPX *c,
int *ldc);
56 void fortran_name(zsymm,ZSYMM)(
char *side,
char *uplo,
int *m,
int *n, CPX *alpha, \
57 CPX *a,
int *lda, CPX *b,
int *ldb, CPX *beta, \
59 void fortran_name(ztrsm,ZTRSM)(
char *side,
char *uplo,
char *transa,
char *diag, \
61 CPX *alpha, CPX *a,
int *lda, CPX *b,
int *ldb);
62 void fortran_name(zhemm,ZHEMM)(
char *side,
char *uplo,
int *m,
int *n, CPX *alpha, \
63 CPX *a,
int *lda, CPX *b,
int *ldb, CPX *beta, \
65 void fortran_name(dgemv,DGEMV)(
char *trans,
int *m,
int *n,
double *alpha,
double *a, \
66 int *lda,
double *x,
int *incx,
double *beta,
double *y, \
69 void fortran_name(zgemv,ZGEMV)(
char *trans,
int *m,
int *n, CPX *alpha, CPX *a, \
70 int *lda, CPX *x,
int *incx, CPX *beta, CPX *y, \
73 double fortran_name(ddot,DDOT)(
int *n,
double *x,
int *incx,
double *y,
int *incy);
74 CPX fortran_name(zdotc,ZDOTC)(
int *n,CPX *x,
int *incx,CPX *y,
int *incy);
75 void fortran_name(zcopy,ZCOPY)(
int *n,CPX *dx,
int *incx,CPX *dy,
int *incy);
76 void fortran_name(zaxpy,ZAXPY)(
int *n, CPX *alpha, CPX *x,
int *incx, CPX *y,
int *incy);
77 double fortran_name(dasum,DASUM)(
int *n,
double *dx,
int *incx);
78 double fortran_name(dzasum,DZASUM)(
int *n,CPX *dx,
int *incx);
81 void fortran_name(dgetrf,DGETRF)(
int *m,
int *n,
double *a,
int *lda,
int *ipiv,
int *info);
82 void fortran_name(dgetrs,DGETRS)(
char *trans,
int *n,
int *nrhs,
double *a,
int *lda, \
83 int *ipiv,
double *b,
int *ldb,
int *info);
84 void fortran_name(zgetrf,ZGETRF)(
int *m,
int *n, CPX *a,
int *lda,
int *ipiv,
int *info);
85 void fortran_name(zgetrs,ZGETRS)(
char *trans,
int *n,
int *nrhs, CPX *a,
int *lda, \
86 int *ipiv, CPX *b,
int *ldb,
int *info);
87 void fortran_name(zgetri,ZGETRI)(
int *n,CPX *a,
int *lda,
int *ipiv,CPX *work,
int *lwork,\
89 void fortran_name(dgeev,DGEEV)(
char *jobvl,
char *jobvr,
int *n,
double *a,
int *lda, \
90 double *wr,
double *wi,
double *vl,
int *ldvl,
double *vr, \
91 int *ldvr,
double *work,
int *lwork,
int *info);
92 void fortran_name(dsyev,DSYEV)(
char *JOBZ,
char *UPLO,
int *N,
double *A,
int *LDA,
double *W,\
93 double *WORK,
int *LWORK,
int *INFO);
94 void fortran_name(dggev,DGGEV)(
char* jobvl,
char *jobvr,
int *n,
double *a,
int *lda, \
95 double *b,
int *ldb,
double *alphar,
double *alphai, \
96 double *beta,
double *vl,
int *ldvl,
double *vr,
int *ldvr, \
97 double *work,
int *lwork,
int *info);
98 void fortran_name(zggev,ZGGEV)(
char* jobvl,
char *jobvr,
int *n, CPX *a,
int *lda, \
99 CPX *b,
int *ldb, CPX *alpha, CPX *beta, CPX *vl,
int *ldvl,\
100 CPX *vr,
int *ldvr, CPX *work,
int *lwork,
double *rwork, \
102 void fortran_name(zgeev,ZGEEV)(
char *jobvl,
char *jobvr,
int *n, CPX *a,
int *lda, \
103 CPX *w, CPX *vl,
int *ldvl, CPX *vr,
int *ldvr, CPX *work, \
104 int *lwork,
double *rwork,
int *info);
105 void fortran_name(zheev,ZHEEV)(
char *jobvl,
char *uplo,
int *n, CPX *a,
int *lda, \
106 double *w, CPX *work,
int *lwork,
double *rwork,
int *info);
107 void fortran_name(dgetri,DGETRI)(
int *n,
double *a,
int *lda,
int *ipiv,
double *work,
int *lwork,\
109 void fortran_name(dsytri,DSYTRI)(
char *uplo,
int *n,
double *a,
int *lda,
int *ipiv,
double *work,\
111 void fortran_name(zhetrf,ZHETRF)(
char *uplo,
int *n,CPX *a,
int *lda,
int *ipiv,CPX *work,\
112 int *lwork,
int *info);
113 void fortran_name(zhetri,ZHETRI)(
char *uplo,
int *n,CPX *a,
int *lda,
int *ipiv,CPX *work,\
115 void fortran_name(zhetrs,ZHETRS)(
char *uplo,
int *n,
int *nrhs,CPX *a,
int *lda,
int *ipiv,\
116 CPX *b,
int *ldb,
int *info);
117 void fortran_name(dsysv,DSYSV)(
char *uplo,
int *n,
int *nrhs,
double *a,
int *lda,
int *ipiv,\
118 double *b,
int *ldb,
double *work,
int *lwork,
int *info);
119 void fortran_name(dsytrf,DSYTRF)(
char *uplo,
int *n,
double *a,
int *lda,
int *ipiv,
double *work,\
120 int *lwork,
int *info);
121 void fortran_name(dsytrs,DSYTRS)(
char *uplo,
int *n,
int *nrhs,
double *a,
int *lda,
int *ipiv,\
122 double *b,
int *ldb,
int *info);
123 void fortran_name(dstebz,DSTEBZ)(
char *range,
char *order,
int *iter,
double *vl,
double *vu,
int *il,
int *iu,\
124 double *abstol,
double *diag,
double *offd,
int *neval,
int *nsplit,\
125 double *eval,
int *iblock,
int *isplit,
double *work,
int *iwork,
int *info);
126 void fortran_name(zlarnv,ZLARNV)(
int*,
int*,
int*,CPX*);
127 void fortran_name(dgesdd,DGESDD)(
char *jobz,
int *m,
int *n,
double *a,
int *lda,
double *s,
double *u,
int *ldu, \
128 double *vt,
int *ldvt,
double *work,
int *lwork,
int *iwork,
int *info);
129 void fortran_name(zgesdd,ZGESDD)(
char *jobz,
int *m,
int *n,CPX *a,
int *lda,
double *s,CPX *u,
int *ldu,\
130 CPX *vt,
int *ldvt,CPX *work,
int *lwork,
double *rwork,
int *iwork,
int *info);
136inline double c_dzasum(
int n,CPX *dx,
int incx)
138 return fortran_name(dzasum,DZASUM)(&n,dx,&incx);
// c_icopy: strided copy of n ints from dx into dy. BLAS has no integer copy
// routine, so this is hand-written, but it follows the same stride
// convention as c_dcopy / c_zcopy (BLAS xCOPY):
//   n     number of elements; n <= 0 copies nothing.
//   incx  stride through dx; may be zero or negative.
//   incy  stride through dy; may be zero or negative.
// With a negative stride the vector is traversed backwards starting at
// element (1 - n) * inc, so the storage touched is always [0, (n-1)*|inc|].
// Positive strides behave exactly as before. Negative strides previously
// indexed before the start of the arrays.
inline void c_icopy(int n, int *dx, int incx, int *dy, int incy)
{
    if (n <= 0) return;

    // BLAS convention: negative increments start from the far end.
    int ix = (incx < 0) ? (1 - n) * incx : 0;
    int iy = (incy < 0) ? (1 - n) * incy : 0;

    for (int i = 0; i < n; ++i) {
        dy[iy] = dx[ix];
        ix += incx;
        iy += incy;
    }
}
152inline void c_dcopy(
int n,
double *dx,
int incx,
double *dy,
int incy)
154 fortran_name(dcopy,DCOPY)(&n,dx,&incx,dy,&incy);
159inline void c_daxpy(
int n,
double da,
double *dx,
int incx,
double *dy,
int incy)
161 fortran_name(daxpy,DAXPY)(&n,&da,dx,&incx,dy,&incy);
166inline double c_dnrm2(
int n,
double* dx,
int incx)
168 return fortran_name(dnrm2,DNRM2)(&n,dx,&incx);
174inline double c_dnrm2(
int n,
float* dx,
int incx)
176 return fortran_name(snrm2,SNRM2)(&n,dx,&incx);
181inline double c_dznrm2(
int n,CPX* dx,
int incx)
183 return fortran_name(dznrm2,DZNRM2)(&n,dx,&incx);
188inline void c_dscal(
int n,
double da,
double *dx,
int incx)
190 fortran_name(dscal,DSCAL)(&n,&da,dx,&incx);
195inline void c_zscal(
int n,CPX da,CPX *dx,
int incx)
197 fortran_name(zscal,ZSCAL)(&n,&da,dx,&incx);
202inline void c_dgemm(
char transa,
char transb,
int m,
int n,
int k,
double alpha,
double *a, \
203 int lda,
double *b,
int ldb,
double beta,
double *c,
int ldc)
205 fortran_name(dgemm,DGEMM)(&transa,&transb,&m,&n,&k,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
210inline void c_sgemm(
char transa,
char transb,
int m,
int n,
int k,
float alpha,
float *a, \
211 int lda,
float *b,
int ldb,
float beta,
float *c,
int ldc)
213 fortran_name(sgemm,SGEMM)(&transa,&transb,&m,&n,&k,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
218inline void c_dsymm(
char side,
char uplo,
int m,
int n,
double alpha,
double *a,
int lda, \
219 double *b,
int ldb,
double beta,
double *c,
int ldc)
221 fortran_name(dsymm,DSYMM)(&side,&uplo,&m,&n,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
226inline void c_dtrsm(
char side,
char uplo,
char transa,
char diag,
int m,
int n,
double alpha,
double *a,
int lda,
double *b,
int ldb)
228 fortran_name(dtrsm,DTRSM)(&side,&uplo,&transa,&diag,&m,&n,&alpha,a,&lda,b,&ldb);
234inline void c_strsm(
char side,
char uplo,
char transa,
char diag,
int m,
int n,
float alpha,
float *a,
int lda,
float *b,
int ldb)
236 fortran_name(strsm,STRSM)(&side,&uplo,&transa,&diag,&m,&n,&alpha,a,&lda,b,&ldb);
241inline void c_zgemm(
char transa,
char transb,
int m,
int n,
int k, CPX alpha, CPX *a, \
242 int lda, CPX *b,
int ldb, CPX beta, CPX *c,
int ldc)
244 fortran_name(zgemm,ZGEMM)(&transa,&transb,&m,&n,&k,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
249inline void c_zsymm(
char side,
char uplo,
int m,
int n, CPX alpha, CPX *a,
int lda, \
250 CPX *b,
int ldb, CPX beta, CPX *c,
int ldc)
252 fortran_name(zsymm,ZSYMM)(&side,&uplo,&m,&n,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
257inline void c_ztrsm(
char side,
char uplo,
char transa,
char diag,
int m,
int n, CPX alpha, CPX *a,
int lda, CPX *b,
int ldb)
259 fortran_name(ztrsm,ZTRSM)(&side,&uplo,&transa,&diag,&m,&n,&alpha,a,&lda,b,&ldb);
264inline void c_zhemm(
char side,
char uplo,
int m,
int n, CPX alpha, CPX *a,
int lda, \
265 CPX *b,
int ldb, CPX beta, CPX *c,
int ldc)
267 fortran_name(zhemm,ZHEMM)(&side,&uplo,&m,&n,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
272inline void c_dgemv(
char transa,
int m,
int n,
double alpha,
double *a,
int lda,
double *x, \
273 int incx,
double beta,
double *y,
int incy)
275 fortran_name(dgemv,DGEMV)(&transa,&m,&n,&alpha,a,&lda,x,&incx,&beta,y,&incy);
280inline void c_zgemv(
char transa,
int m,
int n, CPX alpha, CPX *a,
int lda, CPX *x, \
281 int incx, CPX beta, CPX *y,
int incy)
283 fortran_name(zgemv,ZGEMV)(&transa,&m,&n,&alpha,a,&lda,x,&incx,&beta,y,&incy);
288inline double c_ddot(
int n,
double *x,
int incx,
double *y,
int incy)
290 return fortran_name(ddot,DDOT)(&n,x,&incx,y,&incy);
295inline CPX c_zdotc(
int n, CPX *x,
int incx, CPX *y,
int incy)
302 real = c_ddot(n,(
double*)x,2*incx,(
double*)y,2*incy)+\
303 c_ddot(n,(
double*)&x[0]+1,2*incx,(
double*)&y[0]+1,2*incy);
304 imag = -c_ddot(n,(
double*)&x[0]+1,2*incx,(
double*)y,2*incy)+\
305 c_ddot(n,(
double*)x,2*incx,(
double*)&y[0]+1,2*incy);
307 return CPX(real,imag);
313inline void c_zcopy(
int n,CPX *dx,
int incx,CPX *dy,
int incy)
315 fortran_name(zcopy,ZCOPY)(&n,dx,&incx,dy,&incy);
320inline void c_zaxpy(
int n, CPX alpha, CPX *x,
int incx, CPX *y,
int incy)
322 fortran_name(zaxpy,ZAXPY)(&n,&alpha,x,&incx,y,&incy);
327inline double c_dasum(
int n,
double *dx,
int incx)
329 return fortran_name(dasum,DASUM)(&n,dx,&incx);
334template <
typename T,
typename W>
335inline void c_tcopy(
int n,T *dx,
int incx,W *dy,
int incy);
338inline void c_tcopy(
int n,
double *dx,
int incx,
double *dy,
int incy)
340 c_dcopy(n,dx,incx,dy,incy);
344inline void c_tcopy(
int n,CPX *dx,
int incx,CPX *dy,
int incy)
346 c_zcopy(n,dx,incx,dy,incy);
350inline void c_tcopy(
int n,
double *dx,
int incx,CPX *dy,
int incy)
352 c_dcopy(n,dx,incx,(
double*)dy,2*incy);
357template <
typename T,
typename W>
358inline void c_taxpy(
int n, T alpha, T *x,
int incx, W *y,
int incy);
361inline void c_taxpy(
int n,
double alpha,
double *x,
int incx,
double *y,
int incy)
363 c_daxpy(n,alpha,x,incx,y,incy);
367inline void c_taxpy(
int n, CPX alpha, CPX *x,
int incx, CPX *y,
int incy)
369 c_zaxpy(n,alpha,x,incx,y,incy);
373inline void c_taxpy(
int n,
double alpha,
double *x,
int incx, CPX *y,
int incy)
375 c_daxpy(n,alpha,x,incx,(
double*)y,2*incy);
381inline double c_dtnrm2(
int n,T *x,
int incx);
384inline double c_dtnrm2(
int n,
double *x,
int incx)
386 return c_dnrm2(n,x,incx);
391inline double c_dtnrm2(
int n,
float *x,
int incx)
393 return c_dnrm2(n,x,incx);
397inline double c_dtnrm2(
int n,CPX *x,
int incx)
399 return c_dznrm2(n,x,incx);
405inline void c_tscal(
int n,T da,T *dx,
int incx);
408inline void c_tscal(
int n,
double da,
double *dx,
int incx)
410 c_dscal(n,da,dx,incx);
414inline void c_tscal(
int n,CPX da,CPX *dx,
int incx)
416 c_zscal(n,da,dx,incx);
422inline void c_tgemm(
char transa,
char transb,
int m,
int n,
int k, T alpha, T *a, \
423 int lda, T *b,
int ldb, T beta, T *c,
int ldc);
426inline void c_tgemm(
char transa,
char transb,
int m,
int n,
int k,
double alpha,
double *a, \
427 int lda,
double *b,
int ldb,
double beta,
double *c,
int ldc)
429 c_dgemm(transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
433inline void c_tgemm(
char transa,
char transb,
int m,
int n,
int k,
float alpha,
float *a, \
434 int lda,
float *b,
int ldb,
float beta,
float *c,
int ldc)
436 c_sgemm(transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
440inline void c_tgemm(
char transa,
char transb,
int m,
int n,
int k, CPX alpha, CPX *a, \
441 int lda, CPX *b,
int ldb, CPX beta, CPX *c,
int ldc)
443 c_zgemm(transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
449inline void c_tsymm(
char side,
char uplo,
int m,
int n, T alpha, T *a, \
450 int lda, T *b,
int ldb, T beta, T *c,
int ldc);
453inline void c_tsymm(
char side,
char uplo,
int m,
int n,
double alpha,
double *a, \
454 int lda,
double *b,
int ldb,
double beta,
double *c,
int ldc)
456 c_dsymm(side,uplo,m,n,alpha,a,lda,b,ldb,beta,c,ldc);
460inline void c_tsymm(
char side,
char uplo,
int m,
int n, CPX alpha, CPX *a, \
461 int lda, CPX *b,
int ldb, CPX beta, CPX *c,
int ldc)
463 c_zsymm(side,uplo,m,n,alpha,a,lda,b,ldb,beta,c,ldc);
469inline void c_ttrsm(
char side,
char uplo,
char transa,
char diag,
int m,
int n, \
470 T alpha, T *a,
int lda, T *b,
int ldb);
474inline void c_ttrsm(
char side,
char uplo,
char transa,
char diag,
int m,
int n, \
475 double alpha,
double *a,
int lda,
double *b,
int ldb)
477 c_dtrsm(side,uplo,transa,diag,m,n,alpha,a,lda,b,ldb);
482inline void c_ttrsm(
char side,
char uplo,
char transa,
char diag,
int m,
int n, \
483 float alpha,
float *a,
int lda,
float *b,
int ldb)
485 c_strsm(side,uplo,transa,diag,m,n,alpha,a,lda,b,ldb);
490inline void c_ttrsm(
char side,
char uplo,
char transa,
char diag,
int m,
int n, \
491 CPX alpha, CPX *a,
int lda, CPX *b,
int ldb)
493 c_ztrsm(side,uplo,transa,diag,m,n,alpha,a,lda,b,ldb);
499inline void tgemm_dev(
char transa,
char transb,
int m,
int n,
int k, T alpha, T *a,\
500 int lda, T *b,
int ldb, T beta, T *c,
int ldc, magma_queue_t queue);
503inline void tgemm_dev(
char transa,
char transb,
int m,
int n,
int k,
double alpha,\
504 double *a,
int lda,
double *b,
int ldb,
double beta,
double *c,
int ldc, magma_queue_t queue)
506 magma_trans_t magma_transa = magma_trans_const(transa);
507 magma_trans_t magma_transb = magma_trans_const(transb);
510 magma_dgemm(magma_transa,magma_transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc,queue);
516inline void tgemm_dev(
char transa,
char transb,
int m,
int n,
int k,
float alpha,\
517 float *a,
int lda,
float *b,
int ldb,
float beta,
float *c,
int ldc, magma_queue_t queue)
519 magma_trans_t magma_transa = magma_trans_const(transa);
520 magma_trans_t magma_transb = magma_trans_const(transb);
523 magma_sgemm(magma_transa,magma_transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc,queue);
529inline void tgemm_dev(
char transa,
char transb,
int m,
int n,
int k, CPX alpha,\
530 CPX *a,
int lda, CPX *b,
int ldb, CPX beta, CPX *c,
int ldc, magma_queue_t queue)
532 magma_trans_t magma_transa = magma_trans_const(transa);
533 magma_trans_t magma_transb = magma_trans_const(transb);
536 magma_zgemm(magma_transa,magma_transb,m,n,k,*
reinterpret_cast<magmaDoubleComplex*
>(&alpha),(magmaDoubleComplex_ptr)a,lda,(magmaDoubleComplex_ptr)b,ldb,*
reinterpret_cast<magmaDoubleComplex*
>(&beta),(magmaDoubleComplex_ptr)c,ldc,queue);
542inline void taxpy_dev(
void *handle,
int n,T alpha,T *x,
int incx,T *y,
int incy);
545inline void taxpy_dev(
void *handle,
int n,
double alpha,
double *x,
int incx,
double *y,
int incy)
547 daxpy_on_dev(handle,n,alpha,x,incx,y,incy);
551inline void taxpy_dev(
void *handle,
int n,CPX alpha,CPX *x,
int incx,CPX *y,
int incy)
553 zaxpy_on_dev(handle,n,alpha,x,incx,y,incy);
559inline void tasum_dev(
void *handle,
int n, T *x,
int incx, T *result);
562inline void tasum_dev(
void *handle,
int n,
double *x,
int incx,
double *result)
564 dasum_on_dev(handle, n, x, incx, result);
568inline void tasum_dev(
void *handle,
int n, CPX *x,
int incx, CPX *result)
571 zasum_on_dev(handle, n, x, incx, &dRes);
572 *result = CPX(dRes, 0.0);
578inline void tsum_dev(
int n, T *x,
int incx, T *result, magma_queue_t queue);
581inline void tsum_dev(
int n,
double *x,
int incx,
double *result, magma_queue_t queue)
583 dsum_on_dev(n, x, incx, result, queue);
587inline void tsum_dev(
int n, CPX *x,
int incx, CPX *result, magma_queue_t queue)
589 zsum_on_dev(n, x, incx, result, queue);
595inline void tgetrf_dev(
int m,
int n,T *a,
int lda,
int *ipiv,
int *info);
598inline void tgetrf_dev(
int m,
int n,
double *a,
int lda,
int *ipiv,
int *info)
600 magma_dgetrf_gpu(m,n,a,lda,ipiv,info);
604inline void tgetrf_dev(
int m,
int n,CPX *a,
int lda,
int *ipiv,
int *info)
606 magma_zgetrf_gpu(m,n,(magmaDoubleComplex_ptr)a,lda,ipiv,info);
612inline void tgetrs_dev(
char transa,
int n,
int nrhs,T *a,
int lda,
int *ipiv,T *b,
int ldb,
int *info);
615inline void tgetrs_dev(
char transa,
int n,
int nrhs,
double *a,
int lda,
int *ipiv,
double *b,
int ldb,\
618 magma_trans_t magma_transa = magma_trans_const(transa);
620 magma_dgetrs_gpu(magma_transa,n,nrhs,a,lda,ipiv,b,ldb,info);
624inline void tgetrs_dev(
char transa,
int n,
int nrhs,CPX *a,
int lda,
int *ipiv,CPX *b,
int ldb,\
627 magma_trans_t magma_transa = magma_trans_const(transa);
629 magma_zgetrs_gpu(magma_transa,n,nrhs,(magmaDoubleComplex_ptr)a,lda,ipiv,\
630 (magmaDoubleComplex_ptr)b,ldb,info);
// tgesv_dev<T>: device linear solve A X = B using MAGMA's no-pivoting
// solvers. Each specialization contains two calls: a general LU solve
// (magma_?gesv_nopiv_gpu) and a lower-triangle symmetric/Hermitian solve
// (magma_dsysv_nopiv_gpu / magma_zhesv_nopiv_gpu). Presumably `type` selects
// between them.
// NOTE(review): the condition on `type`, the braces and the end of the
// complex calls' argument lists are not present in this view — confirm
// against the full source.
// NOTE(review): `ipiv` is accepted but not passed to any of the visible
// no-pivot MAGMA calls.
// NOTE(review): the `template <typename T>` / `template <>` headers are not
// present in this view.
636inline void tgesv_dev(
int n,
int nrhs,T *a,
int lda,
int *ipiv,T *b,
int ldb,
int type,
int *info);
// Double specialization: magma_dgesv_nopiv_gpu or magma_dsysv_nopiv_gpu (MagmaLower).
639inline void tgesv_dev(
int n,
int nrhs,
double *a,
int lda,
int *ipiv,
double *b,
int ldb,
int type,
int *info)
642 magma_dgesv_nopiv_gpu(n,nrhs,a,lda,b,ldb,info);
647 magma_dsysv_nopiv_gpu(MagmaLower,n,nrhs,a,lda,b,ldb,info);
// Complex specialization: magma_zgesv_nopiv_gpu or magma_zhesv_nopiv_gpu
// (MagmaLower); CPX buffers are reinterpreted as magmaDoubleComplex.
652inline void tgesv_dev(
int n,
int nrhs,CPX *a,
int lda,
int *ipiv,CPX *b,
int ldb,
int type,
int *info)
656 magma_zgesv_nopiv_gpu(n,nrhs,(magmaDoubleComplex_ptr)a,lda,(magmaDoubleComplex_ptr)b,\
663 magma_zhesv_nopiv_gpu(MagmaLower,n,nrhs,(magmaDoubleComplex_ptr)a,lda,(magmaDoubleComplex_ptr)b,\
671inline void tgetri_dev(
int n,T *a,
int lda,
int *ipiv,T *work,
int lwork,
int *info);
674inline void tgetri_dev(
int n,
double *a,
int lda,
int *ipiv,
double *work,
int lwork,
int *info)
676 magma_dgetri_gpu(n,a,lda,ipiv,work,lwork,info);
680inline void tgetri_dev(
int n,CPX *a,
int lda,
int *ipiv,CPX *work,
int lwork,
int *info)
682 magma_zgetri_gpu(n,(magmaDoubleComplex_ptr)a,lda,ipiv,(magmaDoubleComplex_ptr)work,lwork,info);
// ttrtri_dev_cuda<T>: in-place inverse of a triangular device matrix using
// cuSOLVER's cusolverDnXtrtri on `stream`.
//  - uplo / unit_diag are characters mapped to cublasFillMode_t / cublasDiagType_t.
//  - On the first call (cuda_buffer_flag[0] == 0) the workspace sizes are
//    queried with cusolverDnXtrtri_bufferSize into *dev_size / *host_size,
//    and the flag is set to 1.
//  - After the solve, *info is copied back asynchronously and reported.
// NOTE(review): the `template <typename T>` / `template <>` headers, the
// `if (uplo == ...)` opening the first branch, and several closing braces /
// error exits are not present in this view.
690inline void ttrtri_dev_cuda(
char uplo,
char unit_diag,
int n,T *a,
int lda,
int *info,
int* cuda_buffer_flag,
691 cusolverDnHandle_t* handle, cudaStream_t stream,
692 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host);
// Double-precision specialization (CUDA_R_64F).
695inline void ttrtri_dev_cuda(
char uplo,
char unit_diag,
int n,
double *a,
int lda,
int *info,
int *cuda_buffer_flag,
696 cusolverDnHandle_t *handle, cudaStream_t stream,
697 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host)
700 cublasFillMode_t uplo_cuda;
702 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
703 }
else if(uplo ==
'U') {
704 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
// NOTE(review): this trtri routine reports an invalid uplo as "potrf"
// (copy-paste from the potrf wrapper); the message also misspells "argument".
706 printf(
"invalid potrf argmument uplo: %c\n", uplo);
711 cublasDiagType_t unit_diag_cuda;
712 if(unit_diag ==
'n' || unit_diag ==
'N'){
713 unit_diag_cuda = CUBLAS_DIAG_NON_UNIT;
714 }
else if(unit_diag ==
'u' || unit_diag ==
'U'){
715 unit_diag_cuda = CUBLAS_DIAG_UNIT;
717 printf(
"invalid trtri argmument unit_diag: %c\n", unit_diag);
722 if(cuda_buffer_flag[0] == 0){
724 cusolverStatus_t cuSolverError = cusolverDnXtrtri_bufferSize(
725 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_64F, a, lda, dev_size, host_size);
727 if(cuSolverError != 0){
728 printf(
"cuSolverError buffer size allocation!\n");
732 cuda_buffer_flag[0] = 1;
// NOTE(review): mem_cuda_dev / mem_cuda_host are by-value parameters, so
// these allocations only update local copies. The caller never receives the
// buffers (they leak), and presumably later calls with the flag already set
// use whatever never-allocated pointers the caller passes in.
// NOTE(review): cuSOLVER's *_bufferSize reports workspace sizes in bytes
// (the sizes are also passed unscaled to cusolverDnXtrtri below), so
// multiplying by sizeof(double) over-allocates by 8x.
739 cudaMalloc((
void**)&mem_cuda_dev, (*dev_size) *
sizeof(
double));
740 cudaMallocHost((
void**)&mem_cuda_host, (*host_size) *
sizeof(
double));
744 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
745 if(cuSolverError != 0){
746 printf(
"cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
750 cuSolverError = cusolverDnXtrtri(
751 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_64F, a, lda,
752 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
754 if(cuSolverError != 0){
755 printf(
"cuSolverError trtri! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
// NOTE(review): info_host is filled by an asynchronous copy on `stream`; a
// stream synchronization must happen before it is read. None is visible in
// this view.
760 cudaMemcpyAsync(&info_host, info,
sizeof(
int), cudaMemcpyDeviceToHost, stream);
763 printf(
"cuSolverError trtri info not zero! info = %d, lda = %d\n", info_host, lda);
// ttrtri_dev_cuda, single-precision specialization (CUDA_R_32F): same flow as
// the double version — map uplo/unit_diag, query workspace sizes on the first
// call (cuda_buffer_flag[0] == 0), run cusolverDnXtrtri on `stream`, then
// copy *info back and report it.
// NOTE(review): the `template <>` header, the opening `if (uplo == ...)` and
// several braces / error exits are not present in this view.
771inline void ttrtri_dev_cuda(
char uplo,
char unit_diag,
int n,
float *a,
int lda,
int *info,
int *cuda_buffer_flag,
772 cusolverDnHandle_t *handle, cudaStream_t stream,
773 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host)
776 cublasFillMode_t uplo_cuda;
778 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
779 }
else if(uplo ==
'U') {
780 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
// NOTE(review): message says "potrf" although this is the trtri wrapper.
782 printf(
"invalid potrf argmument uplo: %c\n", uplo);
787 cublasDiagType_t unit_diag_cuda;
788 if(unit_diag ==
'n' || unit_diag ==
'N'){
789 unit_diag_cuda = CUBLAS_DIAG_NON_UNIT;
790 }
else if(unit_diag ==
'u' || unit_diag ==
'U'){
791 unit_diag_cuda = CUBLAS_DIAG_UNIT;
793 printf(
"invalid trtri argmument unit_diag: %c\n", unit_diag);
798 if(cuda_buffer_flag[0] == 0){
800 cusolverStatus_t cuSolverError = cusolverDnXtrtri_bufferSize(
801 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_32F, a, lda, dev_size, host_size);
803 if(cuSolverError != 0){
804 printf(
"cuSolverError buffer size allocation!\n");
808 cuda_buffer_flag[0] = 1;
// NOTE(review): as in the double version, these allocations write into
// by-value parameter copies (lost to the caller), and the byte counts from
// *_bufferSize are scaled by sizeof(double) again.
815 cudaMalloc((
void**)&mem_cuda_dev, (*dev_size) *
sizeof(
double));
816 cudaMallocHost((
void**)&mem_cuda_host, (*host_size) *
sizeof(
double));
820 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
821 if(cuSolverError != 0){
822 printf(
"cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
826 cuSolverError = cusolverDnXtrtri(
827 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_32F, a, lda,
828 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
830 if(cuSolverError != 0){
831 printf(
"cuSolverError trtri! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
// NOTE(review): asynchronous copy; info_host must not be read before the
// stream is synchronized (no synchronization visible here).
836 cudaMemcpyAsync(&info_host, info,
sizeof(
int), cudaMemcpyDeviceToHost, stream);
839 printf(
"cuSolverError trtri info not zero! info = %d, lda = %d\n", info_host, lda);
// ttrtri_dev_cuda, complex specialization: NOT IMPLEMENTED. The visible body
// only prints a placeholder message; `a` and `info` are not touched here.
// NOTE(review): callers relying on the complex path get no inversion and no
// error code — consider failing loudly or implementing it via CUDA_C_64F.
847inline void ttrtri_dev_cuda(
char uplo,
char unit_diag,
int n, CPX *a,
int lda,
int *info,
int *cuda_buffer_flag,
848 cusolverDnHandle_t *handle, cudaStream_t stream,
849 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host)
851 printf(
"just a placeholder. not working!\n");
858inline void ttrtri_dev(
char uplo,
char diag,
int n,T *a,
int lda,
int *info);
861inline void ttrtri_dev(
char uplo,
char diag,
int n,
double *a,
int lda,
int *info)
863 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
864 magma_diag_t magma_diag = magma_diag_const(diag);
866 magma_dtrtri_gpu(magma_uplo,magma_diag,n,a,lda,info);
871inline void ttrtri_dev(
char uplo,
char diag,
int n,
float *a,
int lda,
int *info)
873 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
874 magma_diag_t magma_diag = magma_diag_const(diag);
879 magma_strtri_gpu(magma_uplo,magma_diag,n,a,lda,info);
883inline void ttrtri_dev(
char uplo,
char diag,
int n,CPX *a,
int lda,
int *info)
885 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
886 magma_diag_t magma_diag = magma_diag_const(diag);
888 magma_ztrtri_gpu(magma_uplo,magma_diag,n,(magmaDoubleComplex_ptr)a,lda,info);
895inline void zgetrf_nopiv_dev(
int m,
int n,CPX *a,
int lda,
int *info)
897 magma_zgetrf_nopiv_gpu(m,n,(magmaDoubleComplex_ptr)a,lda,info);
902inline void zgetri_dev(
int n,CPX *a,
int lda,
int *ipiv,CPX *work,
int lwork,
int *info)
904 magma_zgetri_gpu(n,(magmaDoubleComplex_ptr)a,lda,ipiv,(magmaDoubleComplex_ptr)work,lwork,info);
909inline void zgetrs_nopiv_dev(
char transa,
int n,
int nrhs,CPX *a,
int lda,CPX *b,
int ldb,
int *info)
911 magma_trans_t magma_transa = magma_trans_const(transa);
913 magma_zgetrs_nopiv_gpu(magma_transa,n,nrhs,(magmaDoubleComplex_ptr)a,lda,\
914 (magmaDoubleComplex_ptr)b,ldb,info);
// tpotrf_dev_cuda<T>: in-place Cholesky factorization of a device matrix with
// cuSOLVER's 64-bit API (cusolverDnXpotrf) on `stream`.
//  - uplo is mapped to cublasFillMode_t.
//  - On the first call (cuda_buffer_flag[0] == 0) the workspace sizes are
//    queried with cusolverDnXpotrf_bufferSize into *dev_size / *host_size,
//    and the flag is set to 1.
//  - After the factorization, *info is copied back asynchronously and reported.
// NOTE(review): the `template <typename T>` / `template <>` headers, the
// opening `if (uplo == ...)`, the declarations of `temp` and `info_host`, and
// several braces / error exits are not present in this view.
922inline void tpotrf_dev_cuda(
char uplo,
int n,T *a,
int lda,
int *info,
int* cuda_buffer_flag,
923 cusolverDnHandle_t* handle, cusolverDnParams_t *params, cudaStream_t stream,
924 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host);
// Double-precision specialization (CUDA_R_64F). Also contains debugging code
// that snapshots `a` to pinned host buffers.
927inline void tpotrf_dev_cuda(
char uplo,
int n,
double *a,
int lda,
int *info,
int *cuda_buffer_flag,
928 cusolverDnHandle_t *handle, cusolverDnParams_t *params, cudaStream_t stream,
929 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host)
937 static double* save_buffer_dev;
938 static double* save_buffer_dev2;
939 static double* save_buffer_host1;
940 static double* save_buffer_host2;
944 cublasFillMode_t uplo_cuda;
947 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
948 }
else if(uplo ==
'U') {
949 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
951 printf(
"invalid potrf argmument uplo: %c\n", uplo);
955 if(cuda_buffer_flag[0] == 0){
957 cusolverStatus_t cuSolverError = cusolverDnXpotrf_bufferSize(
958 *handle, *params, uplo_cuda, n, CUDA_R_64F, a, lda, CUDA_R_64F, dev_size, host_size);
959 if(cuSolverError != 0){
960 printf(
"cuSolverError buffer size allocation!\n");
964 cuda_buffer_flag[0] = 1;
// NOTE(review): `%ld` is used for size_t values; `%zu` is the portable
// specifier.
// NOTE(review): the four static save_buffer_* allocations below are never
// freed in this view, and each repeat of this path allocates again.
// n*lda*sizeof(double) is evaluated with n*lda in int and can overflow for
// large matrices.
// NOTE(review): the debug section copies `a` to host twice (what happens
// between the copies is not visible) and prints the RMS difference. This looks
// like diagnostic code left in a production path, with full-device
// synchronizations.
967 printf(
"dev size = %ld, host size = %ld\n", *dev_size, *host_size);
968 cudaMalloc(&save_buffer_dev ,n*lda*
sizeof(
double));
969 cudaMalloc(&save_buffer_dev2,n*lda*
sizeof(
double));
970 cudaMallocHost(&save_buffer_host1,n*lda*
sizeof(
double));
971 cudaMallocHost(&save_buffer_host2,n*lda*
sizeof(
double));
972 cudaMemcpy(save_buffer_host1, a, n*lda*
sizeof(
double), cudaMemcpyDeviceToHost);
978 cudaMemcpy(save_buffer_host2, a, n*lda*
sizeof(
double), cudaMemcpyDeviceToHost);
980 cudaDeviceSynchronize();
982 for(
int i=0; i<n*lda; i++){
983 temp += (save_buffer_host1[i]-save_buffer_host2[i])*(save_buffer_host1[i]-save_buffer_host2[i]);
985 printf(
"temp = %f\n", sqrt(temp/(n*lda)));
986 cudaDeviceSynchronize();
// NOTE(review): mem_cuda_dev / mem_cuda_host are by-value parameters, so
// these allocations are lost to the caller (leak). The sizes from
// cusolverDnXpotrf_bufferSize are already in bytes (and are passed unscaled
// to cusolverDnXpotrf below), so scaling by sizeof(double) over-allocates.
991 cudaMalloc((
void**)&mem_cuda_dev, (*dev_size) *
sizeof(
double));
992 cudaMallocHost((
void**)&mem_cuda_host, (*host_size) *
sizeof(
double));
996 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
997 if(cuSolverError != 0){
998 printf(
"cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
1002 cuSolverError = cusolverDnXpotrf(
1003 *handle, *params, uplo_cuda, n, CUDA_R_64F, a, lda, CUDA_R_64F,
1004 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
1007 if(cuSolverError != 0){
1008 printf(
"cuSolverError potrf! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
// NOTE(review): asynchronous copy; info_host must not be read before the
// stream is synchronized (no synchronization visible here).
1013 cudaMemcpyAsync(&info_host, info,
sizeof(
int), cudaMemcpyDeviceToHost, stream);
1016 printf(
"cuSolverError potrf info not zero! info = %d, lda = %d\n", info_host, lda);
// tpotrf_dev_cuda, single-precision specialization (CUDA_R_32F): maps uplo,
// queries workspace sizes on the first call (cuda_buffer_flag[0] == 0), runs
// cusolverDnXpotrf on `stream`, then copies *info back and reports it.
// NOTE(review): the `template <>` header, the opening `if (uplo == ...)` and
// several braces / error exits are not present in this view.
1024inline void tpotrf_dev_cuda(
char uplo,
int n,
float *a,
int lda,
int *info,
int *cuda_buffer_flag,
1025 cusolverDnHandle_t *handle, cusolverDnParams_t *params, cudaStream_t stream,
1026 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host)
1032 cublasFillMode_t uplo_cuda;
1035 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
1036 }
else if(uplo ==
'U') {
1037 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
1039 printf(
"invalid potrf argmument uplo: %c\n", uplo);
1043 if(cuda_buffer_flag[0] == 0){
1045 cusolverStatus_t cuSolverError = cusolverDnXpotrf_bufferSize(
1046 *handle, *params, uplo_cuda, n, CUDA_R_32F, a, lda, CUDA_R_32F, dev_size, host_size);
1047 if(cuSolverError != 0){
1048 printf(
"cuSolverError buffer size allocation!\n");
1052 cuda_buffer_flag[0] = 1;
// NOTE(review): allocations go into by-value parameter copies (lost to the
// caller), and the byte counts from *_bufferSize are scaled by
// sizeof(double) again.
1055 cudaMalloc((
void**)&mem_cuda_dev, (*dev_size) *
sizeof(
double));
1056 cudaMallocHost((
void**)&mem_cuda_host, (*host_size) *
sizeof(
double));
1060 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
1061 if(cuSolverError != 0){
1062 printf(
"cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
1066 cuSolverError = cusolverDnXpotrf(
1067 *handle, *params, uplo_cuda, n, CUDA_R_32F, a, lda, CUDA_R_32F,
1068 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
1071 if(cuSolverError != 0){
1072 printf(
"cuSolverError potrf! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
// NOTE(review): asynchronous copy; synchronize `stream` before reading info_host.
1077 cudaMemcpyAsync(&info_host, info,
sizeof(
int), cudaMemcpyDeviceToHost, stream);
1080 printf(
"cuSolverError potrf info not zero! info = %d, lda = %d\n", info_host, lda);
// tpotrf_dev_cuda, complex specialization: NOT IMPLEMENTED. The visible body
// only prints a placeholder message; `a` and `info` are not touched here.
// NOTE(review): callers get neither a factorization nor an error code.
1087inline void tpotrf_dev_cuda(
char uplo,
int n,CPX *a,
int lda,
int *info,
int *cuda_buffer_flag,
1088 cusolverDnHandle_t *handle, cusolverDnParams_t *params, cudaStream_t stream,
1089 size_t *dev_size,
size_t *host_size,
double *mem_cuda_dev,
double *mem_cuda_host)
1091 printf(
"just a placeholder. not working!\n");
1095#elif defined(MAGMA_EXPERT)
// magma_tpotrf_expert_wrapper<T>: drives MAGMA's expert Cholesky interface
// (magma_tpotrf_expert_gpu_work). The visible sequence is: call it once (it
// reports the required *lwork_host / *lwork_device), print the sizes,
// allocate device and pinned-host workspaces, then call it again with those
// workspaces.
// NOTE(review): the code that uses `init_flag` and several braces are not
// present in this view.
// NOTE(review): magma_tpotrf_expert_gpu_work is declared after this template;
// the call only resolves through argument-dependent lookup at instantiation.
// Consider declaring it first.
1097template <
typename T>
1098inline void magma_tpotrf_expert_wrapper(
char uplo,
int n, T* a,
int lda,
int info[1], magma_mode_t mode,
int subN,
int subSubN,
1099 void* host_work,
int *lwork_host,
void* device_work,
int *lwork_device,
1100 magma_event_t events[2], magma_queue_t queues[2],
int& init_flag){
1102 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1107 printf(
"in initializing potrf expert gpu.\n");
1109 magma_tpotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1110 host_work, lwork_host, device_work, lwork_device, events, queues);
1111 printf(
"DPOTRF EXPERT WORK. Allocating %d of host memory. Allocating %d of device memory.\n", *lwork_host, *lwork_device);
// NOTE(review): `lwork_device > 0` compares the pointer, not the requested
// size; `*lwork_device > 0` is almost certainly intended.
// NOTE(review): device_work / host_work are by-value parameters. The buffers
// are visible to the second call below but are lost (leaked) when this
// function returns.
1113 if(lwork_device > 0){
1114 gpuErrchk(cudaMalloc((
void**)&device_work,*lwork_device));
1118 gpuErrchk(cudaMallocHost((
void**)&host_work,*lwork_host));
1124 magma_tpotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1125 host_work, lwork_host, device_work, lwork_device, events, queues);
// magma_tpotrf_expert_gpu_work<T>: type dispatch to MAGMA's
// magma_{d,s}potrf_expert_gpu_work. The MAGMA return code is printed to
// std::cout.
// NOTE(review): the primary template declares the lwork pointers as
// `magma_int_t *`, while the specializations use `int *`. They only match
// when magma_int_t is int (i.e. not an ILP64 MAGMA build).
// NOTE(review): the `template <>` headers and the code around the std::cout
// lines (possibly an error guard) are not present in this view.
1128template <
typename T>
1129inline void magma_tpotrf_expert_gpu_work(magma_uplo_t magma_uplo,
int n, T* a,
int lda,
int info[1], magma_mode_t mode,
int subN,
int subSubN,
void* host_work, magma_int_t *lwork_host,
void* device_work, magma_int_t *lwork_device, magma_event_t events[2], magma_queue_t queues[2]);
// Double precision: magma_dpotrf_expert_gpu_work.
1133inline void magma_tpotrf_expert_gpu_work(magma_uplo_t magma_uplo,
int n,
double* a,
int lda,
int info[1], magma_mode_t mode,
int subN,
int subSubN,
void* host_work,
int *lwork_host,
void* device_work,
int *lwork_device, magma_event_t events[2], magma_queue_t queues[2])
1135 magma_int_t potrfErr = magma_dpotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1136 host_work, lwork_host, device_work, lwork_device, events, queues);
1138 std::cout <<
"in magma potrf expert work error = " << potrfErr << std::endl;
// Single precision: magma_spotrf_expert_gpu_work.
1145inline void magma_tpotrf_expert_gpu_work(magma_uplo_t magma_uplo,
int n,
float* a,
int lda,
int info[1], magma_mode_t mode,
int subN,
int subSubN,
void* host_work,
int *lwork_host,
void* device_work,
int *lwork_device, magma_event_t events[2], magma_queue_t queues[2])
1147 magma_int_t potrfErr = magma_spotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1148 host_work, lwork_host, device_work, lwork_device, events, queues);
1150 std::cout <<
"in magma potrf expert work error = " << potrfErr << std::endl;
// tpotrf_dev<T>: in-place Cholesky factorization of a device matrix via
// MAGMA's magma_?potrf_gpu. uplo is a BLAS-style character.
1157template <
typename T>
1158inline void tpotrf_dev(
char uplo,
int n,T *a,
int lda,
int *info);
// Double precision. When reporting a failure, the code prints MAGMA's return
// value and the device element at (potrfErr-1)*(lda+1). Assuming potrfErr is
// the 1-based order of the failing leading minor, that is its diagonal entry
// in column-major storage.
// NOTE(review): the `template <>` header and the guard that presumably
// restricts the reporting below to potrfErr != 0 are not present in this
// view. Without that guard, potrfErr == 0 would index a[-(lda+1)].
1162inline void tpotrf_dev(
char uplo,
int n,
double *a,
int lda,
int *info)
1168 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1175 magma_int_t potrfErr = magma_dpotrf_gpu(magma_uplo,n,a,lda,info);
1185 std::cout <<
"magma potrf error = " << potrfErr << std::endl;
1186 int ind_error = (potrfErr-1)*(lda+1);
1187 double a_problem_host;
1188 gpuErrchk(cudaMemcpy(&a_problem_host, &a[ind_error],
sizeof(
double), cudaMemcpyDeviceToHost));
1189 printf(
"ind error = %d, a[ind_error] = %f\n", ind_error, a_problem_host);
1196inline void tpotrf_dev(
char uplo,
int n,
float *a,
int lda,
int *info)
1202 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1204 magma_int_t potrfErr = magma_spotrf_gpu(magma_uplo,n,a,lda,info);
1208 std::cout <<
"magma potrf error = " << potrfErr << std::endl;
1209 int ind_error = (potrfErr-1)*(lda+1);
1210 double a_problem_host;
1211 gpuErrchk(cudaMemcpy(&a_problem_host, &a[ind_error],
sizeof(
double), cudaMemcpyDeviceToHost));
1212 printf(
"ind error = %d, a[ind_error] = %f\n", ind_error, a_problem_host);
1218inline void tpotrf_dev(
char uplo,
int n,CPX *a,
int lda,
int *info)
1220 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1222 magma_zpotrf_gpu(magma_uplo,n,(magmaDoubleComplex_ptr)a,lda,info);
1232template <
typename T>
1233inline void ttrsm_dev(
char side,
char uplo,
char trans,
char diag,
int m,
int n, T alpha, T *a,
int lda, T *b,
int ldb, magma_queue_t queue);
1236inline void ttrsm_dev(
char side,
char uplo,
char trans,
char diag,
int m,
int n,
double alpha,
double *a,
int lda,
double *b,
int ldb, magma_queue_t queue)
1238 magma_side_t magma_side = magma_side_const(side);
1239 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1240 magma_trans_t magma_trans = magma_trans_const(trans);
1241 magma_diag_t magma_diag = magma_diag_const(diag);
1243 magma_dtrsm(magma_side, magma_uplo, magma_trans, magma_diag, m, n, alpha, a, lda, b, ldb, queue);
1248inline void ttrsm_dev(
char side,
char uplo,
char trans,
char diag,
int m,
int n,
float alpha,
float *a,
int lda,
float *b,
int ldb, magma_queue_t queue)
1250 magma_side_t magma_side = magma_side_const(side);
1251 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1252 magma_trans_t magma_trans = magma_trans_const(trans);
1253 magma_diag_t magma_diag = magma_diag_const(diag);
1258 magma_strsm(magma_side, magma_uplo, magma_trans, magma_diag, m, n, alpha, a, lda, b, ldb, queue);
1262inline void ttrsm_dev(
char side,
char uplo,
char trans,
char diag,
int m,
int n, CPX alpha, CPX *a,
int lda, CPX *b,
int ldb, magma_queue_t queue)
1264 magma_side_t magma_side = magma_side_const(side);
1265 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1266 magma_trans_t magma_trans = magma_trans_const(trans);
1267 magma_diag_t magma_diag = magma_diag_const(diag);
1269 magma_ztrsm(magma_side, magma_uplo, magma_trans, magma_diag, m, n, *
reinterpret_cast<magmaDoubleComplex*
>(&alpha), (magmaDoubleComplex_ptr)a, lda, (magmaDoubleComplex_ptr)b, ldb, queue);
1274template <
typename T>
1275inline void tlacpy_dev(
char uplo,
int m,
int n, T *a,
int lda, T *b,
int ldb, magma_queue_t queue);
1278inline void tlacpy_dev(
char uplo,
int m,
int n,
double *a,
int lda,
double *b,
int ldb, magma_queue_t queue)
1280 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1282 magmablas_dlacpy(magma_uplo, m, n, a, lda, b, ldb, queue);
1288inline void tlacpy_dev(
char uplo,
int m,
int n,
float *a,
int lda,
float *b,
int ldb, magma_queue_t queue)
1290 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1294 magmablas_slacpy(magma_uplo, m, n, a, lda, b, ldb, queue);
1299inline void tlacpy_dev(
char uplo,
int m,
int n, CPX *a,
int lda, CPX *b,
int ldb, magma_queue_t queue)
1301 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1303 magmablas_zlacpy(magma_uplo, m, n, (magmaDoubleComplex_ptr)a, lda, (magmaDoubleComplex_ptr)b, ldb, queue);
1310inline void c_dgetrf(
int m,
int n,
double *a,
int lda,
int *ipiv,
int *info)
1312 fortran_name(dgetrf,DGETRF)(&m,&n,a,&lda,ipiv,info);
1317inline void c_dgetrs(
char transa,
int n,
int nrhs,
double *a,
int lda,
int *ipiv,
double *b, \
1320 fortran_name(dgetrs,DGETRS)(&transa,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1325inline void c_zgetrf(
int m,
int n, CPX *a,
int lda,
int *ipiv,
int *info)
1327 fortran_name(zgetrf,ZGETRF)(&m,&n,a,&lda,ipiv,info);
1332inline void c_zgetrs(
char transa,
int n,
int nrhs, CPX *a,
int lda,
int *ipiv, CPX *b, \
1335 fortran_name(zgetrs,ZGETRS)(&transa,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1340inline void c_zgetri(
int n,CPX *a,
int lda,
int *ipiv,CPX *work,
int lwork,
int *info)
1342 fortran_name(zgetri,ZGETRI)(&n,a,&lda,ipiv,work,&lwork,info);
1347inline void c_dgeev(
char jobvl,
char jobvr,
int n,
double *a,
int lda,
double *wr,
double *wi, \
1348 double *vl,
int ldvl,
double *vr,
int ldvr,
double *work,
int lwork, \
1351 fortran_name(dgeev,DGEEV)(&jobvl,&jobvr,&n,a,&lda,wr,wi,vl,&ldvl,vr,&ldvr,work,&lwork,info);
1356inline void c_dsyev(
char JOBZ,
char UPLO,
int N,
double *A,
int LDA,
double *W,
double *WORK,
int LWORK,\
1359 fortran_name(dsyev,DSYEV)(&JOBZ,&UPLO,&N,A,&LDA,W,WORK,&LWORK,INFO);
1364inline void c_dggev(
char jobvl,
char jobvr,
int n,
double *a,
int lda,
double *b,
int ldb, \
1365 double *alphar,
double *alphai,
double *beta,
double *vl,
int ldvl, \
1366 double *vr,
int ldvr,
double *work,
int lwork,
int *info)
1368 fortran_name(dggev,DGGEV)(&jobvl,&jobvr,&n,a,&lda,b,&ldb,alphar,alphai,beta,vl,&ldvl,vr,\
1369 &ldvr,work,&lwork,info);
1374inline void c_zggev(
char jobvl,
char jobvr,
int n, CPX *a,
int lda, CPX *b,
int ldb, CPX *alpha,\
1375 CPX *beta, CPX *vl,
int ldvl, CPX *vr,
int ldvr, CPX *work,
int lwork, \
1376 double *rwork,
int *info)
1378 fortran_name(zggev,ZGGEV)(&jobvl,&jobvr,&n,a,&lda,b,&ldb,alpha,beta,vl,&ldvl,vr,&ldvr,\
1379 work,&lwork,rwork,info);
1384inline void c_zgeev(
char jobvl,
char jobvr,
int n, CPX *a,
int lda, CPX *w, CPX *vl,
int ldvl, \
1385 CPX *vr,
int ldvr, CPX *work,
int lwork,
double *rwork,
int *info)
1387 fortran_name(zgeev,ZGEEV)(&jobvl,&jobvr,&n,a,&lda,w,vl,&ldvl,vr,&ldvr,work,&lwork,rwork,info);
1392inline void c_zheev(
char jobvl,
char uplo,
int n, CPX *a,
int lda,
double *w, CPX *work, \
1393 int lwork,
double *rwork,
int *info)
1395 fortran_name(zheev,ZHEEV)(&jobvl,&uplo,&n,a,&lda,w,work,&lwork,rwork,info);
1400inline void c_dgetri(
int n,
double *a,
int lda,
int *ipiv,
double *work,
int lwork,
int *info)
1402 fortran_name(dgetri,DGETRI)(&n,a,&lda,ipiv,work,&lwork,info);
1407inline void c_dsytri(
char uplo,
int n,
double *a,
int lda,
int *ipiv,
double *work,
int *info)
1409 fortran_name(dsytri,DSYTRI)(&uplo,&n,a,&lda,ipiv,work,info);
1414inline void c_zhetrf(
char uplo,
int n,CPX *a,
int lda,
int *ipiv,CPX *work,
int lwork,
int *info)
1416 fortran_name(zhetrf,ZHETRF)(&uplo,&n,a,&lda,ipiv,work,&lwork,info);
1421inline void c_zhetri(
char uplo,
int n,CPX *a,
int lda,
int *ipiv,CPX *work,
int *info)
1423 fortran_name(zhetri,ZHETRI)(&uplo,&n,a,&lda,ipiv,work,info);
1428inline void c_zhetrs(
char uplo,
int n,
int nrhs,CPX *a,
int lda,
int *ipiv,CPX *b,
int ldb,\
1431 fortran_name(zhetrs,ZHETRS)(&uplo,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1436inline void c_dsysv(
char uplo,
int n,
int nrhs,
double *a,
int lda,
int *ipiv,
double *b,
int ldb,\
1437 double *work,
int lwork,
int *info)
1439 fortran_name(dsysv,DSYSV)(&uplo,&n,&nrhs,a,&lda,ipiv,b,&ldb,work,&lwork,info);
1444inline void c_dsytrf(
char uplo,
int n,
double *a,
int lda,
int *ipiv,
double *work,
int lwork,
int *info)
1446 fortran_name(dsytrf,DSYTRF)(&uplo,&n,a,&lda,ipiv,work,&lwork,info);
1451inline void c_dsytrs(
char uplo,
int n,
int nrhs,
double *a,
int lda,
int *ipiv,
double *b,
int ldb,\
1454 fortran_name(dsytrs,DSYTRS)(&uplo,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1459inline void c_dstebz(
char *range,
char *order,
int *iter,
double *vl,
double *vu,
int *il,
int *iu,\
1460 double *abstol,
double *diag,
double *offd,
int *neval,
int *nsplit,\
1461 double *eval,
int *iblock,
int *isplit,
double *work,
int *iwork,
int *info)
1463 fortran_name(dstebz,DSTEBZ)(range,order,iter,vl,vu,il,iu,abstol,diag,offd,neval,nsplit,eval,\
1464 iblock,isplit,work,iwork,info);
1469inline void c_zlarnv(
int *idist,
int *iseed,
int n,CPX *x)
1471 fortran_name(zlarnv,ZLARNV)(idist,iseed,&n,x);
1476inline void c_dgesdd(
char jobz,
int m,
int n,
double *a,
int lda,
double *s,
double *u,
int ldu,\
1477 double *vt,
int ldvt,
double *work,
int lwork,
int *iwork,
int *info)
1479 fortran_name(dgesdd,DGESDD)(&jobz,&m,&n,a,&lda,s,u,&ldu,vt,&ldvt,work,&lwork,iwork,info);
1484inline void c_zgesdd(
char jobz,
int m,
int n,CPX *a,
int lda,
double *s,CPX *u,
int ldu,\
1485 CPX *vt,
int ldvt,CPX *work,
int lwork,
double *rwork,
int *iwork,
int *info)
1487 fortran_name(zgesdd,ZGESDD)(&jobz,&m,&n,a,&lda,s,u,&ldu,vt,&ldvt,work,&lwork,rwork,iwork,info);
1492template <
typename T>
1493inline void copy_csr_to_device(
int size,
int n_nonzeros,
int *hedge_i,
int *hindex_j,T *hnnz,\
1494 int *dedge_i,
int *dindex_j,T *dnnz);
1497inline void copy_csr_to_device(
int size,
int n_nonzeros,
int *hedge_i,
int *hindex_j,
double *hnnz,\
1498 int *dedge_i,
int *dindex_j,
double *dnnz)
1500 d_copy_csr_to_device(size,n_nonzeros,hedge_i,hindex_j,hnnz,dedge_i,dindex_j,dnnz);
1504inline void copy_csr_to_device(
int size,
int n_nonzeros,
int *hedge_i,
int *hindex_j,CPX *hnnz,\
1505 int *dedge_i,
int *dindex_j,CPX *dnnz)
1507 z_copy_csr_to_device(size,n_nonzeros,hedge_i,hindex_j,hnnz,dedge_i,dindex_j,dnnz);
1512template <
typename T>
1513inline void init_var_on_dev(T *var,
int N,cudaStream_t stream);
1516inline void init_var_on_dev(
double *var,
int N,cudaStream_t stream){
1517 d_init_var_on_dev(var,N,stream);
1521inline void init_var_on_dev(CPX *var,
int N,cudaStream_t stream){
1522 z_init_var_on_dev(var,N,stream);
1527template <
typename T>
1528inline void init_eye_on_dev(T *var,
int N,cudaStream_t stream);
1531inline void init_eye_on_dev(
double *var,
int N,cudaStream_t stream){
1532 d_init_eye_on_dev(var,N,stream);
1537inline void init_eye_on_dev(
float *var,
int N,cudaStream_t stream){
1538 s_init_eye_on_dev(var,N,stream);
1542inline void init_eye_on_dev(CPX *var,
int N,cudaStream_t stream){
1543 z_init_eye_on_dev(var,N,stream);
1548template <
typename T>
1549inline void csr_mult_f(
void *handle,
int m,
int n,
int k,
int n_nonzeros,
int *Aedge_i,
int *Aindex_j,\
1550 T *Annz,T alpha,T *B,T beta,T *C);
1553inline void csr_mult_f(
void *handle,
int m,
int n,
int k,
int n_nonzeros,
int *Aedge_i,
int *Aindex_j,\
1554 double *Annz,
double alpha,
double *B,
double beta,
double *C)
1556 d_csr_mult_f(handle,m,n,k,n_nonzeros,Aedge_i,Aindex_j,Annz,alpha,B,beta,C);
1560inline void csr_mult_f(
void *handle,
int m,
int n,
int k,
int n_nonzeros,
int *Aedge_i,
int *Aindex_j,\
1561 CPX *Annz,CPX alpha,CPX *B,CPX beta,CPX *C)
1563 z_csr_mult_f(handle,m,n,k,n_nonzeros,Aedge_i,Aindex_j,Annz,alpha,B,beta,C);
1568template <
typename T>
1569inline void transpose_matrix(T *odata,T *idata,
int size_x,
int size_y);
1572inline void transpose_matrix(
double *odata,
double *idata,
int size_x,
int size_y)
1574 d_transpose_matrix(odata,idata,size_x,size_y);
1578inline void transpose_matrix(CPX *odata,CPX *idata,
int size_x,
int size_y)
1580 z_transpose_matrix(odata,idata,size_x,size_y);
1585template <
typename T>
1586inline void extract_diag_on_dev(T *D,
int *edge_i,
int *index_j,T *nnz,
int NR,\
1587 int imin,
int imax,
int shift,
int findx,cudaStream_t stream);
1590inline void extract_diag_on_dev(
double *D,
int *edge_i,
int *index_j,
double *nnz,
int NR,\
1591 int imin,
int imax,
int shift,
int findx,cudaStream_t stream){
1592 d_extract_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,shift,findx,stream);
1596inline void extract_diag_on_dev(CPX *D,
int *edge_i,
int *index_j,CPX *nnz,
int NR,\
1597 int imin,
int imax,
int shift,
int findx,cudaStream_t stream){
1598 z_extract_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,shift,findx,stream);
1603template <
typename T>
1604inline void extract_not_diag_on_dev(T *D,
int *edge_i,
int *index_j,T *nnz,
int NR,\
1605 int imin,
int imax,
int jmin,
int side,
int shift,
int findx,cudaStream_t stream);
1608inline void extract_not_diag_on_dev(
double *D,
int *edge_i,
int *index_j,
double *nnz,
int NR,\
1609 int imin,
int imax,
int jmin,
int side,
int shift,
int findx,cudaStream_t stream){
1610 d_extract_not_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,jmin,side,shift,findx,stream);
1614inline void extract_not_diag_on_dev(CPX *D,
int *edge_i,
int *index_j,CPX *nnz,
int NR,\
1615 int imin,
int imax,
int jmin,
int side,
int shift,
int findx,cudaStream_t stream){
1616 z_extract_not_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,jmin,side,shift,findx,stream);
1621template <
typename T>
1622inline void tril_dev(T *A,
int lda,
int N);
1625inline void tril_dev(
double *A,
int lda,
int N)
1627 d_tril_on_dev(A, lda, N);
1632inline void tril_dev(
float *A,
int lda,
int N)
1634 s_tril_on_dev(A, lda, N);
1638inline void tril_dev(CPX *A,
int lda,
int N)
1640 z_tril_on_dev(A, lda, N);
1645template <
typename T>
1646inline void indexed_copy_dev(T *src, T *dst,
size_t *index,
size_t N);
1649inline void indexed_copy_dev(
double *src,
double *dst,
size_t *index,
size_t N)
1651 d_indexed_copy_on_dev(src, dst, index, N);
1655inline void indexed_copy_dev(CPX *src, CPX *dst,
size_t *index,
size_t N)
1657 z_indexed_copy_on_dev(src, dst, index, N);
1662template <
typename T>
1663inline void indexed_copy_offset_dev(T *src, T *dst,
size_t *index,
size_t N,
size_t offset);
1666inline void indexed_copy_offset_dev(
double *src,
double *dst,
size_t *index,
size_t N,
size_t offset)
1668 d_indexed_copy_offset_on_dev(src, dst, index, N, offset);
1673inline void indexed_copy_offset_dev(
float *src,
float *dst,
size_t *index,
size_t N,
size_t offset)
1675 s_indexed_copy_offset_on_dev(src, dst, index, N, offset);
1679inline void indexed_copy_offset_dev(CPX *src, CPX *dst,
size_t *index,
size_t N,
size_t offset)
1681 z_indexed_copy_offset_on_dev(src, dst, index, N, offset);
// Host-side gather: dst[i] = src[index[i]] for every i in [0, N).
// The loop is OpenMP-parallel when OpenMP is enabled. The counter is size_t
// so it has the same type as N. An int counter mismatched signedness and
// overflowed (undefined behaviour) once N exceeded INT_MAX. OpenMP 3.0+
// accepts unsigned loop variables.
template <typename T>
inline void indexed_copy(T *src, T *dst, size_t *index, size_t N)
{
    #pragma omp parallel for
    for (size_t i = 0; i < N; i++)
    {
        dst[i] = src[index[i]];
    }
}
// Adds log(x[index[i]]) for i in [0, N) to *sum, on the host.
// OpenMP reduces into the one-element array section sum[:1] (OpenMP 4.5
// syntax), so the value already in *sum is kept and the partial sums are
// added to it. The counter is size_t to match N. An int counter mismatched
// signedness and overflowed for N > INT_MAX.
template <typename T>
inline void indexed_log_sum(T *x, size_t *index, size_t N, T *sum)
{
    #pragma omp parallel for reduction(+:sum[:1])
    for (size_t i = 0; i < N; i++)
    {
        *sum += log(x[index[i]]);
    }
}
// Adds log(x[i]) for i in [0, N) to *sum, on the host.
// OpenMP reduces into the one-element array section sum[:1] (OpenMP 4.5),
// so the existing value of *sum is kept. The counter is size_t to match N.
// An int counter mismatched signedness and overflowed for N > INT_MAX.
// NOTE(review): the loop body was not visible in this chunk; it is
// reconstructed from the function name and the indexed_log_sum sibling.
template <typename T>
inline void log_sum(T *x, size_t N, T *sum)
{
    #pragma omp parallel for reduction(+:sum[:1])
    for (size_t i = 0; i < N; i++)
    {
        *sum += log(x[i]);
    }
}
1732template <
typename T>
1733inline void log_dev(T *x,
size_t N);
1736inline void log_dev(
double *x,
size_t N)
1742inline void log_dev(CPX *x,
size_t N)
1749template <
typename T>
1750inline void fill_dev(T *x,
const T value,
size_t N);
1753inline void fill_dev(
double *x,
const double value,
size_t N)
1755 d_fill_on_dev(x, value, N);
1760inline void fill_dev(
float *x,
const float value,
size_t N)
1762 s_fill_on_dev(x, value, N);
1766inline void fill_dev(CPX *x,
const CPX value,
size_t N)
1768 z_fill_on_dev(x, value, N);
1773template <
typename T>
1774inline void init_block_matrix_dev(T *M,
size_t *ia,
size_t *ja, T *a,
size_t nnz,
size_t ns,
size_t nt,
size_t nd);
1777inline void init_block_matrix_dev(
double *M,
size_t *ia,
size_t *ja,
double *a,
size_t nnz,
size_t ns,
size_t nt,
size_t nd)
1779 d_init_block_matrix_on_dev(M, ia, ja, a, nnz, ns, nt, nd);
1783inline void init_block_matrix_dev(CPX *M,
size_t *ia,
size_t *ja, CPX *a,
size_t nnz,
size_t ns,
size_t nt,
size_t nd)
1785 z_init_block_matrix_on_dev(M, ia, ja, a, nnz, ns, nt, nd);
1790template <
typename T>
1791inline void init_supernode_dev(T *M,
size_t *ia,
size_t *ja, T *a,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd, cudaStream_t stream );
1794inline void init_supernode_dev(
double *M,
size_t *ia,
size_t *ja,
double *a,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd, cudaStream_t stream )
1796 d_init_supernode_on_dev(M, ia, ja, a, supernode, supernode_nnz, supernode_offset, ns, nt, nd, stream);
1801inline void init_supernode_dev(
float *M,
size_t *ia,
size_t *ja,
float *a,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd, cudaStream_t stream )
1803 s_init_supernode_on_dev(M, ia, ja, a, supernode, supernode_nnz, supernode_offset, ns, nt, nd, stream);
1807inline void init_supernode_dev(CPX *M,
size_t *ia,
size_t *ja, CPX *a,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd, cudaStream_t stream )
1809 z_init_supernode_on_dev(M, ia, ja, a, supernode, supernode_nnz, supernode_offset, ns, nt, nd, stream );
1815template <
typename T>
1816inline void extract_nnzA_dev(T *a,
size_t *ia,
size_t *ja, T *M,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd);
1819inline void extract_nnzA_dev(
double *a,
size_t *ia,
size_t *ja,
double *M,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd)
1821 d_extract_nnzA_on_dev(a, ia, ja, M, supernode, supernode_nnz, supernode_offset, ns, nt, nd);
1826inline void extract_nnzA_dev(
float *a,
size_t *ia,
size_t *ja,
float *M,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd)
1828 s_extract_nnzA_on_dev(a, ia, ja, M, supernode, supernode_nnz, supernode_offset, ns, nt, nd);
1832inline void extract_nnzA_dev(CPX *a,
size_t *ia,
size_t *ja, CPX *M,
size_t supernode,
size_t supernode_nnz,
size_t supernode_offset,
size_t ns,
size_t nt,
size_t nd)
1834 printf(
"just dummy version of COMPLEX extract_nnzA_dev. not implemented.\n");