LAMA
|
00001 00033 #ifndef LAMA_CUDABLAS3_HPP_ 00034 #define LAMA_CUDABLAS3_HPP_ 00035 00036 // for dll_import 00037 #include <lama/config.hpp> 00038 00039 // others 00040 #include <lama/openmp/BLASHelper.hpp> 00041 #include <lama/LAMATypes.hpp> 00042 00043 // logging 00044 #include <logging/logging.hpp> 00045 00046 #include <cublas.h> 00047 #include <cuda_runtime_api.h> 00048 00049 namespace lama 00050 { 00051 00052 class LAMA_DLL_IMPORTEXPORT CUDABLAS3 00053 { 00054 public: 00107 template<typename T> 00108 static void gemm( 00109 const enum CBLAS_ORDER order, 00110 const enum CBLAS_TRANSPOSE TransA, 00111 const enum CBLAS_TRANSPOSE TransB, 00112 const IndexType M, 00113 const IndexType N, 00114 const IndexType K, 00115 const T alpha, 00116 const T* A, 00117 const IndexType lda, 00118 const T* B, 00119 const IndexType ldb, 00120 const T beta, 00121 T* C, 00122 const IndexType ldc, 00123 class SyncToken* syncToken ); 00124 00193 // template<typename T> 00194 // static void symm( 00195 // const enum CBLAS_ORDER order, 00196 // const enum CBLAS_SIDE side, 00197 // const enum CBLAS_UPLO uplo, 00198 // const IndexType m, 00199 // const IndexType n, 00200 // const T alpha, 00201 // const T *A, 00202 // const IndexType lda, 00203 // const T *B, 00204 // const IndexType ldb, 00205 // const T beta, 00206 // T *C, 00207 // const IndexType ldc); 00208 00272 // template<typename T> 00273 // static void trmm( 00274 // const enum CBLAS_ORDER order, 00275 // const enum CBLAS_SIDE Side, 00276 // const enum CBLAS_UPLO Uplo, 00277 // const enum CBLAS_TRANSPOSE TransA, 00278 // const enum CBLAS_DIAG Diag, 00279 // const IndexType M, 00280 // const IndexType N, 00281 // const T alpha, 00282 // const T *A, 00283 // const IndexType lda, 00284 // T *B, 00285 // const IndexType ldb); 00286 00354 template<typename T> 00355 static void trsm( 00356 const enum CBLAS_ORDER order, 00357 const enum CBLAS_SIDE Side, 00358 const enum CBLAS_UPLO Uplo, 00359 const enum CBLAS_TRANSPOSE TransA, 00360 const enum CBLAS_DIAG Diag, 00361 const IndexType M, 00362 const IndexType N, 00363 const T alpha, 00364 const T* A, 00365 const IndexType lda, 00366 T* B, 00367 const IndexType ldb, 00368 class SyncToken* syncToken ); 00369 00438 // template<typename T> 00439 // static void syrk( 00440 // const enum CBLAS_ORDER order, 00441 // const enum CBLAS_UPLO uplo, 00442 // const enum CBLAS_TRANSPOSE trans, 00443 // const IndexType n, 00444 // const IndexType k, 00445 // const T alpha, 00446 // const T* A, 00447 // const IndexType lda, 00448 // const T beta, 00449 // T* C, 00450 // const IndexType ldc ); 00451 00521 // template<typename T> 00522 // static void syrk2( 00523 // const enum CBLAS_ORDER order, 00524 // const enum CBLAS_UPLO uplo, 00525 // const enum CBLAS_TRANSPOSE trans, 00526 // const IndexType n, 00527 // const IndexType k, 00528 // const T alpha, 00529 // const T* A, 00530 // const IndexType lda, 00531 // const T* B, 00532 // const IndexType ldb, 00533 // const T beta, 00534 // T* C, 00535 // const IndexType ldc ); 00536 private: 00537 00538 LAMA_LOG_DECL_STATIC_LOGGER( logger ); 00539 00540 template <typename T> 00541 void gemm_launcher( 00542 const char transA_char, 00543 const char transB_char, 00544 const int m, 00545 const int n, 00546 const int k, 00547 const T alpha, 00548 const T* const A, 00549 const int lda, 00550 const T* const B, 00551 const int ldb, 00552 const T beta, 00553 T* const C, 00554 const int ldc, 00555 cudaStream_t cs ); 00556 }; 00557 00558 } /* namespace lama */ 00559 00560 #endif // LAMA_CUDABLAS3_HPP_