LAMA
|
00001 00033 #ifndef LAMA_CUDABLAS1_HPP_ 00034 #define LAMA_CUDABLAS1_HPP_ 00035 00036 // for dll_import 00037 #include <lama/config.hpp> 00038 00039 // others 00040 #include <lama/LAMATypes.hpp> 00041 #include <lama/SyncToken.hpp> 00042 00043 #include <cublas.h> 00044 #include <cuda_runtime_api.h> 00045 00046 namespace lama 00047 { 00048 00049 class LAMA_DLL_IMPORTEXPORT CUDABLAS1 00050 { 00051 public: 00052 00056 template<typename T> 00057 static void scal( 00058 const IndexType n, 00059 const T alpha, 00060 T* x, 00061 const IndexType incX, 00062 SyncToken* syncToken ); 00063 00067 template<typename T> 00068 static T nrm2( const IndexType n, const T* x, const IndexType incX, 00069 SyncToken* syncToken ); 00070 00074 template<typename T> 00075 static T asum( const IndexType n, const T* x, const IndexType incX, 00076 SyncToken* syncToken ); 00077 00081 template<typename T> 00082 static IndexType iamax( const IndexType n, const T* x, const IndexType incX, 00083 SyncToken* syncToken ); 00084 00088 template<typename T> 00089 static void swap( 00090 const IndexType n, 00091 T* y, 00092 const IndexType incY, 00093 T* x, 00094 const IndexType incX, 00095 SyncToken* syncToken ); 00096 00100 template<typename T> 00101 static void copy( 00102 const IndexType n, 00103 const T* x, 00104 const IndexType incX, 00105 T* y, 00106 const IndexType incY, 00107 SyncToken* syncToken ); 00108 00112 template<typename T> 00113 static void axpy( 00114 const IndexType n, 00115 const T alpha, 00116 const T* x, 00117 const IndexType incX, 00118 T* y, 00119 const IndexType incY, 00120 SyncToken* syncToken ); 00121 00125 template<typename T> 00126 static T dot( 00127 const IndexType n, 00128 const T* x, 00129 const IndexType incX, 00130 const T* y, 00131 const IndexType incY, 00132 SyncToken* syncToken ); 00133 00137 template<typename T> 00138 static void sum( 00139 const IndexType n, 00140 T alpha, 00141 const T* x, 00142 T beta, 00143 const T* y, 00144 T* z, 00145 SyncToken* syncToken ); 00146 00150 template<typename T> 00151 static void rot( 00152 const IndexType n, 00153 T* x, 00154 const IndexType incX, 00155 T* y, 00156 const IndexType incY, 00157 const T c, 00158 const T s, 00159 SyncToken* syncToken ); 00160 00164 template<typename T> 00165 static void rotm( 00166 const IndexType n, 00167 T* x, 00168 const IndexType incX, 00169 T* y, 00170 const IndexType incY, 00171 const T* P, 00172 SyncToken* syncToken ); 00173 00177 template<typename T> 00178 static void ass( const IndexType n, const T value, T* x, SyncToken* syncToken ); 00179 00183 template<typename T> 00184 static T viamax( const IndexType n, const T* x_d, const IndexType incx, SyncToken* syncToken ); 00185 00186 private: 00187 00188 template <typename T> 00189 static void ass_launcher( const int n, const T value, T* x, cudaStream_t stream ); 00190 00191 template<typename T> 00192 static void sum_launcher( const int n, T alpha, const T* x, T beta, const T* y, T* z, cudaStream_t stream ); 00193 }; 00194 00195 } /* namespace lama */ 00196 00197 #endif // LAMA_CUDABLAS1_HPP_