这道题要求统计“好路径”的数量,定义是路径起点和终点的值相等,且路径上所有节点的值都 ≤ 这个值。
我来帮你梳理思路并给出可直接使用的代码。
---
思路分析
1. 关键条件理解
· 路径的起点值 vals[u] = 终点值 vals[v]。
· 路径上所有中间节点的值 ≤ 起点值(所以其实是 ≤ vals[u])。
· 因此,对于节点值相等的两个节点,如果它们之间的路径上所有节点值都 ≤ 这个值,则这条路径就是“好路径”。
2. 并查集 + 值从小到大处理
· 将所有节点按值升序排序。
· 每次处理一种值 v 的所有节点时,只考虑值 ≤ v 的边来连接(按值升序逐步加边)。
· 在加入值 v 的节点及其与值 ≤ v 节点的边后,使用并查集将连通块合并。
· 对于当前值 v,每个连通块内具有值 v 的节点数 cnt,它们两两之间的路径都是好路径(同值且值最大,路径上无更大值节点,因为更大的值还没加进来)。
· 对每个值 v,贡献为 C(cnt, 2) + cnt(每个节点自身也算一条长度为 0 的路径)。
3. 算法步骤
· 按值排序节点索引。
· 按值升序排序所有边(或边只按两端节点的较大值排序)。
· 顺序枚举每个值,对于当前所有值 ≤ 当前值的节点,利用并查集合并它们之间的边。
· 统计每个连通块中当前值节点个数,累加贡献。
4. 时间复杂度
· 并查集 + 排序节点 + 排序边:O(n log n + m log m),其中 m = edges.length。
---
代码实现
```java
import java.util.*;
class Solution {
public int numberOfGoodPaths(int[] vals, int[][] edges) {
int n = vals.length;
// 按节点值升序存储节点索引
List<Integer> nodes = new ArrayList<>();
for (int i = 0; i < n; i++) nodes.add(i);
nodes.sort(Comparator.comparingInt(a -> vals[a]));
// 建图:邻接表
List<Integer>[] graph = new List[n];
for (int i = 0; i < n; i++) graph[i] = new ArrayList<>();
for (int[] e : edges) {
int u = e[0], v = e[1];
graph[u].add(v);
graph[v].add(u);
}
// 并查集
int[] parent = new int[n];
int[] size = new int[n];
for (int i = 0; i < n; i++) {
parent[i] = i;
size[i] = 1;
}
// 按值升序处理节点
int ans = 0;
boolean[] active = new boolean[n];
for (int i = 0; i < n; ) {
int curVal = vals[nodes.get(i)];
// 收集所有值等于 curVal 的节点
List<Integer> curNodes = new ArrayList<>();
while (i < n && vals[nodes.get(i)] == curVal) {
curNodes.add(nodes.get(i));
i++;
}
// 将所有相邻的、值 <= curVal 的节点合并(即当前 active 的节点)
for (int u : curNodes) {
for (int v : graph[u]) {
if (active[v]) {
union(parent, size, u, v);
}
}
active[u] = true;
}
// 统计当前值 curVal 在同一个连通块中的数量
Map<Integer, Integer> cnt = new HashMap<>();
for (int u : curNodes) {
int root = find(parent, u);
cnt.put(root, cnt.getOrDefault(root, 0) + 1);
}
// 组合数 C(cnt,2) + cnt
for (int c : cnt.values()) {
ans += c * (c + 1) / 2;
}
}
return ans;
}
private int find(int[] parent, int x) {
if (parent[x] != x) {
parent[x] = find(parent, parent[x]);
}
return parent[x];
}
private void union(int[] parent, int[] size, int a, int b) {
int ra = find(parent, a);
int rb = find(parent, b);
if (ra == rb) return;
if (size[ra] < size[rb]) {
parent[ra] = rb;
size[rb] += size[ra];
} else {
parent[rb] = ra;
size[ra] += size[rb];
}
}
}
```
---
简要解释
· active 数组标记当前已经允许使用的节点(值 ≤ 当前处理的值)。
· 处理值 v 时,先把这些节点标记为 active,然后通过已激活的边将它们合并到同一个连通块。
· 统计每一个连通块中当前值 v 的节点数量,它们两两之间都是好路径。
· 加上自己到自己的路径(单节点路径)。㇏