LAMA
|
00001 00033 #ifndef LAMA_COMMUNICATOR_HPP_ 00034 #define LAMA_COMMUNICATOR_HPP_ 00035 00036 // for dll_import 00037 #include <lama/config.hpp> 00038 00039 // base classes 00040 #include <lama/NonCopyable.hpp> 00041 #include <lama/Printable.hpp> 00042 00043 // others 00044 #include <lama/LAMATypes.hpp> 00045 #include <lama/HostReadAccess.hpp> 00046 #include <lama/HostWriteAccess.hpp> 00047 #include <lama/CommunicationPlan.hpp> 00048 00049 #include <lama/exception/LAMAAssert.hpp> 00050 00051 // logging 00052 #include <logging/logging.hpp> 00053 00054 // boost 00055 #include <boost/shared_ptr.hpp> 00056 00057 #include <memory> 00058 #include <vector> 00059 #include <cmath> 00060 00061 namespace lama 00062 { 00063 00064 // Forward declaration of all classes that are used in the interface 00065 00066 class CommunicationPlan; 00067 00068 template<typename T> class LAMAArray; 00069 00070 class SyncToken; 00071 00072 class Distribution; 00073 00074 class Halo; 00075 00094 class LAMA_DLL_IMPORTEXPORT Communicator : public Printable, private NonCopyable 00095 { 00096 00097 public: 00098 00101 enum ThreadSafetyLevel { Funneled = 1, 00102 Serialized = 2, 00103 Multiple = 3 00104 }; 00105 00106 virtual ~Communicator(); 00107 00108 void factorize2(const double sizeX, const double sizeY, PartitionId procgrid[2]) const; 00109 00110 void factorize3(const double sizeX, const double sizeY, const double sizeZ, PartitionId procgrid[3]) const; 00111 00112 void getGrid2Rank( PartitionId pos[2], const PartitionId procgrid[2] ) const; 00113 00114 void getGrid3Rank( PartitionId pos[3], const PartitionId procgrid[3] ) const; 00115 00126 bool operator==(const Communicator& other) const; 00127 00128 bool operator!=(const Communicator& other) const; 00129 00132 virtual bool isEqual(const Communicator& other) const = 0; 00133 00134 virtual ThreadSafetyLevel getThreadSafetyLevel() const =0; 00135 00138 virtual PartitionId getSize() const = 0; 00139 00144 virtual PartitionId getRank() const = 0; 00145 00154 inline PartitionId getNeighbor(int pos) const; 00155 00165 virtual void all2all( int* recvValues, const int* sendValues ) const = 0; 00166 00181 virtual void exchangeByPlan( 00182 int* const recvData, 00183 const CommunicationPlan& recvPlan, 00184 const int* const sendData, 00185 const CommunicationPlan& sendPlan ) const = 0; 00186 00189 virtual void exchangeByPlan( 00190 float* const recvData, 00191 const CommunicationPlan& recvPlan, 00192 const float* const sendData, 00193 const CommunicationPlan& sendPlan ) const = 0; 00194 00197 virtual void exchangeByPlan( 00198 double* const recvData, 00199 const CommunicationPlan& recvPlan, 00200 const double* const sendData, 00201 const CommunicationPlan& sendPlan ) const = 0; 00202 00203 virtual std::auto_ptr<SyncToken> exchangeByPlanAsync( 00204 int* const recvData, 00205 const CommunicationPlan& recvPlan, 00206 const int* const sendData, 00207 const CommunicationPlan& sendPlan ) const = 0; 00208 00211 virtual std::auto_ptr<SyncToken> exchangeByPlanAsync( 00212 float* const recvData, 00213 const CommunicationPlan& recvPlan, 00214 const float* const sendData, 00215 const CommunicationPlan& sendPlan ) const = 0; 00216 00219 virtual std::auto_ptr<SyncToken> exchangeByPlanAsync( 00220 double* const recvData, 00221 const CommunicationPlan& recvPlan, 00222 const double* const sendData, 00223 const CommunicationPlan& sendPlan ) const = 0; 00224 00225 template<typename T> 00226 void exchangeByPlan( 00227 LAMAArray<T>& recvArray, 00228 const CommunicationPlan& recvPlan, 00229 const LAMAArray<T>& sendArray, 00230 const CommunicationPlan& sendPlan ) const; 00231 00232 template<typename T> 00233 std::auto_ptr<SyncToken> exchangeByPlanAsync( 00234 LAMAArray<T>& recvArray, 00235 const CommunicationPlan& recvPlan, 00236 const LAMAArray<T>& sendArray, 00237 const CommunicationPlan& sendPlan ) const; 00238 00248 template<typename T> 00249 void updateHalo( LAMAArray<T>& haloValues, 00250 const LAMAArray<T>& localValues, 00251 const Halo& halo ) const; 00252 00255 template<typename T> 00256 std::auto_ptr<SyncToken> updateHaloAsync( LAMAArray<T>& haloValues, 00257 const LAMAArray<T>& localValues, 00258 const Halo& halo ) const; 00259 00270 template<typename T> 00271 void shift( LAMAArray<T>& recv, const LAMAArray<T>& send, const int direction ) const; 00272 00282 template<typename T> 00283 std::auto_ptr<SyncToken> shiftAsync( LAMAArray<T>& recvArray, const LAMAArray<T>& sendArray, 00284 const int direction ) const; 00285 00292 void computeOwners( 00293 const std::vector<IndexType>& requiredIndexes, 00294 const Distribution& distribution, 00295 std::vector<PartitionId>& owners) const; 00296 00304 virtual void bcast (double val[], const IndexType n, const PartitionId root) const = 0; 00305 virtual void bcast (float val[], const IndexType n, const PartitionId root) const = 0; 00306 virtual void bcast (int val[], const IndexType n, const PartitionId root) const = 0; 00307 00316 virtual void scatter (double myvals[], const IndexType n, const PartitionId root, const double allvals[]) const = 0; 00317 virtual void scatter (float myvals[], const IndexType n, const PartitionId root, const float allvals[]) const = 0; 00318 virtual void scatter (int myvals[], const IndexType n, const PartitionId root, const int allvals[]) const = 0; 00319 00329 virtual void scatter (double myvals[], const IndexType n, const PartitionId root, 00330 const double allvals[], const IndexType sizes[]) const = 0; 00331 virtual void scatter (float myvals[], const IndexType n, const PartitionId root, 00332 const float allvals[], const IndexType sizes[]) const = 0; 00333 virtual void scatter (int myvals[], const IndexType n, const PartitionId root, 00334 const int allvals[], const IndexType sizes[]) const = 0; 00335 00344 virtual void gather (double allvals[], const IndexType n, const PartitionId root, const double myvals[]) const = 0; 00345 virtual void gather (float allvals[], const IndexType n, const PartitionId root, const float myvals[]) const = 0; 00346 virtual void gather (int allvals[], const IndexType n, const PartitionId root, const int myvals[]) const = 0; 00347 00357 virtual void gather (double allvals[], const IndexType n, const PartitionId root, 00358 const double myvals[], const IndexType sizes[]) const = 0; 00359 virtual void gather (float allvals[], const IndexType n, const PartitionId root, 00360 const float myvals[], const IndexType sizes[]) const = 0; 00361 virtual void gather (int allvals[], const IndexType n, const PartitionId root, 00362 const int myvals[], const IndexType sizes[]) const = 0; 00363 00378 virtual IndexType shift( double newVals[], const IndexType newSize, 00379 const double oldVals[], const IndexType oldSize, 00380 const int direction ) const = 0; 00381 00382 virtual IndexType shift( float newVals[], const IndexType newSize, 00383 const float oldVals[], const IndexType oldSize, 00384 const int direction ) const = 0; 00385 00386 virtual IndexType shift( int newVals[], const IndexType newSize, 00387 const int oldVals[], const IndexType oldSize, 00388 const int direction ) const = 0; 00389 00397 virtual std::auto_ptr<SyncToken> shiftAsync( double newVals[], const double oldVals[], 00398 const IndexType size, const int direction ) const; 00399 00400 virtual std::auto_ptr<SyncToken> shiftAsync( float newVals[], const float oldVals[], 00401 const IndexType size, const int direction ) const; 00402 00403 virtual std::auto_ptr<SyncToken> shiftAsync( int newVals[], const int oldVals[], 00404 const IndexType size, const int direction ) const; 00405 00412 virtual float sum(const float value) const = 0; 00413 virtual double sum(const double value) const = 0; 00414 virtual int sum(const int value) const = 0; 00415 virtual size_t sum(const size_t value) const = 0; 00416 00417 virtual float min(const float value) const = 0; 00418 virtual float max(const float value) const = 0; 00419 00420 virtual double min(const double value) const = 0; 00421 virtual double max(const double value) const = 0; 00422 00423 virtual int min(const int value) const = 0; 00424 virtual int max(const int value) const = 0; 00425 00434 virtual void maxloc( double& val, int& location, const PartitionId root ) const = 0; 00435 virtual void maxloc( float& val, int& location, const PartitionId root ) const = 0; 00436 virtual void maxloc( int& val, int& location, const PartitionId root ) const = 0; 00437 00447 virtual void swap( double val[], const IndexType n, const PartitionId partner ) const = 0; 00448 virtual void swap( float val[], const IndexType n, const PartitionId partner ) const = 0; 00449 virtual void swap( int val[], const IndexType n, const PartitionId partner ) const = 0; 00450 00453 virtual void gather(std::vector<float>& values, float value) const = 0; 00454 00457 virtual void synchronize() const = 0; 00458 00461 virtual void writeAt(std::ostream& stream) const; 00462 00465 const std::string& getType() const { return mCommunicatorType; } 00466 00467 protected: 00468 00469 // Default constructor can only be called by base classes. 00470 00471 Communicator( const std::string& type ); 00472 00473 std::string mCommunicatorType; 00474 00475 LAMA_LOG_DECL_STATIC_LOGGER(logger); 00476 00482 static void getUserProcArray( PartitionId userProcArray[3] ); 00483 00486 template<typename T> 00487 IndexType shift0( T newVals[], const IndexType newSize, 00488 const T oldVals[], const IndexType oldSize ) const; 00489 00492 template<typename T> 00493 std::auto_ptr<SyncToken> defaultShiftAsync( T newVals[], const T oldVals[], 00494 const IndexType size, const int direction ) const; 00495 00500 virtual ContextPtr getCommunicationContext() const = 0; 00501 00502 }; 00503 00504 typedef boost::shared_ptr<const Communicator> CommunicatorPtr; 00505 00506 /* -------------------------------------------------------------------------- */ 00507 00508 PartitionId Communicator::getNeighbor(int pos) const 00509 { 00510 PartitionId size = getSize(); 00511 PartitionId rank = getRank(); 00512 00513 LAMA_ASSERT( std::abs(pos) <= size, 00514 "neighbor pos "<<pos<<" out of range ("<<size<<")" ); 00515 00516 return (size + rank+pos)%size; 00517 } 00518 00519 /* -------------------------------------------------------------------------- */ 00520 00521 template<typename T> 00522 void Communicator::exchangeByPlan( LAMAArray<T>& recvArray, 00523 const CommunicationPlan& recvPlan, 00524 const LAMAArray<T>& sendArray, 00525 const CommunicationPlan& sendPlan ) const 00526 { 00527 LAMA_ASSERT_ERROR( sendArray.size() == sendPlan.totalQuantity(), 00528 "Send array has size " << sendArray.size() 00529 << ", but send plan requires " << sendPlan.totalQuantity() << " entries" ); 00530 00531 IndexType recvSize = recvPlan.totalQuantity(); 00532 00533 ContextPtr comCtx = getCommunicationContext(); 00534 00535 WriteAccess<T> recvData( recvArray, comCtx ); 00536 ReadAccess<T> sendData( sendArray, comCtx ); 00537 00538 recvData.clear(); 00539 recvData.resize( recvSize ); 00540 00541 exchangeByPlan( recvData.get(), recvPlan, sendData.get(), sendPlan ); 00542 } 00543 00544 /* -------------------------------------------------------------------------- */ 00545 00546 template<typename T> 00547 std::auto_ptr<SyncToken> Communicator::exchangeByPlanAsync( LAMAArray<T>& recvArray, 00548 const CommunicationPlan& recvPlan, 00549 const LAMAArray<T>& sendArray, 00550 const CommunicationPlan& sendPlan ) const 00551 { 00552 LAMA_ASSERT_ERROR( sendArray.size() == sendPlan.totalQuantity(), 00553 "Send array has size " << sendArray.size() 00554 << ", but send plan requires " << sendPlan.totalQuantity() << " entries" ); 00555 00556 IndexType recvSize = recvPlan.totalQuantity(); 00557 00558 // allocate accesses, SyncToken will take ownership 00559 00560 std::auto_ptr<HostWriteAccess<T> > recvData( new HostWriteAccess<T>( recvArray ) ); 00561 std::auto_ptr<HostReadAccess<T> > sendData( new HostReadAccess<T>( sendArray ) ); 00562 00563 recvData->clear(); 00564 recvData->resize( recvSize ); 00565 00566 std::auto_ptr<SyncToken> token = exchangeByPlanAsync( recvData->get(), recvPlan, sendData->get(), sendPlan ); 00567 00568 // Add the read and write access to the sync token to get it freed after successful wait 00569 00570 token->pushAccess( std::auto_ptr<BaseAccess> ( recvData.release() ) ); 00571 token->pushAccess( std::auto_ptr<BaseAccess> ( sendData.release() ) ); 00572 00573 return token; 00574 } 00575 00576 } 00577 00578 #endif // LAMA_COMMUNICATOR_HPP_