Skip to main content

nekolib/ds/
bicremental_median_dev.rs

1//! 中央値と偏差の管理。
2
3use super::super::traits::binop;
4
5use std::collections::BTreeMap;
6use std::fmt::Debug;
7
8use binop::CommutativeGroup;
9
10/// 中央値と偏差の管理。
11///
12/// 多重集合 $S$ への要素の追加と削除を行いつつ、中央値と偏差を管理する。
13/// ここで、偏差は中央値との絶対値の総和とする。
14/// $S$ のうち小さい方から $\\lceil|S|/2\\rceil$ 個取り出したものを $L$、
15/// 残りの $\\lfloor|S|/2\\rfloor$ 個取り出したものを $R$ とする。
16/// 中央値を $L$ の最大値として定義し、$a\_{\\text{med}}$ と書く。
17/// このとき、偏差 $\\sigma$ は次のように書ける。
18/// $$ \\sigma = \\sum\_{a\\in L} (a\_{\\text{med}}-a) + \\sum\_{a\\in R} (a-a\_{\\text{med}}). $$
19/// $\\sum\_{a\\in L} a = \\sigma\_L$、$\\sum\_{a\\in R} a = \\sigma\_R$ とすると、
20/// $$ \\begin{aligned}
21/// \\sigma &= \\begin{cases}
22/// -\\sigma\_L + \\sigma\_R, & \\text{if } |L| = |R|; \\\\
23/// -\\sigma\_L + \\sigma\_R + a\_{\\text{med}}, & \\text{if } |L| = |R|+1. \\\\
24/// \\end{cases}
25/// \\end{aligned} $$
26#[derive(Clone, Debug, Eq, PartialEq)]
27pub struct BicrementalMedianDev<M: CommutativeGroup>
28where
29    M::Set: Ord + Clone,
30{
31    lower_sum: M::Set,
32    upper_sum: M::Set,
33    lower_len: usize,
34    upper_len: usize,
35    lower: BTreeMap<M::Set, usize>,
36    upper: BTreeMap<M::Set, usize>,
37    cgroup: M,
38}
39
40impl<M: CommutativeGroup> BicrementalMedianDev<M>
41where
42    M::Set: Ord + Clone,
43{
44    pub fn new() -> Self
45    where
46        M: Default,
47    {
48        Self::with(M::default())
49    }
50    pub fn with(cgroup: M) -> Self {
51        Self {
52            lower_sum: cgroup.id(),
53            upper_sum: cgroup.id(),
54            lower_len: 0,
55            upper_len: 0,
56            lower: BTreeMap::new(),
57            upper: BTreeMap::new(),
58            cgroup,
59        }
60    }
61    pub fn insert(&mut self, x: M::Set) {
62        if self.lower_len == 0 {
63            self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
64            self.lower.insert(x, 1);
65            self.lower_len += 1;
66        } else if self.lower_len == self.upper_len {
67            // [LLL] X [RRR]
68            if &x <= self.upper.iter().next().unwrap().0 {
69                // [LLXL] [RRR]
70                self.lower_sum =
71                    self.cgroup.op(self.lower_sum.clone(), x.clone());
72                *self.lower.entry(x).or_insert(0) += 1;
73            } else {
74                // [LLLR] [RRX]
75                self.upper_sum =
76                    self.cgroup.op(self.upper_sum.clone(), x.clone());
77                self.rotate_to_lower();
78                *self.upper.entry(x).or_insert(0) += 1;
79            }
80            self.lower_len += 1;
81        } else {
82            // [LLL] X [RR]
83            if self.lower.iter().next_back().unwrap().0 < &x {
84                // [LLL] [RXR]
85                self.upper_sum =
86                    self.cgroup.op(self.upper_sum.clone(), x.clone());
87                *self.upper.entry(x).or_insert(0) += 1;
88            } else {
89                // [XLL] [LRR]
90                self.lower_sum =
91                    self.cgroup.op(self.lower_sum.clone(), x.clone());
92                self.rotate_to_upper();
93                *self.lower.entry(x).or_insert(0) += 1;
94            }
95            self.upper_len += 1;
96        }
97    }
98    pub fn remove(&mut self, x: M::Set) -> bool {
99        if self.lower_len == 0 {
100            false
101        } else if self.lower_len == self.upper_len {
102            // [LLL] [RRR]
103            if self.upper.contains_key(&x) {
104                // [LLL] [RR]
105                self.remove_from_upper(x, false);
106                return true;
107            }
108            if self.lower.contains_key(&x) {
109                // [LLR] [RR]
110                self.remove_from_lower(x, true);
111                return true;
112            }
113            false
114        } else {
115            // [LLL] [RR]
116            if self.lower.contains_key(&x) {
117                // [LL] [RR]
118                self.remove_from_lower(x, false);
119                return true;
120            }
121            if self.upper.contains_key(&x) {
122                // [LL] [LR]
123                self.remove_from_upper(x, true);
124                return true;
125            }
126            false
127        }
128    }
129    pub fn median(&self) -> Option<&M::Set> {
130        if self.lower_len == 0 {
131            None
132        } else {
133            Some(self.lower.iter().next_back().unwrap().0)
134        }
135    }
136    pub fn median_dev(&self) -> M::Set {
137        if self.lower_len == 0 {
138            self.cgroup.id()
139        } else {
140            let diff = self.cgroup.op(
141                self.upper_sum.clone(),
142                self.cgroup.recip(self.lower_sum.clone()),
143            );
144            if self.lower_len == self.upper_len {
145                diff
146            } else {
147                self.cgroup.op(diff, self.median().unwrap().clone())
148            }
149        }
150    }
151}
152
153impl<M: CommutativeGroup> BicrementalMedianDev<M>
154where
155    M::Set: Ord + Clone,
156{
157    fn rotate_to_lower(&mut self) {
158        let (x, k) =
159            self.upper.iter().next().map(|(x, &k)| (x.clone(), k)).unwrap();
160        if k == 1 {
161            self.upper.remove(&x);
162        } else {
163            *self.upper.get_mut(&x).unwrap() -= 1;
164        }
165        self.upper_sum = self
166            .cgroup
167            .op(self.upper_sum.clone(), self.cgroup.recip(x.clone()));
168        self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
169        *self.lower.entry(x).or_insert(0) += 1;
170    }
171    fn rotate_to_upper(&mut self) {
172        let (x, k) = self
173            .lower
174            .iter()
175            .next_back()
176            .map(|(x, &k)| (x.clone(), k))
177            .unwrap();
178        if k == 1 {
179            self.lower.remove(&x);
180        } else {
181            *self.lower.get_mut(&x).unwrap() -= 1;
182        }
183        self.lower_sum = self
184            .cgroup
185            .op(self.lower_sum.clone(), self.cgroup.recip(x.clone()));
186        self.upper_sum = self.cgroup.op(self.upper_sum.clone(), x.clone());
187        *self.upper.entry(x).or_insert(0) += 1;
188    }
189    fn remove_from_lower(&mut self, x: M::Set, rotate: bool) {
190        if self.lower[&x] == 1 {
191            self.lower.remove(&x);
192        } else {
193            *self.lower.get_mut(&x).unwrap() -= 1;
194        }
195        self.lower_sum =
196            self.cgroup.op(self.lower_sum.clone(), self.cgroup.recip(x));
197        if rotate {
198            self.rotate_to_lower();
199            self.upper_len -= 1;
200        } else {
201            self.lower_len -= 1;
202        }
203    }
204    fn remove_from_upper(&mut self, x: M::Set, rotate: bool) {
205        if self.upper[&x] == 1 {
206            self.upper.remove(&x);
207        } else {
208            *self.upper.get_mut(&x).unwrap() -= 1;
209        }
210        self.upper_sum =
211            self.cgroup.op(self.upper_sum.clone(), self.cgroup.recip(x));
212        if rotate {
213            self.rotate_to_upper();
214            self.lower_len -= 1;
215        } else {
216            self.upper_len -= 1;
217        }
218    }
219}
220
221#[test]
222fn test_simple() {
223    use op_add::OpAdd;
224
225    let n = 32768;
226    let mut f =
227        std::iter::successors(Some(296), |&x| Some((x * 258 + 185) % 397))
228            .map(|x| x & 15);
229    let mut bucket = vec![0; 8];
230    let mut bm = BicrementalMedianDev::<OpAdd<i32>>::new();
231    for _ in 0..n {
232        let x = f.next().unwrap();
233        let (remove, x) = (x & 8 != 0, x & 7);
234        if remove && bucket[x as usize] > 0 {
235            bucket[x as usize] -= 1;
236            bm.remove(x);
237        } else {
238            bucket[x as usize] += 1;
239            bm.insert(x);
240        }
241        let mut naive = vec![];
242        for i in 0..8 {
243            naive.extend(std::iter::repeat(i as i32).take(bucket[i]));
244        }
245        assert_eq!(bm.median(), naive.get(naive.len().wrapping_sub(1) / 2));
246        let &median = bm.median().unwrap_or(&0);
247        eprintln!("{:?}", naive);
248        eprintln!("{:?}", bm);
249        let dev: i32 = naive.iter().map(|&x| (x - median).abs()).sum();
250        assert_eq!(bm.median_dev(), dev);
251    }
252}