nekolib/ds/
binary_trie.rs1use std::fmt::{self, Debug};
2use std::ops::AddAssign;
3
4#[derive(Debug)]
5pub struct BinaryTrie<I> {
6 head: Link<I>,
7}
8
/// Owned, optional pointer to a child node (`None` = empty subtree).
type Link<I> = Option<Box<Node<I>>>;

/// A single trie node.
struct Node<I> {
    /// How many inserted values (counting duplicates) pass through this node.
    sum0: usize,
    /// Sum of all inserted values (counting duplicates) passing through this node.
    sum1: I,
    /// Children indexed by the next bit: `next[0]` for a 0 bit, `next[1]` for a 1 bit.
    next: [Option<Box<Node<I>>>; 2],
}
16
17impl<I: Debug> Debug for Node<I> {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 f.debug_map()
20 .entry(&"sum0", &self.sum0)
21 .entry(&"sum1", &self.sum1)
22 .entry(&"next", &[
23 self.next[0].as_ref().map(|_| ..),
24 self.next[1].as_ref().map(|_| ..),
25 ])
26 .finish()
27 }
28}
29
30impl<I: BinaryInt> BinaryTrie<I> {
31 pub fn new() -> Self { Self { head: None } }
32
33 pub fn insert(&mut self, elem: I) {
34 let mut cursor = &mut self.head;
35 for bit in elem.bits() {
36 let tmp = cursor.get_or_insert_with(|| Self::new_node());
37 tmp.sum0 += 1;
38 tmp.sum1 += elem;
39 cursor = &mut tmp.next[bit as usize];
40 }
41 let tmp = cursor.get_or_insert_with(|| Self::new_node());
42 tmp.sum0 += 1;
43 tmp.sum1 += elem;
44 }
45
46 pub fn iter(&self) -> Iter<'_, I> { Iter::new(&self) }
47
48 pub fn iter_dup(
49 &self,
50 ) -> impl '_ + Iterator<Item = I> + DoubleEndedIterator {
51 self.iter().flat_map(|(x, i)| (0..i).map(move |_| x))
52 }
53
54 fn new_node() -> Box<Node<I>> {
55 Box::new(Node { sum0: 0, sum1: I::zero(), next: [None, None] })
56 }
57}
58
59#[derive(Debug)]
60pub struct Iter<'a, I> {
61 left_path: Vec<(&'a Box<Node<I>>, usize)>,
63 right_path: Vec<(&'a Box<Node<I>>, usize)>,
64 left_int: I,
65 right_int: I,
66}
67
68impl<'a, I: BinaryInt> Iter<'a, I> {
69 fn new(trie: &'a BinaryTrie<I>) -> Iter<I> {
70 let (left_path, left_int) = Self::descend(trie, 0);
71 let (right_path, right_int) = Self::descend(trie, 1);
72 Self { left_path, left_int, right_path, right_int }
73 }
74
75 fn descend(
76 trie: &'a BinaryTrie<I>,
77 fst: usize,
78 ) -> (Vec<(&Box<Node<I>>, usize)>, I) {
79 let mut int = I::zero();
80 let cursor = trie.head.as_ref();
81 let mut path = vec![];
82 Self::descend_inner(cursor, 0, &mut path, &mut int, fst);
83 (path, int)
84 }
85
86 fn descend_inner(
87 mut cursor: Option<&'a Box<Node<I>>>,
88 mut dir: usize,
89 path: &mut Vec<(&'a Box<Node<I>>, usize)>,
90 int: &mut I,
91 fst: usize,
92 ) {
93 while let Some(next) = cursor {
94 path.push((next, dir));
95 if let Some(fst_path) = &next.next[fst] {
96 int.push(fst != 0);
97 cursor = Some(&fst_path);
98 dir = fst;
99 } else if let Some(snd_path) = &next.next[fst ^ 1] {
100 int.push((fst ^ 1) != 0);
101 cursor = Some(&snd_path);
102 dir = fst ^ 1;
103 } else {
104 break;
105 }
106 }
107 }
108
109 fn next_dir(&mut self, dir: usize) -> Option<(I, usize)> {
110 let Self { left_int, left_path, right_int, right_path } = self;
113 if left_path.is_empty() {
114 return None;
115 }
116 let res = if dir == 0 { *left_int } else { *right_int };
117 let count = {
118 let path_last = if dir == 0 {
119 left_path.last().unwrap()
120 } else {
121 right_path.last().unwrap()
122 };
123 path_last.0.sum0
124 };
125
126 if left_int == right_int {
127 left_path.clear();
128 right_path.clear();
129 }
130
131 let path = if dir == 0 { left_path } else { right_path };
132 let int = if dir == 0 { left_int } else { right_int };
133
134 let mut last_dir = dir ^ 1;
135 while let Some((node, cur_dir)) = path.pop() {
136 if let Some(next) = &node.next[dir ^ 1] {
137 if last_dir == dir {
138 path.push((node, cur_dir));
139 int.push(dir == 0);
140 Self::descend_inner(Some(next), dir ^ 1, path, int, dir);
141 break;
142 }
143 }
144
145 int.pop();
146 last_dir = cur_dir;
147 }
148 Some((res, count))
149 }
150}
151
152impl<I: BinaryInt> Iterator for Iter<'_, I> {
153 type Item = (I, usize);
154 fn next(&mut self) -> Option<Self::Item> { self.next_dir(0) }
155}
156
157impl<I: BinaryInt> DoubleEndedIterator for Iter<'_, I> {
158 fn next_back(&mut self) -> Option<Self::Item> { self.next_dir(1) }
159}
160
161pub trait BinaryInt: Copy + AddAssign<Self> + Eq + Debug {
162 fn zero() -> Self;
163 fn bits(self) -> Bits<Self>;
164 fn test(self, shift: u32) -> bool;
165 fn push(&mut self, bit: bool);
166 fn pop(&mut self);
167}
168
/// Iterator over the bits of an integer, as returned by [`BinaryInt::bits`].
pub struct Bits<I> {
    /// The integer being decomposed.
    val: I,
    /// Number of bits not yet yielded; the next bit produced is bit `shift - 1`.
    shift: u32,
}
173
174impl<I: BinaryInt> Iterator for Bits<I> {
175 type Item = bool;
176 fn next(&mut self) -> Option<bool> {
177 if self.shift == 0 {
178 return None;
179 }
180 self.shift -= 1;
181 Some(self.val.test(self.shift))
182 }
183}
184
185macro_rules! impl_uint {
186 ( $($ty:ty)* ) => { $(
187 impl BinaryInt for $ty {
188 fn zero() -> Self { 0 }
189 fn bits(self) -> Bits<Self> {
190 let bits = (0 as $ty).count_zeros();
191 Bits { val: self, shift: bits }
192 }
193 fn test(self, shift: u32) -> bool {
194 self >> shift & 1 != 0
195 }
196 fn push(&mut self, bit: bool) {
197 *self <<= 1;
198 *self |= bit as $ty;
199 }
200 fn pop(&mut self) {
201 *self >>= 1;
202 }
203 }
204 )* }
205}
206
207impl_uint! { u8 u16 u32 u64 u128 usize }
208
/// Checks ordering, multiplicities and double-ended behaviour of the trie
/// iterators on a small multiset (containing a duplicate) and on an empty trie.
#[test]
fn test() {
    let mut bt = BinaryTrie::<u8>::new();
    bt.insert(10);
    bt.insert(3);
    bt.insert(1);
    bt.insert(3);
    bt.insert(0);
    bt.insert(14);
    bt.insert(100);

    // Distinct values with their multiplicities, ascending.
    let distinct: Vec<(u8, usize)> =
        vec![(0, 1), (1, 1), (3, 2), (10, 1), (14, 1), (100, 1)];
    assert_eq!(bt.iter().collect::<Vec<_>>(), distinct);

    let mut distinct_rev = distinct.clone();
    distinct_rev.reverse();
    assert_eq!(bt.iter().rev().collect::<Vec<_>>(), distinct_rev);

    // Values repeated according to multiplicity.
    let dup: Vec<u8> = vec![0, 1, 3, 3, 10, 14, 100];
    assert_eq!(bt.iter_dup().collect::<Vec<_>>(), dup);
    let dup_rev: Vec<u8> = dup.iter().rev().copied().collect();
    assert_eq!(bt.iter_dup().rev().collect::<Vec<_>>(), dup_rev);

    // Alternating ends: the two cursors must meet exactly once, without
    // skipping or repeating a value.
    let mut it = bt.iter();
    assert_eq!(it.next(), Some((0, 1)));
    assert_eq!(it.next_back(), Some((100, 1)));
    assert_eq!(it.next(), Some((1, 1)));
    assert_eq!(it.next_back(), Some((14, 1)));
    assert_eq!(it.next(), Some((3, 2)));
    assert_eq!(it.next_back(), Some((10, 1)));
    assert_eq!(it.next(), None);
    assert_eq!(it.next_back(), None);

    // Empty trie yields nothing from either end.
    let empty = BinaryTrie::<u8>::new();
    assert_eq!(empty.iter().next(), None);
    assert_eq!(empty.iter().next_back(), None);
    assert_eq!(empty.iter_dup().next(), None);
}
241
242