LAMA
/home/brandes/workspace/LAMA/src/lama/solver/SpecialLUSolver.hpp
Go to the documentation of this file.
00001 
00034 #ifndef LAMA_LUSOLVER_HPP_
00035 #define LAMA_LUSOLVER_HPP_
00036 
00037 // for dll_import
00038 #include <lama/config.hpp>
00039 
00040 // base classes
00041 #include <lama/solver/Solver.hpp>
00042 
00043 // others
00044 #include <lama/Communicator.hpp>
00045 #include <lama/DenseMatrix.hpp>
00046 
00047 #include <lama/OpenMP/OpenMPBLAS3.hpp>
00048 
00049 // tracing
00050 #include <lama/tracing.hpp>
00051 
// Stand-in types for builds without CUDA (LAMA_BUILD_CUDA undefined).
// Signatures further down in this header mention cudaStream_t, CUevent and
// CUDAStreamSyncTokenPtr unconditionally, so they must name *some* type even
// when the CUDA headers are not available. Plain int is used as a placeholder.
#ifndef LAMA_BUILD_CUDA
    // std::auto_ptr lives in <memory>. Include it here instead of relying on
    // one of the project headers above to pull it in transitively.
    #include <memory>

    typedef int cudaStream_t;
    typedef int CUevent;
    typedef int CUdevice;
    typedef int CUstream;
    typedef int CUresult;
    typedef int CUDAStreamSyncToken;
    typedef std::auto_ptr<CUDAStreamSyncToken> CUDAStreamSyncTokenPtr;
#else
    // Real CUDA build: take the token type from LAMA and the remaining
    // types from the CUDA driver / runtime headers.
    #include <lama/CUDA/CUDAStreamSyncToken.hpp>
    #include <cuda.h>
    #include <cuda_runtime.h>
