Skip to main content

nekolib/ds/
binary_trie.rs

1use std::fmt::{self, Debug};
2use std::ops::AddAssign;
3
4#[derive(Debug)]
5pub struct BinaryTrie<I> {
6    head: Link<I>,
7}
8
9type Link<I> = Option<Box<Node<I>>>;
10
11struct Node<I> {
12    sum0: usize,
13    sum1: I,
14    next: [Link<I>; 2],
15}
16
17impl<I: Debug> Debug for Node<I> {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        f.debug_map()
20            .entry(&"sum0", &self.sum0)
21            .entry(&"sum1", &self.sum1)
22            .entry(&"next", &[
23                self.next[0].as_ref().map(|_| ..),
24                self.next[1].as_ref().map(|_| ..),
25            ])
26            .finish()
27    }
28}
29
30impl<I: BinaryInt> BinaryTrie<I> {
31    pub fn new() -> Self { Self { head: None } }
32
33    pub fn insert(&mut self, elem: I) {
34        let mut cursor = &mut self.head;
35        for bit in elem.bits() {
36            let tmp = cursor.get_or_insert_with(|| Self::new_node());
37            tmp.sum0 += 1;
38            tmp.sum1 += elem;
39            cursor = &mut tmp.next[bit as usize];
40        }
41        let tmp = cursor.get_or_insert_with(|| Self::new_node());
42        tmp.sum0 += 1;
43        tmp.sum1 += elem;
44    }
45
46    pub fn iter(&self) -> Iter<'_, I> { Iter::new(&self) }
47
48    pub fn iter_dup(
49        &self,
50    ) -> impl '_ + Iterator<Item = I> + DoubleEndedIterator {
51        self.iter().flat_map(|(x, i)| (0..i).map(move |_| x))
52    }
53
54    fn new_node() -> Box<Node<I>> {
55        Box::new(Node { sum0: 0, sum1: I::zero(), next: [None, None] })
56    }
57}
58
59#[derive(Debug)]
60pub struct Iter<'a, I> {
61    // trie: &'a BinaryTrie<I>,
62    left_path: Vec<(&'a Box<Node<I>>, usize)>,
63    right_path: Vec<(&'a Box<Node<I>>, usize)>,
64    left_int: I,
65    right_int: I,
66}
67
68impl<'a, I: BinaryInt> Iter<'a, I> {
69    fn new(trie: &'a BinaryTrie<I>) -> Iter<I> {
70        let (left_path, left_int) = Self::descend(trie, 0);
71        let (right_path, right_int) = Self::descend(trie, 1);
72        Self { left_path, left_int, right_path, right_int }
73    }
74
75    fn descend(
76        trie: &'a BinaryTrie<I>,
77        fst: usize,
78    ) -> (Vec<(&Box<Node<I>>, usize)>, I) {
79        let mut int = I::zero();
80        let cursor = trie.head.as_ref();
81        let mut path = vec![];
82        Self::descend_inner(cursor, 0, &mut path, &mut int, fst);
83        (path, int)
84    }
85
86    fn descend_inner(
87        mut cursor: Option<&'a Box<Node<I>>>,
88        mut dir: usize,
89        path: &mut Vec<(&'a Box<Node<I>>, usize)>,
90        int: &mut I,
91        fst: usize,
92    ) {
93        while let Some(next) = cursor {
94            path.push((next, dir));
95            if let Some(fst_path) = &next.next[fst] {
96                int.push(fst != 0);
97                cursor = Some(&fst_path);
98                dir = fst;
99            } else if let Some(snd_path) = &next.next[fst ^ 1] {
100                int.push((fst ^ 1) != 0);
101                cursor = Some(&snd_path);
102                dir = fst ^ 1;
103            } else {
104                break;
105            }
106        }
107    }
108
109    fn next_dir(&mut self, dir: usize) -> Option<(I, usize)> {
110        // (値, 個数) を返したい?
111
112        let Self { left_int, left_path, right_int, right_path } = self;
113        if left_path.is_empty() {
114            return None;
115        }
116        let res = if dir == 0 { *left_int } else { *right_int };
117        let count = {
118            let path_last = if dir == 0 {
119                left_path.last().unwrap()
120            } else {
121                right_path.last().unwrap()
122            };
123            path_last.0.sum0
124        };
125
126        if left_int == right_int {
127            left_path.clear();
128            right_path.clear();
129        }
130
131        let path = if dir == 0 { left_path } else { right_path };
132        let int = if dir == 0 { left_int } else { right_int };
133
134        let mut last_dir = dir ^ 1;
135        while let Some((node, cur_dir)) = path.pop() {
136            if let Some(next) = &node.next[dir ^ 1] {
137                if last_dir == dir {
138                    path.push((node, cur_dir));
139                    int.push(dir == 0);
140                    Self::descend_inner(Some(next), dir ^ 1, path, int, dir);
141                    break;
142                }
143            }
144
145            int.pop();
146            last_dir = cur_dir;
147        }
148        Some((res, count))
149    }
150}
151
152impl<I: BinaryInt> Iterator for Iter<'_, I> {
153    type Item = (I, usize);
154    fn next(&mut self) -> Option<Self::Item> { self.next_dir(0) }
155}
156
157impl<I: BinaryInt> DoubleEndedIterator for Iter<'_, I> {
158    fn next_back(&mut self) -> Option<Self::Item> { self.next_dir(1) }
159}
160
/// Fixed-width unsigned integer that can be stored in a [`BinaryTrie`].
///
/// Besides `+=` for subtree sums, the trie needs bit-level access: reading
/// bits from the most significant end (`bits`, `test`) and building or
/// shrinking a value one bit at a time (`push`, `pop`).
pub trait BinaryInt: Copy + AddAssign<Self> + Eq + Debug {
    /// Returns zero, the value every trie sum and cursor starts from.
    fn zero() -> Self;
    /// Iterates over all bits of `self`, most significant first.
    fn bits(self) -> Bits<Self>;
    /// Returns whether bit `shift` (0 = least significant) of `self` is set.
    fn test(self, shift: u32) -> bool;
    /// Shifts `self` left by one (dropping the top bit) and sets the new
    /// lowest bit to `bit`.
    fn push(&mut self, bit: bool);
    /// Shifts `self` right by one, dropping the lowest bit.
    fn pop(&mut self);
}

