LAMA
/home/brandes/workspace/LAMA/src/lama/Communicator.hpp
Go to the documentation of this file.
00001 
00033 #ifndef LAMA_COMMUNICATOR_HPP_
00034 #define LAMA_COMMUNICATOR_HPP_
00035 
00036 // for dll_import
00037 #include <lama/config.hpp>
00038 
00039 // base classes
00040 #include <lama/NonCopyable.hpp>
00041 #include <lama/Printable.hpp>
00042 
00043 // others
00044 #include <lama/LAMATypes.hpp>
00045 #include <lama/HostReadAccess.hpp>
00046 #include <lama/HostWriteAccess.hpp>
00047 #include <lama/CommunicationPlan.hpp>
00048 
00049 #include <lama/exception/LAMAAssert.hpp>
00050 
00051 // logging
00052 #include <logging/logging.hpp>
00053 
00054 // boost
00055 #include <boost/shared_ptr.hpp>
00056 
00057 #include <memory>
00058 #include <vector>
00059 #include <cmath>
00060 
00061 namespace lama
00062 {
00063 
00064 // Forward declaration of all classes that are used in the interface
00065 
00066 class CommunicationPlan;
00067 
00068 template<typename T> class LAMAArray;
00069 
00070 class SyncToken;
00071 
00072 class Distribution;
00073 
00074 class Halo;
00075 
/**
 * @brief Abstract base class for communication between the partitions of a distributed
 *        computation.
 *
 * Derived classes implement the pure virtual primitives (point-to-point exchange by
 * communication plans, collectives, reductions, shifts). The template member functions
 * are convenience wrappers that operate on LAMAArray objects and forward to the raw
 * pointer overloads.
 *
 * Instances are non-copyable (private NonCopyable base) and printable (Printable base).
 * NOTE(review): presumably one derived class exists per communication backend (e.g. MPI)
 * — confirm with the derived classes.
 */
00094 class LAMA_DLL_IMPORTEXPORT Communicator : public Printable, private NonCopyable
00095 {
00096 
00097 public:
00098 
    /**
     * Thread safety levels a communicator may support, as reported by getThreadSafetyLevel().
     * NOTE(review): the names mirror the MPI thread-support levels
     * (MPI_THREAD_FUNNELED / SERIALIZED / MULTIPLE), presumably with the same meaning — confirm.
     */
00101     enum ThreadSafetyLevel { Funneled = 1,    
00102                              Serialized = 2,  
00103                              Multiple = 3     
00104                            };
00105 
    /** Virtual destructor, so derived communicators can be destroyed through a base pointer (e.g. CommunicatorPtr). */
00106     virtual ~Communicator();
00107 
    /**
     * Presumably factorizes the number of partitions into a 2D processor grid
     * procgrid[0] x procgrid[1] suited to a problem of extent sizeX x sizeY — TODO confirm.
     */
00108     void factorize2(const double sizeX, const double sizeY, PartitionId procgrid[2]) const;
00109 
    /**
     * 3D counterpart of factorize2: presumably fills procgrid[0..2] for a problem of
     * extent sizeX x sizeY x sizeZ — TODO confirm.
     */
00110     void factorize3(const double sizeX, const double sizeY, const double sizeZ, PartitionId procgrid[3]) const;
00111 
    /** Presumably computes the position pos of this partition within the 2D grid procgrid — TODO confirm. */
00112     void getGrid2Rank( PartitionId pos[2], const PartitionId procgrid[2] ) const;
00113 
    /** Presumably computes the position pos of this partition within the 3D grid procgrid — TODO confirm. */
00114     void getGrid3Rank( PartitionId pos[3], const PartitionId procgrid[3] ) const;
00115 
    /** Compares two communicators. NOTE(review): presumably implemented via isEqual() — confirm. */
00126     bool operator==(const Communicator& other) const;
00127 
    /** Negation of operator==. */
00128     bool operator!=(const Communicator& other) const;
00129 
    /** Implemented by derived classes: decides whether other is an equivalent communicator. */
00132     virtual bool isEqual(const Communicator& other) const = 0;
00133 
    /** Returns the thread safety level supported by this communicator. */
00134     virtual ThreadSafetyLevel getThreadSafetyLevel() const =0;
00135 
    /** Returns the number of partitions handled by this communicator. */
00138     virtual PartitionId getSize() const = 0;
00139 
    /** Returns the rank of this partition; getNeighbor() treats ranks cyclically modulo getSize(). */
00144     virtual PartitionId getRank() const = 0;
00145 
    /**
     * Returns the rank at cyclic offset pos from this partition, computed as
     * (getSize() + getRank() + pos) % getSize(). Asserts that |pos| <= getSize().
     */
00154     inline PartitionId getNeighbor(int pos) const;
00155 
    /**
     * Presumably an all-to-all exchange of one int per partition: sendValues[p] is sent to
     * partition p and recvValues[p] is received from p, both with getSize() entries — TODO confirm.
     */
00165     virtual void all2all( int* recvValues, const int* sendValues ) const = 0;
00166 
    /**
     * Exchanges data between partitions as described by the two communication plans.
     * recvData must provide room for recvPlan.totalQuantity() values and sendData must hold
     * sendPlan.totalQuantity() values (this is how the LAMAArray overload prepares the buffers).
     * Overloads exist for int, float and double.
     */
00181     virtual void exchangeByPlan(
00182         int* const recvData,
00183         const CommunicationPlan& recvPlan,
00184         const int* const sendData,
00185         const CommunicationPlan& sendPlan ) const = 0;
00186 
00189     virtual void exchangeByPlan(
00190         float* const recvData,
00191         const CommunicationPlan& recvPlan,
00192         const float* const sendData,
00193         const CommunicationPlan& sendPlan ) const = 0;
00194 
00197     virtual void exchangeByPlan(
00198         double* const recvData,
00199         const CommunicationPlan& recvPlan,
00200         const double* const sendData,
00201         const CommunicationPlan& sendPlan ) const = 0;
00202 
    /**
     * Asynchronous version of exchangeByPlan, returning a SyncToken for the pending transfer.
     * NOTE(review): the buffers presumably must stay valid until the token has been waited for;
     * the LAMAArray overload keeps its accesses alive inside the token for that reason.
     */
00203     virtual std::auto_ptr<SyncToken> exchangeByPlanAsync(
00204         int* const recvData,
00205         const CommunicationPlan& recvPlan,
00206         const int* const sendData,
00207         const CommunicationPlan& sendPlan ) const = 0;
00208 
00211     virtual std::auto_ptr<SyncToken> exchangeByPlanAsync(
00212         float* const recvData,
00213         const CommunicationPlan& recvPlan,
00214         const float* const sendData,
00215         const CommunicationPlan& sendPlan ) const = 0;
00216 
00219     virtual std::auto_ptr<SyncToken> exchangeByPlanAsync(
00220         double* const recvData,
00221         const CommunicationPlan& recvPlan,
00222         const double* const sendData,
00223         const CommunicationPlan& sendPlan ) const = 0;
00224 
    /**
     * LAMAArray overload: asserts that sendArray has sendPlan.totalQuantity() entries, resizes
     * recvArray to recvPlan.totalQuantity() entries and exchanges the data through accesses
     * created in getCommunicationContext().
     */
00225     template<typename T> 
00226     void exchangeByPlan( 
00227         LAMAArray<T>& recvArray,
00228         const CommunicationPlan& recvPlan,
00229         const LAMAArray<T>& sendArray,
00230         const CommunicationPlan& sendPlan ) const;
00231 
    /**
     * Asynchronous LAMAArray overload: uses host accesses that are handed over to the returned
     * SyncToken, which frees them after a successful wait.
     */
00232     template<typename T> 
00233     std::auto_ptr<SyncToken> exchangeByPlanAsync( 
00234         LAMAArray<T>& recvArray,
00235         const CommunicationPlan& recvPlan,
00236         const LAMAArray<T>& sendArray,
00237         const CommunicationPlan& sendPlan ) const;
00238                       
    /**
     * Presumably fills haloValues with the non-local values this partition needs, taken from
     * the localValues of the owning partitions as described by halo — TODO confirm.
     */
00248     template<typename T>
00249     void updateHalo( LAMAArray<T>& haloValues, 
00250                      const LAMAArray<T>& localValues, 
00251                      const Halo& halo ) const;
00252 
    /** Asynchronous version of updateHalo. */
00255     template<typename T>
00256     std::auto_ptr<SyncToken> updateHaloAsync( LAMAArray<T>& haloValues,
00257                                               const LAMAArray<T>& localValues,
00258                                               const Halo& halo ) const;
00259 
    /**
     * Presumably shifts the array send along the ring of partitions by direction positions
     * (cf. getNeighbor) and stores the incoming data in recv — TODO confirm.
     */
00270     template<typename T>
00271     void shift( LAMAArray<T>& recv, const LAMAArray<T>& send, const int direction ) const;
00272 
    /** Asynchronous version of the LAMAArray shift. */
00282     template<typename T>
00283     std::auto_ptr<SyncToken> shiftAsync( LAMAArray<T>& recvArray, const LAMAArray<T>& sendArray,
00284                                          const int direction ) const;
00285 
    /**
     * Presumably determines, for each index in requiredIndexes, the partition that owns it
     * under distribution and writes the result into owners — TODO confirm.
     */
00292     void computeOwners(
00293         const std::vector<IndexType>& requiredIndexes,
00294         const Distribution& distribution,
00295         std::vector<PartitionId>& owners) const;
00296 
    /** Presumably broadcasts the n values of val from partition root to all partitions. */
00304     virtual void bcast (double val[], const IndexType n, const PartitionId root) const = 0;
00305     virtual void bcast (float val[], const IndexType n, const PartitionId root) const = 0;
00306     virtual void bcast (int val[], const IndexType n, const PartitionId root) const = 0;
00307 
    /** Presumably distributes allvals from root so that every partition receives n values in myvals. */
00316     virtual void scatter (double myvals[], const IndexType n, const PartitionId root, const double allvals[]) const = 0;
00317     virtual void scatter (float myvals[], const IndexType n, const PartitionId root, const float allvals[]) const = 0;
00318     virtual void scatter (int myvals[], const IndexType n, const PartitionId root, const int allvals[]) const = 0;
00319 
    /** scatter variant with individual counts; sizes presumably holds the number of values per partition. */
00329     virtual void scatter (double myvals[], const IndexType n, const PartitionId root, 
00330                           const double allvals[], const IndexType sizes[])  const = 0;
00331     virtual void scatter (float myvals[], const IndexType n, const PartitionId root, 
00332                           const float allvals[], const IndexType sizes[])  const = 0;
00333     virtual void scatter (int myvals[], const IndexType n, const PartitionId root, 
00334                           const int allvals[], const IndexType sizes[])  const = 0;
00335 
    /** Presumably collects n values of myvals from every partition into allvals at root. */
00344     virtual void gather (double allvals[], const IndexType n, const PartitionId root, const double myvals[]) const = 0;
00345     virtual void gather (float allvals[], const IndexType n, const PartitionId root, const float myvals[]) const = 0;
00346     virtual void gather (int allvals[], const IndexType n, const PartitionId root, const int myvals[]) const = 0;
00347 
    /** gather variant with individual counts; sizes presumably holds the number of values per partition. */
00357     virtual void gather (double allvals[], const IndexType n, const PartitionId root,
00358                           const double myvals[], const IndexType sizes[])  const = 0;
00359     virtual void gather (float allvals[], const IndexType n, const PartitionId root,
00360                           const float myvals[], const IndexType sizes[])  const = 0;
00361     virtual void gather (int allvals[], const IndexType n, const PartitionId root,
00362                           const int myvals[], const IndexType sizes[])  const = 0;
00363 
    /**
     * Presumably sends the oldSize values of oldVals along the ring by direction positions and
     * receives at most newSize values into newVals; the return value is presumably the number
     * of values received — TODO confirm.
     */
00378     virtual IndexType shift( double newVals[], const IndexType newSize, 
00379                              const double oldVals[], const IndexType oldSize, 
00380                              const int direction ) const = 0;
00381 
00382     virtual IndexType shift( float newVals[], const IndexType newSize, 
00383                              const float oldVals[], const IndexType oldSize,
00384                              const int direction ) const = 0;
00385 
00386     virtual IndexType shift( int newVals[], const IndexType newSize, 
00387                              const int oldVals[], const IndexType oldSize,
00388                              const int direction ) const = 0;
00389 
    /**
     * Asynchronous shift of size values. Not pure virtual, so a base implementation exists;
     * NOTE(review): presumably built on the protected helper defaultShiftAsync — confirm.
     */
00397     virtual std::auto_ptr<SyncToken> shiftAsync( double newVals[], const double oldVals[], 
00398                                                  const IndexType size, const int direction ) const;
00399 
00400     virtual std::auto_ptr<SyncToken> shiftAsync( float newVals[], const float oldVals[], 
00401                                                  const IndexType size, const int direction ) const;
00402 
00403     virtual std::auto_ptr<SyncToken> shiftAsync( int newVals[], const int oldVals[], 
00404                                                  const IndexType size, const int direction ) const;
00405 
    /** Reductions: presumably return the sum / minimum / maximum of value over all partitions. */
00412     virtual float  sum(const float value) const = 0;
00413     virtual double sum(const double value) const = 0;
00414     virtual int    sum(const int value) const = 0;
00415     virtual size_t sum(const size_t value) const = 0;
00416 
00417     virtual float min(const float value) const = 0;
00418     virtual float max(const float value) const = 0;
00419 
00420     virtual double min(const double value) const = 0;
00421     virtual double max(const double value) const = 0;
00422 
00423     virtual int min(const int value) const = 0;
00424     virtual int max(const int value) const = 0;
00425 
    /** Presumably determines the maximal val over all partitions together with its location, result at root — TODO confirm. */
00434     virtual void maxloc( double& val, int& location, const PartitionId root ) const = 0;
00435     virtual void maxloc( float&  val, int& location, const PartitionId root ) const = 0;
00436     virtual void maxloc( int&    val, int& location, const PartitionId root ) const = 0;
00437 
    /** Presumably exchanges the n values of val in place with partition partner. */
00447     virtual void swap( double val[], const IndexType n, const PartitionId partner ) const = 0;
00448     virtual void swap( float  val[], const IndexType n, const PartitionId partner ) const = 0;
00449     virtual void swap( int    val[], const IndexType n, const PartitionId partner ) const = 0;
00450 
    /** Presumably collects one float value per partition into values — TODO confirm. */
00453     virtual void gather(std::vector<float>& values, float value) const = 0;
00454 
    /** Presumably a barrier over all partitions of this communicator. */
00457     virtual void synchronize() const = 0;
00458 
    /** Writes a description of this communicator to stream (presumably overrides Printable::writeAt). */
00461     virtual void writeAt(std::ostream& stream) const;
00462 
    /** Returns the communicator type string stored in mCommunicatorType. */
00465     const std::string& getType() const { return mCommunicatorType; }
00466 
00467 protected:
00468 
00469     // Constructor is protected: only derived classes can create a Communicator.
00470 
    /** @param type  communicator type, presumably stored in mCommunicatorType (see getType()). */
00471     Communicator( const std::string& type );
00472 
    /** Type name of this communicator, returned by getType(). */
00473     std::string mCommunicatorType;
00474 
    /** Static logger of this class (declared via the logging macro). */
00475     LAMA_LOG_DECL_STATIC_LOGGER(logger);
00476 
    /** Presumably reads a user-defined processor array with 3 entries — TODO confirm where it comes from (environment/config). */
00482     static void getUserProcArray( PartitionId userProcArray[3] );
00483 
    /** Presumably the special case of a shift with direction 0 (local copy of oldVals into newVals) — TODO confirm. */
00486     template<typename T>
00487     IndexType shift0( T newVals[], const IndexType newSize,
00488                       const T oldVals[], const IndexType oldSize ) const;
00489 
    /** Presumably the generic implementation behind the virtual shiftAsync overloads — TODO confirm. */
00492     template<typename T>
00493     std::auto_ptr<SyncToken> defaultShiftAsync( T newVals[], const T oldVals[], 
00494                                                 const IndexType size, const int direction ) const;
00495 
    /**
     * Context used for accessing communication buffers; the LAMAArray overload of
     * exchangeByPlan creates its read and write accesses in this context.
     */
00500     virtual ContextPtr getCommunicationContext() const = 0;
00501 
00502 };
00503 
00504 typedef boost::shared_ptr<const Communicator> CommunicatorPtr;
00505 
00506 /* -------------------------------------------------------------------------- */
00507 
00508 PartitionId Communicator::getNeighbor(int pos) const
00509 {
00510     PartitionId size = getSize();
00511     PartitionId rank = getRank();
00512 
00513     LAMA_ASSERT( std::abs(pos) <= size,
00514             "neighbor pos "<<pos<<" out of range ("<<size<<")" );
00515 
00516     return (size + rank+pos)%size;
00517 }
00518 
00519 /* -------------------------------------------------------------------------- */
00520 
00521 template<typename T>
00522 void Communicator::exchangeByPlan( LAMAArray<T>& recvArray,
00523                                    const CommunicationPlan& recvPlan,
00524                                    const LAMAArray<T>& sendArray,
00525                                    const CommunicationPlan& sendPlan ) const
00526 {
00527     LAMA_ASSERT_ERROR( sendArray.size() == sendPlan.totalQuantity(), 
00528                   "Send array has size " << sendArray.size() 
00529                   << ", but send plan requires " << sendPlan.totalQuantity() << " entries" );
00530                   
00531     IndexType recvSize = recvPlan.totalQuantity();
00532 
00533     ContextPtr comCtx = getCommunicationContext();
00534 
00535     WriteAccess<T> recvData( recvArray, comCtx );
00536     ReadAccess<T> sendData( sendArray, comCtx );
00537 
00538     recvData.clear();
00539     recvData.resize( recvSize );
00540 
00541     exchangeByPlan( recvData.get(), recvPlan, sendData.get(), sendPlan );
00542 }
00543 
00544 /* -------------------------------------------------------------------------- */
00545 
00546 template<typename T>
00547 std::auto_ptr<SyncToken> Communicator::exchangeByPlanAsync( LAMAArray<T>& recvArray,
00548                                                             const CommunicationPlan& recvPlan,
00549                                                             const LAMAArray<T>& sendArray,
00550                                                             const CommunicationPlan& sendPlan ) const
00551 {
00552     LAMA_ASSERT_ERROR( sendArray.size() == sendPlan.totalQuantity(), 
00553                   "Send array has size " << sendArray.size() 
00554                   << ", but send plan requires " << sendPlan.totalQuantity() << " entries" );
00555                   
00556     IndexType recvSize = recvPlan.totalQuantity();
00557 
00558     // allocate accesses, SyncToken will take ownership
00559 
00560     std::auto_ptr<HostWriteAccess<T> > recvData( new HostWriteAccess<T>( recvArray ) );
00561     std::auto_ptr<HostReadAccess<T> > sendData( new HostReadAccess<T>( sendArray ) );
00562 
00563     recvData->clear();
00564     recvData->resize( recvSize );
00565 
00566     std::auto_ptr<SyncToken> token = exchangeByPlanAsync( recvData->get(), recvPlan, sendData->get(), sendPlan );
00567 
00568     // Add the read and write access to the sync token to get it freed after successful wait
00569 
00570     token->pushAccess( std::auto_ptr<BaseAccess> ( recvData.release() ) );
00571     token->pushAccess( std::auto_ptr<BaseAccess> ( sendData.release() ) );
00572 
00573     return token;
00574 }
00575 
00576 }
00577 
00578 #endif // LAMA_COMMUNICATOR_HPP_