#endif
00065 
00066 namespace lama
00067 {
00068 
00069 class LAMA_DLL_IMPORTEXPORT LUSolver : public Solver
00070 {
00071 public:
00077     LUSolver(const std::string& id);
00078 
00085     LUSolver(const std::string& id, LoggerPtr logger);
00086 
00090     LUSolver( const LUSolver& other );
00091 
00095     virtual ~LUSolver();
00096 
00106     virtual void initialize( const Matrix& coefficients );
00107 
00108     void factorMatrixToLU( Matrix& matrix,std::vector<IndexType>& permutation );
00109 
00122     virtual void solve( Vector& solution, const Vector& rhs );
00123 
00124     void setTileSize( const IndexType tilesize );
00125     IndexType getTileSize( );
00126 
00127     virtual void setDeviceNumber( const IndexType dev );
00128     virtual IndexType getDeviceNumber( );
00129 
00130     struct LUSolverRuntime : SolverRuntime
00131     {
00132         LUSolverRuntime();
00133         virtual ~LUSolverRuntime();
00134         
00135         Matrix*                 mLUfactorization;
00136         std::vector<IndexType>  mPermutation;
00137     };
00138 
00142     virtual LUSolverRuntime& getRuntime();
00143     
00147     virtual const LUSolverRuntime& getConstRuntime() const;
00148 
00155     virtual SolverPtr copy();
00156 
00157 protected:
00158 
00159     LUSolverRuntime mLUSolverRuntime;
00160 
00161     IndexType mTilesize;
00162     IndexType mDev;
00163 
00164     const static double epsilon;
00165 
00166 private:
00167     template<typename T>
00168     struct lama_swap
00169     {
00170         typedef void (*swap_func)( const int,T*,const int,T*,const int );
00171         ContextPtr ctxt;
00172         swap_func  func;
00173     };
00174 
00175     template<typename T>
00176     struct lama_gemm
00177     {
00178         typedef void (*gemm_func)(
00179                         const enum CBLAS_ORDER,const enum CBLAS_TRANSPOSE,const enum CBLAS_TRANSPOSE,const int,const int,
00180                         const int,const T,const T*,const int,const T*,const int,const T,T*,const int,cudaStream_t );
00181         typedef void( *__rs )( const CUevent );
00182         typedef bool( *__qu )( const CUevent );
00183 
00184         cudaStream_t stream;
00185         gemm_func    func;
00186         __rs         record;
00187         __rs         synchronize;
00188         __qu         query;
00189 
00190         static CUDAStreamSyncTokenPtr __syncTok;
00191 
00192         static void __gemm( const enum CBLAS_ORDER order,const enum CBLAS_TRANSPOSE transa,
00193                             const enum CBLAS_TRANSPOSE transb,const int m,const int n,const int k,
00194                             const T alpha,const T* A,const int lda,const T* B,const int ldb,const T beta,T* C,
00195                             const int ldc,cudaStream_t stream );
00196         static void __recordDef( const CUevent event );
00197         static bool __queryDef ( const CUevent event );
00198         static void __synchronizeDef( const CUevent event );
00199         static void __recordCuda( const CUevent event );
00200         static bool __queryCuda ( const CUevent event );
00201         static void __synchronizeCuda( const CUevent event );
00202     };
00203 
00204     template<typename T>
00205     void computeLUFactorization( DenseMatrix<T>& matrix,std::vector<IndexType>& permutation );
00206 
00207     template<typename T>
00208     void pgetf2( const IndexType numBlockRows,DenseStorage<T>** const A,IndexType* const ipiv,const PartitionId ROOT );
00209 
00210     template<typename T>
00211     void plaswp( DenseStorage<T>** const A,const PartitionId ROOT,const IndexType* const ipiv,const IndexType n,
00212         const lama_swap<T> swap );
00213 
00214     template<typename T>
00215     IndexType piamax_own( const IndexType numBlockCol,DenseStorage<T>** const local,const IndexType col,
00216         const IndexType locRow=0 );
00217 
00218     template<typename T>
00219     void ptrsm( const enum CBLAS_UPLO uplo,const DenseMatrix<T>& matrix,DenseVector<T>& solution );
00220 
00221     IndexType computeTilesize( IndexType m,IndexType n );
00222 
00223     inline void initializeCommunicator( );
00224 
00225     CommunicatorPtr mComm;
00226 
00227     LAMA_LOG_DECL_STATIC_LOGGER(logger);
00228 };
00229 
00230 // implementation of inner classes.
00231 
00232 template<typename T>
00233 CUDAStreamSyncTokenPtr LUSolver::lama_gemm<T>::__syncTok;
00234 
00235 template<typename T>
00236 void LUSolver::lama_gemm<T>::__gemm( const enum CBLAS_ORDER order,const enum CBLAS_TRANSPOSE transa,
00237                                      const enum CBLAS_TRANSPOSE transb,const int m,const int n,const int k,
00238                                      const T alpha,const T* A,const int lda,const T* B,const int ldb,const T beta,T* C,
00239                                      const int ldc,cudaStream_t )
00240 {
00241     LAMA_REGION("GEMM");
00242     OpenMPBLAS3::gemm( order, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, NULL );
00243 }
00244 
00245 template<typename T>
00246 void LUSolver::lama_gemm<T>::__recordDef( const CUevent )
00247 { }
00248 
00249 template<typename T>
00250 void LUSolver::lama_gemm<T>::__synchronizeDef( const CUevent )
00251 { }
00252 
00253 template<typename T>
00254 bool LUSolver::lama_gemm<T>::__queryDef( const CUevent )
00255 {
00256     return true;
00257 }
00258 
00259 template<typename T>
00260 void LUSolver::lama_gemm<T>::__recordCuda( const CUevent event )
00261 {
00262 #ifdef LAMA_BUILD_CUDA
00263     __syncTok->recordEvent( event );
00264 #endif
00265 }
00266 
00267 template<typename T>
00268 void LUSolver::lama_gemm<T>::__synchronizeCuda( const CUevent event )
00269 {
00270 #ifdef LAMA_BUILD_CUDA
00271     __syncTok->synchronizeEvent( event );
00272 #endif
00273 }
00274 
00275 template<typename T>
00276 bool LUSolver::lama_gemm<T>::__queryCuda( const CUevent event )
00277 {
00278 #ifdef LAMA_BUILD_CUDA
00279     return __syncTok->queryEvent( event );
00280 #endif
00281     return true;
00282 }
00283 
00284 } // namespace LAMA
00285 
00286 #endif /* LAMA_LUSOLVER_HPP_ */