C++实现的线性代数矩阵计算


/**
 *  线性代数矩阵计算
 *  实现功能:行列向量获取,子矩阵获取,转置矩阵获取,
 *          行列式计算,转置伴随矩阵获取,逆矩阵计算
 *
 *  Copyright 2011 Shi Y.M. All Rights Reserved
 */

#ifndef MATICAL_TMATRIX_H_INCLUDED
#define MATICAL_TMATRIX_H_INCLUDED

#include <memory.h>

namespace matical
{

template<typename T>
class TMatrix
{
private:
    T * m_Elems;
    int m_Rows;
    int m_Cols;

public:
    typedef T ElemType;

public:

    TMatrix() : m_Elems(NULL), m_Rows(0), m_Cols(0) {}
    TMatrix(int rows, int cols) : m_Rows(rows), m_Cols(cols) {
        m_Elems = new ElemType[m_Rows * m_Cols];
        memset(m_Elems, 0, sizeof(ElemType) * m_Rows * m_Cols);
    }
    // 生成n阶方阵
    TMatrix(int n) : m_Rows(n), m_Cols(n){
        m_Elems = new ElemType[m_Rows * m_Cols];
        memset(m_Elems, 0, sizeof(ElemType) * m_Rows * m_Cols);
    }

    TMatrix(const TMatrix& m) : m_Rows(m.Rows()), m_Cols(m.Cols()) {
        m_Elems = new ElemType[m_Rows * m_Cols];
        memcpy(m_Elems, m.m_Elems, sizeof(ElemType) * m_Rows * m_Cols);
    }
    virtual ~TMatrix() {
        if ( m_Elems ) delete[] m_Elems;
        m_Elems = NULL;
        m_Rows  = 0;
        m_Cols  = 0;
    }

public:

    int     Rows() const { return m_Rows;}
    int     Cols() const { return m_Cols;}

    // 是否为方阵
    bool    IsSquare()   const { return ( m_Rows > 0 && (m_Rows == m_Cols)); }

    // 取行向量
    TMatrix<T> Row(int row) const {
        TMatrix m(1, m_Cols);
        for(int i = 0; i < m_Cols; i++) {
            m(0, i) = (*this)(row, i);
        }
        return m;
    }

    // 取列向量
    TMatrix<T> Col(int col) const {
        TMatrix m(m_Rows, 1);
        for(int i = 0; i < m_Rows; i++) {
            m(i, 0) = (*this)(i, col);
        }
        return m;
    }

public:

