union_find/
lib.rs

1use std::{cell::RefCell, fmt};
2
/// A disjoint-set forest (union–find) over the elements `0..n`.
///
/// Internally, `self.0` holds one slot per element:
/// * a non-root element stores the index of its parent (always `< n`);
/// * a root stores the *negated* size of its set (`size.wrapping_neg()`),
///   which is never a valid index, so the two cases are told apart by
///   comparing the slot against the vector's length.
///
/// `self.1` caches the current number of disjoint sets.
///
/// The vector sits in a `RefCell` so that queries taking `&self`, such as
/// [`UnionFind::repr`], can still perform path compression.
#[derive(Clone)]
pub struct UnionFind(RefCell<Vec<usize>>, usize);

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, …, {n - 1}`.
    pub fn new(n: usize) -> Self {
        // Every element starts as a root of size 1, stored negated.
        Self(RefCell::new(vec![1_usize.wrapping_neg(); n]), n)
    }

    /// Merges the sets containing `u` and `v`.
    ///
    /// Returns `true` if they were in different sets (and are now merged),
    /// `false` if they already shared a set.
    ///
    /// Uses union by size: the root of the smaller set is attached below the
    /// root of the larger one (on a tie, `u`'s root stays the root), which
    /// keeps every tree `O(log n)` deep.
    ///
    /// # Panics
    ///
    /// Panics if `u` or `v` is out of bounds.
    pub fn unite(&mut self, u: usize, v: usize) -> bool {
        let u = self.repr(u);
        let v = self.repr(v);
        if u == v {
            return false;
        }

        // `&mut self` already guarantees exclusive access; no runtime borrow needed.
        let buf = self.0.get_mut();
        // Both `u` and `v` are roots, so their slots hold negated sizes.
        let (par, child) = if buf[u].wrapping_neg() < buf[v].wrapping_neg() {
            (v, u)
        } else {
            (u, v)
        };
        // -size(par) + -size(child) == -(size(par) + size(child)).
        buf[par] = buf[par].wrapping_add(buf[child]);
        buf[child] = par;
        self.1 -= 1;
        true
    }

    /// Returns `true` if `u` and `v` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `u` or `v` is out of bounds.
    pub fn equiv(&self, u: usize, v: usize) -> bool {
        self.repr(u) == self.repr(v)
    }

    /// Returns the representative (root) of the set containing `u`.
    ///
    /// Two elements are in the same set exactly when they share a
    /// representative. As a side effect, every node on the path from `u` to
    /// the root is repointed directly at the root (path compression).
    ///
    /// # Panics
    ///
    /// Panics if `u` is out of bounds.
    pub fn repr(&self, u: usize) -> usize {
        let mut buf = self.0.borrow_mut();
        let len = buf.len();

        // Pass 1: follow parent links until reaching a slot holding a negated
        // size (>= len), i.e. the root.
        let mut root = u;
        while buf[root] < len {
            root = buf[root];
        }

        // Pass 2: repoint every node on the path straight at the root.
        // Iterative, so even an unusually deep tree cannot overflow the stack.
        let mut node = u;
        while node != root {
            let next = buf[node];
            buf[node] = root;
            node = next;
        }
        root
    }

    /// Returns the number of elements in the set containing `u`.
    ///
    /// # Panics
    ///
    /// Panics if `u` is out of bounds.
    pub fn count(&self, u: usize) -> usize {
        let root = self.repr(u);
        self.0.borrow()[root].wrapping_neg()
    }

    /// Groups all elements by set.
    ///
    /// The result has one entry per element index: entry `r` lists, in
    /// increasing order, the members of the set whose representative is `r`,
    /// and is empty when `r` is not a representative.
    pub fn partition(&self) -> Vec<Vec<usize>> {
        let len = self.0.borrow().len();
        let mut ptn = vec![vec![]; len];
        for i in 0..len {
            ptn[self.repr(i)].push(i);
        }
        ptn
    }

    /// Returns the current number of disjoint sets.
    pub fn partition_len(&self) -> usize {
        self.1
    }
}

/// Debug-formats a list of elements with set notation, e.g. `{0, 2, 4}`.
struct AsSet<'a>(&'a [usize]);

impl fmt::Debug for AsSet<'_> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_set().entries(self.0).finish()
    }
}

/// Shows each non-empty set keyed by its representative,
/// e.g. `{0: {0, 2}, 1: {1}}`, ordered by representative index.
impl fmt::Debug for UnionFind {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ptn = self.partition();
        fmt.debug_map()
            .entries(
                ptn.iter()
                    .enumerate()
                    .filter(|(_, set)| !set.is_empty())
                    .map(|(repr, set)| (repr, AsSet(set.as_slice()))),
            )
            .finish()
    }
}

/// Shows the sets as a set of sets, e.g. `{{0, 2}, {1}}`, ordered by
/// representative index.
impl fmt::Display for UnionFind {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ptn = self.partition();
        fmt.debug_set()
            .entries(
                ptn.iter()
                    .filter(|set| !set.is_empty())
                    .map(|set| AsSet(set.as_slice())),
            )
            .finish()
    }
}

#[test]
fn sanity_check() {
    let n = 10;
    let mut actual = UnionFind::new(n);
    let mut expected = naive::DisjointSet::new(n);

    let f = |(u, v)| 2_u128.pow(u as _) * 3_u128.pow(v as _) % 625;
    let query = {
        let mut query: Vec<_> =
            (0..n).flat_map(|u| (0..u).map(move |v| (u, v))).collect();
        query.sort_unstable_by_key(|&(u, v)| f((u, v)));
        query
    };

    for (u, v) in query {
        assert_eq!(actual.unite(u, v), expected.unite(u, v));
        for i in 0..n {
            for j in 0..n {
                assert_eq!(actual.equiv(i, j), expected.equiv(i, j));
            }
            assert_eq!(actual.count(i), expected.count(i));
        }
    }
}

#[test]
fn debug_fmt() {
    let mut uf = UnionFind::new(8);
    uf.unite(1, 5);
    uf.unite(2, 4);
    uf.unite(0, 2);
    uf.unite(1, 6);
    uf.unite(6, 7);
    // Union by size roots each merged set at its larger side:
    // {1, 5, 6, 7} at 1 and {0, 2, 4} at 2.
    assert_eq!(format!("{uf}"), "{{1, 5, 6, 7}, {0, 2, 4}, {3}}");
    assert_eq!(format!("{uf:?}"), "{1: {1, 5, 6, 7}, 2: {0, 2, 4}, 3: {3}}");
}

#[test]
fn sequential_unions_stay_shallow() {
    // Uniting a growing set with a fresh singleton is the worst case for a
    // union that roots the smaller set (it builds an n-deep chain); with
    // union by size every element stays one step from the root.
    let n = 100_000;
    let mut uf = UnionFind::new(n);
    for i in 1..n {
        assert!(uf.unite(i - 1, i));
    }
    assert_eq!(uf.count(0), n);
    assert!(uf.equiv(0, n - 1));
    assert_eq!(uf.partition_len(), 1);
}