LAMA
/home/brandes/workspace/LAMA/src/lama/cuda/CUDABLAS3.hpp
Go to the documentation of this file.
00001 
00033 #ifndef LAMA_CUDABLAS3_HPP_
00034 #define LAMA_CUDABLAS3_HPP_
00035 
00036 // for dll_import
00037 #include <lama/config.hpp>
00038 
00039 // others
00040 #include <lama/openmp/BLASHelper.hpp>
00041 #include <lama/LAMATypes.hpp>
00042 
00043 // logging
00044 #include <logging/logging.hpp>
00045 
00046 #include <cublas.h>
00047 #include <cuda_runtime_api.h>
00048 
00049 namespace lama
00050 {
00051 
00052 class LAMA_DLL_IMPORTEXPORT CUDABLAS3
00053 {
00054 public:
00107     template<typename T>
00108     static void gemm(
00109         const enum CBLAS_ORDER order,
00110         const enum CBLAS_TRANSPOSE TransA,
00111         const enum CBLAS_TRANSPOSE TransB,
00112         const IndexType M,
00113         const IndexType N,
00114         const IndexType K,
00115         const T alpha,
00116         const T* A,
00117         const IndexType lda,
00118         const T* B,
00119         const IndexType ldb,
00120         const T beta,
00121         T* C,
00122         const IndexType ldc,
00123         class SyncToken* syncToken );
00124 
00193 //    template<typename T>
00194 //    static void symm(
00195 //        const enum CBLAS_ORDER order,
00196 //        const enum CBLAS_SIDE side,
00197 //        const enum CBLAS_UPLO uplo,
00198 //        const IndexType m,
00199 //        const IndexType n,
00200 //        const T alpha,
00201 //        const T *A,
00202 //        const IndexType lda,
00203 //        const T *B,
00204 //        const IndexType ldb,
00205 //        const T beta,
00206 //        T *C,
00207 //        const IndexType ldc);
00208 
00272 //    template<typename T>
00273 //    static void trmm(
00274 //        const enum CBLAS_ORDER order,
00275 //        const enum CBLAS_SIDE Side,
00276 //        const enum CBLAS_UPLO Uplo,
00277 //        const enum CBLAS_TRANSPOSE TransA,
00278 //        const enum CBLAS_DIAG Diag,
00279 //        const IndexType M,
00280 //        const IndexType N,
00281 //        const T alpha,
00282 //        const T *A,
00283 //        const IndexType lda,
00284 //        T *B,
00285 //        const IndexType ldb);
00286 
00354     template<typename T>
00355     static void trsm(
00356         const enum CBLAS_ORDER order,
00357         const enum CBLAS_SIDE Side,
00358         const enum CBLAS_UPLO Uplo,
00359         const enum CBLAS_TRANSPOSE TransA,
00360         const enum CBLAS_DIAG Diag,
00361         const IndexType M,
00362         const IndexType N,
00363         const T alpha,
00364         const T* A,
00365         const IndexType lda,
00366         T* B,
00367         const IndexType ldb,
00368         class SyncToken* syncToken );
00369 
00438 //    template<typename T>
00439 //    static void syrk(
00440 //        const enum CBLAS_ORDER order,
00441 //        const enum CBLAS_UPLO uplo,
00442 //        const enum CBLAS_TRANSPOSE trans,
00443 //        const IndexType n,
00444 //        const IndexType k,
00445 //        const T alpha,
00446 //        const T* A,
00447 //        const IndexType lda,
00448 //        const T beta,
00449 //        T* C,
00450 //        const IndexType ldc );
00451 
00521 //    template<typename T>
00522 //    static void syrk2(
00523 //        const enum CBLAS_ORDER order,
00524 //        const enum CBLAS_UPLO uplo,
00525 //        const enum CBLAS_TRANSPOSE trans,
00526 //        const IndexType n,
00527 //        const IndexType k,
00528 //        const T alpha,
00529 //        const T* A,
00530 //        const IndexType lda,
00531 //        const T* B,
00532 //        const IndexType ldb,
00533 //        const T beta,
00534 //        T* C,
00535 //        const IndexType ldc );
00536 private:
00537 
00538     LAMA_LOG_DECL_STATIC_LOGGER( logger );
00539 
00540     template <typename T>
00541     void gemm_launcher(
00542         const char transA_char,
00543         const char transB_char,
00544         const int m,
00545         const int n,
00546         const int k,
00547         const T alpha,
00548         const T* const A,
00549         const int lda,
00550         const T* const B,
00551         const int ldb,
00552         const T beta,
00553         T* const C,
00554         const int ldc,
00555         cudaStream_t cs );
00556 };
00557 
00558 } /* namespace lama */
00559 
00560 #endif // LAMA_CUDABLAS3_HPP_