    // 将当前矩阵设置为单位矩阵
    void Identity() {
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                if ( r == c ) (*this)(r, c) = 1;
                else  (*this)(r, c) = 0;
            }
        }
        return ;
    }

    // 当前矩阵设置为零矩阵
    void Zero() {
        int sz = m_Rows * m_Cols;
        for (int i = 0; i < sz; i++ ) m_Elems[i] = 0;
        return ;
    }

    // 获取指定范围的子矩阵,rrow>=lrow, rcol>=lcol
    TMatrix<T> SubMatrix(int lrow, int lcol, int rrow, int rcol) const {
        TMatrix m(rrow - lrow + 1, rcol - lcol + 1);
        for ( int r = 0; r < m.m_Rows; r++) {
            for (int c = 0; c < m.m_Cols; c++) {
                m(r, c ) = (*this)(r + lrow, c + lcol);
            }
        }
        return m;
    }

    // 获取当前矩阵元素(row, col)的子矩阵(去掉元素(row,col)所在的行和列)
    TMatrix<T> SubMatrixEx(int row, int col) const {
        int rr = 0;
        int cc = 0;
        TMatrix m(this->m_Rows - 1, this->m_Cols - 1);
        for ( int r = 0; r < m.m_Rows; r++, rr++) {
            if ( r == row) rr++;
            cc = 0;
            for (int c = 0; c < m.m_Cols; c++, cc++) {
                if ( c == col ) cc++;
                m(r, c) = (*this)(rr, cc);
            }
        }
        return m;
    }

    // 计算当前矩阵的转置矩阵
    TMatrix<T> Transpose() const {
        printf("Transpose\n");
        TMatrix m(this->m_Cols, this->m_Rows);
        for (int r = 0; r < m.m_Rows; r++ ) {
            for ( int c = 0; c < m.m_Cols; c++ ) {
                m(r, c) = (*this)(c, r);
            }
        }
        return m;
    }

    // 计算当前矩阵的行列式(递归方式), 计算行列式的矩阵,须是n阶方阵
    ElemType Det() const {
        ElemType det = 0;
        if ( m_Rows == 1) {
            det = (*this)(0, 0);
        } else if ( m_Rows == 2) {
            det = (*this)(0, 0) * (*this)(1, 1) - (*this)(0 ,1) * (*this)(1, 0);
        } else {
            for (int r = 0; r < m_Rows; r++) {
                TMatrix<ElemType> m = this->SubMatrixEx(r, 0);
                if ( r % 2 ) det += (-1) * (*this)(r, 0) * m.Det();
                else det += (*this)(r, 0) * m.Det();
            }
        }
        return det;
    }

    // 计算当前矩阵的转置伴随矩阵,当前矩阵须为n阶方阵
    TMatrix<T> Adj() const {
        TMatrix<T> m(m_Rows, m_Cols);
        for (int r = 0; r < m_Rows; r++) {
            for (int c = 0; c < m_Cols; c++) {
                if ( (r + c) % 2 ) m(c, r) = -1 * this->SubMatrixEx(r, c).Det();
                else m(c, r) = this->SubMatrixEx(r, c).Det();
            }
        }
        return m;
    }

    // 计算当前矩阵的逆矩阵,当前矩阵须为n阶方阵
    TMatrix<T> Inv() const {
        ElemType   det = this->Det();
        TMatrix<T> adj = this->Adj();
        return adj * (1 / det);
    }

public:

    ElemType operator()(int row, int col) const { return m_Elems[row * m_Cols + col];};
    ElemType& operator()(int row, int col) { return m_Elems[row * m_Cols + col]; }

public:

    // 两矩阵相加,必须具有相同的行数以及列数
    TMatrix<T>  operator+(const TMatrix& m) const {
        TMatrix mr(this->m_Rows, this->m_Cols);
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                mr(r, c) = (*this)(r, c) + m(r, c);
            }
        }
        return mr;
    } // operator+(const TMatrix& m) const

    // 两矩阵相减,必须具有相同的行数以及列数
    TMatrix<T> operator-(const TMatrix& m) const {
        TMatrix mr(this->m_Rows, this->m_Cols);
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                mr(r, c) = (*this)(r, c) - m(r, c);
            }
        }
        return mr;
    } // operator-(const TMatrix& m) const

    // 矩阵与常数相乘(数乘)
    TMatrix<T>  operator*(ElemType v) const {
        TMatrix mr(this->m_Rows, this->m_Cols);
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                mr(r, c) = (*this)(r, c) * v;
            }
        }
        return mr;
    } // operator*(ElemType v) const

    // 矩阵相乘(当前矩阵列数须等于参数矩阵的列数),
    // 结果矩阵行数等于当前矩阵,列数等于参数矩阵的列数
    TMatrix<T>  operator*(const TMatrix& m) const {
        TMatrix mr(this->m_Rows, m.m_Cols);
        for (int r = 0; r < this->m_Rows; r++) {
            for (int c = 0; c < m.m_Cols; c++) {
                for (int i = 0; i < this->m_Cols; i++) {
                    mr(r, c) += ( (*this)(r, i) * m(i, c) );
                    printf("(%d, %d)\n", r, c);
                }
            }
        }
        return mr;
    } // operator*(const TMatrix& m) const
}; // class TMatrix

} // namespace matical

#endif // MATICAL_TMATRIX_H_INCLUDED

相关内容