LAMA
|
00001 00034 #ifndef LAMA_CRTP_MATRIX_STORAGE_HPP_ 00035 #define LAMA_CRTP_MATRIX_STORAGE_HPP_ 00036 00037 // for dll_import 00038 #include <lama/config.hpp> 00039 00040 // base classes 00041 #include <lama/storage/MatrixStorage.hpp> 00042 00043 namespace lama 00044 { 00045 00059 template<class Derived, typename ValueType> 00060 class LAMA_DLL_IMPORTEXPORT CRTPMatrixStorage : public MatrixStorage<ValueType> 00061 { 00062 public: 00063 00066 CRTPMatrixStorage( const IndexType numRows, const IndexType numColumns) 00067 00068 : MatrixStorage<ValueType>( numRows, numColumns ) 00069 00070 { 00071 } 00072 00075 CRTPMatrixStorage( ) 00076 00077 : MatrixStorage<ValueType>( 0, 0 ) 00078 00079 { 00080 } 00081 00088 void setCSRData( const IndexType numRows, const IndexType numColumns, 00089 const IndexType numValues, 00090 const LAMAArray<IndexType>& ia, const LAMAArray<IndexType>& ja, 00091 const _LAMAArray& values ) 00092 { 00093 Scalar::ScalarType arrayType = values.getValueType(); 00094 00095 if ( arrayType == Scalar::DOUBLE ) 00096 { 00097 const LAMAArray<double>& typedValues = dynamic_cast<const LAMAArray<double>&>( values ); 00098 static_cast<Derived*>(this)->setCSRDataImpl( numRows, numColumns, numValues, 00099 ia, ja, typedValues, this->getContextPtr() ); 00100 } 00101 else if ( arrayType == Scalar::FLOAT ) 00102 { 00103 const LAMAArray<float>& typedValues = dynamic_cast<const LAMAArray<float>&>( values ); 00104 static_cast<Derived*>(this)->setCSRDataImpl( numRows, numColumns, numValues, 00105 ia, ja, typedValues, this->getContextPtr() ); 00106 } 00107 else 00108 { 00109 LAMA_THROWEXCEPTION( *this << ": setCSRData with value type " << arrayType << " not supported" ); 00110 } 00111 } 00112 00115 void buildCSRSizes( LAMAArray<IndexType>& ia ) const 00116 { 00117 // The sizes will be available via buildCSR with NULL for ja, values 00118 00119 LAMAArray<IndexType>* ja = NULL; 00120 LAMAArray<ValueType>* values = NULL; 00121 00122 static_cast<const Derived*>(this)->buildCSR( ia, ja, values, this->getContextPtr() ); 00123 } 00124 00125 void buildCSRData( LAMAArray<IndexType>& ia, LAMAArray<IndexType>& ja, _LAMAArray& values ) const 00126 { 00127 Scalar::ScalarType arrayType = values.getValueType(); 00128 00129 if ( arrayType == Scalar::DOUBLE ) 00130 { 00131 LAMAArray<double>& typedValues = dynamic_cast<LAMAArray<double>&>( values ); 00132 static_cast<const Derived*>(this)->buildCSR( ia, &ja, &typedValues, this->getContextPtr() ); 00133 } 00134 else if ( arrayType == Scalar::FLOAT ) 00135 { 00136 LAMAArray<float>& typedValues = dynamic_cast<LAMAArray<float>&>( values ); 00137 static_cast<const Derived*>(this)->buildCSR( ia, &ja, &typedValues, this->getContextPtr() ); 00138 } 00139 else 00140 { 00141 LAMA_THROWEXCEPTION( *this << ": build CSR with value type " << arrayType << " not supported" ); 00142 } 00143 } 00144 00147 void getRow( _LAMAArray& row, const IndexType i ) const 00148 { 00149 Scalar::ScalarType arrayType = row.getValueType(); 00150 00151 if ( arrayType == Scalar::DOUBLE ) 00152 { 00153 LAMAArray<double>& typedRow = dynamic_cast<LAMAArray<double>&>( row ); 00154 static_cast<const Derived*>( this )->getRowImpl( typedRow, i ); 00155 } 00156 else if ( arrayType == Scalar::FLOAT ) 00157 { 00158 LAMAArray<float>& typedRow = dynamic_cast<LAMAArray<float>&>( row ); 00159 static_cast<const Derived*>( this )->getRowImpl( typedRow, i ); 00160 } 00161 else 00162 { 00163 LAMA_THROWEXCEPTION( "getRow for array of type " << arrayType << " not supported" ); 00164 } 00165 } 00166 00167 void getDiagonal( _LAMAArray& diagonal ) const 00168 { 00169 if ( ! this->hasDiagonalProperty() ) 00170 { 00171 LAMA_THROWEXCEPTION( *this << ": has not diagonal property, cannot set diagonal"); 00172 } 00173 00174 Scalar::ScalarType arrayType = diagonal.getValueType(); 00175 00176 if ( arrayType == Scalar::DOUBLE ) 00177 { 00178 LAMAArray<double>& typedDiagonal = dynamic_cast<LAMAArray<double>&>( diagonal ); 00179 static_cast<const Derived*>(this)->getDiagonalImpl( typedDiagonal ); 00180 } 00181 else if ( arrayType == Scalar::FLOAT ) 00182 { 00183 LAMAArray<float>& typedDiagonal = dynamic_cast<LAMAArray<float>&>( diagonal ); 00184 static_cast<const Derived*>(this)->getDiagonalImpl( typedDiagonal ); 00185 } 00186 else 00187 { 00188 LAMA_THROWEXCEPTION( "getDiagonal for array of type " << arrayType << " not supported" ); 00189 } 00190 } 00191 00192 void setDiagonal( const Scalar value ) 00193 { 00194 static_cast<Derived*>(this)->setDiagonalImpl( value ); 00195 } 00196 00197 void setDiagonal( const _LAMAArray& diagonal ) 00198 { 00199 IndexType numDiagonalElements = diagonal.size(); 00200 00201 if ( numDiagonalElements > this->getNumRows() 00202 || numDiagonalElements > this->getNumColumns() ) 00203 { 00204 LAMA_THROWEXCEPTION( "Diagonal of size " << numDiagonalElements 00205 << " too large for matrix: " << *this ); 00206 } 00207 00208 if ( ! this->hasDiagonalProperty() ) 00209 { 00210 LAMA_THROWEXCEPTION( *this << ": has not diagonal property, cannot set diagonal"); 00211 } 00212 00213 Scalar::ScalarType arrayType = diagonal.getValueType(); 00214 00215 if ( arrayType == Scalar::DOUBLE ) 00216 { 00217 const LAMAArray<double>& typedDiagonal = dynamic_cast<const LAMAArray<double>&>( diagonal ); 00218 static_cast<Derived*>(this)->setDiagonalImpl( typedDiagonal ); 00219 } 00220 else if ( arrayType == Scalar::FLOAT ) 00221 { 00222 const LAMAArray<float>& typedDiagonal = dynamic_cast<const LAMAArray<float>&>( diagonal ); 00223 static_cast<Derived*>(this)->setDiagonalImpl( typedDiagonal ); 00224 } 00225 else 00226 { 00227 LAMA_THROWEXCEPTION( "setDiagonal to array of type " << arrayType << " not supported" ); 00228 } 00229 } 00230 00231 void scale( const Scalar value ) 00232 { 00233 static_cast<Derived*>(this)->scaleImpl( value ); 00234 } 00235 00238 void scale( const _LAMAArray& diagonal ) 00239 { 00240 LAMA_ASSERT_EQUAL_ERROR( this->getNumRows(), diagonal.size() ); 00241 00242 Scalar::ScalarType arrayType = diagonal.getValueType(); 00243 00244 if ( arrayType == Scalar::DOUBLE ) 00245 { 00246 const LAMAArray<double>& typedDiagonal = dynamic_cast<const LAMAArray<double>&>( diagonal ); 00247 static_cast<Derived*>(this)->scaleImpl( typedDiagonal ); 00248 } 00249 else if ( arrayType == Scalar::FLOAT ) 00250 { 00251 const LAMAArray<float>& typedDiagonal = dynamic_cast<const LAMAArray<float>&>( diagonal ); 00252 static_cast<Derived*>(this)->scaleImpl( typedDiagonal ); 00253 } 00254 else 00255 { 00256 LAMA_THROWEXCEPTION( "scale of type " << arrayType << " not supported" ); 00257 } 00258 } 00259 00262 virtual const char* getTypeName() const 00263 { 00264 // each derived class provides static method to get the type name. 00265 00266 return Derived::typeName(); 00267 } 00268 }; 00269 00270 } // namespace lama 00271 00272 #endif // LAMA_CRTP_MATRIX_STORAGE_HPP_