/// Iterator over the bits of an integer, most significant first.
///
/// Created by [`BinaryInt::bits`].
#[derive(Clone, Debug)]
pub struct Bits<I> {
    val: I,
    /// Number of bits not yet yielded; the next bit read is `shift - 1`.
    shift: u32,
}

impl<I: BinaryInt> Iterator for Bits<I> {
    type Item = bool;
    fn next(&mut self) -> Option<bool> {
        if self.shift == 0 {
            return None;
        }
        self.shift -= 1;
        Some(self.val.test(self.shift))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // `shift` never exceeds the bit width (at most 128), so this cast is lossless.
        let remaining = self.shift as usize;
        (remaining, Some(remaining))
    }
}

impl<I: BinaryInt> ExactSizeIterator for Bits<I> {}

// `shift` stays at 0 once exhausted, so `None` repeats.
impl<I: BinaryInt> std::iter::FusedIterator for Bits<I> {}

// Implements `BinaryInt` for primitive unsigned integer types.
macro_rules! impl_uint {
    ( $($ty:ty)* ) => { $(
        impl BinaryInt for $ty {
            fn zero() -> Self { 0 }
            fn bits(self) -> Bits<Self> {
                Bits { val: self, shift: <$ty>::BITS }
            }
            fn test(self, shift: u32) -> bool {
                (self >> shift) & 1 != 0
            }
            fn push(&mut self, bit: bool) {
                *self <<= 1;
                *self |= bit as $ty;
            }
            fn pop(&mut self) {
                *self >>= 1;
            }
        }
    )* }
}

impl_uint! { u8 u16 u32 u64 u128 usize }
208
#[test]
fn test() {
    let mut bt = BinaryTrie::<u8>::new();
    for x in [10, 3, 1, 3, 0, 14, 100] {
        bt.insert(x);
    }

    // `take` bounds each traversal so a non-terminating iterator fails the
    // assertion instead of hanging the test.
    let distinct: Vec<_> = bt.iter().take(10).collect();
    assert_eq!(distinct, vec![(0, 1), (1, 1), (3, 2), (10, 1), (14, 1), (100, 1)]);

    let distinct_rev: Vec<_> = bt.iter().rev().take(10).collect();
    assert_eq!(distinct_rev, vec![(100, 1), (14, 1), (10, 1), (3, 2), (1, 1), (0, 1)]);

    let all: Vec<_> = bt.iter_dup().take(10).collect();
    assert_eq!(all, vec![0, 1, 3, 3, 10, 14, 100]);

    let all_rev: Vec<_> = bt.iter_dup().rev().take(10).collect();
    assert_eq!(all_rev, vec![100, 14, 10, 3, 3, 1, 0]);

    // Alternating ends: the two cursors must meet without skipping or
    // repeating an element, and then stay exhausted.
    let mut it = bt.iter();
    assert_eq!(it.next(), Some((0, 1)));
    assert_eq!(it.next_back(), Some((100, 1)));
    assert_eq!(it.next(), Some((1, 1)));
    assert_eq!(it.next_back(), Some((14, 1)));
    assert_eq!(it.next(), Some((3, 2)));
    assert_eq!(it.next_back(), Some((10, 1)));
    assert_eq!(it.next(), None);
    assert_eq!(it.next_back(), None);

    // Empty trie.
    let empty = BinaryTrie::<u32>::new();
    assert_eq!(empty.iter().next(), None);
    assert_eq!(empty.iter_dup().next_back(), None);

    // Single distinct element, inserted twice.
    let mut single = BinaryTrie::<u64>::new();
    single.insert(7);
    single.insert(7);
    assert_eq!(single.iter().collect::<Vec<_>>(), vec![(7, 2)]);
    assert_eq!(single.iter().rev().collect::<Vec<_>>(), vec![(7, 2)]);
    assert_eq!(single.iter_dup().collect::<Vec<_>>(), vec![7, 7]);
}
241
242// ```
243// bt.insert(0);
244// bt.insert(0);
245// bt.remove(0);
246// bt.insert(1);
247// bt.insert(3);
248// bt.insert(3);
249// bt.insert(10);
250// bt.insert(14);
251// bt.insert(100);
252// bt.iter_dup().collect();    // {0, 1, 3, 3, 10, 14, 100}
253// bt.count(.., 3);            // 2
254// bt.count_3way(.., 3);       // {lt: 2, eq: 2, gt: 3}
255// bt.count(.., 3..=11);       // 3
256// bt.count_3way(.., 3..=11);  // {lt: 2, eq: 3, gt: 2}
257// bt.sum(.., 3..=11);         // 16
258// bt.sum_3way(.., 3..=11);    // {lt: 1, eq: 16, gt: 114}
259// bt.sum(.., ..);             // 131
260// bt.quantile(.., 4);         // 10
261// bt.quantile_sum(.., 4);     // 7
262// ```