CF1060E Sergey and Subway(树形dp / tree DP)

题意:

给定一棵树,在原树中有公共邻点的每一对点之间连一条边(新连的边不参与"公共邻点"的判断)。求新图中所有点对的距离之和。

思路:

法一:烦人的树形dp。对每棵子树,维护其中与子树根距离为奇数的点数、距离为偶数的点数,以及这两类距离各自的总和。

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

constexpr int N = 2e5 + 5;  // vertex capacity
constexpr int M = 4e5 + 5;  // edge-slot capacity: each tree edge is stored in both directions

// Adjacency lists kept as singly linked lists of edge slots; slot 0 ends a list.
int h[N];   // h[v]: most recently added slot leaving v (0 if none)
int e[M];   // e[i]: target vertex of slot i
int ne[M];  // ne[i]: next slot in the same list
int idx;    // number of slots used so far; slots are numbered from 1

// Records the directed edge a -> b at the head of a's list.
void add(int a, int b) {
    ++idx;
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
}

// Per-vertex DP state. Distances are measured from the vertex itself and only
// cover the part of its subtree merged so far.
ll res[N];   // res[u]: sum of ceil(dist/2) over pairs whose LCA is u (pairs (u, x) included)
ll siz1[N];  // number of vertices at odd distance
ll siz0[N];  // number of vertices at even distance
ll sum1[N];  // sum of the odd distances
ll sum0[N];  // sum of the even distances

// Post-order DP rooted at u (fa = parent, 0 for the root). Children are merged
// one at a time; a cross pair (x already merged, y in the new child v) has new
// distance ceil((dx + dy) / 2), which is split below by the parities of dx, dy.
// Recursion depth equals the height of the tree.
void dfs(int u, int fa)
{
    siz0[u] = 1;  // u itself, at distance 0

    for (int i = h[u]; i != 0; i = ne[i]) {
        const int v = e[i];
        if (v == fa) continue;
        dfs(v, u);

        // Side A: vertices already merged into u's subtree, u itself excluded
        // (pairs that involve u are added once after the loop).
        const ll aEvenCnt = siz0[u] - 1;
        const ll aOddCnt = siz1[u];
        const ll aEvenHalf = sum0[u] / 2;              // sum of dx/2 over even dx
        const ll aOddUp = (sum1[u] + siz1[u]) / 2;     // sum of (dx+1)/2 over odd dx

        // Side B: v's subtree seen from u. dy = (distance from v) + 1, so parities swap.
        const ll bEvenCnt = siz1[v];
        const ll bOddCnt = siz0[v];
        const ll bEvenHalf = (sum1[v] + siz1[v]) / 2;  // sum of dy/2 over even dy
        const ll bOddDown = sum0[v] / 2;               // sum of (dy-1)/2 over odd dy
        const ll bOddUp = bOddDown + bOddCnt;          // sum of (dy+1)/2 over odd dy

        // ceil((dx+dy)/2) = A-part + B-part. The A-part is dx/2 or (dx+1)/2.
        // An odd dy is rounded down when dx is odd (odd + odd is even)
        // and rounded up when dx is even.
        const ll aCnt = aEvenCnt + aOddCnt;
        const ll bCnt = bEvenCnt + bOddCnt;
        res[u] += (aEvenHalf + aOddUp) * bCnt
                + bEvenHalf * aCnt
                + bOddDown * aOddCnt
                + bOddUp * aEvenCnt;

        // Fold v's subtree into u's; every distance grows by one edge.
        siz0[u] += siz1[v];
        siz1[u] += siz0[v];
        sum0[u] += sum1[v] + siz1[v];
        sum1[u] += sum0[v] + siz0[v];
    }

    // Pairs (u, x) for every other x in the subtree contribute ceil(dx/2).
    res[u] += sum0[u] / 2 + (sum1[u] + siz1[u]) / 2;
}

signed main()
{
    int n; scanf("%d", &n);
    for(int i = 1, a, b; i < n; i++) scanf("%d%d", &a, &b), add(a, b), add(b, a);

    dfs(1, 0);

    ll ans = 0;
    for(int i = 1; i <= n; i++) ans += res[i];
    printf("%lld", ans);

    return 0;
}

法二:不加边之前的答案为 \(\sum_u size_u(n-size_u)\) (点 \(u\) 与其父亲之间的边恰好被 \(size_u(n-size_u)\) 个点对经过)。

加边后,原距离 \(d\) 为偶数的点对,距离变为 \(d/2\);\(d\) 为奇数的点对,距离变为 \((d+1)/2\)(即下取整减半后再 +1)。

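不加边前,ans = 偶数层到偶数层的点对距离和(X) + 奇数层到奇数层的点对距离和(Y) + 奇数层到偶数层的点对距离和(Z)。两点距离的奇偶性等于两点深度之和的奇偶性,所以 X 和 Y 中的每个距离都是偶数,直接减半即可;Z 中的每个距离都是奇数,每对要先 +1 再减半,而 Z 中的点对数为深度为奇的点数 × 深度为偶的点数。因此答案为 \(\frac{1}{2}\left(\sum_u size_u(n-size_u) + cnt\cdot(n-cnt)\right)\),其中 \(cnt\) 为深度为奇数的点数。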

ll siz[N];  // siz[v]: number of vertices in v's subtree
ll cnt;     // number of visited vertices whose depth is odd

// Fills siz[] for u's subtree and counts the odd-depth vertices in it.
// fa is u's parent (skipped while scanning the adjacency list); dep is u's depth.
void dfs(int u, int fa, int dep)
{
    siz[u] = 1;
    cnt += (dep % 2 != 0);
    for (int i = h[u]; i != 0; i = ne[i]) {
        const int v = e[i];
        if (v != fa) {
            dfs(v, u, dep + 1);
            siz[u] += siz[v];
        }
    }
}

// Fragment of main(). Presumably it runs after n has been read, the tree has been
// built and dfs(1, 0, 0) has filled siz[] and cnt (not shown here).
// Sum of all original pairwise distances: the edge above i is crossed by
// siz[i] * (n - siz[i]) pairs (the root has siz == n, so it contributes 0).
ll ans = 0;
for(int i = 1; i <= n; i++) ans += siz[i] * (n - siz[i]);
// Pairs at odd distance are exactly the (odd-depth, even-depth) pairs: cnt * (n - cnt).
// Each needs +1 before halving, so the new total is (sum + oddPairs) / 2.
// cnt * (n - cnt) is symmetric, so the parity chosen for the root's depth does not matter.
ans += cnt * (n - cnt); ans /= 2;
printf("%lld", ans);
————————

**Problem statement:**

Given a tree, add an edge between every pair of vertices that have a common neighbor in the original tree. The newly added edges are not taken into account when deciding which vertices are neighbors. Find the sum of distances over all pairs of vertices in the resulting graph.

**Approach:**

Method 1: a tedious tree DP. For each subtree, maintain the number of vertices at odd distance and at even distance from the subtree root, together with the sum of each kind of distance.

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

constexpr int N = 2e5 + 5;  // vertex capacity
constexpr int M = 4e5 + 5;  // edge-slot capacity: each tree edge is stored in both directions

// Adjacency lists kept as singly linked lists of edge slots; slot 0 ends a list.
int h[N];   // h[v]: most recently added slot leaving v (0 if none)
int e[M];   // e[i]: target vertex of slot i
int ne[M];  // ne[i]: next slot in the same list
int idx;    // number of slots used so far; slots are numbered from 1

// Records the directed edge a -> b at the head of a's list.
void add(int a, int b) {
    ++idx;
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
}

// Per-vertex DP state. Distances are measured from the vertex itself and only
// cover the part of its subtree merged so far.
ll res[N];   // res[u]: sum of ceil(dist/2) over pairs whose LCA is u (pairs (u, x) included)
ll siz1[N];  // number of vertices at odd distance
ll siz0[N];  // number of vertices at even distance
ll sum1[N];  // sum of the odd distances
ll sum0[N];  // sum of the even distances

// Post-order DP rooted at u (fa = parent, 0 for the root). Children are merged
// one at a time; a cross pair (x already merged, y in the new child v) has new
// distance ceil((dx + dy) / 2), which is split below by the parities of dx, dy.
// Recursion depth equals the height of the tree.
void dfs(int u, int fa)
{
    siz0[u] = 1;  // u itself, at distance 0

    for (int i = h[u]; i != 0; i = ne[i]) {
        const int v = e[i];
        if (v == fa) continue;
        dfs(v, u);

        // Side A: vertices already merged into u's subtree, u itself excluded
        // (pairs that involve u are added once after the loop).
        const ll aEvenCnt = siz0[u] - 1;
        const ll aOddCnt = siz1[u];
        const ll aEvenHalf = sum0[u] / 2;              // sum of dx/2 over even dx
        const ll aOddUp = (sum1[u] + siz1[u]) / 2;     // sum of (dx+1)/2 over odd dx

        // Side B: v's subtree seen from u. dy = (distance from v) + 1, so parities swap.
        const ll bEvenCnt = siz1[v];
        const ll bOddCnt = siz0[v];
        const ll bEvenHalf = (sum1[v] + siz1[v]) / 2;  // sum of dy/2 over even dy
        const ll bOddDown = sum0[v] / 2;               // sum of (dy-1)/2 over odd dy
        const ll bOddUp = bOddDown + bOddCnt;          // sum of (dy+1)/2 over odd dy

        // ceil((dx+dy)/2) = A-part + B-part. The A-part is dx/2 or (dx+1)/2.
        // An odd dy is rounded down when dx is odd (odd + odd is even)
        // and rounded up when dx is even.
        const ll aCnt = aEvenCnt + aOddCnt;
        const ll bCnt = bEvenCnt + bOddCnt;
        res[u] += (aEvenHalf + aOddUp) * bCnt
                + bEvenHalf * aCnt
                + bOddDown * aOddCnt
                + bOddUp * aEvenCnt;

        // Fold v's subtree into u's; every distance grows by one edge.
        siz0[u] += siz1[v];
        siz1[u] += siz0[v];
        sum0[u] += sum1[v] + siz1[v];
        sum1[u] += sum0[v] + siz0[v];
    }

    // Pairs (u, x) for every other x in the subtree contribute ceil(dx/2).
    res[u] += sum0[u] / 2 + (sum1[u] + siz1[u]) / 2;
}

signed main()
{
    int n; scanf("%d", &n);
    for(int i = 1, a, b; i < n; i++) scanf("%d%d", &a, &b), add(a, b), add(b, a);

    dfs(1, 0);

    ll ans = 0;
    for(int i = 1; i <= n; i++) ans += res[i];
    printf("%lld", ans);

    return 0;
}

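Method 2: before the edges are added, the answer is \(\sum_u size_u(n-size_u)\), because the edge between \(u\) and its parent is crossed by exactly \(size_u(n-size_u)\) pairs.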

After the edges are added, a pair at original distance \(d\) ends up at distance \(d/2\) if \(d\) is even and \((d+1)/2\) if \(d\) is odd (halve, rounding down, then add 1).

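Before the edges are added, ans = X + Y + Z. X is the sum of distances over pairs whose vertices both have even depth, Y is the same for pairs whose vertices both have odd depth, and Z is the sum over pairs with one odd-depth and one even-depth vertex. The parity of a distance equals the parity of the sum of the two depths. So every distance counted in X and Y is even and is simply halved. Every distance counted in Z is odd and needs +1 before halving. The number of pairs in Z is (number of odd-depth vertices) × (number of even-depth vertices). Hence the answer is \(\frac{1}{2}\left(\sum_u size_u(n-size_u) + cnt\cdot(n-cnt)\right)\), where \(cnt\) is the number of odd-depth vertices.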

ll siz[N];  // siz[v]: number of vertices in v's subtree
ll cnt;     // number of visited vertices whose depth is odd

// Fills siz[] for u's subtree and counts the odd-depth vertices in it.
// fa is u's parent (skipped while scanning the adjacency list); dep is u's depth.
void dfs(int u, int fa, int dep)
{
    siz[u] = 1;
    cnt += (dep % 2 != 0);
    for (int i = h[u]; i != 0; i = ne[i]) {
        const int v = e[i];
        if (v != fa) {
            dfs(v, u, dep + 1);
            siz[u] += siz[v];
        }
    }
}

// Fragment of main(). Presumably it runs after n has been read, the tree has been
// built and dfs(1, 0, 0) has filled siz[] and cnt (not shown here).
// Sum of all original pairwise distances: the edge above i is crossed by
// siz[i] * (n - siz[i]) pairs (the root has siz == n, so it contributes 0).
ll ans = 0;
for(int i = 1; i <= n; i++) ans += siz[i] * (n - siz[i]);
// Pairs at odd distance are exactly the (odd-depth, even-depth) pairs: cnt * (n - cnt).
// Each needs +1 before halving, so the new total is (sum + oddPairs) / 2.
// cnt * (n - cnt) is symmetric, so the parity chosen for the root's depth does not matter.
ans += cnt * (n - cnt); ans /= 2;
printf("%lld", ans);