CSparseMatrix.h文件
/*
*Copyright? 中国地质大学(武汉) 信息工程学院
*All right reserved.
*
*文件名称:CSparseMatrix.h
*摘 要:用三元组实现稀疏矩阵的表示及运算
*
*当前版本:1.0
*作 者:———
*完成日期:2018-04-29
*/
#pragma once
#ifndef CSPARSEMATRIX
#define CSPARSEMATRIX
#include<iostream>
using namespace std;
template<class _Ty>
struct CTrituple //三元组结构体定义
{
int m_nRowIndex; //行下标
int m_nColIndex; //列下标
_Ty m_tData; //数据
};
#define MAXSIZE 55
template<class _Ty>
class CSparseMatrix {
public:
CSparseMatrix(int size = MAXSIZE); //构造函数
~CSparseMatrix(); //析构函数
CSparseMatrix(CSparseMatrix<_Ty>& S); //复制构造函数
//赋值运算符
CSparseMatrix<_Ty>& operator = (CSparseMatrix<_Ty>& S);
void Transpose(CSparseMatrix<_Ty>& S); //转置函数-快速转置
// -重载加法运算符
CSparseMatrix<_Ty> operator + (CSparseMatrix<_Ty>& S);
// -重载乘法运算符
CSparseMatrix<_Ty> operator * (CSparseMatrix<_Ty>& S1);
//重载输入运算符
friend istream& operator >>(istream& in,
CSparseMatrix<_Ty>& S) {
int term, row, col;
int rindex, cindex, data;
cout << "依次输入矩阵的行与列大小:";
while (in >> row >> col) {
if (row < 1 || col < 1) {
cout << "行号或者列号个数小于1,请重新输入!" << endl;
continue;
}
else {
S.m_nRowSize = row;
S.m_nColSize = col;
break;
}
}
cout << "输入三元组的个数:";
while (in >> term) {
if (term > S.m_nMaxTerm) {
cout << "Term > m_nMaxTerm,请重新输入三元组的个数:" << endl;
continue;
}
else if (term < 1) {
cout << "Term < 1,请重新输入三元组的个数:" << endl;
continue;
}
else {
S.m_nTerm = term;
break;
}
}
cout << "依次输入三元组,格式为:行号 列号 数值" << endl;
int count = 0;
while (count < S.m_nTerm //应先判断是否count < S.m_nTerm
&& in >> rindex >> cindex >> data) { //不然会多输入一次无用的三元组
if (rindex < 0
|| rindex >= S.m_nRowSize
|| cindex < 0
|| cindex >= S.m_nColSize) {
cout << "元素下标越界,请重新输入该三元组:" << endl;
continue;
}
else {
S.m_tmElement[count].m_nRowIndex = rindex;
S.m_tmElement[count].m_nColIndex = cindex;
S.m_tmElement[count].m_tData = data;
count++;
}
}
return in;
}
//重载输出函数
friend ostream& operator <<(ostream& out,
CSparseMatrix<_Ty>& S) {
if (S.m_nTerm > 0) {
out << "矩阵的行数与列数:" << S.m_nRowSize << " " << S.m_nColSize << endl;
out << "三元组个数:" << S.m_nTerm << endl;
out << "三元组列表:" << endl;
for (int i = 0; i < S.m_nTerm; i++)
out << S.m_tmElement[i].m_nRowIndex << " "
<< S.m_tmElement[i].m_nColIndex << " "
<< S.m_tmElement[i].m_tData << " "
<< endl;
}
return out;
}
private:
int m_nMaxTerm; //最大三元组个数
int m_nTerm; //当前三元组个数
int m_nRowSize; //行规格
int m_nColSize; //列规格
CTrituple<_Ty>* m_tmElement; //存放三元组的数组
void Init(CSparseMatrix<_Ty>& S);
};
template<class _Ty>
CSparseMatrix<_Ty>::CSparseMatrix(int size /*= MAXSIZE*/) { //构造函数
if (size < 1) { //检查参数的合理性
cout << "构造函数错误,最大元素个数小于1!";
exit(-1);
}
//将成员变量赋初值
m_nMaxTerm = size;
m_nTerm = 0;
m_nRowSize = 0;
m_nColSize = 0;
m_tmElement = new CTrituple<_Ty>[m_nMaxTerm];
//判断空间分配成功与否
if (m_tmElement == nullptr) {
cerr << "内存分配错误!" << endl;
exit(-1);
}
}
template<class _Ty>
CSparseMatrix<_Ty>::~CSparseMatrix() { //析构函数
if (m_tmElement != nullptr)
delete[]m_tmElement;
}
template<class _Ty>
void CSparseMatrix<_Ty>::Init(CSparseMatrix<_Ty>& S) {
//利用参数S给本对象成员变量赋初值
this->m_nMaxTerm = S.m_nMaxTerm;
this->m_nTerm = S.m_nTerm;
this->m_nRowSize = S.m_nRowSize;
this->m_nColSize = S.m_nColSize;
//给本对象的m_tmElement数组分配内存
this->m_tmElement = new CTrituple<_Ty>[this->m_nMaxTerm];
//判断内存分配成功与否
if (this->m_tmElement == nullptr) {
cerr << "内存分配错误!" << endl;
exit(-1);
}
//给本对象的三原数组赋初值
for (int i = 0; i < this->m_nTerm; i++) {
this->m_tmElement[i].m_nRowIndex = S.m_tmElement[i].m_nRowIndex;
this->m_tmElement[i].m_nColIndex = S.m_tmElement[i].m_nColIndex;
this->m_tmElement[i].m_tData = S.m_tmElement[i].m_tData;
}
}
template<class _Ty>
CSparseMatrix<_Ty>::CSparseMatrix(CSparseMatrix<_Ty>& S) { //复制构造函数
Init(S); //直接调用
}
template<class _Ty>
CSparseMatrix<_Ty>& CSparseMatrix<_Ty>::operator =
(CSparseMatrix<_Ty>& S) { //赋值运算符
Init(S);
return *this;
}
template<class _Ty>
void CSparseMatrix<_Ty>::Transpose(CSparseMatrix<_Ty>& S) { //转置函数-快速转置
S.m_nMaxTerm = this->m_nMaxTerm;
S.m_nTerm = this->m_nTerm;
S.m_nRowSize = this->m_nColSize; //转置后的行数等于原矩阵的列数
S.m_nColSize = this->m_nRowSize; //转置后的列数等于原矩阵的行数
//为转置后的三元组数组分配内存
S.m_tmElement = new CTrituple<_Ty>[S.m_nMaxTerm];
//检查内存是否分配成功
if (S.m_tmElement == nullptr) {
cerr << "内存分配错误!" << endl;
exit(-1);
}
if (S.m_nTerm > 0) { //当三元组个数大于零时便进行转置
int* rowSize = new int[S.m_nRowSize]; //转置后的每行三元组的个数
int* rowStart = new int[S.m_nRowSize]; //转置后每行在数组中的起始位置
int i = 0;
//为rowSize数组初始化,初始化为每行的元素个数为零
for (i = 0; i < S.m_nRowSize; i++)
rowSize[i] = 0;
//在原矩阵中寻找列号-转置之后的行号,找到一个加一
for (i = 0; i < S.m_nTerm; i++)
rowSize[this->m_tmElement[i].m_nColIndex]++;
//rowStart[0]初始化为零
rowStart[0] = 0;
//累加
for (i = 1; i < S.m_nRowSize; i++)
rowStart[i] = rowStart[i - 1] + rowSize[i - 1];
//开始转置
for (i = 0; i < S.m_nTerm; i++) {
int j = rowStart[this->m_tmElement[i].m_nColIndex];
S.m_tmElement[j].m_nRowIndex = this->m_tmElement[i].m_nColIndex;
S.m_tmElement[j].m_nColIndex = this->m_tmElement[i].m_nRowIndex;
S.m_tmElement[j].m_tData = this->m_tmElement[i].m_tData;
rowStart[this->m_tmElement[i].m_nColIndex]++;
}
delete[]rowSize;
delete[]rowStart;
}
}
template<class _Ty>
CSparseMatrix<_Ty> CSparseMatrix<_Ty>::operator +
(CSparseMatrix<_Ty>& S) { // -重载加法运算符
CSparseMatrix<_Ty> result; //构造函数中result.m_nTerm = 0;
if (this->m_nRowSize != S.m_nRowSize
|| this->m_nColSize != S.m_nColSize) {
cout << "两矩阵的大小不匹配,不能进行相加!" << endl;
return result;
}
//为返回结果的成员变量赋值
result.m_nMaxTerm = this->m_nMaxTerm;
result.m_nRowSize = this->m_nRowSize;
result.m_nColSize = this->m_nColSize;
int i = 0, j = 0; //用于循环
int index_this = 0, index_S = 0; //用于记录下标大小
while (i < this->m_nTerm && j < S.m_nTerm) {
index_this = this->m_tmElement[i].m_nRowIndex
* this->m_nColSize + this->m_tmElement[i].m_nColIndex;
index_S = S.m_tmElement[j].m_nRowIndex
* S.m_nColSize + S.m_tmElement[j].m_nColIndex;
if (index_this < index_S) { //说明应先插入this中的这个三元组
result.m_tmElement[result.m_nTerm++] = this->m_tmElement[i]; //系统会提供默认的赋值运算符,不涉及new,可以不用自己重载
i++;
}
else if (index_this = index_S) { //说明此时this与S中此三元组在矩阵中的位置相同
_Ty temp = this->m_tmElement[i].m_tData + S.m_tmElement[j].m_tData;
if (temp != 0) {
result.m_tmElement[result.m_nTerm] = this->m_tmElement[i];
result.m_tmElement[result.m_nTerm].m_tData = temp;
}
result.m_nTerm++;
i++;
j++;
}
else {
result.m_tmElement[S.m_nTerm++] = S.m_tmElement[j];
j++;
}
}
//复制剩余元素
while (i < this->m_nTerm) { //如果在this中有剩余
result.m_tmElement[result.m_nTerm++] = this->m_tmElement[i];
i++;
}
while (j < S.m_nTerm) {
result.m_tmElement[result.m_nTerm++] = S.m_tmElement[j]; //如果S中有剩余
j++;
}
return result;
}
template<class _Ty>
CSparseMatrix<_Ty> CSparseMatrix<_Ty>::operator *
(CSparseMatrix<_Ty>& S) { // -重载乘法运算符
//稀疏矩阵(*this)与参数中的稀疏矩阵S相乘
//结果放置在result中
CSparseMatrix<_Ty> result;
if (this->m_nColSize != S.m_nRowSize) {
cout << "两矩阵类型不匹配,不能进行相乘!" << endl;
return result;
}
int* rowSize = new int[S.m_nRowSize]; //参数矩阵S中各行非零元素个数
if (rowSize == nullptr) {
cerr << "内存分配失败!" << endl;
exit(-1);
}
int* rowStart = new int[S.m_nRowSize + 1]; //参数矩阵S中各行在三元组开始位置
if (rowStart == nullptr) {
cerr << "内存分配失败!" << endl;
exit(-1);
}
_Ty* temp = new _Ty[S.m_nColSize]; //暂存每一行的计算结果
if (temp == nullptr) {
cerr << "内存分配失败!" << endl;
exit(-1);
}
int i = 0; //用于循环的临时变量
for (i = 0; i < S.m_nRowSize; i++)
rowSize[i] = 0; //rowSize数组初始化
for (i = 0; i < S.m_nTerm; i++)
rowSize[S.m_tmElement[i].m_nRowIndex]++; //计算得到参数矩阵中每行非零元素个数
rowStart[0] = 0;
for (i = 1; i < S.m_nRowSize + 1; i++)
rowStart[i] = rowStart[i - 1] + rowSize[i - 1]; //计算得到b中每行非零元素开始位置
int current = 0, lastInResult = -1; //current:扫描指针,lastInResult:记录结果Term值
int row_this, col_this, col_S;
while (current < this->m_nTerm) {
row_this = this->m_tmElement[current].m_nRowIndex; //this当前的行号
for (i = 0; i < S.m_nColSize; i++)
temp[i] = 0;
while (current < this->m_nTerm
&& this->m_tmElement[current].m_nRowIndex == row_this) {
col_this = this->m_tmElement[current].m_nColIndex; //this矩阵当前扫描到的料号
for (i = rowStart[col_this]; i < rowStart[col_this + 1]; i++) {
col_S = S.m_tmElement[i].m_nColIndex;
temp[col_S] += this->m_tmElement[current].m_tData * S.m_tmElement[i].m_tData;
}
current++;
}
//将temp中的非零元素压缩到result中去
for (i = 0; i < S.m_nColSize; i++) {
if (temp[i] != 0) {
lastInResult++;
result.m_tmElement[lastInResult].m_nRowIndex = row_this;
result.m_tmElement[lastInResult].m_nColIndex = i;
result.m_tmElement[lastInResult].m_tData = temp[i];
}
}
}
result.m_nRowSize = this->m_nRowSize;
result.m_nColSize = S.m_nColSize;
result.m_nTerm = lastInResult + 1;
delete[] rowSize;
delete[] rowStart;
delete[] temp;
return result;
}
#endif
main.cpp文件(测试文件)
#include"CSparseMatrix.h"
int main() {
CSparseMatrix<int> sparse1, sparse2, sparse3, result;
//测试输入输出运算符
cout << "输入sparse1:" << endl;
cin >> sparse1;
cout << "测试加法,输入sparse2" << endl;
cin >> sparse2;
cout << "sparse1 + sparse2 = " << endl;
result = sparse1 + sparse2;
cout << result;
/*cout << "测试乘法,输入sparse3" << endl;
cin >> sparse3;
result = sparse2 * sparse3;
cout << "sparse2 * sparse3 = " << endl;
cout << result;*/
return 0;
}
测试结果
加法:
乘法: