/*
 * Point divide-and-conquer (centroid decomposition) on a weighted tree:
 * for each of m offline queries k, decide whether some pair of vertices is
 * at distance exactly k.
 *
 *   size()        subtree sizes d[] and component size cnt
 *   dp()          largest remaining piece f[] and the centroid rot
 *   get()         distances from the centroid into t[0..top)
 *   calc()        one divide-and-conquer step
 *   bu[]          bucket of distances seen so far at the current centroid
 *   hd/to/nx/wg   forward-star adjacency list of the tree
 *   ak[] / as[]   offline queries and their answers
 *   ok[]          vertices already used as a centroid
 */
#include <algorithm>
#include <iostream>

const int N = 1e4 + 4, M = 105, Q = 1e7 + 7;
int n, m, hd[N], to[N * 2], nx[N * 2], wg[N * 2];
int ak[M], d[N], f[N], t[N], top, cnt, rot;
bool as[M], ok[N], bu[Q];

// Fills d[] with subtree sizes of u's component (rooted at u), counting the
// visited vertices into cnt. Returns d[u].
int size(int u, int pa) {
  cnt++, d[u] = 1;
  for (int p = hd[u]; ~p; p = nx[p])
    if (to[p] != pa && !ok[to[p]]) d[u] += size(to[p], u);
  return d[u];
}

// f[u] = size of the largest piece left after deleting u. rot starts as the
// component root; a vertex replaces it only when its f is strictly smaller.
void dp(int u, int pa) {
  f[u] = cnt - d[u];  // the piece "above" u
  for (int p = hd[u]; ~p; p = nx[p])
    if (to[p] != pa && !ok[to[p]]) {
      f[u] = std::max(f[u], d[to[p]]);
      dp(to[p], u);
    }
  if (f[u] < f[rot]) rot = u;
}

// Appends the centroid distance of every vertex in u's subtree to t[].
void get(int u, int pa, int dis) {
  t[top++] = dis;
  for (int p = hd[u]; ~p; p = nx[p])
    if (to[p] != pa && !ok[to[p]]) get(to[p], u, dis + wg[p]);
}

// Answers every query on paths through the centroid of u's component, then
// recurses into the pieces left after removing that centroid.
void calc(int u) {
  cnt = 0, size(u, u);
  rot = u, dp(u, u);
  bu[0] = true, t[0] = 0, top = 1;  // the centroid itself, at distance 0
  for (int p = hd[rot], i; ~p; p = nx[p])
    if (!ok[to[p]]) {
      i = top, get(to[p], rot, wg[p]);
      // Match only against distances from earlier subtrees (and the centroid),
      // so both endpoints never come from the same child subtree.
      for (int q = 0; q < m; q++)
        for (int j = i; j < top && !as[q]; j++)
          if (ak[q] >= t[j] && ak[q] - t[j] < Q) as[q] = bu[ak[q] - t[j]];
      --i;
      while (++i < top)
        if (t[i] < Q) bu[t[i]] = true;  // distances >= Q do not fit in bu[]
    }
  // Reset only the bucket entries that were set, instead of all of bu[].
  while (top--)
    if (t[top] < Q) bu[t[top]] = false;
  ok[rot] = true;
  for (int p = hd[rot]; ~p; p = nx[p])
    if (!ok[to[p]]) calc(to[p]);
}

int main() {
  std::cin >> n >> m;
  for (int i = 1; i <= n; i++) hd[i] = -1;
  // n - 1 undirected edges; slots i and i + 1 hold the two directions.
  for (int i = 0, u, v; i + 2 < n * 2;) {
    std::cin >> u >> v >> wg[i];
    wg[i + 1] = wg[i];
    to[i] = v, nx[i] = hd[u], hd[u] = i++;
    to[i] = u, nx[i] = hd[v], hd[v] = i++;
  }
  for (int i = 0; i < m; i++) std::cin >> ak[i];
  calc(1);
  for (int i = 0; i < m; i++) std::cout << (as[i] ? "AYE\n" : "NAY\n");
}
#include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;

constexpr long long MAXN = 2000010;
constexpr long long inf = 2e9;  // sentinel "worst" max-piece size for rt = 0
// Value range of the dynamic segment tree; distance d is stored at index d + 1.
constexpr long long VMAX = 20000000;

long long n, a, b, c, q, rt, siz[MAXN], maxx[MAXN], dist[MAXN];
long long cur, h[MAXN], nxt[MAXN], p[MAXN], w[MAXN], ret;
bool vis[MAXN];

// Appends the directed edge x -> y with weight z to the forward-star list.
void add_edge(long long x, long long y, long long z) {
  cur++;
  nxt[cur] = h[x];
  h[x] = cur;
  p[cur] = y;
  w[cur] = z;
}

long long sum;  // size of the component currently being searched for a centroid

