树上问题总结

8.1k 词

0 前言

总结一些树上问题的常用模板 (感觉开了个大坑)

1 树的直径

1.1 树形 DP 求树的直径

DuD_u 表示从节点 uu 出发走向以 uu 为根的子树,能够到达的最远节点的距离,对于 uu 的子节点 vi(1iti)v_i(1 \leq i \leq t_i) ,有

Du=max1itu{Dvi+edgeu,vi}D_u = \max_{1 \leq i \leq t_u}\{D_{v_i} + edge_{u, v_i}\}

FuF_u 表示经过节点 uu 的最长链的长度, 考虑 uu 的两个节点 vi,vjv_i, v_j,将其通过节点 uu 连接即可,转移方程

Fu=max1j<itu{Dvi+edgeu,vi+Dvj+edgeu,vj}F_u = \max_{1 \leq j < i \leq t_u}\{D_{v_i} + edge_{u, v_i} + D_{v_j} + edge_{u, v_j}\}

在树形DP的过程中转移,时间复杂度 O(n)O(n)

1
// TODO

1.1 两次 DFS/BFS 求树的直径

跑两边DFS/BFS,将每次找到最远距离的点分别设为起点和终点

DFS实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <bits/stdc++.h>

using namespace std;

vector<pair<int, int>> E[N];

bool vis[N];
int far, S, T, fa[N], dis[N];
void dfs(int u, int f)
{
fa[u] = f;
if (dis[u] > dis[far])
far = u;
for (auto nxt : E[u])
{
int v = nxt.first, w = nxt.second;
if (v == f || vis[v])
continue;
dis[v] = dis[u] + w;
dfs(v, u);
}
}

int n;

int main()
{
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
E[u].push_back(make_pair(v, w));
E[v].push_back(make_pair(u, w));
}
dis[1] = 0, dfs(1, 0), S = far;
dis[S] = 0, dfs(S, 0), T = far;
printf("%d %d\n", S, T);
return 0;
}

BFS实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <bits/stdc++.h>

using namespace std;

int n, m;

vector<pair<int, int>> E[N];

int dis[N];
bool vis[N];
inline int bfs(int s)
{
queue<int> Q;
memset(dis, 0, sizeof(dis));
memset(vis, false, sizeof(vis));
int far = 0;
Q.push(s);
vis[s] = true;
while (!Q.empty())
{
int u = Q.front();
Q.pop();
for (auto nxt : E[u])
{
int v = nxt.first, w = nxt.second;
if (vis[v])
continue;
vis[v] = true;
dis[v] = dis[u] + w;
Q.push(v);
if (dis[v] > dis[far])
far = v;
}
}
return far;
}

int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= m; i++)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
E[u].push_back(make_pair(v, w));
E[v].push_back(make_pair(u, w));
}
int S = bfs(1);
int T = bfs(S);
printf("%d %d\n", S, T);
return 0;
}

2 树的重心

考虑树形DP,设 fuf_u 表示以 uu 根节点的最大子树大小,树的中心为使得 fuf_u 最小的点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>

using namespace std;

const int N = 5e4 + 5;

int n;
int w[N];
vector<int> E[N];

int ctr;
int f[N], siz[N];
void dfs(int u, int fa)
{
f[u] = 0;
siz[u] = w[u];
for (int v : E[u])
{
if (v == fa)
continue;
dfs(v, u);
siz[u] += siz[v];
f[u] = max(f[u], siz[v]);
}
f[u] = max(f[u], n - siz[u]);
if (f[u] < f[ctr] || (f[u] == f[ctr] && u < ctr))
ctr = u;
}

int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &w[i]);
for (int i = 1; i < n; i++)
{
int a, b;
scanf("%d%d", &a, &b);
E[a].push_back(b);
E[b].push_back(a);
}
ctr = 0;
f[0] = 0x3f3f3f3f;
dfs(1, 0);
printf("%d\n", ctr);
return 0;
}

2 最近公共祖先

2.1 原理

基本思路: 先将两个节点跳到同一层,再同时向上跳只到找到共同祖先

2.2 实现

2.2.1 倍增法

最常见的一种求LCA的算法,设计数组 Fu,kF_{u, k} 表示节点 uu 向上(根节点)跳 2k2^k 次到达的节点,每次跳跃按照 kk 从大到小的顺序尝试

时间复杂度 O(qlogn)O(qlogn)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <bits/stdc++.h>

using namespace std;

const int N = 5e5 + 5;
const int F = 20;

int n, m, s;
vector<int> E[N];

int f[N][F], dep[N];
void dfs(int u, int fa)
{
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (int i = 1; i < F; i++)
{
f[u][i] = f[f[u][i - 1]][i - 1];
}
for (int v : E[u])
{
if (v == fa)
continue;
dfs(v, u);
}
}

inline int LCA(int u, int v)
{
if (dep[u] < dep[v])
swap(u, v);
for (int i = F - 1; i >= 0; i--)
{
if (dep[f[u][i]] >= dep[v])
{
u = f[u][i];
}
}
if (u == v)
return u;
for (int i = F - 1; i >= 0; i--)
{
if (f[u][i] ^ f[v][i])
{
u = f[u][i];
v = f[v][i];
}
}
return f[u][0];
}

int main()
{
scanf("%d%d%d", &n, &m, &s);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
E[x].push_back(y);
E[y].push_back(x);
}
dfs(s, 0);
while (m--)
{
int a, b;
scanf("%d%d", &a, &b);
printf("%d\n", LCA(a, b));
}
return 0;
}

2.2.2 树链剖分/重链剖分

代码长度较短的一种做法,先对树进行重链剖分,再沿着重链不断往上跳

时间复杂度 O(qlogn)O(qlogn)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>

using namespace std;

