nekolib/ds/
bicremental_median_dev.rs1use super::super::traits::binop;
4
5use std::collections::BTreeMap;
6use std::fmt::Debug;
7
8use binop::CommutativeGroup;
9
10#[derive(Clone, Debug, Eq, PartialEq)]
27pub struct BicrementalMedianDev<M: CommutativeGroup>
28where
29 M::Set: Ord + Clone,
30{
31 lower_sum: M::Set,
32 upper_sum: M::Set,
33 lower_len: usize,
34 upper_len: usize,
35 lower: BTreeMap<M::Set, usize>,
36 upper: BTreeMap<M::Set, usize>,
37 cgroup: M,
38}
39
40impl<M: CommutativeGroup> BicrementalMedianDev<M>
41where
42 M::Set: Ord + Clone,
43{
44 pub fn new() -> Self
45 where
46 M: Default,
47 {
48 Self::with(M::default())
49 }
50 pub fn with(cgroup: M) -> Self {
51 Self {
52 lower_sum: cgroup.id(),
53 upper_sum: cgroup.id(),
54 lower_len: 0,
55 upper_len: 0,
56 lower: BTreeMap::new(),
57 upper: BTreeMap::new(),
58 cgroup,
59 }
60 }
61 pub fn insert(&mut self, x: M::Set) {
62 if self.lower_len == 0 {
63 self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
64 self.lower.insert(x, 1);
65 self.lower_len += 1;
66 } else if self.lower_len == self.upper_len {
67 if &x <= self.upper.iter().next().unwrap().0 {
69 self.lower_sum =
71 self.cgroup.op(self.lower_sum.clone(), x.clone());
72 *self.lower.entry(x).or_insert(0) += 1;
73 } else {
74 self.upper_sum =
76 self.cgroup.op(self.upper_sum.clone(), x.clone());
77 self.rotate_to_lower();
78 *self.upper.entry(x).or_insert(0) += 1;
79 }
80 self.lower_len += 1;
81 } else {
82 if self.lower.iter().next_back().unwrap().0 < &x {
84 self.upper_sum =
86 self.cgroup.op(self.upper_sum.clone(), x.clone());
87 *self.upper.entry(x).or_insert(0) += 1;
88 } else {
89 self.lower_sum =
91 self.cgroup.op(self.lower_sum.clone(), x.clone());
92 self.rotate_to_upper();
93 *self.lower.entry(x).or_insert(0) += 1;
94 }
95 self.upper_len += 1;
96 }
97 }
98 pub fn remove(&mut self, x: M::Set) -> bool {
99 if self.lower_len == 0 {
100 false
101 } else if self.lower_len == self.upper_len {
102 if self.upper.contains_key(&x) {
104 self.remove_from_upper(x, false);
106 return true;
107 }
108 if self.lower.contains_key(&x) {
109 self.remove_from_lower(x, true);
111 return true;
112 }
113 false
114 } else {
115 if self.lower.contains_key(&x) {
117 self.remove_from_lower(x, false);
119 return true;
120 }
121 if self.upper.contains_key(&x) {
122 self.remove_from_upper(x, true);
124 return true;
125 }
126 false
127 }
128 }
129 pub fn median(&self) -> Option<&M::Set> {
130 if self.lower_len == 0 {
131 None
132 } else {
133 Some(self.lower.iter().next_back().unwrap().0)
134 }
135 }
136 pub fn median_dev(&self) -> M::Set {
137 if self.lower_len == 0 {
138 self.cgroup.id()
139 } else {
140 let diff = self.cgroup.op(
141 self.upper_sum.clone(),
142 self.cgroup.recip(self.lower_sum.clone()),
143 );
144 if self.lower_len == self.upper_len {
145 diff
146 } else {
147 self.cgroup.op(diff, self.median().unwrap().clone())
148 }
149 }
150 }
151}
152
153impl<M: CommutativeGroup> BicrementalMedianDev<M>
154where
155 M::Set: Ord + Clone,
156{
157 fn rotate_to_lower(&mut self) {
158 let (x, k) =
159 self.upper.iter().next().map(|(x, &k)| (x.clone(), k)).unwrap();
160 if k == 1 {
161 self.upper.remove(&x);
162 } else {
163 *self.upper.get_mut(&x).unwrap() -= 1;
164 }
165 self.upper_sum = self
166 .cgroup
167 .op(self.upper_sum.clone(), self.cgroup.recip(x.clone()));
168 self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
169 *self.lower.entry(x).or_insert(0) += 1;
170 }
171 fn rotate_to_upper(&mut self) {
172 let (x, k) = self
173 .lower
174 .iter()
175 .next_back()
176 .map(|(x, &k)| (x.clone(), k))
177 .unwrap();
178 if k == 1 {
179 self.lower.remove(&x);
180 } else {
181 *self.lower.get_mut(&x).unwrap() -= 1;
182 }
183 self.lower_sum = self
184 .cgroup
185 .op(self.lower_sum.clone(), self.cgroup.recip(x.clone()));
186 self.upper_sum = self.cgroup.op(self.upper_sum.clone(), x.clone());
187 *self.upper.entry(x).or_insert(0) += 1;
188 }
189 fn remove_from_lower(&mut self, x: M::Set, rotate: bool) {
190 if self.lower[&x] == 1 {
191 self.lower.remove(&x);
192 } else {
193 *self.lower.get_mut(&x).unwrap() -= 1;
194 }
195 self.lower_sum =
196 self.cgroup.op(self.lower_sum.clone(), self.cgroup.recip(x));
197 if rotate {
198 self.rotate_to_lower();
199 self.upper_len -= 1;
200 } else {
201 self.lower_len -= 1;
202 }
203 }
204 fn remove_from_upper(&mut self, x: M::Set, rotate: bool) {
205 if self.upper[&x] == 1 {
206 self.upper.remove(&x);
207 } else {
208 *self.upper.get_mut(&x).unwrap() -= 1;
209 }
210 self.upper_sum =
211 self.cgroup.op(self.upper_sum.clone(), self.cgroup.recip(x));
212 if rotate {
213 self.rotate_to_upper();
214 self.lower_len -= 1;
215 } else {
216 self.upper_len -= 1;
217 }
218 }
219}
220
/// Randomized check of `BicrementalMedianDev` against a naively recomputed
/// sorted multiset of values in `0..8`.
///
/// Each step draws a value from a deterministic sequence. Bit 3 selects
/// "remove" and the low 3 bits select the value. After every step the
/// median and the median deviation are compared with the naive answer.
#[test]
fn test_simple() {
    use op_add::OpAdd;

    let n = 32768;
    let mut f =
        std::iter::successors(Some(296), |&x| Some((x * 258 + 185) % 397))
            .map(|x| x & 15);
    let mut bucket = vec![0; 8];
    let mut bm = BicrementalMedianDev::<OpAdd<i32>>::new();
    for _ in 0..n {
        let x = f.next().unwrap();
        let (remove, x) = (x & 8 != 0, x & 7);
        if remove && bucket[x as usize] > 0 {
            bucket[x as usize] -= 1;
            assert!(bm.remove(x));
        } else {
            if remove {
                // `x` is absent: removal must report failure and change nothing.
                assert!(!bm.remove(x));
            }
            bucket[x as usize] += 1;
            bm.insert(x);
        }
        let naive: Vec<i32> = (0..8)
            .flat_map(|i| std::iter::repeat(i as i32).take(bucket[i]))
            .collect();
        // `wrapping_sub` maps the empty case to an out-of-range index -> None.
        assert_eq!(bm.median(), naive.get(naive.len().wrapping_sub(1) / 2));
        let &median = bm.median().unwrap_or(&0);
        let dev: i32 = naive.iter().map(|&x| (x - median).abs()).sum();
        assert_eq!(bm.median_dev(), dev);
    }
}