# P6329 【模板】点分树 | 震波 - 其他

## P6329 【模板】点分树 | 震波

### Solution

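点分树上两点 $x,y$ 的 $LCA$ 一定落在原树中 $x$ 到 $y$ 的路径上，而 $LCA$ 的祖先与这条路径再无关系，所以每个点对只应在它们的点分树 $LCA$ 处被统计一次。容斥时要思考清楚：从 $x$ 沿点分树向上跳到祖先 $f$ 时，先用 $f$ 的树状数组统计到 $f$ 距离不超过 $k-\operatorname{dis}(f,x)$ 的点权和，再用子节点 $i$（$fa_i=f$）中以“到 $fa_i$ 的距离”为下标的树状数组，减去 $i$ 子树内会被重复统计的部分。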

### Code

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>
#define IN inline
using namespace std;

const int N = 1e5 + 5;

// Problem input: node count n, operation count m, node weights a[1..n].
int n, m;
int a[N];

// Adjacency list ("forward star"): h[x] is the newest edge leaving x,
// e[i].nxt links to the previous one (0 ends the chain); tot counts stored edges.
int h[N], tot;
struct edge {
    int to, nxt;
} e[N * 2];

// Push the directed edge x -> y onto the front of x's edge chain.
IN void add(int x, int y) {
    ++tot;
    e[tot].to = y;
    e[tot].nxt = h[x];
    h[x] = tot;
}

// Tree / centroid-decomposition state:
//   dep[x]  depth of x in the original tree (filled by dfs)
//   rt      best centroid candidate found by getrt; Rt is the centroid-tree root
//   size    node count of the component currently being decomposed
//   used[x] nonzero once x has been taken as a centroid
//   sz[x]   subtree size from the latest getrt; son[x] = largest piece left if x is removed
//   fa[x]   parent of x in the centroid tree (0 for the root)
// NOTE(review): with `using namespace std;` under C++17, unqualified `size` can be
// ambiguous with std::size; renaming this global file-wide would avoid that.
int dep[N];
int rt, Rt, size;
int used[N], son[N], sz[N], fa[N];
// 1-based Fenwick tree over c[]; build(n) makes indices 1 .. n-1 usable.
struct BIT {
    vector <int> c;
    // Allocate n zeroed slots (indices 0 .. n-1).
    IN void build(int n) {c.resize(n);}
    IN int lowbit(int x) {return x & (-x);}
    // Point update: add v at index x (expects x >= 1). Indices at or beyond
    // c.size() are ignored; the bound check is unsigned, as in vector::size().
    IN void add(int x, int v) {
        const size_t limit = c.size();
        while (static_cast<size_t>(x) < limit) {
            c[x] += v;
            x += lowbit(x);
        }
    }
    // Prefix sum of indices 1 .. x. x is clamped to the last slot, x <= 0 gives 0.
    IN int query(int x) {
        const int last = static_cast<int>(c.size()) - 1;
        if (x > last) x = last;
        int sum = 0;
        while (x > 0) {
            sum += c[x];
            x -= lowbit(x);
        }
        return sum;
    }
}tr[N][2];

// Euler-tour storage: st[x] = first tour position of x, rev[p] = node at tour
// position p (dfc ends at 2n - 1), lg[i] = floor(log2 i), and mn[p][k] = shallowest
// node in rev[p .. p + 2^k - 1] (lg and mn are filled by obtain()).
int rev[N * 2];
int st[N];
int dfc;
int lg[N * 2];
int mn[N * 2][21];

// Euler tour of the original tree below x (coming from dad): x is recorded on
// entry and again after each child returns; children get dep = dep[x] + 1.
void dfs(int x, int dad) {
    st[x] = ++dfc;
    rev[st[x]] = x;
    for (int i = h[x]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == dad) continue;
        dep[v] = dep[x] + 1;
        dfs(v, x);
        rev[++dfc] = x;
    }
}
// LCA of x and y in the original tree: the shallowest node on the Euler tour
// between their first occurrences, answered from the sparse table mn/lg.
// On equal depth the right-hand window's node is returned.
IN int LCA(int x, int y) {
    int l = st[x], r = st[y];
    if (l > r) swap(l, r);
    const int k = lg[r - l + 1];
    const int p = mn[l][k];
    const int q = mn[r - (1 << k) + 1][k];
    return dep[p] < dep[q] ? p : q;
}
// Number of edges on the original-tree path between x and y.
IN int Dis(int x, int y) {
    const int w = LCA(x, y);
    return (dep[x] - dep[w]) + (dep[y] - dep[w]);
}

void getrt(int x, int dad) {
sz[x] = 1, son[x] = 0;
for(int i = h[x], v; i; i = e[i].nxt) {
if ((v = e[i].to) == dad || used[v]) continue;
getrt(v, x), sz[x] += sz[v], son[x] = max(son[x], sz[v]);
}
son[x] = max(son[x], size - sz[x]);
if (son[rt] > son[x]) rt = x;
}
// Takes x as a centroid: marks it used and sizes its two BITs from the current
// component size. Then, for every unused neighbour, it finds that sub-component's
// centroid, hangs it under x in the centroid tree (fa[]), and recurses.
// NOTE(review): sz[v] was computed by an earlier getrt with a different root, so for
// the neighbour on that root's side it need not equal the true component size; it
// appears to only over-estimate it (which just enlarges BITs) — worth confirming.
void divide(int x) {
    used[x] = 1;
    tr[x][0].build(size + 1);  // keys: dist(x, node) + 1
    tr[x][1].build(size + 2);  // keys: dist(fa[x], node) + 1, one farther
    for (int i = h[x]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (used[v]) continue;
        rt = 0;
        size = sz[v];
        getrt(v, x);
        fa[rt] = x;
        divide(rt);
    }
}
void obtain() {
lg[0] = -1;
for(int i = 1; i <= dfc; i++) mn[i][0] = rev[i], lg[i] = lg[i >> 1] + 1;
for(int i = 1; i <= lg[dfc]; i++) {
for(int j = 1; j + (1 << i) - 1 <= dfc; j++)
if (dep[mn[j][i - 1]] < dep[mn[j + (1 << i - 1)][i - 1]])
mn[j][i] = mn[j][i - 1]; else mn[j][i] = mn[j + (1 << i - 1)][i - 1];
}
for(int i = 1; i <= n; i++)
for(int j = i; j; j = fa[j]) {
if (fa[j]) tr[j][1].add(Dis(fa[j], i) + 1, a[i]);
}
}

x = 0; char ch = getchar(); int f = 1;
for(; !isdigit(ch); f = (ch == '-' ? -1 : f), ch = getchar());
for(; isdigit(ch); x = (x<<3)+(x<<1)+(ch^48), ch = getchar());
x *= f;
}
// Total weight (as stored in the BITs) of nodes within distance k of x.
// The walk goes up the centroid tree from x. At each ancestor cur, it adds the
// weight of cur's component lying within k - dist(cur, x) of cur. It then
// subtracts the part of that component which the parent `up` will count again
// (tr[cur][1] is keyed by distance to up).
IN int Query(int x, int k) {
    int total = 0;
    int cur = x;
    while (cur) {
        const int up = fa[cur];
        total += tr[cur][0].query(k - Dis(cur, x) + 1);
        if (up) total -= tr[cur][1].query(k - Dis(up, x) + 1);
        cur = up;
    }
    return total;
}

int main() {
for(int i = 1; i <= n; i++) read(a[i]);
rt = 0, size = n, son[0] = 2e9, getrt(1, 0), Rt = rt, divide(rt), dfs(Rt, 0), obtain();
for(int op, x, y, lst = 0; m; --m) {
if (op) {
for(int i = x; i; i = fa[i]) {
tr[i][0].add(Dis(x, i) + 1, y - a[x]);
if (fa[i]) tr[i][1].add(Dis(x, fa[i]) + 1, y - a[x]);
}
a[x] = y;
}
else printf("%d\n", lst = Query(x, y));
}
}

