LAMA
/home/brandes/workspace/LAMA/src/lama/storage/CRTPMatrixStorage.hpp
Go to the documentation of this file.
00001 
00034 #ifndef LAMA_CRTP_MATRIX_STORAGE_HPP_
00035 #define LAMA_CRTP_MATRIX_STORAGE_HPP_
00036 
00037 // for dll_import
00038 #include <lama/config.hpp>
00039 
00040 // base classes
00041 #include <lama/storage/MatrixStorage.hpp>
00042 
00043 namespace lama
00044 {
00045 
00059 template<class Derived, typename ValueType>
00060 class LAMA_DLL_IMPORTEXPORT CRTPMatrixStorage : public MatrixStorage<ValueType>
00061 {
00062 public:
00063 
00066     CRTPMatrixStorage( const IndexType numRows, const IndexType numColumns) 
00067 
00068        : MatrixStorage<ValueType>( numRows, numColumns )
00069 
00070     {
00071     }
00072 
00075     CRTPMatrixStorage( ) 
00076 
00077        : MatrixStorage<ValueType>( 0, 0 )
00078 
00079     {
00080     }
00081 
00088     void setCSRData( const IndexType numRows, const IndexType numColumns, 
00089                      const IndexType numValues,
00090                      const LAMAArray<IndexType>& ia, const LAMAArray<IndexType>& ja,
00091                      const _LAMAArray& values )
00092     {
00093         Scalar::ScalarType arrayType = values.getValueType();
00094 
00095         if ( arrayType == Scalar::DOUBLE )
00096         {
00097             const LAMAArray<double>& typedValues = dynamic_cast<const LAMAArray<double>&>( values );
00098             static_cast<Derived*>(this)->setCSRDataImpl( numRows, numColumns, numValues, 
00099                                                ia, ja, typedValues, this->getContextPtr() );
00100         }
00101         else if ( arrayType == Scalar::FLOAT )
00102         {
00103             const LAMAArray<float>& typedValues = dynamic_cast<const LAMAArray<float>&>( values );
00104             static_cast<Derived*>(this)->setCSRDataImpl( numRows, numColumns, numValues, 
00105                                                ia, ja, typedValues, this->getContextPtr() );
00106         }
00107         else
00108         {
00109             LAMA_THROWEXCEPTION( *this << ": setCSRData with value type " << arrayType << " not supported" );
00110         }
00111     }
00112 
00115     void buildCSRSizes( LAMAArray<IndexType>& ia ) const
00116     {
00117         // The sizes will be available via buildCSR with NULL for ja, values
00118 
00119         LAMAArray<IndexType>* ja     = NULL;
00120         LAMAArray<ValueType>* values = NULL;
00121 
00122         static_cast<const Derived*>(this)->buildCSR( ia, ja, values, this->getContextPtr() );
00123     }
00124     
00125     void buildCSRData( LAMAArray<IndexType>& ia, LAMAArray<IndexType>& ja, _LAMAArray& values ) const
00126     {
00127         Scalar::ScalarType arrayType = values.getValueType();
00128 
00129         if ( arrayType == Scalar::DOUBLE )
00130         {
00131             LAMAArray<double>& typedValues = dynamic_cast<LAMAArray<double>&>( values );
00132             static_cast<const Derived*>(this)->buildCSR( ia, &ja, &typedValues, this->getContextPtr() );
00133         }
00134         else if ( arrayType == Scalar::FLOAT )
00135         {
00136             LAMAArray<float>& typedValues = dynamic_cast<LAMAArray<float>&>( values );
00137             static_cast<const Derived*>(this)->buildCSR( ia, &ja, &typedValues, this->getContextPtr() );
00138         }
00139         else
00140         {
00141             LAMA_THROWEXCEPTION( *this << ": build CSR with value type " << arrayType << " not supported" );
00142         }
00143     }
00144     
00147     void getRow( _LAMAArray& row, const IndexType i ) const
00148     {
00149         Scalar::ScalarType arrayType = row.getValueType();
00150 
00151         if ( arrayType == Scalar::DOUBLE )
00152         {
00153             LAMAArray<double>& typedRow = dynamic_cast<LAMAArray<double>&>( row );
00154             static_cast<const Derived*>( this )->getRowImpl( typedRow, i );
00155         }
00156         else if ( arrayType == Scalar::FLOAT )
00157         {
00158             LAMAArray<float>& typedRow = dynamic_cast<LAMAArray<float>&>( row );
00159             static_cast<const Derived*>( this )->getRowImpl( typedRow, i );
00160         }
00161         else
00162         {
00163             LAMA_THROWEXCEPTION( "getRow for array of type " << arrayType << " not supported" );
00164         }
00165     }
00166 
00167     void getDiagonal( _LAMAArray& diagonal ) const
00168     {
00169         if ( ! this->hasDiagonalProperty() )
00170         {
00171            LAMA_THROWEXCEPTION( *this << ": has not diagonal property, cannot set diagonal");
00172         }
00173 
00174         Scalar::ScalarType arrayType = diagonal.getValueType();
00175     
00176         if ( arrayType == Scalar::DOUBLE )
00177         {
00178             LAMAArray<double>& typedDiagonal = dynamic_cast<LAMAArray<double>&>( diagonal );
00179             static_cast<const Derived*>(this)->getDiagonalImpl( typedDiagonal );
00180         }
00181         else if ( arrayType == Scalar::FLOAT )
00182         {
00183             LAMAArray<float>& typedDiagonal = dynamic_cast<LAMAArray<float>&>( diagonal );
00184             static_cast<const Derived*>(this)->getDiagonalImpl( typedDiagonal );
00185         }
00186         else
00187         {
00188             LAMA_THROWEXCEPTION( "getDiagonal for array of type " << arrayType << " not supported" );
00189         }
00190     }
00191 
00192     void setDiagonal( const Scalar value ) 
00193     {
00194         static_cast<Derived*>(this)->setDiagonalImpl( value );
00195     }
00196 
00197     void setDiagonal( const _LAMAArray& diagonal )
00198     {
00199         IndexType numDiagonalElements = diagonal.size();
00200     
00201         if (    numDiagonalElements > this->getNumRows()
00202              || numDiagonalElements > this->getNumColumns() )
00203         {
00204             LAMA_THROWEXCEPTION( "Diagonal of size " << numDiagonalElements
00205                                 << " too large for matrix: " << *this );
00206         }
00207     
00208         if ( ! this->hasDiagonalProperty() )
00209         {
00210             LAMA_THROWEXCEPTION( *this << ": has not diagonal property, cannot set diagonal");
00211         }
00212     
00213         Scalar::ScalarType arrayType = diagonal.getValueType();
00214     
00215         if ( arrayType == Scalar::DOUBLE )
00216         {
00217             const LAMAArray<double>& typedDiagonal = dynamic_cast<const LAMAArray<double>&>( diagonal );
00218             static_cast<Derived*>(this)->setDiagonalImpl( typedDiagonal );
00219         }
00220         else if ( arrayType == Scalar::FLOAT )
00221         {
00222             const LAMAArray<float>& typedDiagonal = dynamic_cast<const LAMAArray<float>&>( diagonal );
00223             static_cast<Derived*>(this)->setDiagonalImpl( typedDiagonal );
00224         }
00225         else
00226         {
00227             LAMA_THROWEXCEPTION( "setDiagonal to array of type " << arrayType << " not supported" );
00228         }
00229     }
00230 
00231     void scale( const Scalar value ) 
00232     {
00233         static_cast<Derived*>(this)->scaleImpl( value );
00234     }
00235 
00238     void scale( const _LAMAArray& diagonal )
00239     {
00240         LAMA_ASSERT_EQUAL_ERROR( this->getNumRows(), diagonal.size() );
00241 
00242         Scalar::ScalarType arrayType = diagonal.getValueType();
00243 
00244         if ( arrayType == Scalar::DOUBLE )
00245         {
00246             const LAMAArray<double>& typedDiagonal = dynamic_cast<const LAMAArray<double>&>( diagonal );
00247             static_cast<Derived*>(this)->scaleImpl( typedDiagonal );
00248         }
00249         else if ( arrayType == Scalar::FLOAT )
00250         {
00251             const LAMAArray<float>& typedDiagonal = dynamic_cast<const LAMAArray<float>&>( diagonal );
00252             static_cast<Derived*>(this)->scaleImpl( typedDiagonal );
00253         }
00254         else
00255         {
00256             LAMA_THROWEXCEPTION( "scale of type " << arrayType << " not supported" );
00257         }
00258     }
00259 
00262     virtual const char* getTypeName() const
00263     {
00264         // each derived class provides static method to get the type name.
00265 
00266         return Derived::typeName();
00267     }
00268 };
00269 
00270 } // namespace lama
00271 
00272 #endif // LAMA_CRTP_MATRIX_STORAGE_HPP_