LAMA
|
00001 00034 #ifndef LAMA_LUSOLVER_HPP_ 00035 #define LAMA_LUSOLVER_HPP_ 00036 00037 // for dll_import 00038 #include <lama/config.hpp> 00039 00040 // base classes 00041 #include <lama/solver/Solver.hpp> 00042 00043 // others 00044 #include <lama/Communicator.hpp> 00045 #include <lama/DenseMatrix.hpp> 00046 00047 #include <lama/OpenMP/OpenMPBLAS3.hpp> 00048 00049 // tracing 00050 #include <lama/tracing.hpp> 00051 00052 #ifndef LAMA_BUILD_CUDA 00053 typedef int cudaStream_t; 00054 typedef int CUevent; 00055 typedef int CUdevice; 00056 typedef int CUstream; 00057 typedef int CUresult; 00058 typedef int CUDAStreamSyncToken; 00059 typedef std::auto_ptr<CUDAStreamSyncToken> CUDAStreamSyncTokenPtr; 00060 #else 00061 #include <lama/CUDA/CUDAStreamSyncToken.hpp> 00062 #include <cuda.h> 00063 #include <cuda_runtime.h> 00064 #endif 00065 00066 namespace lama 00067 { 00068 00069 class LAMA_DLL_IMPORTEXPORT LUSolver : public Solver 00070 { 00071 public: 00077 LUSolver(const std::string& id); 00078 00085 LUSolver(const std::string& id, LoggerPtr logger); 00086 00090 LUSolver( const LUSolver& other ); 00091 00095 virtual ~LUSolver(); 00096 00106 virtual void initialize( const Matrix& coefficients ); 00107 00108 void factorMatrixToLU( Matrix& matrix,std::vector<IndexType>& permutation ); 00109 00122 virtual void solve( Vector& solution, const Vector& rhs ); 00123 00124 void setTileSize( const IndexType tilesize ); 00125 IndexType getTileSize( ); 00126 00127 virtual void setDeviceNumber( const IndexType dev ); 00128 virtual IndexType getDeviceNumber( ); 00129 00130 struct LUSolverRuntime : SolverRuntime 00131 { 00132 LUSolverRuntime(); 00133 virtual ~LUSolverRuntime(); 00134 00135 Matrix* mLUfactorization; 00136 std::vector<IndexType> mPermutation; 00137 }; 00138 00142 virtual LUSolverRuntime& getRuntime(); 00143 00147 virtual const LUSolverRuntime& getConstRuntime() const; 00148 00155 virtual SolverPtr copy(); 00156 00157 protected: 00158 00159 LUSolverRuntime mLUSolverRuntime; 00160 00161 IndexType mTilesize; 00162 IndexType mDev; 00163 00164 const static double epsilon; 00165 00166 private: 00167 template<typename T> 00168 struct lama_swap 00169 { 00170 typedef void (*swap_func)( const int,T*,const int,T*,const int ); 00171 ContextPtr ctxt; 00172 swap_func func; 00173 }; 00174 00175 template<typename T> 00176 struct lama_gemm 00177 { 00178 typedef void (*gemm_func)( 00179 const enum CBLAS_ORDER,const enum CBLAS_TRANSPOSE,const enum CBLAS_TRANSPOSE,const int,const int, 00180 const int,const T,const T*,const int,const T*,const int,const T,T*,const int,cudaStream_t ); 00181 typedef void( *__rs )( const CUevent ); 00182 typedef bool( *__qu )( const CUevent ); 00183 00184 cudaStream_t stream; 00185 gemm_func func; 00186 __rs record; 00187 __rs synchronize; 00188 __qu query; 00189 00190 static CUDAStreamSyncTokenPtr __syncTok; 00191 00192 static void __gemm( const enum CBLAS_ORDER order,const enum CBLAS_TRANSPOSE transa, 00193 const enum CBLAS_TRANSPOSE transb,const int m,const int n,const int k, 00194 const T alpha,const T* A,const int lda,const T* B,const int ldb,const T beta,T* C, 00195 const int ldc,cudaStream_t stream ); 00196 static void __recordDef( const CUevent event ); 00197 static bool __queryDef ( const CUevent event ); 00198 static void __synchronizeDef( const CUevent event ); 00199 static void __recordCuda( const CUevent event ); 00200 static bool __queryCuda ( const CUevent event ); 00201 static void __synchronizeCuda( const CUevent event ); 00202 }; 00203 00204 template<typename T> 00205 void computeLUFactorization( DenseMatrix<T>& matrix,std::vector<IndexType>& permutation ); 00206 00207 template<typename T> 00208 void pgetf2( const IndexType numBlockRows,DenseStorage<T>** const A,IndexType* const ipiv,const PartitionId ROOT ); 00209 00210 template<typename T> 00211 void plaswp( DenseStorage<T>** const A,const PartitionId ROOT,const IndexType* const ipiv,const IndexType n, 00212 const lama_swap<T> swap ); 00213 00214 template<typename T> 00215 IndexType piamax_own( const IndexType numBlockCol,DenseStorage<T>** const local,const IndexType col, 00216 const IndexType locRow=0 ); 00217 00218 template<typename T> 00219 void ptrsm( const enum CBLAS_UPLO uplo,const DenseMatrix<T>& matrix,DenseVector<T>& solution ); 00220 00221 IndexType computeTilesize( IndexType m,IndexType n ); 00222 00223 inline void initializeCommunicator( ); 00224 00225 CommunicatorPtr mComm; 00226 00227 LAMA_LOG_DECL_STATIC_LOGGER(logger); 00228 }; 00229 00230 // implementation of inner classes. 00231 00232 template<typename T> 00233 CUDAStreamSyncTokenPtr LUSolver::lama_gemm<T>::__syncTok; 00234 00235 template<typename T> 00236 void LUSolver::lama_gemm<T>::__gemm( const enum CBLAS_ORDER order,const enum CBLAS_TRANSPOSE transa, 00237 const enum CBLAS_TRANSPOSE transb,const int m,const int n,const int k, 00238 const T alpha,const T* A,const int lda,const T* B,const int ldb,const T beta,T* C, 00239 const int ldc,cudaStream_t ) 00240 { 00241 LAMA_REGION("GEMM"); 00242 OpenMPBLAS3::gemm( order, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, NULL ); 00243 } 00244 00245 template<typename T> 00246 void LUSolver::lama_gemm<T>::__recordDef( const CUevent ) 00247 { } 00248 00249 template<typename T> 00250 void LUSolver::lama_gemm<T>::__synchronizeDef( const CUevent ) 00251 { } 00252 00253 template<typename T> 00254 bool LUSolver::lama_gemm<T>::__queryDef( const CUevent ) 00255 { 00256 return true; 00257 } 00258 00259 template<typename T> 00260 void LUSolver::lama_gemm<T>::__recordCuda( const CUevent event ) 00261 { 00262 #ifdef LAMA_BUILD_CUDA 00263 __syncTok->recordEvent( event ); 00264 #endif 00265 } 00266 00267 template<typename T> 00268 void LUSolver::lama_gemm<T>::__synchronizeCuda( const CUevent event ) 00269 { 00270 #ifdef LAMA_BUILD_CUDA 00271 __syncTok->synchronizeEvent( event ); 00272 #endif 00273 } 00274 00275 template<typename T> 00276 bool LUSolver::lama_gemm<T>::__queryCuda( const CUevent event ) 00277 { 00278 #ifdef LAMA_BUILD_CUDA 00279 return __syncTok->queryEvent( event ); 00280 #endif 00281 return true; 00282 } 00283 00284 } // namespace LAMA 00285 00286 #endif /* LAMA_LUSOLVER_HPP_ */