const int N = 5e5 + 5;
const int F = 20;

int n, m, s;
vector<int> E[N];

int fat[N], siz[N], dep[N], son[N];
void dfs1(int u, int fa)
{
fat[u] = fa;
siz[u] = 1;
dep[u] = dep[fa] + 1;
for (int v : E[u])
{
if (v == fa)
continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]])
son[u] = v;
}
}

int dfn[N], idx[N], top[N], Index;
void dfs2(int u, int Top)
{
dfn[u] = ++Index;
idx[Index] = u;
top[u] = Top;
if (son[u])
{
dfs2(son[u], Top);
for (int v : E[u])
{
if (v == fat[u] || v == son[u])
continue;
dfs2(v, v);
}
}
}

inline int LCA(int u, int v)
{
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
swap(u, v);
u = fat[top[u]];
}
return dep[u] < dep[v] ? u : v;
}

int main()
{
scanf("%d%d%d", &n, &m, &s);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
E[x].push_back(y);
E[y].push_back(x);
}
dfs1(s, 0);
dfs2(s, 0);
while (m--)
{
int a, b;
scanf("%d%d", &a, &b);
printf("%d\n", LCA(a, b));
}
return 0;
}

2.2.3 Tarjan算法

  • 这是一个 离线算法 使用时要注意
  1. 考虑离线操作,先将询问排序

  2. 自根节点开始DFS遍历整个树

  3. 在进入一个节点时,先将他在并查集中的父亲设置为自己,在遍历完所有子节点后,处理与该节点相关的所有询问,如果询问中另一个节点为已经访问过的节点,那么该节点所在的并查集的代表元即为他们的最近公共祖先

  4. 回溯离开节点时,在并查集上将它所处的集合与其父节点所处的集合合并

时间复杂度 O(n+m)O(n + m)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <bits/stdc++.h>

using namespace std;

const int N = 5e5 + 5;
const int F = 20;

int n, m, s;
vector<int> E[N];

vector<pair<int, int>> Q[N];

int fa[N];
int find(int x)
{
return fa[x] == x ? fa[x] : fa[x] = find(fa[x]);
}

bool vis[N];
int ans[N];
void dfs(int u)
{
fa[u] = u;
vis[u] = true;
for (int v : E[u])
{
if (vis[v]) continue;
dfs(v);
fa[v] = u;
}
for (auto cur : Q[u])
if (vis[cur.first])
ans[cur.second] = find(cur.first);
}

int main()
{
scanf("%d%d%d", &n, &m, &s);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
E[x].push_back(y);
E[y].push_back(x);
}
for (int i = 1; i <= m; i++)
{
int a, b;
scanf("%d%d", &a, &b);
Q[a].push_back(make_pair(b, i));
Q[b].push_back(make_pair(a, i));
}
dfs(s);
for (int i = 1; i <= m; i++)
printf("%d\n", ans[i]);
return 0;
}

2.3 应用

2.3.1 树上差分

问题描述

对于 nn 个节点的树,进行 kk 次修改操作,每次将 sstt 的路径上的点权增加 xx,查询每个点的最终值.

思路

对于每一次修改的路径 $ s \to t $ 我们可以把将其拆分为 $ s \to LCA(s, t) \to t$ 两段,考虑到这是一个区间修改单点查询的问题,我们可以用差分思想将修改操作进行转换:

建立一个差分数组 dud_u 对应根节点到 uu 的节点路径的差分值,对于每一个修改,我们将 $ s \to LCA(s, t) $ 和 $ LCA(s, t) \to t$ 两条路径都加上 xx,由于节点 LCA(s,t)LCA(s, t) 被重复计算了一次,所以该点的点权还要减去 xx。因此,只需将 dsd_sdtd_t 增加 xxdLCA(s,t)d_{LCA(s, t)}dfaLCA(s,t)d_{fa_{LCA(s, t)}} 减少 xx 即可,使用倍增求 LCA。

最后遍历整棵树,还原出每个点的点权并统计答案。

代码实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <bits/stdc++.h>

using namespace std;

const int N = 5e5 + 5;
const int F = 20;

int n, k;
vector<int> E[N];

int f[N][F], dep[N];
void dfs(int u, int fa)
{
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (int i = 1; i < F; i++)
{
f[u][i] = f[f[u][i - 1]][i - 1];
}
for (int v : E[u])
{
if (v == fa)
continue;
dfs(v, u);
}
}

inline int LCA(int u, int v)
{
if (dep[u] < dep[v])
swap(u, v);
for (int i = F - 1; i >= 0; i--)
{
if (dep[f[u][i]] >= dep[v])
{
u = f[u][i];
}
}
if (u == v)
return u;
for (int i = F - 1; i >= 0; i--)
{
if (f[u][i] ^ f[v][i])
{
u = f[u][i];
v = f[v][i];
}
}
return f[u][0];
}

int d[N], val[N];
void dfs_(int u, int fa)
{
val[u] = d[u];
for (int v : E[u])
{
if (v == fa)
continue;
dfs_(v, u);
val[u] += val[v];
}
}

int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++)
{
int x, y;
scanf("%d%d", &x, &y);
E[x].push_back(y);
E[y].push_back(x);
}
dfs(1, 0);
while (k--)
{
int s, t, x;
scanf("%d%d", &s, &t, &x);
int u = LCA(s, t);
d[s] += x, d[t] += x;
d[f[u][0]] -= x, d[u] -= x;
}
dfs_(1, 0);
for (int i = 1; i <= n; i++)
printf("%d%c", ans[i], i == n ? '\n' : ' ');
return 0;
}

树链剖分

咕咕咕

TODO: 树的重链剖分(路径点修改/查询,单点修改/查询,子树点修改/查询,边权转点权)