// Computes siz[] (rooted at x) and maxx[] (largest piece after deleting a
// vertex) over x's component of size `sum`; keeps the best vertex in rt.
void calcsiz(long long x, long long fa) {
  siz[x] = 1;
  maxx[x] = 0;
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      calcsiz(p[j], x);
      maxx[x] = max(maxx[x], siz[p[j]]);
      siz[x] += siz[p[j]];
    }
  maxx[x] = max(maxx[x], sum - siz[x]);
  if (maxx[x] < maxx[rt]) rt = x;
}

long long dd[MAXN], cnt;  // dd[1..cnt]: distances collected by calcdist

// Collects dist[] of every vertex in x's subtree into dd[].
void calcdist(long long x, long long fa) {
  dd[++cnt] = dist[x];
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) dist[p[j]] = dist[x] + w[j], calcdist(p[j], x);
}

queue<long long> tag;  // every index inserted into st, so it can be undone

// Dynamic-node segment tree of counts over [1, VMAX]. Nodes whose count
// drops to 0 are unlinked, so clear() leaves every node zeroed for reuse.
struct segtree {
  long long cnt, rt, lc[MAXN], rc[MAXN], sum[MAXN];

  // Removes every tagged insertion and recycles all nodes.
  void clear() {
    while (!tag.empty()) update(rt, 1, VMAX, tag.front(), -1), tag.pop();
    cnt = 0;
  }

  // Debug helper: prints every non-empty leaf as "index count".
  void print(long long o, long long l, long long r) {
    if (!o || !sum[o]) return;
    if (l == r) {
      cout << l << ' ' << sum[o] << '\n';
      return;
    }
    long long mid = (l + r) >> 1;
    print(lc[o], l, mid);
    print(rc[o], mid + 1, r);
  }

  // Adds v to the count at position x, creating/unlinking nodes as needed.
  void update(long long& o, long long l, long long r, long long x, long long v) {
    if (!o) o = ++cnt;
    if (l == r) {
      sum[o] += v;
      if (!sum[o]) o = 0;
      return;
    }
    long long mid = (l + r) >> 1;
    if (x <= mid)
      update(lc[o], l, mid, x, v);
    else
      update(rc[o], mid + 1, r, x, v);
    sum[o] = sum[lc[o]] + sum[rc[o]];
    if (!sum[o]) o = 0;
  }

  // Sum of counts over positions [ql, qr].
  long long query(long long o, long long l, long long r, long long ql, long long qr) {
    if (!o) return 0;
    if (r < ql || l > qr) return 0;
    if (ql <= l && r <= qr) return sum[o];
    long long mid = (l + r) >> 1;
    return query(lc[o], l, mid, ql, qr) + query(rc[o], mid + 1, r, ql, qr);
  }
} st;

// Divide step at centroid x (fa = previous centroid or -1): adds to ret the
// number of paths through x with length in [1, q], then recurses.
void dfz(long long x, long long fa) {
  // Seed with x itself: distance 0, stored at index 1.
  st.update(st.rt, 1, VMAX, 1, 1);
  tag.push(1);
  vis[x] = true;
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      dist[p[j]] = w[j];
      calcdist(p[j], x);
      // Query before inserting so both endpoints never share a child subtree.
      for (long long k = 1; k <= cnt; k++)
        if (q - dd[k] >= 0)
          ret += st.query(st.rt, 1, VMAX, max(0ll, 1 - dd[k]) + 1,
                          max(0ll, q - dd[k]) + 1);
      // Queries never reach past index q + 1, so a distance > q can never be
      // matched (weights assumed non-negative); skip storing it.
      for (long long k = 1; k <= cnt; k++)
        if (dd[k] <= q) st.update(st.rt, 1, VMAX, dd[k] + 1, 1), tag.push(dd[k] + 1);
      cnt = 0;
    }
  st.clear();
  for (long long j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      sum = siz[p[j]];  // siz[] is rooted at x here
      rt = 0;
      maxx[rt] = inf;
      calcsiz(p[j], x);
      calcsiz(rt, -1);  // re-root siz[] at the new centroid
      dfz(rt, x);
    }
}

signed main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n;
  for (long long i = 1; i < n; i++) cin >> a >> b >> c, add_edge(a, b, c), add_edge(b, a, c);
  cin >> q;
  rt = 0;
  maxx[rt] = inf;
  sum = n;
  calcsiz(1, -1);
  calcsiz(rt, -1);
  dfz(rt, -1);
  cout << ret << '\n';
  return 0;
}
而针对第 2 部分,设当前根节点 $u$ 的一个子节点为 $d$,在 $d$ 的子树中任取一点 $v$,那么 $v$ 的答案可以分为两部分:
$(u,v)$ 路径上出现过的颜色:设其数量为 $num$,并设 $u$ 除 $d$ 以外的其他所有子树的总大小为 $siz_1$,那么这些出现过的颜色对 $v$ 的答案贡献为 $num \times siz_1$。
$(u,v)$ 路径上没有出现过的颜色 $j$:它们的贡献来自 $u$ 除 $d$ 以外的其他所有子树中的 $cnt_j$,这部分答案为 $\sum_{j \notin (u,v)} cnt_j$。
以上是全部统计思路,实现细节详见参考代码。
参考代码
#include <algorithm>
#include <iostream>
using namespace std;
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
constexpr int N = 200005;
int h[N], nxt[N * 2], to[N * 2], c[N], gr;

