INLA_DIST
Blas.H
1#ifndef __BLAS
2#define __BLAS
3
4#include <iostream>
5#include "Types.H"
6#include "CWC_utility.H"
7#include "magma_v2.h"
8
9#define CUDA_POTRF
10//#define MAGMA_EXPERT
11
12#ifdef CUDA_POTRF
13#include "cusolverDn.h"
14#endif
15
16#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
17
18inline void gpuAssert(cudaError_t code, const char *file, int line)
19{
20 if (code != cudaSuccess)
21 {
22 fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
23 exit(code);
24 abort();
25 }
26}
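// Usage sketch (illustrative only, not part of this header): gpuErrchk wraps any
// call returning a cudaError_t and reports file/line context before exiting, e.g.
//
//     double *d_a = nullptr;                                        // hypothetical device buffer
//     gpuErrchk(cudaMalloc((void**)&d_a, n * sizeof(double)));
//     gpuErrchk(cudaMemcpy(d_a, h_a, n * sizeof(double), cudaMemcpyHostToDevice));
//
// where h_a and n are assumed to be an existing host buffer and its length.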
27
28extern "C" {
29//Blas
30 void fortran_name(dcopy,DCOPY)(int *n,double *dx,int *incx,double *dy,int *incy);
31 void fortran_name(daxpy,DAXPY)(int *n,double *da,double *dx,int *incx,double *dy,\
32 int *incy);
33 double fortran_name(dnrm2,DNRM2)(int *n,double *dx,int *incx);
34 float fortran_name(snrm2,SNRM2)(int *n,float *dx,int *incx);
35 double fortran_name(dznrm2,DZNRM2)(int *n,CPX *dx,int *incx);
36 void fortran_name(dscal,DSCAL)(int *n,double *da,double *dx,int *incx);
37 void fortran_name(zscal,ZSCAL)(int *n,CPX *da,CPX *dx,int *incx);
38 void fortran_name(dgemm,DGEMM)(char *transa, char *transb, int *m, int *n, int *k, \
39 double *alpha, double *a, int *lda, double *b, int *ldb, \
40 double *beta, double *c, int *ldc);
41 void fortran_name(sgemm,SGEMM)(char *transa, char *transb, int *m, int *n, int *k, \
42 float *alpha, float *a, int *lda, float *b, int *ldb, \
43 float *beta, float *c, int *ldc);
44 void fortran_name(dsymm,DSYMM)(char *side, char *uplo, int *m, int *n, double *alpha, \
45 double *a, int *lda, double *b, int *ldb, double *beta, \
46 double *c, int *ldc);
47 void fortran_name(dtrsm,DTRSM)(char *side, char *uplo, char *transa, char *diag, \
48 int *m, int *n, \
49 double *alpha, double *a, int *lda, double *b, int *ldb);
50 void fortran_name(strsm,STRSM)(char *side, char *uplo, char *transa, char *diag, \
51 int *m, int *n, \
52 float *alpha, float *a, int *lda, float *b, int *ldb);
53 void fortran_name(zgemm,ZGEMM)(char *transa, char *transb, int *m, int *n, int *k, \
54 CPX *alpha, CPX *a, int *lda, CPX *b, int *ldb, \
55 CPX *beta, CPX *c, int *ldc);
56 void fortran_name(zsymm,ZSYMM)(char *side, char *uplo, int *m, int *n, CPX *alpha, \
57 CPX *a, int *lda, CPX *b, int *ldb, CPX *beta, \
58 CPX *c, int *ldc);
59 void fortran_name(ztrsm,ZTRSM)(char *side, char *uplo, char *transa, char *diag, \
60 int *m, int *n, \
61 CPX *alpha, CPX *a, int *lda, CPX *b, int *ldb);
62 void fortran_name(zhemm,ZHEMM)(char *side, char *uplo, int *m, int *n, CPX *alpha, \
63 CPX *a, int *lda, CPX *b, int *ldb, CPX *beta, \
64 CPX *c, int *ldc);
65 void fortran_name(dgemv,DGEMV)(char *trans, int *m, int *n, double *alpha, double *a, \
66 int *lda, double *x, int *incx, double *beta, double *y, \
67 int *incy);
68
69 void fortran_name(zgemv,ZGEMV)(char *trans, int *m, int *n, CPX *alpha, CPX *a, \
70 int *lda, CPX *x, int *incx, CPX *beta, CPX *y, \
71 int *incy);
72
73 double fortran_name(ddot,DDOT)(int *n,double *x,int *incx,double *y,int *incy);
74 CPX fortran_name(zdotc,ZDOTC)(int *n,CPX *x,int *incx,CPX *y,int *incy);
75 void fortran_name(zcopy,ZCOPY)(int *n,CPX *dx,int *incx,CPX *dy,int *incy);
76 void fortran_name(zaxpy,ZAXPY)(int *n, CPX *alpha, CPX *x, int *incx, CPX *y, int *incy);
77 double fortran_name(dasum,DASUM)(int *n,double *dx,int *incx);
78 double fortran_name(dzasum,DZASUM)(int *n,CPX *dx,int *incx);
79
80//Lapack
81 void fortran_name(dgetrf,DGETRF)(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
82 void fortran_name(dgetrs,DGETRS)(char *trans, int *n, int *nrhs, double *a, int *lda, \
83 int *ipiv, double *b, int *ldb, int *info);
84 void fortran_name(zgetrf,ZGETRF)(int *m, int *n, CPX *a, int *lda, int *ipiv, int *info);
85 void fortran_name(zgetrs,ZGETRS)(char *trans, int *n, int *nrhs, CPX *a, int *lda, \
86 int *ipiv, CPX *b, int *ldb, int *info);
87 void fortran_name(zgetri,ZGETRI)(int *n,CPX *a,int *lda,int *ipiv,CPX *work,int *lwork,\
88 int *info);
89 void fortran_name(dgeev,DGEEV)(char *jobvl, char *jobvr, int *n, double *a, int *lda, \
90 double *wr, double *wi, double *vl, int *ldvl, double *vr, \
91 int *ldvr, double *work, int *lwork, int *info);
92 void fortran_name(dsyev,DSYEV)(char *JOBZ,char *UPLO,int *N,double *A,int *LDA,double *W,\
93 double *WORK,int *LWORK,int *INFO);
94 void fortran_name(dggev,DGGEV)(char* jobvl, char *jobvr, int *n, double *a, int *lda, \
95 double *b, int *ldb, double *alphar, double *alphai, \
96 double *beta, double *vl, int *ldvl, double *vr, int *ldvr, \
97 double *work, int *lwork, int *info);
98 void fortran_name(zggev,ZGGEV)(char* jobvl, char *jobvr, int *n, CPX *a, int *lda, \
99 CPX *b, int *ldb, CPX *alpha, CPX *beta, CPX *vl, int *ldvl,\
100 CPX *vr, int *ldvr, CPX *work, int *lwork, double *rwork, \
101 int *info);
102 void fortran_name(zgeev,ZGEEV)(char *jobvl, char *jobvr, int *n, CPX *a, int *lda, \
103 CPX *w, CPX *vl, int *ldvl, CPX *vr, int *ldvr, CPX *work, \
104 int *lwork, double *rwork, int *info);
105 void fortran_name(zheev,ZHEEV)(char *jobvl, char *uplo, int *n, CPX *a, int *lda, \
106 double *w, CPX *work, int *lwork, double *rwork, int *info);
107 void fortran_name(dgetri,DGETRI)(int *n,double *a,int *lda,int *ipiv,double *work,int *lwork,\
108 int *info);
109 void fortran_name(dsytri,DSYTRI)(char *uplo,int *n,double *a,int *lda,int *ipiv,double *work,\
110 int *info);
111 void fortran_name(zhetrf,ZHETRF)(char *uplo,int *n,CPX *a,int *lda,int *ipiv,CPX *work,\
112 int *lwork,int *info);
113 void fortran_name(zhetri,ZHETRI)(char *uplo,int *n,CPX *a,int *lda,int *ipiv,CPX *work,\
114 int *info);
115 void fortran_name(zhetrs,ZHETRS)(char *uplo,int *n,int *nrhs,CPX *a,int *lda,int *ipiv,\
116 CPX *b,int *ldb,int *info);
117 void fortran_name(dsysv,DSYSV)(char *uplo,int *n,int *nrhs,double *a,int *lda,int *ipiv,\
118 double *b,int *ldb,double *work,int *lwork,int *info);
119 void fortran_name(dsytrf,DSYTRF)(char *uplo,int *n,double *a,int *lda,int *ipiv,double *work,\
120 int *lwork,int *info);
121 void fortran_name(dsytrs,DSYTRS)(char *uplo,int *n,int *nrhs,double *a,int *lda,int *ipiv,\
122 double *b,int *ldb,int *info);
123 void fortran_name(dstebz,DSTEBZ)(char *range,char *order,int *iter,double *vl,double *vu,int *il,int *iu,\
124 double *abstol,double *diag,double *offd,int *neval,int *nsplit,\
125 double *eval,int *iblock,int *isplit,double *work,int *iwork,int *info);
126 void fortran_name(zlarnv,ZLARNV)(int*,int*,int*,CPX*);
127 void fortran_name(dgesdd,DGESDD)(char *jobz,int *m,int *n,double *a,int *lda,double *s,double *u,int *ldu, \
128 double *vt,int *ldvt,double *work,int *lwork,int *iwork,int *info);
129 void fortran_name(zgesdd,ZGESDD)(char *jobz,int *m,int *n,CPX *a,int *lda,double *s,CPX *u,int *ldu,\
130 CPX *vt,int *ldvt,CPX *work,int *lwork,double *rwork,int *iwork,int *info);
131
132}
133
134/*SAB********************************************************************************************/
135
136inline double c_dzasum(int n,CPX *dx,int incx)
137{
138 return fortran_name(dzasum,DZASUM)(&n,dx,&incx);
139}
140
141/*My Blas*******************************************************************************************/
142
143inline void c_icopy(int n,int *dx,int incx,int *dy,int incy)
144{
145 int i;
146
147 for(i=0;i<n;i++) dy[i*incy] = dx[i*incx];
148}
149
150/*Blas*******************************************************************************************/
151
152inline void c_dcopy(int n,double *dx,int incx,double *dy,int incy)
153{
154 fortran_name(dcopy,DCOPY)(&n,dx,&incx,dy,&incy);
155}
156
157/************************************************************************************************/
158
159inline void c_daxpy(int n,double da,double *dx,int incx,double *dy,int incy)
160{
161 fortran_name(daxpy,DAXPY)(&n,&da,dx,&incx,dy,&incy);
162}
163
164/************************************************************************************************/
165
166inline double c_dnrm2(int n,double* dx,int incx)
167{
168 return fortran_name(dnrm2,DNRM2)(&n,dx,&incx);
169}
170
171// NEW SINGLE PRECISION
172/************************************************************************************************/
173
174inline double c_dnrm2(int n,float* dx,int incx)
175{
176 return fortran_name(snrm2,SNRM2)(&n,dx,&incx);
177}
178
179/************************************************************************************************/
180
181inline double c_dznrm2(int n,CPX* dx,int incx)
182{
183 return fortran_name(dznrm2,DZNRM2)(&n,dx,&incx);
184}
185
186/************************************************************************************************/
187
188inline void c_dscal(int n,double da,double *dx,int incx)
189{
190 fortran_name(dscal,DSCAL)(&n,&da,dx,&incx);
191}
192
193/************************************************************************************************/
194
195inline void c_zscal(int n,CPX da,CPX *dx,int incx)
196{
197 fortran_name(zscal,ZSCAL)(&n,&da,dx,&incx);
198}
199
200/************************************************************************************************/
201
202inline void c_dgemm(char transa, char transb, int m, int n, int k, double alpha, double *a, \
203 int lda, double *b, int ldb, double beta, double *c, int ldc)
204{
205 fortran_name(dgemm,DGEMM)(&transa,&transb,&m,&n,&k,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
206}
207
208/************************************************************************************************/
209
210inline void c_sgemm(char transa, char transb, int m, int n, int k, float alpha, float *a, \
211 int lda, float *b, int ldb, float beta, float *c, int ldc)
212{
213 fortran_name(sgemm,SGEMM)(&transa,&transb,&m,&n,&k,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
214}
215
216/************************************************************************************************/
217
218inline void c_dsymm(char side,char uplo, int m, int n, double alpha, double *a, int lda, \
219 double *b, int ldb, double beta, double *c, int ldc)
220{
221 fortran_name(dsymm,DSYMM)(&side,&uplo,&m,&n,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
222}
223
224/************************************************************************************************/
225
226inline void c_dtrsm(char side, char uplo, char transa, char diag, int m, int n, double alpha, double *a, int lda, double *b, int ldb)
227{
228 fortran_name(dtrsm,DTRSM)(&side,&uplo,&transa,&diag,&m,&n,&alpha,a,&lda,b,&ldb);
229}
230
231// new SINGLE PRECISION
232/************************************************************************************************/
233
234inline void c_strsm(char side, char uplo, char transa, char diag, int m, int n, float alpha, float *a, int lda, float *b, int ldb)
235{
236 fortran_name(strsm,STRSM)(&side,&uplo,&transa,&diag,&m,&n,&alpha,a,&lda,b,&ldb);
237}
238
239/************************************************************************************************/
240
241inline void c_zgemm(char transa, char transb, int m, int n, int k, CPX alpha, CPX *a, \
242 int lda, CPX *b, int ldb, CPX beta, CPX *c, int ldc)
243{
244 fortran_name(zgemm,ZGEMM)(&transa,&transb,&m,&n,&k,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
245}
246
247/************************************************************************************************/
248
249inline void c_zsymm(char side, char uplo, int m, int n, CPX alpha, CPX *a, int lda, \
250 CPX *b, int ldb, CPX beta, CPX *c, int ldc)
251{
252 fortran_name(zsymm,ZSYMM)(&side,&uplo,&m,&n,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
253}
254
255/************************************************************************************************/
256
257inline void c_ztrsm(char side, char uplo, char transa, char diag, int m, int n, CPX alpha, CPX *a, int lda, CPX *b, int ldb)
258{
259 fortran_name(ztrsm,ZTRSM)(&side,&uplo,&transa,&diag,&m,&n,&alpha,a,&lda,b,&ldb);
260}
261
262/************************************************************************************************/
263
264inline void c_zhemm(char side, char uplo, int m, int n, CPX alpha, CPX *a, int lda, \
265 CPX *b, int ldb, CPX beta, CPX *c, int ldc)
266{
267 fortran_name(zhemm,ZHEMM)(&side,&uplo,&m,&n,&alpha,a,&lda,b,&ldb,&beta,c,&ldc);
268}
269
270/************************************************************************************************/
271
272inline void c_dgemv(char transa, int m, int n, double alpha, double *a, int lda, double *x, \
273 int incx, double beta, double *y, int incy)
274{
275 fortran_name(dgemv,DGEMV)(&transa,&m,&n,&alpha,a,&lda,x,&incx,&beta,y,&incy);
276}
277
278/************************************************************************************************/
279
280inline void c_zgemv(char transa, int m, int n, CPX alpha, CPX *a, int lda, CPX *x, \
281 int incx, CPX beta, CPX *y, int incy)
282{
283 fortran_name(zgemv,ZGEMV)(&transa,&m,&n,&alpha,a,&lda,x,&incx,&beta,y,&incy);
284
285}
286
287/***********************************************************************************************/
288inline double c_ddot(int n, double *x, int incx, double *y, int incy)
289{
290 return fortran_name(ddot,DDOT)(&n,x,&incx,y,&incy);
291}
292
293/************************************************************************************************/
294
295inline CPX c_zdotc(int n, CPX *x, int incx, CPX *y, int incy)
296{
297
298 //return fortran_name(zdotc,ZDOTC)(&n,x,&incx,y,&incy);
299
300 double real,imag;
301
302 real = c_ddot(n,(double*)x,2*incx,(double*)y,2*incy)+\
303 c_ddot(n,(double*)&x[0]+1,2*incx,(double*)&y[0]+1,2*incy);
304 imag = -c_ddot(n,(double*)&x[0]+1,2*incx,(double*)y,2*incy)+\
305 c_ddot(n,(double*)x,2*incx,(double*)&y[0]+1,2*incy);
306
307 return CPX(real,imag);
308
309}
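// Note on the identity used above: for x, y in C^n the conjugated dot product is
//
//     zdotc(x,y) = sum_i conj(x_i) * y_i
//                = (Re x . Re y + Im x . Im y) + i * (Re x . Im y - Im x . Re y),
//
// so it can be assembled from four real DDOT calls over the interleaved (re, im)
// layout of CPX with stride 2*inc. This sidesteps the direct ZDOTC call (commented
// out above), whose complex return-value convention varies between Fortran compilers.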
310
311/************************************************************************************************/
312
313inline void c_zcopy(int n,CPX *dx,int incx,CPX *dy,int incy)
314{
315 fortran_name(zcopy,ZCOPY)(&n,dx,&incx,dy,&incy);
316}
317
318/************************************************************************************************/
319
320inline void c_zaxpy(int n, CPX alpha, CPX *x, int incx, CPX *y, int incy)
321{
322 fortran_name(zaxpy,ZAXPY)(&n,&alpha,x,&incx,y,&incy);
323}
324
325/************************************************************************************************/
326
327inline double c_dasum(int n,double *dx,int incx)
328{
329 return fortran_name(dasum,DASUM)(&n,dx,&incx);
330}
331
332/************************************************************************************************/
333
334template <typename T,typename W>
335inline void c_tcopy(int n,T *dx,int incx,W *dy,int incy);
336
337template<>
338inline void c_tcopy(int n,double *dx,int incx,double *dy,int incy)
339{
340 c_dcopy(n,dx,incx,dy,incy);
341}
342
343template<>
344inline void c_tcopy(int n,CPX *dx,int incx,CPX *dy,int incy)
345{
346 c_zcopy(n,dx,incx,dy,incy);
347}
348
349template<>
350inline void c_tcopy(int n,double *dx,int incx,CPX *dy,int incy)
351{
352 c_dcopy(n,dx,incx,(double*)dy,2*incy);
353}
354
355/************************************************************************************************/
356
357template <typename T,typename W>
358inline void c_taxpy(int n, T alpha, T *x, int incx, W *y, int incy);
359
360template <>
361inline void c_taxpy(int n, double alpha, double *x, int incx, double *y, int incy)
362{
363 c_daxpy(n,alpha,x,incx,y,incy);
364}
365
366template <>
367inline void c_taxpy(int n, CPX alpha, CPX *x, int incx, CPX *y, int incy)
368{
369 c_zaxpy(n,alpha,x,incx,y,incy);
370}
371
372template <>
373inline void c_taxpy(int n, double alpha, double *x, int incx, CPX *y, int incy)
374{
375 c_daxpy(n,alpha,x,incx,(double*)y,2*incy);
376}
377
378/************************************************************************************************/
379
380template <typename T>
381inline double c_dtnrm2(int n,T *x,int incx);
382
383template <>
384inline double c_dtnrm2(int n,double *x,int incx)
385{
386 return c_dnrm2(n,x,incx);
387}
388
389// new SINGLE PRECISION
390template <>
391inline double c_dtnrm2(int n,float *x,int incx)
392{
393 return c_dnrm2(n,x,incx);
394}
395
396template <>
397inline double c_dtnrm2(int n,CPX *x,int incx)
398{
399 return c_dznrm2(n,x,incx);
400}
401
402/************************************************************************************************/
403
404template <typename T>
405inline void c_tscal(int n,T da,T *dx,int incx);
406
407template <>
408inline void c_tscal(int n,double da,double *dx,int incx)
409{
410 c_dscal(n,da,dx,incx);
411}
412
413template <>
414inline void c_tscal(int n,CPX da,CPX *dx,int incx)
415{
416 c_zscal(n,da,dx,incx);
417}
418
419/************************************************************************************************/
420
421template <typename T>
422inline void c_tgemm(char transa, char transb, int m, int n, int k, T alpha, T *a, \
423 int lda, T *b, int ldb, T beta, T *c, int ldc);
424
425template <>
426inline void c_tgemm(char transa, char transb, int m, int n, int k, double alpha, double *a, \
427 int lda, double *b, int ldb, double beta, double *c, int ldc)
428{
429 c_dgemm(transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
430}
431
432template <>
433inline void c_tgemm(char transa, char transb, int m, int n, int k, float alpha, float *a, \
434 int lda, float *b, int ldb, float beta, float *c, int ldc)
435{
436 c_sgemm(transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
437}
438
439template <>
440inline void c_tgemm(char transa, char transb, int m, int n, int k, CPX alpha, CPX *a, \
441 int lda, CPX *b, int ldb, CPX beta, CPX *c, int ldc)
442{
443 c_zgemm(transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
444}
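// Usage sketch (illustrative only): c_tgemm dispatches to dgemm/sgemm/zgemm based on
// the deduced scalar type. Assuming column-major host buffers A (m x k), B (k x n)
// and C (m x n), all hypothetical, a plain product C = A * B would be
//
//     c_tgemm('N', 'N', m, n, k, 1.0, A, m, B, k, 0.0, C, m);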
445
446/************************************************************************************************/
447
448template <typename T>
449inline void c_tsymm(char side, char uplo, int m, int n, T alpha, T *a, \
450 int lda, T *b, int ldb, T beta, T *c, int ldc);
451
452template <>
453inline void c_tsymm(char side, char uplo, int m, int n, double alpha, double *a, \
454 int lda, double *b, int ldb, double beta, double *c, int ldc)
455{
456 c_dsymm(side,uplo,m,n,alpha,a,lda,b,ldb,beta,c,ldc);
457}
458
459template <>
460inline void c_tsymm(char side, char uplo, int m, int n, CPX alpha, CPX *a, \
461 int lda, CPX *b, int ldb, CPX beta, CPX *c, int ldc)
462{
463 c_zsymm(side,uplo,m,n,alpha,a,lda,b,ldb,beta,c,ldc);
464}
465
466/************************************************************************************************/
467
468template <typename T>
469inline void c_ttrsm(char side, char uplo, char transa, char diag, int m, int n, \
470 T alpha, T *a, int lda, T *b, int ldb);
471
472// for real numbers (double)
473template <>
474inline void c_ttrsm(char side, char uplo, char transa, char diag, int m, int n, \
475 double alpha, double *a, int lda, double *b, int ldb)
476{
477 c_dtrsm(side,uplo,transa,diag,m,n,alpha,a,lda,b,ldb);
478}
479
480// new SINGLE PRECISION
481template <>
482inline void c_ttrsm(char side, char uplo, char transa, char diag, int m, int n, \
483 float alpha, float *a, int lda, float *b, int ldb)
484{
485 c_strsm(side,uplo,transa,diag,m,n,alpha,a,lda,b,ldb);
486}
487
488// for complex numbers
489template <>
490inline void c_ttrsm(char side, char uplo, char transa, char diag, int m, int n, \
491 CPX alpha, CPX *a, int lda, CPX *b, int ldb)
492{
493 c_ztrsm(side,uplo,transa,diag,m,n,alpha,a,lda,b,ldb);
494}
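// Usage sketch (illustrative only): c_ttrsm solves a triangular system in place,
// following the Fortran xTRSM convention. With a hypothetical lower-triangular
// m x m factor L and an m x nrhs right-hand-side block B (column-major),
//
//     c_ttrsm('L', 'L', 'N', 'N', m, nrhs, 1.0, L, m, B, m);
//
// overwrites B with inv(L) * B (side = left, uplo = lower, no transpose,
// non-unit diagonal).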
495
496/************************************************************************************************/
497
498template <typename T>
499inline void tgemm_dev(char transa, char transb, int m, int n, int k, T alpha, T *a,\
500 int lda, T *b, int ldb, T beta, T *c, int ldc, magma_queue_t queue);
501
502template <>
503inline void tgemm_dev(char transa, char transb, int m, int n, int k, double alpha,\
504 double *a, int lda, double *b, int ldb, double beta, double *c, int ldc, magma_queue_t queue)
505{
506 magma_trans_t magma_transa = magma_trans_const(transa);
507 magma_trans_t magma_transb = magma_trans_const(transb);
508
509 //dgemm_on_dev(handle,transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
510 magma_dgemm(magma_transa,magma_transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc,queue);
511
512}
513
514// new SINGLE PRECISION
515template <>
516inline void tgemm_dev(char transa, char transb, int m, int n, int k, float alpha,\
517 float *a, int lda, float *b, int ldb, float beta, float *c, int ldc, magma_queue_t queue)
518{
519 magma_trans_t magma_transa = magma_trans_const(transa);
520 magma_trans_t magma_transb = magma_trans_const(transb);
521
522 //dgemm_on_dev(handle,transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
523 magma_sgemm(magma_transa,magma_transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc,queue);
524
525}
526
527
528template <>
529inline void tgemm_dev(char transa, char transb, int m, int n, int k, CPX alpha,\
530 CPX *a, int lda, CPX *b, int ldb, CPX beta, CPX *c, int ldc, magma_queue_t queue)
531{
532 magma_trans_t magma_transa = magma_trans_const(transa);
533 magma_trans_t magma_transb = magma_trans_const(transb);
534
535 //zgemm_on_dev(handle,transa,transb,m,n,k,alpha,a,lda,b,ldb,beta,c,ldc);
536 magma_zgemm(magma_transa,magma_transb,m,n,k,*reinterpret_cast<magmaDoubleComplex*>(&alpha),(magmaDoubleComplex_ptr)a,lda,(magmaDoubleComplex_ptr)b,ldb,*reinterpret_cast<magmaDoubleComplex*>(&beta),(magmaDoubleComplex_ptr)c,ldc,queue);
537}
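// Usage sketch (illustrative only): tgemm_dev expects device pointers and a MAGMA
// queue. A hypothetical call site (dA, dB, dC already allocated on the device) is
//
//     magma_queue_t queue;
//     magma_queue_create(0 /* device id */, &queue);
//     tgemm_dev('N', 'N', m, n, k, 1.0, dA, m, dB, k, 0.0, dC, m, queue);  // dC = dA * dB
//     magma_queue_sync(queue);
//     magma_queue_destroy(queue);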
538
539/************************************************************************************************/
540
541template <typename T>
542inline void taxpy_dev(void *handle,int n,T alpha,T *x,int incx,T *y,int incy);
543
544template <>
545inline void taxpy_dev(void *handle,int n,double alpha,double *x,int incx,double *y,int incy)
546{
547 daxpy_on_dev(handle,n,alpha,x,incx,y,incy);
548}
549
550template <>
551inline void taxpy_dev(void *handle,int n,CPX alpha,CPX *x,int incx,CPX *y,int incy)
552{
553 zaxpy_on_dev(handle,n,alpha,x,incx,y,incy);
554}
555
556/************************************************************************************************/
557
558template <typename T>
559inline void tasum_dev(void *handle, int n, T *x, int incx, T *result);
560
561template <>
562inline void tasum_dev(void *handle, int n, double *x, int incx, double *result)
563{
564 dasum_on_dev(handle, n, x, incx, result);
565}
566
567template <>
568inline void tasum_dev(void *handle, int n, CPX *x, int incx, CPX *result)
569{
570 double dRes = 0.0;
571 zasum_on_dev(handle, n, x, incx, &dRes);
572 *result = CPX(dRes, 0.0);
573}
574
575/************************************************************************************************/
576
577template <typename T>
578inline void tsum_dev(int n, T *x, int incx, T *result, magma_queue_t queue);
579
580template <>
581inline void tsum_dev(int n, double *x, int incx, double *result, magma_queue_t queue)
582{
583 dsum_on_dev(n, x, incx, result, queue);
584}
585
586template <>
587inline void tsum_dev(int n, CPX *x, int incx, CPX *result, magma_queue_t queue)
588{
589 zsum_on_dev(n, x, incx, result, queue);
590}
591
592/*magma*****************************************************************************************/
593
594template <typename T>
595inline void tgetrf_dev(int m,int n,T *a,int lda,int *ipiv,int *info);
596
597template <>
598inline void tgetrf_dev(int m,int n,double *a,int lda,int *ipiv,int *info)
599{
600 magma_dgetrf_gpu(m,n,a,lda,ipiv,info);
601}
602
603template <>
604inline void tgetrf_dev(int m,int n,CPX *a,int lda,int *ipiv,int *info)
605{
606 magma_zgetrf_gpu(m,n,(magmaDoubleComplex_ptr)a,lda,ipiv,info);
607}
608
609/************************************************************************************************/
610
611template <typename T>
612inline void tgetrs_dev(char transa,int n,int nrhs,T *a,int lda,int *ipiv,T *b,int ldb,int *info);
613
614template <>
615inline void tgetrs_dev(char transa,int n,int nrhs,double *a,int lda,int *ipiv,double *b,int ldb,\
616 int *info)
617{
618 magma_trans_t magma_transa = magma_trans_const(transa);
619
620 magma_dgetrs_gpu(magma_transa,n,nrhs,a,lda,ipiv,b,ldb,info);
621}
622
623template <>
624inline void tgetrs_dev(char transa,int n,int nrhs,CPX *a,int lda,int *ipiv,CPX *b,int ldb,\
625 int *info)
626{
627 magma_trans_t magma_transa = magma_trans_const(transa);
628
629 magma_zgetrs_gpu(magma_transa,n,nrhs,(magmaDoubleComplex_ptr)a,lda,ipiv,\
630 (magmaDoubleComplex_ptr)b,ldb,info);
631}
632
633/************************************************************************************************/
634
635template <typename T>
636inline void tgesv_dev(int n,int nrhs,T *a,int lda,int *ipiv,T *b,int ldb,int type,int *info);
637
638template <>
639inline void tgesv_dev(int n,int nrhs,double *a,int lda,int *ipiv,double *b,int ldb,int type,int *info)
640{
641 if(type){
642 magma_dgesv_nopiv_gpu(n,nrhs,a,lda,b,ldb,info);
643 }else{
644 /*
645 magma_dgesv_nopiv_gpu(n,nrhs,a,lda,b,ldb,info);
646 */
647 magma_dsysv_nopiv_gpu(MagmaLower,n,nrhs,a,lda,b,ldb,info);
648 }
649}
650
651template <>
652inline void tgesv_dev(int n,int nrhs,CPX *a,int lda,int *ipiv,CPX *b,int ldb,int type,int *info)
653{
654
655 if(type){
656 magma_zgesv_nopiv_gpu(n,nrhs,(magmaDoubleComplex_ptr)a,lda,(magmaDoubleComplex_ptr)b,\
657 ldb,info);
658 }else{
659 /*
660 magma_zgesv_nopiv_gpu(n,nrhs,(magmaDoubleComplex_ptr)a,lda,(magmaDoubleComplex_ptr)b,\
661 ldb,info);
662 */
663 magma_zhesv_nopiv_gpu(MagmaLower,n,nrhs,(magmaDoubleComplex_ptr)a,lda,(magmaDoubleComplex_ptr)b,\
664 ldb,info);
665 }
666}
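// Note: in tgesv_dev the integer `type` selects the MAGMA solver: a nonzero value
// uses the general no-pivoting LU solver (magma_dgesv_nopiv_gpu / magma_zgesv_nopiv_gpu),
// while zero falls back to the symmetric/Hermitian no-pivoting solver
// (magma_dsysv_nopiv_gpu / magma_zhesv_nopiv_gpu) on the lower triangle; the ipiv
// argument is unused in both paths.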
667
668/************************************************************************************************/
669
670template <typename T>
671inline void tgetri_dev(int n,T *a,int lda,int *ipiv,T *work,int lwork,int *info);
672
673template <>
674inline void tgetri_dev(int n,double *a,int lda,int *ipiv,double *work,int lwork,int *info)
675{
676 magma_dgetri_gpu(n,a,lda,ipiv,work,lwork,info);
677}
678
679template <>
680inline void tgetri_dev(int n,CPX *a,int lda,int *ipiv,CPX *work,int lwork,int *info)
681{
682 magma_zgetri_gpu(n,(magmaDoubleComplex_ptr)a,lda,ipiv,(magmaDoubleComplex_ptr)work,lwork,info);
683}
684
685/************************************************************************************************/
686
687#ifdef CUDA_POTRF
688
689template <typename T>
690inline void ttrtri_dev_cuda(char uplo, char unit_diag, int n,T *a,int lda,int *info, int* cuda_buffer_flag,
691 cusolverDnHandle_t* handle, cudaStream_t stream,
692 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host);
693
694template <>
695inline void ttrtri_dev_cuda(char uplo, char unit_diag, int n, double *a,int lda,int *info, int *cuda_buffer_flag,
696 cusolverDnHandle_t *handle, cudaStream_t stream,
697 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host)
698{
699
700 cublasFillMode_t uplo_cuda;
701 if(uplo == 'L'){
702 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
703 } else if(uplo =='U') {
704 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
705 } else {
706 printf("invalid potrf argmument uplo: %c\n", uplo);
707 exit(1);
708 }
709
710 // indicates whether the triangular matrix has a unit diagonal (all ones); in our use case it does not
711 cublasDiagType_t unit_diag_cuda;
712 if(unit_diag == 'n' || unit_diag == 'N'){
713 unit_diag_cuda = CUBLAS_DIAG_NON_UNIT;
714 } else if(unit_diag == 'u' || unit_diag == 'U'){
715 unit_diag_cuda = CUBLAS_DIAG_UNIT;
716 } else {
717 printf("invalid trtri argmument unit_diag: %c\n", unit_diag);
718 exit(1);
719 }
720
721 // initialize cuda buffer
722 if(cuda_buffer_flag[0] == 0){
723
724 cusolverStatus_t cuSolverError = cusolverDnXtrtri_bufferSize(
725 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_64F, a, lda, dev_size, host_size);
726
727 if(cuSolverError != 0){
728 printf("cuSolverError buffer size allocation!\n");
729 exit(1);
730 }
731
732 cuda_buffer_flag[0] = 1;
733
734 //printf("dev size = %ld, host size = %ld\n", *dev_size, *host_size);
735
736
737 }
738
739 cudaMalloc((void**)&mem_cuda_dev, (*dev_size) * sizeof(double));
740 cudaMallocHost((void**)&mem_cuda_host, (*host_size) * sizeof(double));
741
742 //printf("n = %d, lda = %d, uplo = %c, dev_size = %ld, host_size = %ld\n", n, lda, uplo, dev_size[0], host_size[0]);
743
744 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
745 if(cuSolverError != 0){
746 printf("cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
747 exit(1);
748 }
749
750 cuSolverError = cusolverDnXtrtri(
751 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_64F, a, lda,
752 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
753
754 if(cuSolverError != 0){
755 printf("cuSolverError trtri! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
756 exit(1);
757 }
758
759 int info_host;
760 cudaMemcpyAsync(&info_host, info, sizeof(int), cudaMemcpyDeviceToHost, stream);
761
762 if(info_host != 0){
763 printf("cuSolverError trtri info not zero! info = %d, lda = %d\n", info_host, lda);
764 exit(1);
765 }
766
767}
768
769// *** new SINGLE PRECISION
770template <>
771inline void ttrtri_dev_cuda(char uplo, char unit_diag, int n, float *a,int lda,int *info, int *cuda_buffer_flag,
772 cusolverDnHandle_t *handle, cudaStream_t stream,
773 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host)
774{
775
776 cublasFillMode_t uplo_cuda;
777 if(uplo == 'L'){
778 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
779 } else if(uplo =='U') {
780 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
781 } else {
782 printf("invalid potrf argmument uplo: %c\n", uplo);
783 exit(1);
784 }
785
786 // indicates whether the triangular matrix has a unit diagonal (all ones); in our use case it does not
787 cublasDiagType_t unit_diag_cuda;
788 if(unit_diag == 'n' || unit_diag == 'N'){
789 unit_diag_cuda = CUBLAS_DIAG_NON_UNIT;
790 } else if(unit_diag == 'u' || unit_diag == 'U'){
791 unit_diag_cuda = CUBLAS_DIAG_UNIT;
792 } else {
793 printf("invalid trtri argmument unit_diag: %c\n", unit_diag);
794 exit(1);
795 }
796
797 // initialize cuda buffer
798 if(cuda_buffer_flag[0] == 0){
799
800 cusolverStatus_t cuSolverError = cusolverDnXtrtri_bufferSize(
801 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_32F, a, lda, dev_size, host_size);
802
803 if(cuSolverError != 0){
804 printf("cuSolverError buffer size allocation!\n");
805 exit(1);
806 }
807
808 cuda_buffer_flag[0] = 1;
809
810 //printf("dev size = %ld, host size = %ld\n", *dev_size, *host_size);
811
812
813 }
814
815 cudaMalloc((void**)&mem_cuda_dev, (*dev_size) * sizeof(double));
816 cudaMallocHost((void**)&mem_cuda_host, (*host_size) * sizeof(double));
817
818 //printf("n = %d, lda = %d, uplo = %c, dev_size = %ld, host_size = %ld\n", n, lda, uplo, dev_size[0], host_size[0]);
819
820 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
821 if(cuSolverError != 0){
822 printf("cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
823 exit(1);
824 }
825
826 cuSolverError = cusolverDnXtrtri(
827 *handle, uplo_cuda, unit_diag_cuda, n, CUDA_R_32F, a, lda,
828 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
829
830 if(cuSolverError != 0){
831 printf("cuSolverError trtri! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
832 exit(1);
833 }
834
835 int info_host;
836 cudaMemcpyAsync(&info_host, info, sizeof(int), cudaMemcpyDeviceToHost, stream);
837
838 if(info_host != 0){
839 printf("cuSolverError trtri info not zero! info = %d, lda = %d\n", info_host, lda);
840 exit(1);
841 }
842
843}
844
845
846template <>
847inline void ttrtri_dev_cuda(char uplo, char unit_diag, int n, CPX *a,int lda,int *info, int *cuda_buffer_flag,
848 cusolverDnHandle_t *handle, cudaStream_t stream,
849 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host)
850{
851 printf("just a placeholder. not working!\n");
852 exit(1);
853}
854
855#else
856
857template <typename T>
858inline void ttrtri_dev(char uplo,char diag,int n,T *a,int lda,int *info);
859
860template <>
861inline void ttrtri_dev(char uplo,char diag,int n,double *a,int lda,int *info)
862{
863 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
864 magma_diag_t magma_diag = magma_diag_const(diag);
865
866 magma_dtrtri_gpu(magma_uplo,magma_diag,n,a,lda,info);
867}
868
869// new SINGLE PRECISION
870template <>
871inline void ttrtri_dev(char uplo,char diag,int n,float *a,int lda,int *info)
872{
873 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
874 magma_diag_t magma_diag = magma_diag_const(diag);
875
876 //printf("trtri float.\n");
877 //cudaDeviceSynchronize();
878
879 magma_strtri_gpu(magma_uplo,magma_diag,n,a,lda,info);
880}
881
882template <>
883inline void ttrtri_dev(char uplo,char diag,int n,CPX *a,int lda,int *info)
884{
885 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
886 magma_diag_t magma_diag = magma_diag_const(diag);
887
888 magma_ztrtri_gpu(magma_uplo,magma_diag,n,(magmaDoubleComplex_ptr)a,lda,info);
889}
890
891#endif // CUDA / MAGMA TRTRI
892
893/************************************************************************************************/
894
895inline void zgetrf_nopiv_dev(int m,int n,CPX *a,int lda,int *info)
896{
897 magma_zgetrf_nopiv_gpu(m,n,(magmaDoubleComplex_ptr)a,lda,info);
898}
899
900/************************************************************************************************/
901
902inline void zgetri_dev(int n,CPX *a,int lda,int *ipiv,CPX *work,int lwork,int *info)
903{
904 magma_zgetri_gpu(n,(magmaDoubleComplex_ptr)a,lda,ipiv,(magmaDoubleComplex_ptr)work,lwork,info);
905}
906
907/************************************************************************************************/
908
909inline void zgetrs_nopiv_dev(char transa,int n,int nrhs,CPX *a,int lda,CPX *b,int ldb,int *info)
910{
911 magma_trans_t magma_transa = magma_trans_const(transa);
912
913 magma_zgetrs_nopiv_gpu(magma_transa,n,nrhs,(magmaDoubleComplex_ptr)a,lda,\
914 (magmaDoubleComplex_ptr)b,ldb,info);
915}
916
917/************************************************************************************************/
918
919#ifdef CUDA_POTRF
920
921template <typename T>
922inline void tpotrf_dev_cuda(char uplo,int n,T *a,int lda,int *info, int* cuda_buffer_flag,
923 cusolverDnHandle_t* handle, cusolverDnParams_t *params, cudaStream_t stream,
924 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host);
925
926template <>
927inline void tpotrf_dev_cuda(char uplo,int n,double *a,int lda,int *info, int *cuda_buffer_flag,
928 cusolverDnHandle_t *handle, cusolverDnParams_t *params, cudaStream_t stream,
929 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host)
930{
931
932 // initialize cuda buffer
933 //printf("in Potrf. cuda buffer flag: %d\n", cuda_buffer_flag[0]);
934
935#if 0
936 // CAREFUL with STATIC variables. GLOBAL!
937 static double* save_buffer_dev;
938 static double* save_buffer_dev2;
939 static double* save_buffer_host1;
940 static double* save_buffer_host2;
941 static double res;
942#endif
943
944 cublasFillMode_t uplo_cuda;
945
946 if(uplo == 'L'){
947 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
948 } else if(uplo =='U') {
949 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
950 } else {
951 printf("invalid potrf argmument uplo: %c\n", uplo);
952 exit(1);
953 }
954
955 if(cuda_buffer_flag[0] == 0){
956
957 cusolverStatus_t cuSolverError = cusolverDnXpotrf_bufferSize(
958 *handle, *params, uplo_cuda, n, CUDA_R_64F, a, lda, CUDA_R_64F, dev_size, host_size);
959 if(cuSolverError != 0){
960 printf("cuSolverError buffer size allocation!\n");
961 exit(1);
962 }
963
964 cuda_buffer_flag[0] = 1;
965
966#if 0
967 printf("dev size = %ld, host size = %ld\n", *dev_size, *host_size);
968 cudaMalloc(&save_buffer_dev ,n*lda*sizeof(double));
969 cudaMalloc(&save_buffer_dev2,n*lda*sizeof(double));
970 cudaMallocHost(&save_buffer_host1,n*lda*sizeof(double));
971 cudaMallocHost(&save_buffer_host2,n*lda*sizeof(double));
972 cudaMemcpy(save_buffer_host1, a, n*lda*sizeof(double), cudaMemcpyDeviceToHost);
973#endif
974
975 } else {
976
977#if 0
978 cudaMemcpy(save_buffer_host2, a, n*lda*sizeof(double), cudaMemcpyDeviceToHost);
979 //printf("after cuda mem copy\n");
980 cudaDeviceSynchronize();
981 double temp = 0;
982 for(int i=0; i<n*lda; i++){
983 temp += (save_buffer_host1[i]-save_buffer_host2[i])*(save_buffer_host1[i]-save_buffer_host2[i]);
984 }
985 printf("temp = %f\n", sqrt(temp/(n*lda)));
986 cudaDeviceSynchronize();
987#endif
988
989 }
990
991 cudaMalloc((void**)&mem_cuda_dev, (*dev_size) * sizeof(double));
992 cudaMallocHost((void**)&mem_cuda_host, (*host_size) * sizeof(double));
993
994 //printf("n = %d, lda = %d, uplo = %c, dev_size = %ld, host_size = %ld\n", n, lda, uplo, dev_size[0], host_size[0]);
995
996 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
997 if(cuSolverError != 0){
998 printf("cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
999 exit(1);
1000 }
1001
1002 cuSolverError = cusolverDnXpotrf(
1003 *handle, *params, uplo_cuda, n, CUDA_R_64F, a, lda, CUDA_R_64F,
1004 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
1005
1006
1007 if(cuSolverError != 0){
1008 printf("cuSolverError potrf! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
1009 exit(1);
1010 }
1011
1012 int info_host;
1013 cudaMemcpyAsync(&info_host, info, sizeof(int), cudaMemcpyDeviceToHost, stream);
1014
1015 if(info_host != 0){
1016 printf("cuSolverError potrf info not zero! info = %d, lda = %d\n", info_host, lda);
1017 exit(1);
1018 }
1019
1020}
1021
1022// *** new SINGLE PRECISION *** //
1023template <>
1024inline void tpotrf_dev_cuda(char uplo,int n,float *a,int lda,int *info, int *cuda_buffer_flag,
1025 cusolverDnHandle_t *handle, cusolverDnParams_t *params, cudaStream_t stream,
1026 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host)
1027{
1028
1029 // initialize cuda buffer
1030 //printf("in Potrf. cuda buffer flag: %d\n", cuda_buffer_flag[0]);
1031
1032 cublasFillMode_t uplo_cuda;
1033
1034 if(uplo == 'L'){
1035 uplo_cuda = CUBLAS_FILL_MODE_LOWER;
1036 } else if(uplo =='U') {
1037 uplo_cuda = CUBLAS_FILL_MODE_UPPER;
1038 } else {
1039 printf("invalid potrf argmument uplo: %c\n", uplo);
1040 exit(1);
1041 }
1042
1043 if(cuda_buffer_flag[0] == 0){
1044
1045 cusolverStatus_t cuSolverError = cusolverDnXpotrf_bufferSize(
1046 *handle, *params, uplo_cuda, n, CUDA_R_32F, a, lda, CUDA_R_32F, dev_size, host_size);
1047 if(cuSolverError != 0){
1048 printf("cuSolverError buffer size allocation!\n");
1049 exit(1);
1050 }
1051
1052 cuda_buffer_flag[0] = 1;
1053 }
1054
1055 cudaMalloc((void**)&mem_cuda_dev, (*dev_size) * sizeof(double));
1056 cudaMallocHost((void**)&mem_cuda_host, (*host_size) * sizeof(double));
1057
1058 //printf("n = %d, lda = %d, uplo = %c, dev_size = %ld, host_size = %ld\n", n, lda, uplo, dev_size[0], host_size[0]);
1059
1060 cusolverStatus_t cuSolverError = cusolverDnSetStream(*handle, stream);
1061 if(cuSolverError != 0){
1062 printf("cuSolverError set Stream! cuSolverError = %d\n", cuSolverError);
1063 exit(1);
1064 }
1065
1066 cuSolverError = cusolverDnXpotrf(
1067 *handle, *params, uplo_cuda, n, CUDA_R_32F, a, lda, CUDA_R_32F,
1068 mem_cuda_dev, *dev_size, mem_cuda_host, *host_size, info);
1069
1070
1071 if(cuSolverError != 0){
1072 printf("cuSolverError potrf! cuSolverError : %d, lda = %d\n", cuSolverError, lda);
1073 exit(1);
1074 }
1075
1076 int info_host;
1077 cudaMemcpyAsync(&info_host, info, sizeof(int), cudaMemcpyDeviceToHost, stream);
1078
1079 if(info_host != 0){
1080 printf("cuSolverError potrf info not zero! info = %d, lda = %d\n", info_host, lda);
1081 exit(1);
1082 }
1083
1084}
1085
1086template <>
1087inline void tpotrf_dev_cuda(char uplo,int n,CPX *a,int lda,int *info, int *cuda_buffer_flag,
1088 cusolverDnHandle_t *handle, cusolverDnParams_t *params, cudaStream_t stream,
1089 size_t *dev_size, size_t *host_size, double *mem_cuda_dev, double *mem_cuda_host)
1090{
1091 printf("just a placeholder. not working!\n");
1092 exit(1);
1093}
1094
1095#elif defined(MAGMA_EXPERT)
1096
1097template <typename T>
1098inline void magma_tpotrf_expert_wrapper(char uplo, int n, T* a, int lda, int info[1], magma_mode_t mode, int subN, int subSubN,
1099 void* host_work, int *lwork_host, void* device_work, int *lwork_device,
1100 magma_event_t events[2], magma_queue_t queues[2], int& init_flag){
1101
1102 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1103
1104 if(init_flag < 1){
1105 *lwork_host = -1;
1106 *lwork_device = -1;
1107 printf("in initializing potrf expert gpu.\n");
1108
1109 magma_tpotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1110 host_work, lwork_host, device_work, lwork_device, events, queues);
1111 printf("DPOTRF EXPERT WORK. Allocating %d of host memory. Allocating %d of device memory.\n", *lwork_host, *lwork_device);
1112
1113 if(*lwork_device > 0){
1114 gpuErrchk(cudaMalloc((void**)&device_work,*lwork_device));
1115 }
1116
1117 if(*lwork_host > 0){
1118 gpuErrchk(cudaMallocHost((void**)&host_work,*lwork_host));
1119 }
1120
1121 init_flag = 1;
1122 }
1123
1124 magma_tpotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1125 host_work, lwork_host, device_work, lwork_device, events, queues);
1126}
1127
1128template <typename T>
1129inline void magma_tpotrf_expert_gpu_work(magma_uplo_t magma_uplo, int n, T* a, int lda, int info[1], magma_mode_t mode, int subN, int subSubN, void* host_work, magma_int_t *lwork_host, void* device_work, magma_int_t *lwork_device, magma_event_t events[2], magma_queue_t queues[2]);
1130
1131// new magma -- double precision
1132template <>
1133inline void magma_tpotrf_expert_gpu_work(magma_uplo_t magma_uplo, int n, double* a, int lda, int info[1], magma_mode_t mode, int subN, int subSubN, void* host_work, int *lwork_host, void* device_work, int *lwork_device, magma_event_t events[2], magma_queue_t queues[2])
1134{
1135 magma_int_t potrfErr = magma_dpotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1136 host_work, lwork_host, device_work, lwork_device, events, queues);
1137 if(potrfErr != 0){
1138 std::cout << "in magma potrf expert work error = " << potrfErr << std::endl;
1139 exit(1);
1140 }
1141}
1142
1143// new magma -- single precision
1144template <>
1145inline void magma_tpotrf_expert_gpu_work(magma_uplo_t magma_uplo, int n, float* a, int lda, int info[1], magma_mode_t mode, int subN, int subSubN, void* host_work, int *lwork_host, void* device_work, int *lwork_device, magma_event_t events[2], magma_queue_t queues[2])
1146{
1147 magma_int_t potrfErr = magma_spotrf_expert_gpu_work(magma_uplo, n, a, lda, info, mode, subN, subSubN,
1148 host_work, lwork_host, device_work, lwork_device, events, queues);
1149 if(potrfErr != 0){
1150 std::cout << "in magma potrf expert work error = " << potrfErr << std::endl;
1151 exit(1);
1152 }
1153}
1154
1155#else
1156
1157template <typename T>
1158inline void tpotrf_dev(char uplo,int n,T *a,int lda,int *info);
1159
1160 // OLD POTRF WITH MAGMA, which gives problems for _native (when run in non-blocking mode)
1161template <>
1162inline void tpotrf_dev(char uplo,int n,double *a,int lda,int *info)
1163{
1164 //double a_host;
1165 //gpuErrchk(cudaMemcpy(&a_host, &a[0], sizeof(double), cudaMemcpyDeviceToHost));
1166 //printf("in potrf. n = %d, lda = %d, a[0] = %f\n", n, lda, a_host);
1167
1168 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1169 // using the magma_dpotrf_expert_gpu interface to
1170 // 1. put the entire Dpotrf function on the GPU and
1171 // 2. explicitly specify the block size to use.
1172 // Of the powers of 2, 64 appears to be the best for our test case; 128 would be second-best.
1173 // Did not evaluate intervening values.
1174 //magma_int_t potrfErr = magma_dpotrf_expert_gpu(magma_uplo,n,a,lda,info, 64, MagmaNative ); // native: gpu only
1175 magma_int_t potrfErr = magma_dpotrf_gpu(magma_uplo,n,a,lda,info); // hybrid version
1176 //magma_int_t potrfErr = magma_dpotrf_native(magma_uplo,n,a,lda,info); // gpu only version
1177
1178 /*double a_problem_host;
1179 int ind_error = 991;
1180 gpuErrchk(cudaMemcpy(&a_problem_host, &a[ind_error], sizeof(double), cudaMemcpyDeviceToHost));
1181 printf("ind error = %d, a[ind_error] = %f\n", ind_error, a_problem_host);
1182 exit(1);*/
1183
1184 if(potrfErr != 0){
1185 std::cout << "magma potrf error = " << potrfErr << std::endl;
1186 int ind_error = (potrfErr-1)*(lda+1);
1187 double a_problem_host;
1188 gpuErrchk(cudaMemcpy(&a_problem_host, &a[ind_error], sizeof(double), cudaMemcpyDeviceToHost));
1189 printf("ind error = %d, a[ind_error] = %f\n", ind_error, a_problem_host);
1190 exit(1);
1191 }
1192}
1193
1194// new SINGLE PRECISION
1195template <>
1196inline void tpotrf_dev(char uplo,int n,float *a,int lda,int *info)
1197{
1198 //double a_host;
1199 //gpuErrchk(cudaMemcpy(&a_host, &a[0], sizeof(double), cudaMemcpyDeviceToHost));
1200 //printf("in potrf. n = %d, lda = %d, a[0] = %f\n", n, lda, a_host);
1201
1202 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1203 //magma_int_t potrfErr = magma_dpotrf_expert_gpu(magma_uplo,n,a,lda,info, 64, MagmaNative ); // native: gpu only
1204 magma_int_t potrfErr = magma_spotrf_gpu(magma_uplo,n,a,lda,info); // hybrid version
1205 //magma_int_t potrfErr = magma_dpotrf_native(magma_uplo,n,a,lda,info); // gpu only version
1206
1207 if(potrfErr != 0){
1208 std::cout << "magma potrf error = " << potrfErr << std::endl;
1209 int ind_error = (potrfErr-1)*(lda+1);
1210 double a_problem_host;
1211 gpuErrchk(cudaMemcpy(&a_problem_host, &a[ind_error], sizeof(double), cudaMemcpyDeviceToHost));
1212 printf("ind error = %d, a[ind_error] = %f\n", ind_error, a_problem_host);
1213 exit(1);
1214 }
1215}
1216
1217template <>
1218inline void tpotrf_dev(char uplo,int n,CPX *a,int lda,int *info)
1219{
1220 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1221
1222 magma_zpotrf_gpu(magma_uplo,n,(magmaDoubleComplex_ptr)a,lda,info);
1223}
1224
1225//#else
1226
1227#endif // end #ifelse defined CUDA_POTRF
1228
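// Usage sketch (illustrative only): depending on the CUDA_POTRF / MAGMA_EXPERT flags
// above, the Cholesky entry point is tpotrf_dev_cuda(...) or tpotrf_dev(...). For the
// plain MAGMA path, a hypothetical call on an n x n SPD matrix dA already resident on
// the device (leading dimension lda) would be
//
//     int info = 0;
//     tpotrf_dev('L', n, dA, lda, &info);  // dA is overwritten by its lower Cholesky factor
//     if (info != 0) { /* not positive definite, or an illegal argument */ }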
1229
1230/************************************************************************************************/
1231
1232template <typename T>
1233inline void ttrsm_dev(char side, char uplo, char trans, char diag, int m, int n, T alpha, T *a, int lda, T *b, int ldb, magma_queue_t queue);
1234
1235template <>
1236inline void ttrsm_dev(char side, char uplo, char trans, char diag, int m, int n, double alpha, double *a, int lda, double *b, int ldb, magma_queue_t queue)
1237{
1238 magma_side_t magma_side = magma_side_const(side);
1239 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1240 magma_trans_t magma_trans = magma_trans_const(trans);
1241 magma_diag_t magma_diag = magma_diag_const(diag);
1242
1243 magma_dtrsm(magma_side, magma_uplo, magma_trans, magma_diag, m, n, alpha, a, lda, b, ldb, queue);
1244}
1245
1246// new SINGLE PRECISION
1247template <>
1248inline void ttrsm_dev(char side, char uplo, char trans, char diag, int m, int n, float alpha, float *a, int lda, float *b, int ldb, magma_queue_t queue)
1249{
1250 magma_side_t magma_side = magma_side_const(side);
1251 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1252 magma_trans_t magma_trans = magma_trans_const(trans);
1253 magma_diag_t magma_diag = magma_diag_const(diag);
1254
1255 //printf("in ttrsm float.\n");
1256 //cudaDeviceSynchronize();
1257
1258 magma_strsm(magma_side, magma_uplo, magma_trans, magma_diag, m, n, alpha, a, lda, b, ldb, queue);
1259}
1260
1261template <>
1262inline void ttrsm_dev(char side, char uplo, char trans, char diag, int m, int n, CPX alpha, CPX *a, int lda, CPX *b, int ldb, magma_queue_t queue)
1263{
1264 magma_side_t magma_side = magma_side_const(side);
1265 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1266 magma_trans_t magma_trans = magma_trans_const(trans);
1267 magma_diag_t magma_diag = magma_diag_const(diag);
1268
1269 magma_ztrsm(magma_side, magma_uplo, magma_trans, magma_diag, m, n, *reinterpret_cast<magmaDoubleComplex*>(&alpha), (magmaDoubleComplex_ptr)a, lda, (magmaDoubleComplex_ptr)b, ldb, queue);
1270}
1271
1272/************************************************************************************************/
1273
1274template <typename T>
1275inline void tlacpy_dev(char uplo, int m, int n, T *a, int lda, T *b, int ldb, magma_queue_t queue);
1276
1277template <>
1278inline void tlacpy_dev(char uplo, int m, int n, double *a, int lda, double *b, int ldb, magma_queue_t queue)
1279{
1280 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1281
1282 magmablas_dlacpy(magma_uplo, m, n, a, lda, b, ldb, queue);
1283
1284}
1285
1286// new SINGLE PRECISION
1287template <>
1288inline void tlacpy_dev(char uplo, int m, int n, float *a, int lda, float *b, int ldb, magma_queue_t queue)
1289{
1290 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1291 //printf("in tlacpy float.\n");
1292 //cudaDeviceSynchronize();
1293
1294 magmablas_slacpy(magma_uplo, m, n, a, lda, b, ldb, queue);
1295
1296}
1297
1298template <>
1299inline void tlacpy_dev(char uplo, int m, int n, CPX *a, int lda, CPX *b, int ldb, magma_queue_t queue)
1300{
1301 magma_uplo_t magma_uplo = magma_uplo_const(uplo);
1302
1303 magmablas_zlacpy(magma_uplo, m, n, (magmaDoubleComplex_ptr)a, lda, (magmaDoubleComplex_ptr)b, ldb, queue);
1304}
1305
1306/************************************************************************************************/
1307
1308/*Lapack*****************************************************************************************/
1309
1310inline void c_dgetrf(int m, int n, double *a, int lda, int *ipiv, int *info)
1311{
1312 fortran_name(dgetrf,DGETRF)(&m,&n,a,&lda,ipiv,info);
1313}
1314
1315/************************************************************************************************/
1316
1317inline void c_dgetrs(char transa, int n, int nrhs, double *a, int lda, int *ipiv, double *b, \
1318 int ldb, int *info)
1319{
1320 fortran_name(dgetrs,DGETRS)(&transa,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1321}
1322
1323/************************************************************************************************/
1324
1325inline void c_zgetrf(int m, int n, CPX *a, int lda, int *ipiv, int *info)
1326{
1327 fortran_name(zgetrf,ZGETRF)(&m,&n,a,&lda,ipiv,info);
1328}
1329
1330/************************************************************************************************/
1331
1332inline void c_zgetrs(char transa, int n, int nrhs, CPX *a, int lda, int *ipiv, CPX *b, \
1333 int ldb, int *info)
1334{
1335 fortran_name(zgetrs,ZGETRS)(&transa,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1336}
1337
1338/************************************************************************************************/
1339
1340inline void c_zgetri(int n,CPX *a,int lda,int *ipiv,CPX *work,int lwork,int *info)
1341{
1342 fortran_name(zgetri,ZGETRI)(&n,a,&lda,ipiv,work,&lwork,info);
1343}
1344
1345/************************************************************************************************/
1346
1347inline void c_dgeev(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, \
1348 double *vl, int ldvl, double *vr, int ldvr, double *work, int lwork, \
1349 int *info)
1350{
1351 fortran_name(dgeev,DGEEV)(&jobvl,&jobvr,&n,a,&lda,wr,wi,vl,&ldvl,vr,&ldvr,work,&lwork,info);
1352}
1353
1354/************************************************************************************************/
1355
1356inline void c_dsyev(char JOBZ,char UPLO,int N,double *A,int LDA,double *W,double *WORK,int LWORK,\
1357 int *INFO)
1358{
1359 fortran_name(dsyev,DSYEV)(&JOBZ,&UPLO,&N,A,&LDA,W,WORK,&LWORK,INFO);
1360}
1361
1362/************************************************************************************************/
1363
1364inline void c_dggev(char jobvl, char jobvr, int n, double *a, int lda, double *b, int ldb, \
1365 double *alphar, double *alphai, double *beta, double *vl, int ldvl, \
1366 double *vr, int ldvr, double *work, int lwork, int *info)
1367{
1368 fortran_name(dggev,DGGEV)(&jobvl,&jobvr,&n,a,&lda,b,&ldb,alphar,alphai,beta,vl,&ldvl,vr,\
1369 &ldvr,work,&lwork,info);
1370}
1371
1372/************************************************************************************************/
1373
1374inline void c_zggev(char jobvl, char jobvr, int n, CPX *a, int lda, CPX *b, int ldb, CPX *alpha,\
1375 CPX *beta, CPX *vl, int ldvl, CPX *vr, int ldvr, CPX *work, int lwork, \
1376 double *rwork,int *info)
1377{
1378 fortran_name(zggev,ZGGEV)(&jobvl,&jobvr,&n,a,&lda,b,&ldb,alpha,beta,vl,&ldvl,vr,&ldvr,\
1379 work,&lwork,rwork,info);
1380}
1381
1382/************************************************************************************************/
1383
1384inline void c_zgeev(char jobvl, char jobvr, int n, CPX *a, int lda, CPX *w, CPX *vl, int ldvl, \
1385 CPX *vr, int ldvr, CPX *work, int lwork, double *rwork, int *info)
1386{
1387 fortran_name(zgeev,ZGEEV)(&jobvl,&jobvr,&n,a,&lda,w,vl,&ldvl,vr,&ldvr,work,&lwork,rwork,info);
1388}
1389
1390/************************************************************************************************/
1391
1392inline void c_zheev(char jobvl, char uplo, int n, CPX *a, int lda, double *w, CPX *work, \
1393 int lwork, double *rwork, int *info)
1394{
1395 fortran_name(zheev,ZHEEV)(&jobvl,&uplo,&n,a,&lda,w,work,&lwork,rwork,info);
1396}
1397
1398/************************************************************************************************/
1399
1400inline void c_dgetri(int n,double *a,int lda,int *ipiv,double *work,int lwork,int *info)
1401{
1402 fortran_name(dgetri,DGETRI)(&n,a,&lda,ipiv,work,&lwork,info);
1403}
1404
1405/************************************************************************************************/
1406
1407inline void c_dsytri(char uplo,int n,double *a,int lda,int *ipiv,double *work,int *info)
1408{
1409 fortran_name(dsytri,DSYTRI)(&uplo,&n,a,&lda,ipiv,work,info);
1410}
1411
1412/************************************************************************************************/
1413
1414inline void c_zhetrf(char uplo,int n,CPX *a,int lda,int *ipiv,CPX *work,int lwork,int *info)
1415{
1416 fortran_name(zhetrf,ZHETRF)(&uplo,&n,a,&lda,ipiv,work,&lwork,info);
1417}
1418
1419/************************************************************************************************/
1420
1421inline void c_zhetri(char uplo,int n,CPX *a,int lda,int *ipiv,CPX *work,int *info)
1422{
1423 fortran_name(zhetri,ZHETRI)(&uplo,&n,a,&lda,ipiv,work,info);
1424}
1425
1426/************************************************************************************************/
1427
1428inline void c_zhetrs(char uplo,int n,int nrhs,CPX *a,int lda,int *ipiv,CPX *b,int ldb,\
1429 int *info)
1430{
1431 fortran_name(zhetrs,ZHETRS)(&uplo,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1432}
1433
1434/************************************************************************************************/
1435
1436inline void c_dsysv(char uplo,int n,int nrhs,double *a,int lda,int *ipiv,double *b,int ldb,\
1437 double *work,int lwork,int *info)
1438{
1439 fortran_name(dsysv,DSYSV)(&uplo,&n,&nrhs,a,&lda,ipiv,b,&ldb,work,&lwork,info);
1440}
1441
1442/************************************************************************************************/
1443
1444inline void c_dsytrf(char uplo,int n,double *a,int lda,int *ipiv,double *work,int lwork,int *info)
1445{
1446 fortran_name(dsytrf,DSYTRF)(&uplo,&n,a,&lda,ipiv,work,&lwork,info);
1447}
1448
1449/************************************************************************************************/
1450
1451inline void c_dsytrs(char uplo,int n,int nrhs,double *a,int lda,int *ipiv,double *b,int ldb,\
1452 int *info)
1453{
1454 fortran_name(dsytrs,DSYTRS)(&uplo,&n,&nrhs,a,&lda,ipiv,b,&ldb,info);
1455}
1456
1457/************************************************************************************************/
1458
1459inline void c_dstebz(char *range,char *order,int *iter,double *vl,double *vu,int *il,int *iu,\
1460 double *abstol,double *diag,double *offd,int *neval,int *nsplit,\
1461 double *eval,int *iblock,int *isplit,double *work,int *iwork,int *info)
1462{
1463 fortran_name(dstebz,DSTEBZ)(range,order,iter,vl,vu,il,iu,abstol,diag,offd,neval,nsplit,eval,\
1464 iblock,isplit,work,iwork,info);
1465}
1466
1467/************************************************************************************************/
1468
1469inline void c_zlarnv(int *idist,int *iseed,int n,CPX *x)
1470{
1471 fortran_name(zlarnv,ZLARNV)(idist,iseed,&n,x);
1472}
1473
1474/************************************************************************************************/
1475
1476inline void c_dgesdd(char jobz,int m,int n,double *a,int lda,double *s,double *u,int ldu,\
1477 double *vt,int ldvt,double *work,int lwork,int *iwork,int *info)
1478{
1479 fortran_name(dgesdd,DGESDD)(&jobz,&m,&n,a,&lda,s,u,&ldu,vt,&ldvt,work,&lwork,iwork,info);
1480}
1481
1482/************************************************************************************************/
1483
1484inline void c_zgesdd(char jobz,int m,int n,CPX *a,int lda,double *s,CPX *u,int ldu,\
1485 CPX *vt,int ldvt,CPX *work,int lwork,double *rwork,int *iwork,int *info)
1486{
1487 fortran_name(zgesdd,ZGESDD)(&jobz,&m,&n,a,&lda,s,u,&ldu,vt,&ldvt,work,&lwork,rwork,iwork,info);
1488}
1489
1490/************************************************************************************************/
1491
1492template <typename T>
1493inline void copy_csr_to_device(int size,int n_nonzeros,int *hedge_i,int *hindex_j,T *hnnz,\
1494 int *dedge_i,int *dindex_j,T *dnnz);
1495
1496template <>
1497inline void copy_csr_to_device(int size,int n_nonzeros,int *hedge_i,int *hindex_j,double *hnnz,\
1498 int *dedge_i,int *dindex_j,double *dnnz)
1499{
1500 d_copy_csr_to_device(size,n_nonzeros,hedge_i,hindex_j,hnnz,dedge_i,dindex_j,dnnz);
1501}
1502
1503template <>
1504inline void copy_csr_to_device(int size,int n_nonzeros,int *hedge_i,int *hindex_j,CPX *hnnz,\
1505 int *dedge_i,int *dindex_j,CPX *dnnz)
1506{
1507 z_copy_csr_to_device(size,n_nonzeros,hedge_i,hindex_j,hnnz,dedge_i,dindex_j,dnnz);
1508}
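
/*
  Usage sketch (illustrative): push a small host CSR matrix to the GPU with
  the precision-dispatched copy_csr_to_device wrapper. The usual size+1
  row-pointer layout and 0-based column indices are assumed here; adjust if
  the calling code uses 1-based (Fortran-style) indexing.

    int size = 3, n_nonzeros = 5;
    int    h_ia[4]  = {0, 2, 3, 5};
    int    h_ja[5]  = {0, 2, 1, 0, 2};
    double h_val[5] = {4.0, 1.0, 3.0, 2.0, 5.0};

    int *d_ia, *d_ja;
    double *d_val;
    cudaMalloc((void**)&d_ia,  (size + 1) * sizeof(int));
    cudaMalloc((void**)&d_ja,  n_nonzeros * sizeof(int));
    cudaMalloc((void**)&d_val, n_nonzeros * sizeof(double));
    copy_csr_to_device(size, n_nonzeros, h_ia, h_ja, h_val, d_ia, d_ja, d_val);
*/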
1509
1510/************************************************************************************************/
1511
1512template <typename T>
1513inline void init_var_on_dev(T *var,int N,cudaStream_t stream);
1514
1515template <>
1516inline void init_var_on_dev(double *var,int N,cudaStream_t stream){
1517 d_init_var_on_dev(var,N,stream);
1518}
1519
1520template <>
1521inline void init_var_on_dev(CPX *var,int N,cudaStream_t stream){
1522 z_init_var_on_dev(var,N,stream);
1523}
1524
1525/************************************************************************************************/
1526
1527template <typename T>
1528inline void init_eye_on_dev(T *var,int N,cudaStream_t stream);
1529
1530template <>
1531inline void init_eye_on_dev(double *var,int N,cudaStream_t stream){
1532 d_init_eye_on_dev(var,N,stream);
1533}
1534
1535// single-precision (float) specialization, newly added
1536template <>
1537inline void init_eye_on_dev(float *var,int N,cudaStream_t stream){
1538 s_init_eye_on_dev(var,N,stream);
1539}
1540
1541template <>
1542inline void init_eye_on_dev(CPX *var,int N,cudaStream_t stream){
1543 z_init_eye_on_dev(var,N,stream);
1544}
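
/*
  Usage sketch (illustrative; it is assumed here that N is the matrix
  dimension and that init_var_on_dev zero-initializes the buffer): build an
  N x N identity matrix in device memory on a dedicated CUDA stream.

    int N = 256;
    double *d_A;
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    cudaMalloc((void**)&d_A, (size_t)N * N * sizeof(double));
    init_var_on_dev(d_A, N * N, stream);  // zero the buffer (assumed semantics)
    init_eye_on_dev(d_A, N, stream);      // write the identity pattern
    cudaStreamSynchronize(stream);
*/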
1545
1546/************************************************************************************************/
1547
1548template <typename T>
1549inline void csr_mult_f(void *handle,int m,int n,int k,int n_nonzeros,int *Aedge_i,int *Aindex_j,\
1550 T *Annz,T alpha,T *B,T beta,T *C);
1551
1552template <>
1553inline void csr_mult_f(void *handle,int m,int n,int k,int n_nonzeros,int *Aedge_i,int *Aindex_j,\
1554 double *Annz,double alpha,double *B,double beta,double *C)
1555{
1556 d_csr_mult_f(handle,m,n,k,n_nonzeros,Aedge_i,Aindex_j,Annz,alpha,B,beta,C);
1557}
1558
1559template <>
1560inline void csr_mult_f(void *handle,int m,int n,int k,int n_nonzeros,int *Aedge_i,int *Aindex_j,\
1561 CPX *Annz,CPX alpha,CPX *B,CPX beta,CPX *C)
1562{
1563 z_csr_mult_f(handle,m,n,k,n_nonzeros,Aedge_i,Aindex_j,Annz,alpha,B,beta,C);
1564}
1565
1566/************************************************************************************************/
1567
1568template <typename T>
1569inline void transpose_matrix(T *odata,T *idata,int size_x,int size_y);
1570
1571template <>
1572inline void transpose_matrix(double *odata,double *idata,int size_x,int size_y)
1573{
1574 d_transpose_matrix(odata,idata,size_x,size_y);
1575}
1576
1577template <>
1578inline void transpose_matrix(CPX *odata,CPX *idata,int size_x,int size_y)
1579{
1580 z_transpose_matrix(odata,idata,size_x,size_y);
1581}
1582
1583/************************************************************************************************/
1584
1585template <typename T>
1586inline void extract_diag_on_dev(T *D,int *edge_i,int *index_j,T *nnz,int NR,\
1587 int imin,int imax,int shift,int findx,cudaStream_t stream);
1588
1589template <>
1590inline void extract_diag_on_dev(double *D,int *edge_i,int *index_j,double *nnz,int NR,\
1591 int imin,int imax,int shift,int findx,cudaStream_t stream){
1592 d_extract_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,shift,findx,stream);
1593}
1594
1595template <>
1596inline void extract_diag_on_dev(CPX *D,int *edge_i,int *index_j,CPX *nnz,int NR,\
1597 int imin,int imax,int shift,int findx,cudaStream_t stream){
1598 z_extract_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,shift,findx,stream);
1599}
1600
1601/************************************************************************************************/
1602
1603template <typename T>
1604inline void extract_not_diag_on_dev(T *D,int *edge_i,int *index_j,T *nnz,int NR,\
1605 int imin,int imax,int jmin,int side,int shift,int findx,cudaStream_t stream);
1606
1607template <>
1608inline void extract_not_diag_on_dev(double *D,int *edge_i,int *index_j,double *nnz,int NR,\
1609 int imin,int imax,int jmin,int side,int shift,int findx,cudaStream_t stream){
1610 d_extract_not_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,jmin,side,shift,findx,stream);
1611}
1612
1613template <>
1614inline void extract_not_diag_on_dev(CPX *D,int *edge_i,int *index_j,CPX *nnz,int NR,\
1615 int imin,int imax,int jmin,int side,int shift,int findx,cudaStream_t stream){
1616 z_extract_not_diag_on_dev(D,edge_i,index_j,nnz,NR,imin,imax,jmin,side,shift,findx,stream);
1617}
1618
1619/************************************************************************************************/
1620
1621template <typename T>
1622inline void tril_dev(T *A, int lda, int N);
1623
1624template <>
1625inline void tril_dev(double *A, int lda, int N)
1626{
1627 d_tril_on_dev(A, lda, N);
1628}
1629
1630// single-precision (float) specialization, newly added
1631template <>
1632inline void tril_dev(float *A, int lda, int N)
1633{
1634 s_tril_on_dev(A, lda, N);
1635}
1636
1637template <>
1638inline void tril_dev(CPX *A, int lda, int N)
1639{
1640 z_tril_on_dev(A, lda, N);
1641}
1642
1643/************************************************************************************************/
1644
1645template <typename T>
1646inline void indexed_copy_dev(T *src, T *dst, size_t *index, size_t N);
1647
1648template <>
1649inline void indexed_copy_dev(double *src, double *dst, size_t *index, size_t N)
1650{
1651 d_indexed_copy_on_dev(src, dst, index, N);
1652}
1653
1654template <>
1655inline void indexed_copy_dev(CPX *src, CPX *dst, size_t *index, size_t N)
1656{
1657 z_indexed_copy_on_dev(src, dst, index, N);
1658}
1659
1660/************************************************************************************************/
1661
1662template <typename T>
1663inline void indexed_copy_offset_dev(T *src, T *dst, size_t *index, size_t N, size_t offset);
1664
1665template <>
1666inline void indexed_copy_offset_dev(double *src, double *dst, size_t *index, size_t N, size_t offset)
1667{
1668 d_indexed_copy_offset_on_dev(src, dst, index, N, offset);
1669}
1670
1671// single-precision (float) specialization, newly added
1672template <>
1673inline void indexed_copy_offset_dev(float *src, float *dst, size_t *index, size_t N, size_t offset)
1674{
1675 s_indexed_copy_offset_on_dev(src, dst, index, N, offset);
1676}
1677
1678template <>
1679inline void indexed_copy_offset_dev(CPX *src, CPX *dst, size_t *index, size_t N, size_t offset)
1680{
1681 z_indexed_copy_offset_on_dev(src, dst, index, N, offset);
1682}
1683
1684/************************************************************************************************/
1685
1686template <typename T>
1687inline void indexed_copy(T *src, T *dst, size_t *index, size_t N)
1688{
1689 #pragma omp parallel for
1690 for (size_t i = 0; i < N; i++)
1691 {
1692 dst[i] = src[index[i]];
1693 }
1694}
1695
1696/************************************************************************************************/
1697
1698template <typename T>
1699inline void indexed_log_sum(T *x, size_t *index, size_t N, T *sum)
1700{
1701 // initialize to zero ...
1702 *sum = 0.0;
1703
1705 #pragma omp parallel for reduction(+:sum[:1])
1706 for (size_t i = 0; i < N; i++)
1707 {
1708 *sum += log(x[index[i]]);
1710 }
1712}
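
/*
  Usage sketch: one typical use of indexed_log_sum (an assumption, not stated
  in this header) is accumulating a log-determinant from the diagonal entries
  of a Cholesky factor stored in a larger dense array, using
  log det(Q) = 2 * sum_i log(L_ii).

    size_t n = 2;
    double L[4]        = {2.0, 0.5, 0.0, 3.0};  // column-major lower factor, lda = 2
    size_t diag_idx[2] = {0, 3};                // positions of L_00 and L_11

    double log_det;
    indexed_log_sum(L, diag_idx, n, &log_det);
    log_det *= 2.0;                             // log det(Q) for Q = L * L^T
*/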
1713
1714
1715/************************************************************************************************/
1716
1717template <typename T>
1718inline void log_sum(T *x, size_t N, T *sum)
1719{
1720 // initialize to zero ...
1721 *sum = 0.0;
1722
1723 #pragma omp parallel for reduction(+:sum[:1])
1724 for (size_t i = 0; i < N; i++)
1725 {
1726 *sum += log(x[i]);
1727 }
1728}
1729
1730/************************************************************************************************/
1731
1732template <typename T>
1733inline void log_dev(T *x, size_t N);
1734
1735template <>
1736inline void log_dev(double *x, size_t N)
1737{
1738 d_log_on_dev(x, N);
1739}
1740
1741template <>
1742inline void log_dev(CPX *x, size_t N)
1743{
1744 z_log_on_dev(x, N);
1745}
1746
1747/************************************************************************************************/
1748
1749template <typename T>
1750inline void fill_dev(T *x, const T value, size_t N);
1751
1752template <>
1753inline void fill_dev(double *x, const double value, size_t N)
1754{
1755 d_fill_on_dev(x, value, N);
1756}
1757
1758// single-precision (float) specialization, newly added
1759template <>
1760inline void fill_dev(float *x, const float value, size_t N)
1761{
1762 s_fill_on_dev(x, value, N);
1763}
1764
1765template <>
1766inline void fill_dev(CPX *x, const CPX value, size_t N)
1767{
1768 z_fill_on_dev(x, value, N);
1769}
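
/*
  Usage sketch (illustrative values): set every entry of a device array to a
  constant with the precision-dispatched fill_dev wrapper.

    size_t N = 1024;
    double *d_x;
    cudaMalloc((void**)&d_x, N * sizeof(double));
    fill_dev(d_x, 1.0, N);
*/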
1770
1771/************************************************************************************************/
1772
1773template <typename T>
1774inline void init_block_matrix_dev(T *M, size_t *ia, size_t *ja, T *a, size_t nnz, size_t ns, size_t nt, size_t nd);
1775
1776template <>
1777inline void init_block_matrix_dev(double *M, size_t *ia, size_t *ja, double *a, size_t nnz, size_t ns, size_t nt, size_t nd)
1778{
1779 d_init_block_matrix_on_dev(M, ia, ja, a, nnz, ns, nt, nd);
1780}
1781
1782template <>
1783inline void init_block_matrix_dev(CPX *M, size_t *ia, size_t *ja, CPX *a, size_t nnz, size_t ns, size_t nt, size_t nd)
1784{
1785 z_init_block_matrix_on_dev(M, ia, ja, a, nnz, ns, nt, nd);
1786}
1787
1788/************************************************************************************************/
1789
1790template <typename T>
1791inline void init_supernode_dev(T *M, size_t *ia, size_t *ja, T *a, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd, cudaStream_t stream );
1792
1793template <>
1794inline void init_supernode_dev(double *M, size_t *ia, size_t *ja, double *a, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd, cudaStream_t stream )
1795{
1796 d_init_supernode_on_dev(M, ia, ja, a, supernode, supernode_nnz, supernode_offset, ns, nt, nd, stream);
1797}
1798
1799// single-precision (float) specialization, newly added
1800template <>
1801inline void init_supernode_dev(float *M, size_t *ia, size_t *ja, float *a, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd, cudaStream_t stream )
1802{
1803 s_init_supernode_on_dev(M, ia, ja, a, supernode, supernode_nnz, supernode_offset, ns, nt, nd, stream);
1804}
1805
1806template <>
1807inline void init_supernode_dev(CPX *M, size_t *ia, size_t *ja, CPX *a, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd, cudaStream_t stream )
1808{
1809 z_init_supernode_on_dev(M, ia, ja, a, supernode, supernode_nnz, supernode_offset, ns, nt, nd, stream );
1810}
1811
1812/************************************************************************************************/
1813// extract_nnzA_dev: precision-dispatched extraction of the nonzero entries of A from a dense supernode
1814
1815template <typename T>
1816inline void extract_nnzA_dev(T *a, size_t *ia, size_t *ja, T *M, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd);
1817
1818template <>
1819inline void extract_nnzA_dev(double *a, size_t *ia, size_t *ja, double *M, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd)
1820{
1821 d_extract_nnzA_on_dev(a, ia, ja, M, supernode, supernode_nnz, supernode_offset, ns, nt, nd);
1822}
1823
1824// single-precision (float) specialization, newly added
1825template <>
1826inline void extract_nnzA_dev(float *a, size_t *ia, size_t *ja, float *M, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd)
1827{
1828 s_extract_nnzA_on_dev(a, ia, ja, M, supernode, supernode_nnz, supernode_offset, ns, nt, nd);
1829}
1830
1831template <>
1832inline void extract_nnzA_dev(CPX *a, size_t *ia, size_t *ja, CPX *M, size_t supernode, size_t supernode_nnz, size_t supernode_offset, size_t ns, size_t nt, size_t nd)
1833{
1834 fprintf(stderr, "extract_nnzA_dev: complex (CPX) specialization is not implemented.\n");
1835 exit(1);
1836
1837}
1838
1839/************************************************************************************************/
1840
1841#endif
1842