#include #include "sparsemat.h" using namespace std; // Text of error messages, used by Matrix::ReportError // Do not change this! char *SparseMatrix::ErrorMessages[] = { "", // ERROR_INVALID_SIZE "dimension mismatch", // ERROR_SIZE_MISMATCH "index out of range", // ERROR_INVALID_INDEX } ; void SparseMatrix::ReportError( ErrorCode code ) { string prefix( "Error: " ); throw std::runtime_error( prefix + ErrorMessages[code] ); } SparseMatrix::SparseMatrix( int rows, int cols ) { if ( cols == 0 ) cols = rows; m_rows = rows; m_cols = cols; } SparseMatrix::SparseMatrix( const SparseMatrix& A ) { m_rows = A.m_rows; m_cols = A.m_cols; m_data = A.m_data; m_row_indices = A.m_row_indices; m_col_indices = A.m_col_indices; } SparseMatrix::~SparseMatrix() { } void SparseMatrix::SetSize( int rows, int cols ) { m_rows = rows; m_cols = cols; m_row_indices.clear(); m_col_indices.clear(); m_data.clear(); } ostream& operator { for ( int i = 0; i < A.get_rows(); i++ ) { for ( int j = 0; j < A.get_cols(); j++ ) { double x = A( i, j ); os.width(10); os } cout } return os; } double& SparseMatrix::operator()( int i, int j ) { if ( i >= m_rows || j >= m_cols ) ReportError( ERROR_INVALID_INDEX ); m_row_indices[i].insert( j ); m_col_indices[j].insert( i ); return m_data[std::make_pair(i, j)]; } double SparseMatrix::operator()( int i, int j ) const { return Get(i, j); } void SparseMatrix::Squeeze( double tol ) { for ( iset_iter i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { if ( fabs( m_data[std::make_pair(i->first, *j)] ) { i->second.erase( *j ); if ( i->second.size() == 0 ) m_row_indices.erase( i ); m_col_indices[*j].erase( i->first ); if ( m_col_indices[*j].size() == 0 ) m_col_indices.erase( *j ); } } } } int SparseMatrix::nnz() const { int total = 0; for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { total += i->second.size(); } return total; } elt_citer SparseMatrix::get_row_begin( int i ) const { return m_row_indices.find( i )->second.begin(); } elt_citer SparseMatrix::get_row_end( int i ) const { return m_row_indices.find( i )->second.end(); } elt_citer SparseMatrix::get_col_begin( int i ) const { return m_col_indices.find( i )->second.begin(); } elt_citer SparseMatrix::get_col_end( int i ) const { return m_col_indices.find( i )->second.end(); } SparseMatrix SparseMatrix::operator*( const SparseMatrix& B ) const { if ( m_cols != B.m_rows ) ReportError( ERROR_SIZE_MISMATCH ); SparseMatrix C( m_rows, B.m_cols ); for ( int i = 0; i < m_rows; i++ ) { for ( int j = 0; j < m_cols; j++ ) { double elt = 0.0; elt_citer irow = get_row_begin( i ); elt_citer jcol = B.get_col_begin( j ); for ( ; irow != get_row_end( i ); irow++ ) { elt += Get( i, *irow ) * B( *irow, j ); } if ( elt != 0.0 ) C( i, j ) = elt; } } return C; } SparseMatrix& SparseMatrix::operator=( const SparseMatrix& A ) { m_rows = A.m_rows; m_cols = A.m_cols; m_data = A.m_data; m_row_indices = A.m_row_indices; m_col_indices = A.m_col_indices; return *this; } SparseMatrix& SparseMatrix::operator+=( const SparseMatrix& A ) { if ( m_rows != A.m_rows || m_cols != A.m_cols ) ReportError( ERROR_SIZE_MISMATCH ); for ( iset_citer i = A.m_row_indices.begin(); i != A.m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = A( i->first, *j ); Set( i->first, *j, Get( i->first, *j ) + elt ); } } return *this; } SparseMatrix& SparseMatrix::operator-=( const SparseMatrix& A ) { *this += -A; return *this; } SparseMatrix& SparseMatrix::operator*=( double s ) { for ( mat_iter i = m_data.begin(); i != m_data.end(); i++ ) i->second *= s; return *this; } SparseMatrix& SparseMatrix::operator/=( double s ) { *this *= (1.0 / s); return *this; } SparseMatrix operator*( double s, const SparseMatrix& A ) { return A * s; } SparseMatrix SparseMatrix::operator+( const SparseMatrix& B ) const { if ( m_rows != B.m_rows || m_cols != B.m_cols ) ReportError( ERROR_SIZE_MISMATCH ); SparseMatrix C( *this ); for ( iset_citer i = B.m_row_indices.begin(); i != B.m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = B( i->first, *j ); C( i->first, *j ) += elt; } } return C; } double SparseMatrix::Get( int i, int j ) const { if ( i >= m_rows || j >= m_cols ) ReportError( ERROR_INVALID_INDEX ); mat_citer iter = m_data.find( make_pair( i, j ) ); if ( iter == m_data.end() ) return 0.0; return iter->second; } void SparseMatrix::Set( int i, int j, double x ) { m_data[make_pair(i,j)] = x; m_row_indices[i].insert( j ); m_col_indices[j].insert( i ); } void SparseMatrix::Identity( int rows, int cols ) { if ( cols == 0 ) cols = rows; SetSize( rows, cols ); for ( int i = 0; i < rows && i < cols; i++ ) Set( i, i, 1.0 ); } SparseMatrix SparseMatrix::Transpose() const { SparseMatrix T( m_cols, m_rows ); for ( mat_citer i = m_data.begin(); i != m_data.end(); i++ ) { T( i->first.second, i->first.first ) = Get( i->first.first, i->first.second ); } return T; } SparseMatrix SparseMatrix::TriU( int diag ) const { SparseMatrix S( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { if ( i->first S( i->first, *j ) = Get( i->first, *j ); } } return S; } SparseMatrix SparseMatrix::TriL( int diag ) const { SparseMatrix S( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { if ( i->first >= *j - diag ) S( i->first, *j ) = Get( i->first, *j ); } } return S; } SparseMatrix SparseMatrix::operator-() const { SparseMatrix C( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = Get( i->first, *j ); C( i->first, *j ) = -elt; } } return C; } SparseMatrix SparseMatrix::operator*( double s ) const { SparseMatrix C( m_rows, m_cols ); for ( iset_citer i = m_row_indices.begin(); i != m_row_indices.end(); i++ ) { for ( elt_citer j = i->second.begin(); j != i->second.end(); j++ ) { double elt = Get( i->first, *j ); C( i->first, *j ) = elt * s; } } return C; } SparseMatrix SparseMatrix::operator-( const SparseMatrix& B ) const { return *this + (-B); }