// Appends the directed edge x -> y to the forward-star list.
void tu(int x, int y) { to[++gr] = y, nxt[gr] = h[x], h[x] = gr; }

using ll = long long;
int n, nn, siz[N], mn, rt;
bool vis[N];

// Finds the centroid of u's component (size nn) into rt, minimising the
// largest piece mn left after deleting a vertex.
void get_root(int u, int f) {
  siz[u] = 1;
  int mx = 0;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    get_root(d, u);
    siz[u] += siz[d];
    mx = max(mx, siz[d]);
  }
  mx = max(mx, nn - siz[u]);
  if (mx < mn) mn = mx, rt = u;
}

ll ans[N], sum;
int cnt[N], v[N];  // sum tracks the running total of cnt[i] over all colours
int nowrt;         // current centroid

// Adds to ans[] the contribution of paths through the centroid, for every
// vertex of u's subtree. now = distinct colours on the chain so far (excluding u).
void get_dis(int u, int f, int now) {
  siz[u] = 1;
  if (!v[c[u]]) {
    sum -= cnt[c[u]];  // colour is now on the path: drop it from the "not on path" total
    now++;
  }
  v[c[u]]++;
  // Widen before multiplying: now * siz can exceed INT_MAX (~ (n/2)^2).
  ans[u] += sum + static_cast<ll>(now) * siz[nowrt];
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (d == f || vis[d]) continue;
    get_dis(d, u, now);
    siz[u] += siz[d];
  }
  v[c[u]]--;
  if (!v[c[u]]) {
    sum += cnt[c[u]];  // backtrack: the colour leaves the path again
  }
}

// Folds the just-processed subtree into cnt[] and sum: each colour's first
// occurrence on a root-down chain contributes that vertex's subtree size.
void get_cnt(int u, int f) {
  if (!v[c[u]]) {
    cnt[c[u]] += siz[u];
    sum += siz[u];
  }
  v[c[u]]++;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    get_cnt(d, u);
  }
  v[c[u]]--;
}

// Resets cnt[] over the component while applying the per-chain correction:
// subtracts now from ans[u] and adds it to ans[nowrt].
void clear(int u, int f, int now) {
  if (!v[c[u]]) now++;
  v[c[u]]++;
  ans[u] -= now;
  ans[nowrt] += now;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    clear(d, u, now);
  }
  v[c[u]]--;
  cnt[c[u]] = 0;
}

// Resets cnt[] for every colour in u's component.
void clear2(int u, int f) {
  cnt[c[u]] = 0;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    clear2(d, u);
  }
}

int son[N];

// One divide-and-conquer step at centroid u.
void divid(int u) {
  vis[u] = true;
  int tot = 0;
  nowrt = u;
  ans[u]++;
  for (int i = h[u]; i; i = nxt[i]) {
    if (vis[to[i]]) continue;
    son[++tot] = to[i];
  }
  siz[u] = sum = cnt[c[u]] = 1;
  v[c[u]]++;
  rep(i, 1, tot) {  // pairs between each subtree and all subtrees before it
    int d = son[i];
    get_dis(d, u, 0);
    get_cnt(d, u);
    siz[u] += siz[d];
    cnt[c[u]] += siz[d];
    sum += siz[d];
  }
  // Reset only the visited component; memset would cost O(N) per centroid.
  clear2(u, 0);
  siz[u] = sum = cnt[c[u]] = 1;
  for (int i = tot; i >= 1; --i) {  // pairs between each subtree and all subtrees after it
    int d = son[i];
    get_dis(d, u, 0);
    get_cnt(d, u);
    siz[u] += siz[d];
    cnt[c[u]] += siz[d];
    sum += siz[d];
  }
  v[c[u]]--;
  clear(u, 0, 0);  // reset while also applying the answer correction
  for (int i = h[u]; i; i = nxt[i]) {  // recurse into each remaining component
    int d = to[i];
    if (vis[d]) continue;
    nn = siz[d], mn = n + 1, rt = 0;
    get_root(d, u);
    divid(rt);
  }
}

int main() {
  cin.tie(nullptr)->sync_with_stdio(false);
  cin >> n;
  int u, v;
  rep(i, 1, n) cin >> c[i];
  rep(i, 2, n) cin >> u >> v, tu(u, v), tu(v, u);
  rt = 0, nn = n, mn = n + 1;
  get_root(1, 0);
  divid(rt);
  rep(i, 1, n) cout << ans[i] << '\n';
  return 0;
}