P6329 【模板】点分树 | 震波()

\(\text{Solution}\)

点分树就是将点分治过程中的重心连成一棵虚树
对点分树子树信息的记录,就是点分治处理每个重心时需要的信息
这样就可以留下点分治的过程,支持多次修改和查询
点分树树高 \(O(log n)\) 且 \(\sum size_x = O(n \log n)\)
可以使用很多暴力的手段
但要注意:点分树和原树唯一的联系是点分树中两点的 \(LCA\) 在原树两点的路径上
\(LCA\) 的祖先和这两点的路径再无关系,容斥时要思考清楚
所以统计路径长一切行为以原树为准
因为要大量求 \(LCA\),所以用欧拉序转 \(RMQ\)
两点的 \(LCA\) 就是 \([\min(first_u,first_v),\max(first_u,first_v)]\) 中 \(dep\) 最小的点

注意:
欧拉序有两种:一个点入栈和出栈时记录,序列长 \(2n\)
一个点入栈记一次,每次回溯都记一次,考虑边数得序列长 \(2n-1\)
求 \(LCA\) 时用第二种

又:\(\text{vector}\) 的 \(\text{size()}\) 返回值为 \(\text{unsigned int}\),比较时将参与比较的元素强转 \(\text{unsigned int}\)
所以用负数比较会挂,这点让我懵逼了很久

\(\text{Code}\)

#include <cstdio>
#include <iostream>
#include <vector>
#define IN inline
using namespace std;

const int N = 1e5 + 5;
int n, m, h[N], tot, a[N];
struct edge{int to, nxt;}e[N * 2];
IN void add(int x, int y) {e[++tot] = edge{y, h[x]}, h[x] = tot;}

int dep[N], rt, size, used[N], son[N], sz[N], Rt, fa[N];
struct BIT {
	vector <int> c;
	IN void build(int n) {c.resize(n);}
	IN int lowbit(int x) {return x & (-x);}
	IN void add(int x, int v) {for(; x < c.size(); x += lowbit(x)) c[x] += v;}
	IN int query(int x) {
		if (x >= (int)c.size()) x = c.size() - 1;
		int s = 0; for(; x > 0; x -= lowbit(x)) s += c[x]; return s;
	}
}tr[N][2];

int rev[N * 2], st[N], dfc, lg[N * 2], mn[N * 2][21];
void dfs(int x, int dad) {
	st[x] = ++dfc, rev[dfc] = x;
	for(int i = h[x], v; i; i = e[i].nxt) {
		if ((v = e[i].to) == dad) continue;
		dep[v] = dep[x] + 1, dfs(v, x), rev[++dfc] = x;
	}
}
IN int LCA(int x, int y) {
	x = st[x], y = st[y]; if (x > y) swap(x, y);
	int k = lg[y - x + 1];
	if (dep[mn[x][k]] < dep[mn[y - (1 << k) + 1][k]]) return mn[x][k];
	return mn[y - (1 << k) + 1][k];
}
IN int Dis(int x, int y) {return dep[x] + dep[y] - dep[LCA(x, y)] * 2;}

void getrt(int x, int dad) {
	sz[x] = 1, son[x] = 0;
	for(int i = h[x], v; i; i = e[i].nxt) {
		if ((v = e[i].to) == dad || used[v]) continue;
		getrt(v, x), sz[x] += sz[v], son[x] = max(son[x], sz[v]);
	}
	son[x] = max(son[x], size - sz[x]);
	if (son[rt] > son[x]) rt = x;
}
void divide(int x) {
	used[x] = 1, tr[x][0].build(size + 1), tr[x][1].build(size + 2);
	for(int i = h[x], v; i; i = e[i].nxt) {
		if (used[v = e[i].to]) continue;
		rt = 0, size = sz[v], getrt(v, x), fa[rt] = x, divide(rt);
	}
}
void obtain() {
	lg[0] = -1;
	for(int i = 1; i <= dfc; i++) mn[i][0] = rev[i], lg[i] = lg[i >> 1] + 1;
	for(int i = 1; i <= lg[dfc]; i++) {
		for(int j = 1; j + (1 << i) - 1 <= dfc; j++)
			if (dep[mn[j][i - 1]] < dep[mn[j + (1 << i - 1)][i - 1]])
				mn[j][i] = mn[j][i - 1]; else mn[j][i] = mn[j + (1 << i - 1)][i - 1];
	}
	for(int i = 1; i <= n; i++)
		for(int j = i; j; j = fa[j]) {
			tr[j][0].add(Dis(j, i) + 1, a[i]);
			if (fa[j]) tr[j][1].add(Dis(fa[j], i) + 1, a[i]);
		}
}

