//========= Copyright Valve Corporation, All rights reserved. ============// // // Purpose: // // A class allowing storage of a sparse NxN matirx as an array of sparse rows //===========================================================================// #ifndef SPARSEMATRIX_H #define SPARSEMATRIX_H #include "tier1/utlvector.h" /// CSparseMatrix is a matrix which compresses each row individually, not /// storing the zeros. NOte, that while you can randomly set any element in a /// CSparseMatrix, you really want to do it top to bottom or you will have bad /// perf as data is moved around to insert new elements. class CSparseMatrix { public: struct NonZeroValueDescriptor_t { int m_nColumnNumber; float m_flValue; }; struct RowDescriptor_t { int m_nNonZeroCount; // number of non-zero elements in the row int m_nDataIndex; // index of NonZeroValueDescriptor_t for the first // non-zero value }; int m_nNumRows; int m_nNumCols; CUtlVector m_rowDescriptors; CUtlVector m_entries; int m_nHighestRowAppendedTo; void AdjustAllRowIndicesAfter(int nStartRow, int nDelta); public: FORCEINLINE float Element(int nRow, int nCol) const; void SetElement(int nRow, int nCol, float flValue); void SetDimensions(int nNumRows, int nNumCols); void AppendElement(int nRow, int nCol, float flValue); void FinishedAppending(void); FORCEINLINE int Height(void) const { return m_nNumRows; } FORCEINLINE int Width(void) const { return m_nNumCols; } }; FORCEINLINE float CSparseMatrix::Element(int nRow, int nCol) const { Assert(nCol < m_nNumCols); int nCount = m_rowDescriptors[nRow].m_nNonZeroCount; if (nCount) { NonZeroValueDescriptor_t const *pValue = &(m_entries[m_rowDescriptors[nRow].m_nDataIndex]); do { int nIdx = pValue->m_nColumnNumber; if (nIdx == nCol) { return pValue->m_flValue; } if (nIdx > nCol) { break; } pValue++; } while (--nCount); } return 0; } // type-specific overrides of matrixmath template for special case sparse // routines namespace MatrixMath { /// sparse * dense matrix x matrix multiplication template void MatrixMultiply(CSparseMatrix const &matA, BTYPE const &matB, OUTTYPE *pMatrixOut) { Assert(matA.Width() == matB.Height()); pMatrixOut->SetDimensions(matA.Height(), matB.Width()); for (int i = 0; i < matA.Height(); i++) { for (int j = 0; j < matB.Width(); j++) { // compute inner product efficiently because of sparsity int nCnt = matA.m_rowDescriptors[i].m_nNonZeroCount; int nDataIdx = matA.m_rowDescriptors[i].m_nDataIndex; float flDot = 0.0; for (int nIdx = 0; nIdx < nCnt; nIdx++) { float flAValue = matA.m_entries[nIdx + nDataIdx].m_flValue; int nCol = matA.m_entries[nIdx + nDataIdx].m_nColumnNumber; flDot += flAValue * matB.Element(nCol, j); } pMatrixOut->SetElement(i, j, flDot); } } } FORCEINLINE void AppendElement(CSparseMatrix &matrix, int nRow, int nCol, float flValue) { matrix.AppendElement(nRow, nCol, flValue); // default implementation } FORCEINLINE void FinishedAppending(CSparseMatrix &matrix) { matrix.FinishedAppending(); } }; // namespace MatrixMath #endif // SPARSEMATRIX_H