前言:这两天刚接手一个并行项目,对矩阵分块进行计算,而采用的矩阵存储格式就是CSR,之前刚好对这个格式一直存在迷惑,借着机会好好的学习一下。
1、CSR介绍
CSR(Compressed Sparse Row,压缩稀疏行),也就是常用在稀疏矩阵存储的一种格式。CSR 的核心思想是:只存非零的数据,并且把“空着”的地方挤掉。通过使用三个一维数组来达到存储的效果,分别是values、col_indices、row_pointer。
- values(数值):存储矩阵中的所有非零元素。
- col_indices(列索引):用于存储每个非零元素对应的列索引。col_indices[i]表示第i个非零元素所在的列索引。
- row_pointer(行偏移/行指针):用于存储每一行在col_indices和values数组中的起始索引位置,而row_pointer[i+1] - row_pointer[i]表示第i行的非零元素个数。
2、示例说明
如下图所示是一个6×4的矩阵,以及对应的CSR格式
1、values:按行(从上到下),行内从左到右读取所有绿色的数字。
2、col_indices:根据上面的数值分别位于哪一列填写,也就是列坐标
3、row_pointer:是最难理解的部分,(如果是行坐标那就好理解了,可惜不是),它不直接存行号,它存的是每一行在values中的起始下标。也就是第0行的起始索引是0,第一行的起始数据是3,3前面有两个元素,所以第一行就是就是2……
计算过程如下:数组长度 = 行数 + 1 = 7
- Row 0: 从下标0开始 (对应数值 4)
- Row 1: Row 0 有 2 个元素,所以 Row 1 从下标 0+2 =2开始 (对应数值 3)
- Row 2: Row 1 有 1 个元素,所以 Row 2 从下标 2+1 =3开始 (对应数值 5)
- Row 3: Row 2 有 1 个元素,所以 Row 3 从下标 3+1 =4开始 (对应数值 7)
- Row 4: Row 3 有 2 个元素,所以 Row 4 从下标 4+2 =6开始 (对应数值 2)
- Row 5: Row 4 有 1 个元素,所以 Row 5 从下标 6+1 =7开始 (对应数值 9)
- End: Row 5 有 1 个元素,最后一位是 7+1 =8(非零元素总数)
3、二维稀疏矩阵转CSR格式
#include<iostream>#include<vector>#include<iomanip>// 用于格式化输出// 定义 CSR 矩阵结构structCSRMatrix{std::vector<double>values;// 非零数值std::vector<int>col_indices;// 列索引std::vector<int>row_ptr;// 行偏移量introws;// 行数intcols;// 列数};/** * 将二维向量 (std::vector<std::vector<double>>) 转换为 CSR 格式 */CSRMatrixdense_to_csr(conststd::vector<std::vector<double>>&dense_matrix){CSRMatrix csr;// 获取行数和列数if(dense_matrix.empty())returncsr;csr.rows=dense_matrix.size();csr.cols=dense_matrix[0].size();// 1. 初始化 row_ptr// 规则:row_ptr 的第一个元素永远是 0csr.row_ptr.push_back(0);// 2. 遍历二维矩阵for(inti=0;i<csr.rows;++i){// 遍历当前行的每一列for(intj=0;j<csr.cols;++j){doubleval=dense_matrix[i][j];// 3. 判断非零元素if(val!=0.0){csr.values.push_back(val);// 存数值csr.col_indices.push_back(j);// 存列号}}// 4. 当前行遍历结束 row_ptr 的下一个值 = 当前 values 数组的总长度csr.row_ptr.push_back(csr.values.size());}returncsr;}4、查找特定位置元素
说完了他的构造,那么如果访问呢,对于上述的例子,比如我们想访问第 4 行,第 3 列,从图片看,就是数字7,在 CSR 格式中,不能像普通二维数组那样直接用A[3][2]访问。需要通过row_pointer先找到“第 3 行”的数据范围,然后在里面“搜”第 2 列。
下面是具体的访问步骤和逻辑:
- 查行范围:去
row_pointer查第 4 行的起始位置和结束位置。- 起始:
row_pointer[3]=4 - 结束:
row_pointer[3+1](即row_pointer[4]) =6 - 意味着:第 3 行的数据存储在
values数组下标[4, 6)的区间内。
- 起始:
- 搜列索引:遍历
col_indices的下标 4 到 5,寻找列号为2的元素。- 检查下标4:
col_indices[4]是2。 - 匹配成功!
- 检查下标4:
- 取值:既然下标 4 对应的列号是对的,那么取
values[4]。values[4]=7。
#include<iostream>#include<vector>// 查找矩阵中特定的元素 A[row, col]doubleget_element(inttarget_row,inttarget_col,conststd::vector<int>&row_ptr,conststd::vector<int>&col_indices,conststd::vector<double>&values){// 1. 查找行范围intstart_idx=row_ptr[target_row];// 起点intend_idx=row_ptr[target_row+1];// 终点 (不包含)// 2. 在该范围内遍历,查找是否有列索引等于target_colfor(intidx=start_idx;idx<end_idx;++idx){if(col_indices[idx]==target_col){returnvalues[idx];}}// 3. 如果循环结束还没找到,说明该位置是 0return0.0;}// 定义一个结构体用来存 (行, 列, 值) 三元组structTriplet{introw;intcol;doublevalue;};/** * 遍历 CSR 矩阵的所有非零元素 * 对应 Python: get_all_non_zero_elements */std::vector<Triplet>get_all_non_zero_elements(conststd::vector<int>&row_ptr,conststd::vector<int>&col_indices,conststd::vector<double>&values){std::vector<Triplet>results;// 非零元素的总数就是 values 的大小results.reserve(values.size());// 1. 获取行数 (row_ptr 的长度减 1)intnum_rows=row_ptr.size()-1;// 2. 外层循环:遍历每一行for(inti=0;i<num_rows;++i){// 获取当前行的起止位置intstart_idx=row_ptr[i];intend_idx=row_ptr[i+1];// 3. 内层循环:遍历这一行的所有非零元素for(intidx=start_idx;idx<end_idx;++idx){// 构造三元组并存入结果// row = i (当前行号)// col = col_indices[idx] (当前列号)// val = values[idx] (当前数值)results.push_back({i,col_indices[idx],values[idx]});}}returnresults;}5、CSR总结
| 特性 | 描述 / 评价 | 备注 (Why?) |
|---|---|---|
| 存储方式 | 按行压缩存储非零元素 | 也就是 Row-Major (行优先) |
| 空间复杂度 | 2×NNZ+(N+1)2 \times NNZ + (N + 1)2×NNZ+(N+1) 即 values、column_indices、row_pointer三个数组 | NNZNNZNNZ个数值 +NNZNNZNNZ个列号 +(行数+1)(行数+1)(行数+1)个行偏移。非常省内存。 |
| 优势 | 行切片 (Row Slicing) 极快 | 想取第iii行?直接row_ptr[i]到row_ptr[i+1]就拿到了。 |
| 优势 | 矩阵-向量乘法 (SpMV) 极快 | CPU 缓存命中率高,因为values数组是连续访问的。 |
| 缺点 (致命) | 结构修改 (Insertion/Deletion) 极慢 | 千万别往 CSR 里插入新元素!这意味着需要把后面几十万个数据全部往后挪一位。 |
| 缺点 | 列切片 (Column Slicing) 极慢 | 想取第jjj列,必须遍历所有行,去“搜”有没有第jjj列的元素。 |
| 缺点 | 随机访问 (Random Access) 不是O(1)O(1)O(1) | 查A[i][j]的时间复杂度是O(该行非零元个数)O(\text{该行非零元个数})O(该行非零元个数),因为要二分查找或线性扫描。 |
| 并行性 | 适合 OpenMP/多线程 | 每一行是独立的,很容易并行 (#pragma omp parallel for)。 |
| 负载均衡 | 可能较差 (Load Imbalance) | 如果第1行有100个元素,第2行只有1个,多线程分配时会导致有的核累死,有的核闲死。 |
| 适用场景 | 矩阵乘法、有限元分析、迭代求解器 | 写好了就不动的数据结构。 |
| 不适用场景 | 矩阵组装 (Assembly) | 正在生成矩阵时别用 CSR,要用 COO 或 List of Lists,生成完了一次性转 CSR。 |
CSR 的“三不”原则
- 不要动态修改:如果需要频繁
A[i][j] = x且(i,j)原本是 0,绝对不要用 CSR。请先用 COO 格式存,最后make_compressed()转成 CSR。 - 不要按列遍历:如果算法里有
for col in columns的逻辑,CSR 会慢到让你怀疑人生。请转置矩阵或者改用CSC (压缩稀疏列)格式。 - 不要直接用于复杂的矩阵-矩阵乘 (SpGEMM):虽然可以做,但如果结果矩阵的稀疏结构不可预测,通常需要先计算符号模式(Symbolic Phase),比稠密矩阵乘法麻烦得多。