IN void read(int &x) {
	x = 0; char ch = getchar(); int f = 1;
	for(; !isdigit(ch); f = (ch == '-' ? -1 : f), ch = getchar());
	for(; isdigit(ch); x = (x<<3)+(x<<1)+(ch^48), ch = getchar());
	x *= f;
}
IN int Query(int x, int k) {
	int ans = 0;
	for(int i = x; i; i = fa[i]) {
		ans += tr[i][0].query(k - Dis(i, x) + 1);
		if (fa[i]) ans -= tr[i][1].query(k - Dis(fa[i], x) + 1);
	}
	return ans;
}

int main() {
	read(n), read(m);
	for(int i = 1; i <= n; i++) read(a[i]);
	for(int i = 1, u, v; i < n; i++) read(u), read(v), add(u, v), add(v, u);
	rt = 0, size = n, son[0] = 2e9, getrt(1, 0), Rt = rt, divide(rt), dfs(Rt, 0), obtain();
	for(int op, x, y, lst = 0; m; --m) {
		read(op), read(x), read(y), x ^= lst, y ^= lst;
		if (op) {
			for(int i = x; i; i = fa[i]) {
				tr[i][0].add(Dis(x, i) + 1, y - a[x]);
				if (fa[i]) tr[i][1].add(Dis(x, fa[i]) + 1, y - a[x]);
			}
			a[x] = y;
		}
		else printf("%d\n", lst = Query(x, y));
	}
}
————————

\(\text{Solution}\)

点分树就是将点分治过程中的重心连成一棵虚树
对点分树子树信息的记录,就是点分治处理每个重心时需要的信息
这样就可以留下点分治的过程,支持多次修改和查询
点分树树高 \(O(log n)\) 且 \(\sum size_x = O(n \log n)\)
可以使用很多暴力的手段
但要注意:点分树和原树唯一的联系是点分树中两点的 \(LCA\) 在原树两点的路径上
\(LCA\) 的祖先和这两点的路径再无关系,容斥时要思考清楚
所以统计路径长一切行为以原树为准
因为要大量求 \(LCA\),所以用欧拉序转 \(RMQ\)
两点的 \(LCA\) 就是 \([\min(first_u,first_v),\max(first_u,first_v)]\) 中 \(dep\) 最小的点

注意:
欧拉序有两种:一个点入栈和出栈时记录,序列长 \(2n\)
一个点入栈记一次,每次回溯都记一次,考虑边数得序列长 \(2n-1\)
求 \(LCA\) 时用第二种

又:\(\text{vector}\) 的 \(\text{size()}\) 返回值为 \(\text{unsigned int}\),比较时将参与比较的元素强转 \(\text{unsigned int}\)
所以用负数比较会挂,这点让我懵逼了很久

\(\text{Code}\)

#include <cstdio>
#include <iostream>
#include <vector>
#define IN inline
using namespace std;

const int N = 1e5 + 5;
int n, m, h[N], tot, a[N];
struct edge{int to, nxt;}e[N * 2];
IN void add(int x, int y) {e[++tot] = edge{y, h[x]}, h[x] = tot;}

int dep[N], rt, size, used[N], son[N], sz[N], Rt, fa[N];
struct BIT {
	vector <int> c;
	IN void build(int n) {c.resize(n);}
	IN int lowbit(int x) {return x & (-x);}
	IN void add(int x, int v) {for(; x < c.size(); x += lowbit(x)) c[x] += v;}
	IN int query(int x) {
		if (x >= (int)c.size()) x = c.size() - 1;
		int s = 0; for(; x > 0; x -= lowbit(x)) s += c[x]; return s;
	}
}tr[N][2];

