# cf1060 E. Sergey and Subway（树形dp）(Cf1060 E. Sergey and subway (tree DP))-其他

## cf1060 E. Sergey and Subway（树形dp）(Cf1060 E. Sergey and subway (tree DP))

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 2e5 + 5, M = 4e5 + 5;

// Forward-star adjacency list:
//   h[a]  - index of the most recently added edge leaving a (0 = none),
//   e[i]  - target vertex of edge i,
//   ne[i] - next edge leaving the same source as edge i (0 = end of list).
// Edge indices start at 1 so that 0 can serve as the terminator.
int h[N], e[M], ne[M], idx;

// Appends the directed edge a -> b.
void add(int a, int b) {
    ++idx;
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
}

// Per-vertex DP state, valid for u's subtree once dfs(u, ...) returns:
//   siz0[u], siz1[u] - number of subtree vertices at even / odd tree distance
//                      from u (siz0 counts u itself),
//   sum0[u], sum1[u] - sum of those tree distances,
//   res[u]           - sum of new distances over all pairs whose LCA is u,
//                      including the pairs that have u as an endpoint.
ll res[N], siz1[N], siz0[N], sum1[N], sum0[N];

// Tree DP over the subtree of u; fa is u's parent (0 for the root).
// A pair at tree distance d ends up ceil(d / 2) apart, and over a group of
// vertices sum(ceil(d / 2)) = sum0 / 2 + (sum1 + siz1) / 2. Both halvings are
// exact: every term of sum0 is even, and adding 1 to each odd distance makes
// every term of sum1 + siz1 even.
// NOTE(review): recursion depth equals the tree height, so a long path needs
// a large stack.
void dfs(int u, int fa)
{
    siz0[u] = 1;
    for (int i = h[u]; i != 0; i = ne[i]) {
        const int v = e[i];
        if (v == fa) {
            continue;
        }
        dfs(v, u);

        // sum of ceil(d/2) measured from u over the part merged so far,
        // and measured from v over v's subtree.
        const ll halfU = sum0[u] / 2 + (sum1[u] + siz1[u]) / 2;
        const ll halfV = sum0[v] / 2 + (sum1[v] + siz1[v]) / 2;
        const ll cntU = (siz0[u] - 1) + siz1[u];  // merged part, u excluded
        const ll cntV = siz0[v] + siz1[v];
        // x at distance a from u, y at distance b from v => a + b + 1 apart, and
        //   ceil((a + b + 1) / 2) = ceil(a / 2) + ceil(b / 2) + [a and b even].
        res[u] += halfU * cntV + halfV * cntU + siz0[v] * (siz0[u] - 1);

        // Fold v's subtree into u: every distance grows by 1, flipping parity.
        siz1[u] += siz0[v];
        siz0[u] += siz1[v];
        sum1[u] += sum0[v] + siz0[v];
        sum0[u] += sum1[v] + siz1[v];
    }
    // Pairs (u, x) for every other vertex x of u's subtree.
    res[u] += sum0[u] / 2 + (sum1[u] + siz1[u]) / 2;
}

signed main()
{
int n; scanf("%d", &n);
for(int i = 1, a, b; i < n; i++) scanf("%d%d", &a, &b), add(a, b), add(b, a);

dfs(1, 0);

ll ans = 0;
for(int i = 1; i <= n; i++) ans += res[i];
printf("%lld", ans);

return 0;
}



// siz[u] - number of vertices in u's subtree; cnt - vertices at odd depth.
ll siz[N], cnt;
// Method 2 helper: fills siz[] for u's subtree and adds to cnt every vertex
// of that subtree whose depth is odd. fa is u's parent, dep is u's depth.
void dfs(int u, int fa, int dep)
{
    if (dep & 1) {
        ++cnt;
    }
    siz[u] = 1;
    for (int i = h[u]; i; i = ne[i]) {
        const int v = e[i];
        if (v != fa) {
            dfs(v, u, dep + 1);
            siz[u] += siz[v];
        }
    }
}

// NOTE(review): fragment of main() for Method 2, not standalone code. It
// presumably runs after n is read, the tree is built with add(), and
// dfs(1, 0, 0) has filled siz[] and cnt — confirm against the full program.
ll ans = 0;
// Sum of all tree distances: the edge above v lies on siz[v] * (n - siz[v])
// paths (the root's term is 0 because siz[root] == n).
for(int i = 1; i <= n; i++) ans += siz[i] * (n - siz[i]);
// Odd-distance pairs gain 1 before halving; they are exactly the pairs with
// endpoints of opposite depth parity, cnt * (n - cnt) of them.
ans += cnt * (n - cnt); ans /= 2;
printf("%lld", ans);

————————

**Problem statement:**

Given a tree, add an edge between every pair of vertices that share a common neighbour in the original tree. The newly added edges are not used when deciding which vertices are neighbours. Find the sum of the distances between all pairs of vertices in the resulting graph.

**Approach:**

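Method 1: a fiddly tree DP. A pair at tree distance \(d\) ends up \(\lceil d/2 \rceil\) apart. For every subtree, maintain the number of vertices at odd and at even distance from the subtree root, together with the sums of those distances, and count each pair at its lowest common ancestor.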

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int N = 2e5 + 5, M = 4e5 + 5;

// Forward-star adjacency list:
//   h[a]  - index of the most recently added edge leaving a (0 = none),
//   e[i]  - target vertex of edge i,
//   ne[i] - next edge leaving the same source as edge i (0 = end of list).
// Edge indices start at 1 so that 0 can serve as the terminator.
int h[N], e[M], ne[M], idx;

// Appends the directed edge a -> b.
void add(int a, int b) {
    ++idx;
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
}

// Per-vertex DP state, valid for u's subtree once dfs(u, ...) returns:
//   siz0[u], siz1[u] - number of subtree vertices at even / odd tree distance
//                      from u (siz0 counts u itself),
//   sum0[u], sum1[u] - sum of those tree distances,
//   res[u]           - sum of new distances over all pairs whose LCA is u,
//                      including the pairs that have u as an endpoint.
ll res[N], siz1[N], siz0[N], sum1[N], sum0[N];

// Tree DP over the subtree of u; fa is u's parent (0 for the root).
// A pair at tree distance d ends up ceil(d / 2) apart, and over a group of
// vertices sum(ceil(d / 2)) = sum0 / 2 + (sum1 + siz1) / 2. Both halvings are
// exact: every term of sum0 is even, and adding 1 to each odd distance makes
// every term of sum1 + siz1 even.
// NOTE(review): recursion depth equals the tree height, so a long path needs
// a large stack.
void dfs(int u, int fa)
{
    siz0[u] = 1;
    for (int i = h[u]; i != 0; i = ne[i]) {
        const int v = e[i];
        if (v == fa) {
            continue;
        }
        dfs(v, u);

        // sum of ceil(d/2) measured from u over the part merged so far,
        // and measured from v over v's subtree.
        const ll halfU = sum0[u] / 2 + (sum1[u] + siz1[u]) / 2;
        const ll halfV = sum0[v] / 2 + (sum1[v] + siz1[v]) / 2;
        const ll cntU = (siz0[u] - 1) + siz1[u];  // merged part, u excluded
        const ll cntV = siz0[v] + siz1[v];
        // x at distance a from u, y at distance b from v => a + b + 1 apart, and
        //   ceil((a + b + 1) / 2) = ceil(a / 2) + ceil(b / 2) + [a and b even].
        res[u] += halfU * cntV + halfV * cntU + siz0[v] * (siz0[u] - 1);

        // Fold v's subtree into u: every distance grows by 1, flipping parity.
        siz1[u] += siz0[v];
        siz0[u] += siz1[v];
        sum1[u] += sum0[v] + siz0[v];
        sum0[u] += sum1[v] + siz1[v];
    }
    // Pairs (u, x) for every other vertex x of u's subtree.
    res[u] += sum0[u] / 2 + (sum1[u] + siz1[u]) / 2;
}

signed main()
{
int n; scanf("%d", &n);
for(int i = 1, a, b; i < n; i++) scanf("%d%d", &a, &b), add(a, b), add(b, a);

dfs(1, 0);

ll ans = 0;
for(int i = 1; i <= n; i++) ans += res[i];
printf("%lld", ans);

return 0;
}



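Method 2: before adding any edges, the sum of all pairwise distances is \(\sum_{v} size_v (n - size_v)\), where \(size_v\) is the number of vertices in the subtree of \(v\). The edge from \(v\) to its parent lies on exactly \(size_v (n - size_v)\) paths.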

After the edges are added, a pair at even distance \(d\) is \(d/2\) apart, and a pair at odd distance \(d\) is \((d+1)/2\) apart.

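Before adding edges, ans = x + y + z, where:

- x is the sum of distances over pairs whose vertices are both at even depth;
- y is the same sum over pairs both at odd depth;
- z is the sum over pairs with one endpoint at odd depth and the other at even depth.

Every distance counted in x and y is even, so it is simply halved. Every distance counted in z is odd, so it gains 1 before halving. The number of pairs counted in z is (number of odd-depth vertices) × (number of even-depth vertices). Hence the answer is \(\bigl(\sum_v size_v (n - size_v) + cnt \cdot (n - cnt)\bigr) / 2\), where \(cnt\) is the number of odd-depth vertices.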

// siz[u] - number of vertices in u's subtree; cnt - vertices at odd depth.
ll siz[N], cnt;
// Method 2 helper: fills siz[] for u's subtree and adds to cnt every vertex
// of that subtree whose depth is odd. fa is u's parent, dep is u's depth.
void dfs(int u, int fa, int dep)
{
    if (dep & 1) {
        ++cnt;
    }
    siz[u] = 1;
    for (int i = h[u]; i; i = ne[i]) {
        const int v = e[i];
        if (v != fa) {
            dfs(v, u, dep + 1);
            siz[u] += siz[v];
        }
    }
}

// NOTE(review): fragment of main() for Method 2, not standalone code. It
// presumably runs after n is read, the tree is built with add(), and
// dfs(1, 0, 0) has filled siz[] and cnt — confirm against the full program.
ll ans = 0;
// Sum of all tree distances: the edge above v lies on siz[v] * (n - siz[v])
// paths (the root's term is 0 because siz[root] == n).
for(int i = 1; i <= n; i++) ans += siz[i] * (n - siz[i]);
// Odd-distance pairs gain 1 before halving; they are exactly the pairs with
// endpoints of opposite depth parity, cnt * (n - cnt) of them.
ans += cnt * (n - cnt); ans /= 2;
printf("%lld", ans);