class Dsu: def __init__(self, size): self.pa = list(range(size)) self.size = [1] * size def unite(self, x, y): x, y = self.find(x), self.find(y) if x == y: return if self.size[x] < self.size[y]: x, y = y, x self.pa[y] = x self.size[x] += self.size[y]
这道题目中,操作 1 和操作 3 都容易处理,难点在于操作 2。假定要将元素 x 移动到元素 y 所在的集合。在普通的并查集中,直接将元素 x 的父亲设为元素 y 所在集合的根节点是不行的,因为这样会将元素 x 所在子树的元素都一起移动。针对这个问题,解决方法就是保证元素 x 没有子节点。为此,在建立并查集时为每个元素 x 都建立一个虚点 x~,并将元素 x 的父亲指向对应的虚点 x~。这样,在合并两个集合的时候,因为总是将一个树根连接到另一个树根,而树根又全部是虚点,所以,只有虚点会有子节点,而所有实际存储元素的点都没有子节点。此时,要移动元素,就容易实现得多。
参考实现
[list2tab]
C++
#include <cassert>#include <iostream>#include <numeric>#include <vector>using namespace std;struct dsu { vector<size_t> pa, size, sum; explicit dsu(size_t size_) : pa(size_ * 2), size(size_ * 2, 1), sum(size_ * 2) { // size 与 sum 的前半段其实没有使用,只是为了让下标计算更简单 iota(pa.begin(), pa.begin() + size_, size_); iota(pa.begin() + size_, pa.end(), size_); iota(sum.begin() + size_, sum.end(), 0); } void unite(size_t x, size_t y) { x = find(x), y = find(y); if (x == y) return; if (size[x] < size[y]) swap(x, y); pa[y] = x; size[x] += size[y]; sum[x] += sum[y]; } void move(size_t x, size_t y) { auto fx = find(x), fy = find(y); if (fx == fy) return; pa[x] = fy; --size[fx], ++size[fy]; sum[fx] -= x, sum[fy] += x; } size_t find(size_t x) { return pa[x] == x ? x : pa[x] = find(pa[x]); }};int main() { size_t n, m, op, x, y; while (cin >> n >> m) { dsu dsu(n + 1); // 元素范围是 1..n while (m--) { cin >> op; switch (op) { case 1: cin >> x >> y; dsu.unite(x, y); break; case 2: cin >> x >> y; dsu.move(x, y); break; case 3: cin >> x; x = dsu.find(x); cout << dsu.size[x] << ' ' << dsu.sum[x] << '\n'; break; default: assert(false); // not reachable } } } return 0;}
Python
class Dsu: def __init__(self, size): # size 与 sum 的前半段其实没有使用,只是为了让下标计算更简单 self.pa = list(range(size, size * 2)) * 2 self.size = [1] * size * 2 self.sum = list(range(size)) * 2 def unite(self, x, y): x, y = self.find(x), self.find(y) if x == y: return if self.size[x] < self.size[y]: x, y = y, x self.pa[y] = x self.size[x] += self.size[y] self.sum[x] += self.sum[y] def move(self, x, y): fx, fy = self.find(x), self.find(y) if fx == fy: return self.pa[x] = fy self.size[fx] -= 1 self.size[fy] += 1 self.sum[fx] -= x self.sum[fy] += x def find(self, x): if self.pa[x] != x: self.pa[x] = self.find(self.pa[x]) return self.pa[x]if __name__ == "__main__": while True: try: n, m = map(int, input().split()) dsu = Dsu(n + 1) # 元素范围是 1..n for _ in range(m): op_x_y = list(map(int, input().split())) op = op_x_y[0] if op == 1: dsu.unite(op_x_y[1], op_x_y[2]) elif op == 2: dsu.move(op_x_y[1], op_x_y[2]) elif op == 3: x = dsu.find(op_x_y[1]) print(dsu.size[x], dsu.sum[x]) except EOFError: break
现有 N 个动物,以 1∼N 编号。每个动物都是 A,B,C 中的一种,但是我们并不知道它到底是哪一种。
有人用两种说法对这 N 个动物所构成的食物链关系进行描述:
第一种说法是 1 X Y,表示 X 和 Y 是同类。
第二种说法是 2 X Y,表示 X 吃 Y。
此人对 N 个动物,用上述两种说法,一句接一句地说出 K 句话,这 K 句话有的是真的,有的是假的。当一句话满足下列三条之一时,这句话就是假话,否则就是真话。
当前的话与前面的某些真的话冲突,就是假话;
当前的话中 X 或 Y 比 N 大,就是假话;
当前的话表示 X 吃 X,就是假话。
你的任务是根据给定的 N 和 K 句话,输出假话的总数。
解答一
考虑用带权并查集维护食物链信息。如果 x 和 y 是同类,那么 x≡y(mod3);如果 x 吃 y,那么 x−y≡1(mod3)。这样就将本题转化为前文的模板题。
具体地,对于每一句话,除去那些那些 x>n 或 y>n 的显然的假话外,需要判断 x 和 y 是否已经连接:如果已经连接,计算两者的模意义下的距离,并与这句话声称的信息进行比较;否则,将两者按照这句话提供的信息连接。除了显然的情形外,一句话是假话,当且仅当提到的两个节点已经连接,且对应的距离与这句话声称的信息矛盾。
参考实现一
[list2tab]
C++
#include <algorithm>#include <iostream>#include <numeric>#include <vector>constexpr int M = 3;struct DSU { std::vector<size_t> pa, size, dist; explicit DSU(size_t size_) : pa(size_), size(size_, 1), dist(size_) { std::iota(pa.begin(), pa.end(), 0); } size_t find(size_t x) { if (pa[x] == x) return x; size_t y = find(pa[x]); (dist[x] += dist[pa[x]]) %= M; return pa[x] = y; } bool unite(size_t x, size_t y, int d) { find(x), find(y); (d += M - dist[y]) %= M; (d += dist[x]) %= M; x = pa[x], y = pa[y]; if (x == y) return d == 0; if (size[x] < size[y]) { std::swap(x, y); d = (M - d) % M; } pa[y] = x; size[x] += size[y]; dist[y] = d; return true; }};int main() { int n, m; std::cin >> n >> m; DSU dsu(n + 1); int res = 0; for (; m; --m) { int op, x, y; std::cin >> op >> x >> y; if (x > n || y > n) ++res; else res += !dsu.unite(x, y, op == 1 ? 0 : 1); } std::cout << res << std::endl; return 0;}
Python
M = 3class DSU: def __init__(self, size: int): self.pa = list(range(size)) self.size = [1] * size self.dist = [0] * size def find(self, x: int) -> int: if self.pa[x] == x: return x y = self.find(self.pa[x]) self.dist[x] = (self.dist[x] + self.dist[self.pa[x]]) % M self.pa[x] = y return y def unite(self, x: int, y: int, d: int) -> bool: self.find(x) self.find(y) d = (d + M - self.dist[y]) % M d = (d + self.dist[x]) % M x, y = self.pa[x], self.pa[y] if x == y: return d == 0 if self.size[x] < self.size[y]: x, y = y, x d = (M - d) % M self.pa[y] = x self.size[x] += self.size[y] self.dist[y] = d return Trueif __name__ == "__main__": n, m = map(int, input().split()) dsu = DSU(n + 1) res = 0 for _ in range(m): op, x, y = map(int, input().split()) if x > n or y > n: res += 1 else: res += not dsu.unite(x, y, 0 if op == 1 else 1) print(res)
[!note]- 解答二
将一种生物 x 拆分为三种状态。在具体实现中,我们可以直接将不同的状态当作不同的元素:
与 x 处于同一集合的状态与 x 属于同一物种;
与 x+n 处于同一集合的状态能被 x 吃;
与 x+2n 处于同一集合的能吃 x。
于是,对于一句话:
1 x y 为假话当且仅当:
x>N 或 y>N;
y 与 x+n 或 x+2n 中的一个处于同一集合内。
2 x y 为假话当且仅当:
x>N 或 y>N;
y 与 x 或 x+2n 中的一个处于同一集合内。
若为真话,合并对应状态。
参考实现二
[list2tab]
C++
#include <algorithm>#include <iostream>#include <numeric>#include <vector>struct DSU { std::vector<size_t> pa, size; explicit DSU(size_t size_) : pa(size_), size(size_, 1) { std::iota(pa.begin(), pa.end(), 0); } size_t find(size_t x) { return pa[x] == x ? x : pa[x] = find(pa[x]); } void unite(size_t x, size_t y) { x = find(x), y = find(y); if (x == y) return; if (size[x] < size[y]) std::swap(x, y); pa[y] = x; size[x] += size[y]; }};int main() { int n, m; std::cin >> n >> m; DSU dsu(n * 3 + 1); int res = 0; for (; m; --m) { int op, x, y; std::cin >> op >> x >> y; if (x > n || y > n) ++res; else if (op == 1) { if (dsu.find(x) == dsu.find(y + n) || dsu.find(x) == dsu.find(y + (n << 1))) { ++res; } else { dsu.unite(x, y); dsu.unite(x + n, y + n); dsu.unite(x + n * 2, y + n * 2); } } else { if (dsu.find(x) == dsu.find(y) || dsu.find(x) == dsu.find(y + n)) { ++res; } else { dsu.unite(x, y + n * 2); dsu.unite(x + n, y); dsu.unite(x + n * 2, y + n); } } } std::cout << res << std::endl; return 0;}
Python
class Dsu: def __init__(self, size): self.pa = list(range(size)) self.size = [1] * size def find(self, x): if self.pa[x] != x: self.pa[x] = self.find(self.pa[x]) return self.pa[x] def unite(self, x, y): x, y = self.find(x), self.find(y) if x == y: return if self.size[x] < self.size[y]: x, y = y, x self.pa[y] = x self.size[x] += self.size[y]if __name__ == "__main__": n, m = map(int, input().split()) dsu = Dsu(n * 3 + 1) res = 0 for _ in range(m): op, x, y = map(int, input().split()) if x > n or y > n: res += 1 elif op == 1: if dsu.find(x) == dsu.find(y + n) or dsu.find(x) == dsu.find(y + (n << 1)): res += 1 else: dsu.unite(x, y) dsu.unite(x + n, y + n) dsu.unite(x + n * 2, y + n * 2) else: if dsu.find(x) == dsu.find(y) or dsu.find(x) == dsu.find(y + n): res += 1 else: dsu.unite(x, y + n * 2) dsu.unite(x + n, y) dsu.unite(x + n * 2, y + n) print(res)
异或就是单个二进制位上的「相同」或「不同」关系。那么,将 Ai 的所有二进制位拆开,异或关系就能用带权并查集(或种类并查集)维护了。同一个连通块内的元素一定对应着 A 中不同数字的同一个数位。统计答案时,同一连通块的元素通常分为两组,两组之间取值应当不同,只需要取其中较大的一组赋值为 0,另一组赋值为 1 即可保证总权值最小。
参考实现
[list2tab]
C++
#include <algorithm>#include <iostream>#include <numeric>#include <vector>constexpr int M = 2;struct DSU { std::vector<size_t> pa, size, dist; explicit DSU(size_t size_) : pa(size_), size(size_, 1), dist(size_) { std::iota(pa.begin(), pa.end(), 0); } size_t find(size_t x) { if (pa[x] == x) return x; size_t y = find(pa[x]); (dist[x] += dist[pa[x]]) %= M; return pa[x] = y; } bool unite(size_t x, size_t y, int d) { find(x), find(y); (d += M - dist[y]) %= M; (d += dist[x]) %= M; x = pa[x], y = pa[y]; if (x == y) return d == 0; if (size[x] < size[y]) { std::swap(x, y); d = (M - d) % M; } pa[y] = x; size[x] += size[y]; dist[y] = d; return true; }};int main() { int n, m; std::cin >> n >> m; DSU dsu((n + 1) << 5); for (; m; --m) { int x, y, z; std::cin >> x >> y >> z; for (int i = 0; i < 31; ++i) { if (!dsu.unite((x << 5) | i, (y << 5) | i, (z >> i) & 1)) { std::cout << -1 << std::endl; return 0; } } } std::vector<int> a(n + 1), cnt((n + 1) << 5); for (int i = 1; i < ((n + 1) << 5); ++i) { dsu.find(i); if (dsu.dist[i]) ++cnt[dsu.pa[i]]; } for (int i = 1; i <= n; ++i) { for (int j = 0; j < 31; ++j) { int x = (i << 5) | j, y = dsu.pa[x]; if ((cnt[y] > dsu.size[y] / 2) ^ dsu.dist[x]) { a[i] |= 1 << j; } } } for (int i = 1; i <= n; ++i) std::cout << a[i] << ' '; std::cout << std::endl; return 0;}
Python
M = 2class DSU: def __init__(self, size: int): self.pa = list(range(size)) self.size = [1] * size self.dist = [0] * size def find(self, x: int) -> int: if self.pa[x] == x: return x y = self.find(self.pa[x]) self.dist[x] = (self.dist[x] + self.dist[self.pa[x]]) % M self.pa[x] = y return y def unite(self, x: int, y: int, d: int) -> bool: self.find(x) self.find(y) d = (d + M - self.dist[y]) % M d = (d + self.dist[x]) % M x, y = self.pa[x], self.pa[y] if x == y: return d == 0 if self.size[x] < self.size[y]: x, y = y, x d = (M - d) % M self.pa[y] = x self.size[x] += self.size[y] self.dist[y] = d return Trueif __name__ == "__main__": n, m = map(int, input().split()) dsu = DSU((n + 1) << 5) for _ in range(m): x, y, z = map(int, input().split()) for i in range(31): if not dsu.unite((x << 5) | i, (y << 5) | i, (z >> i) & 1): print(-1) exit() a = [0] * (n + 1) cnt = [0] * ((n + 1) << 5) for i in range(1, (n + 1) << 5): dsu.find(i) if dsu.dist[i]: cnt[dsu.pa[i]] += 1 for i in range(1, n + 1): for j in range(31): x = (i << 5) | j y = dsu.pa[x] if (cnt[y] > dsu.size[y] // 2) ^ dsu.dist[x]: a[i] |= 1 << j print(" ".join(map(str, a[1:])))
Gabow, H. N., & Tarjan, R. E. (1985). A Linear-Time Algorithm for a Special Case of Disjoint Set Union. JOURNAL OF COMPUTER AND SYSTEM SCIENCES, 30, 209-221.PDF