int rev[N * 2], st[N], dfc, lg[N * 2], mn[N * 2][21];
void dfs(int x, int dad) {
	st[x] = ++dfc, rev[dfc] = x;
	for(int i = h[x], v; i; i = e[i].nxt) {
		if ((v = e[i].to) == dad) continue;
		dep[v] = dep[x] + 1, dfs(v, x), rev[++dfc] = x;
	}
}
IN int LCA(int x, int y) {
	x = st[x], y = st[y]; if (x > y) swap(x, y);
	int k = lg[y - x + 1];
	if (dep[mn[x][k]] < dep[mn[y - (1 << k) + 1][k]]) return mn[x][k];
	return mn[y - (1 << k) + 1][k];
}
IN int Dis(int x, int y) {return dep[x] + dep[y] - dep[LCA(x, y)] * 2;}

void getrt(int x, int dad) {
	sz[x] = 1, son[x] = 0;
	for(int i = h[x], v; i; i = e[i].nxt) {
		if ((v = e[i].to) == dad || used[v]) continue;
		getrt(v, x), sz[x] += sz[v], son[x] = max(son[x], sz[v]);
	}
	son[x] = max(son[x], size - sz[x]);
	if (son[rt] > son[x]) rt = x;
}
void divide(int x) {
	used[x] = 1, tr[x][0].build(size + 1), tr[x][1].build(size + 2);
	for(int i = h[x], v; i; i = e[i].nxt) {
		if (used[v = e[i].to]) continue;
		rt = 0, size = sz[v], getrt(v, x), fa[rt] = x, divide(rt);
	}
}
void obtain() {
	lg[0] = -1;
	for(int i = 1; i <= dfc; i++) mn[i][0] = rev[i], lg[i] = lg[i >> 1] + 1;
	for(int i = 1; i <= lg[dfc]; i++) {
		for(int j = 1; j + (1 << i) - 1 <= dfc; j++)
			if (dep[mn[j][i - 1]] < dep[mn[j + (1 << i - 1)][i - 1]])
				mn[j][i] = mn[j][i - 1]; else mn[j][i] = mn[j + (1 << i - 1)][i - 1];
	}
	for(int i = 1; i <= n; i++)
		for(int j = i; j; j = fa[j]) {
			tr[j][0].add(Dis(j, i) + 1, a[i]);
			if (fa[j]) tr[j][1].add(Dis(fa[j], i) + 1, a[i]);
		}
}

IN void read(int &x) {
	x = 0; char ch = getchar(); int f = 1;
	for(; !isdigit(ch); f = (ch == '-' ? -1 : f), ch = getchar());
	for(; isdigit(ch); x = (x<<3)+(x<<1)+(ch^48), ch = getchar());
	x *= f;
}
IN int Query(int x, int k) {
	int ans = 0;
	for(int i = x; i; i = fa[i]) {
		ans += tr[i][0].query(k - Dis(i, x) + 1);
		if (fa[i]) ans -= tr[i][1].query(k - Dis(fa[i], x) + 1);
	}
	return ans;
}

int main() {
	read(n), read(m);
	for(int i = 1; i <= n; i++) read(a[i]);
	for(int i = 1, u, v; i < n; i++) read(u), read(v), add(u, v), add(v, u);
	rt = 0, size = n, son[0] = 2e9, getrt(1, 0), Rt = rt, divide(rt), dfs(Rt, 0), obtain();
	for(int op, x, y, lst = 0; m; --m) {
		read(op), read(x), read(y), x ^= lst, y ^= lst;
		if (op) {
			for(int i = x; i; i = fa[i]) {
				tr[i][0].add(Dis(x, i) + 1, y - a[x]);
				if (fa[i]) tr[i][1].add(Dis(x, fa[i]) + 1, y - a[x]);
			}
			a[x] = y;
		}
		else printf("%d\n", lst = Query(x, y));
	}
}