use super::super::traits::binop;
use std::collections::BTreeMap;
use std::fmt::Debug;
use binop::CommutativeGroup;
/// Multiset of `M::Set` values that supports insertion and removal while
/// reporting the lower median and the group sum of deviations from it.
///
/// Elements are split into two sorted halves stored as value -> count maps:
/// `lower` holds the smaller half and `upper` the larger one. The methods
/// keep `lower_len == upper_len` or `lower_len == upper_len + 1`, and every
/// key of `lower` is `<=` every key of `upper`, so the largest key of
/// `lower` is the (lower) median.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BicrementalMedianDev<M: CommutativeGroup>
where
M::Set: Ord + Clone,
{
// Group sum (`cgroup.op`) of every element in `lower`, with multiplicity.
lower_sum: M::Set,
// Group sum of every element in `upper`, with multiplicity.
upper_sum: M::Set,
// Number of elements in `lower`, counting multiplicity.
lower_len: usize,
// Number of elements in `upper`, counting multiplicity.
upper_len: usize,
// Lower half: value -> occurrence count (counts are always >= 1).
lower: BTreeMap<M::Set, usize>,
// Upper half: value -> occurrence count (counts are always >= 1).
upper: BTreeMap<M::Set, usize>,
// Group supplying `op`, `recip` and `id` for the running sums.
cgroup: M,
}
impl<M: CommutativeGroup> BicrementalMedianDev<M>
where
    M::Set: Ord + Clone,
{
    /// Creates an empty structure using the group's `Default` value.
    pub fn new() -> Self
    where
        M: Default,
    {
        Self::with(M::default())
    }

    /// Creates an empty structure over the given group value.
    ///
    /// Both running sums start at `cgroup.id()`.
    pub fn with(cgroup: M) -> Self {
        Self {
            lower_sum: cgroup.id(),
            upper_sum: cgroup.id(),
            lower_len: 0,
            upper_len: 0,
            lower: BTreeMap::new(),
            upper: BTreeMap::new(),
            cgroup,
        }
    }

    /// Inserts one occurrence of `x`.
    ///
    /// Keeps `lower_len` equal to `upper_len` or `upper_len + 1`, and keeps
    /// every key of `lower` `<=` every key of `upper`.
    pub fn insert(&mut self, x: M::Set) {
        if self.lower_len == 0 {
            // Empty structure (so `upper` is empty too): `x` becomes the median.
            self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
            self.lower.insert(x, 1);
            self.lower_len += 1;
        } else if self.lower_len == self.upper_len {
            // Even count: after insertion `lower` must hold one extra element.
            let into_lower = &x
                <= self
                    .upper
                    .keys()
                    .next()
                    .expect("upper is non-empty when lower_len == upper_len > 0");
            if into_lower {
                self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
                *self.lower.entry(x).or_default() += 1;
            } else {
                // `x` is above the current upper minimum: that minimum moves down
                // to `lower` and `x` joins `upper`. `x` is added to `upper_sum`
                // first; `rotate_to_lower` then subtracts the moved minimum.
                self.upper_sum = self.cgroup.op(self.upper_sum.clone(), x.clone());
                self.rotate_to_lower();
                *self.upper.entry(x).or_default() += 1;
            }
            self.lower_len += 1;
        } else {
            // Odd count: after insertion both halves must have equal size.
            let into_upper = self
                .lower
                .keys()
                .next_back()
                .expect("lower is non-empty when lower_len > 0")
                < &x;
            if into_upper {
                self.upper_sum = self.cgroup.op(self.upper_sum.clone(), x.clone());
                *self.upper.entry(x).or_default() += 1;
            } else {
                // `x` is at or below the current median: the median moves up to
                // `upper` and `x` joins `lower`.
                self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
                self.rotate_to_upper();
                *self.lower.entry(x).or_default() += 1;
            }
            self.upper_len += 1;
        }
    }

    /// Removes one occurrence of `x`.
    ///
    /// Returns `true` if an occurrence was removed. Returns `false`, leaving
    /// `self` unchanged, if `x` is not present.
    pub fn remove(&mut self, x: M::Set) -> bool {
        if self.lower_len == 0 {
            return false;
        }
        if self.lower_len == self.upper_len {
            // Even count: `upper` must end up one smaller.
            if self.upper.contains_key(&x) {
                self.remove_from_upper(x, false);
            } else if self.lower.contains_key(&x) {
                // Refill `lower` from `upper` so the size balance is restored.
                self.remove_from_lower(x, true);
            } else {
                return false;
            }
        } else {
            // Odd count: `lower` must end up one smaller.
            if self.lower.contains_key(&x) {
                self.remove_from_lower(x, false);
            } else if self.upper.contains_key(&x) {
                // Refill `upper` from `lower` so the size balance is restored.
                self.remove_from_upper(x, true);
            } else {
                return false;
            }
        }
        true
    }

    /// Returns the lower median, or `None` when empty.
    ///
    /// The lower median is the element at 0-based index `(len - 1) / 2` in
    /// sorted order, i.e. the largest key of `lower`.
    pub fn median(&self) -> Option<&M::Set> {
        // `lower` is empty exactly when `lower_len == 0`.
        self.lower.keys().next_back()
    }

    /// Returns the group sum of deviations from the median.
    ///
    /// Each element `a` in `lower` contributes `m - a` and each `b` in `upper`
    /// contributes `b - m`. The total is therefore
    /// `upper_sum - lower_sum + (lower_len - upper_len) * m`, and the last
    /// term is `m` or nothing. For integer addition this is the sum of
    /// absolute deviations. Returns `cgroup.id()` when empty.
    pub fn median_dev(&self) -> M::Set {
        match self.median() {
            None => self.cgroup.id(),
            Some(median) => {
                let diff = self.cgroup.op(
                    self.upper_sum.clone(),
                    self.cgroup.recip(self.lower_sum.clone()),
                );
                if self.lower_len == self.upper_len {
                    diff
                } else {
                    self.cgroup.op(diff, median.clone())
                }
            }
        }
    }
}

impl<M: CommutativeGroup + Default> Default for BicrementalMedianDev<M>
where
    M::Set: Ord + Clone,
{
    /// Equivalent to [`BicrementalMedianDev::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<M: CommutativeGroup> BicrementalMedianDev<M>
where
M::Set: Ord + Clone,
{
fn rotate_to_lower(&mut self) {
let (x, k) =
self.upper.iter().next().map(|(x, &k)| (x.clone(), k)).unwrap();
if k == 1 {
self.upper.remove(&x);
} else {
*self.upper.get_mut(&x).unwrap() -= 1;
}
self.upper_sum = self
.cgroup
.op(self.upper_sum.clone(), self.cgroup.recip(x.clone()));
self.lower_sum = self.cgroup.op(self.lower_sum.clone(), x.clone());
*self.lower.entry(x).or_insert(0) += 1;
}
fn rotate_to_upper(&mut self) {
let (x, k) = self
.lower
.iter()
.next_back()
.map(|(x, &k)| (x.clone(), k))
.unwrap();
if k == 1 {
self.lower.remove(&x);
} else {
*self.lower.get_mut(&x).unwrap() -= 1;
}
self.lower_sum = self
.cgroup
.op(self.lower_sum.clone(), self.cgroup.recip(x.clone()));
self.upper_sum = self.cgroup.op(self.upper_sum.clone(), x.clone());
*self.upper.entry(x).or_insert(0) += 1;
}
fn remove_from_lower(&mut self, x: M::Set, rotate: bool) {
if self.lower[&x] == 1 {
self.lower.remove(&x);
} else {
*self.lower.get_mut(&x).unwrap() -= 1;
}
self.lower_sum =
self.cgroup.op(self.lower_sum.clone(), self.cgroup.recip(x));
if rotate {
self.rotate_to_lower();
self.upper_len -= 1;
} else {
self.lower_len -= 1;
}
}
fn remove_from_upper(&mut self, x: M::Set, rotate: bool) {
if self.upper[&x] == 1 {
self.upper.remove(&x);
} else {
*self.upper.get_mut(&x).unwrap() -= 1;
}
self.upper_sum =
self.cgroup.op(self.upper_sum.clone(), self.cgroup.recip(x));
if rotate {
self.rotate_to_upper();
self.lower_len -= 1;
} else {
self.upper_len -= 1;
}
}
}
#[test]
fn test_simple() {
    use op_add::OpAdd;
    let n = 32768;
    // Deterministic pseudo-random stream of 4-bit values: bit 3 requests a
    // removal, bits 0..=2 choose the value 0..=7.
    let mut f =
        std::iter::successors(Some(296), |&x| Some((x * 258 + 185) % 397))
            .map(|x| x & 15);
    // Reference multiset: bucket[v] is the number of stored copies of v.
    let mut bucket = vec![0; 8];
    let mut bm = BicrementalMedianDev::<OpAdd<i32>>::new();
    for _ in 0..n {
        let x = f.next().unwrap();
        let (remove, x) = (x & 8 != 0, x & 7);
        if remove && bucket[x as usize] > 0 {
            bucket[x as usize] -= 1;
            assert!(bm.remove(x));
        } else {
            if remove {
                // `x` is absent: removal must report failure and change nothing.
                assert!(!bm.remove(x));
            }
            bucket[x as usize] += 1;
            bm.insert(x);
        }
        let mut naive = vec![];
        for i in 0..8 {
            naive.extend(std::iter::repeat(i as i32).take(bucket[i]));
        }
        // Lower median: index (len - 1) / 2; `wrapping_sub` makes the empty
        // case index out of range so `get` yields `None`.
        assert_eq!(bm.median(), naive.get(naive.len().wrapping_sub(1) / 2));
        let &median = bm.median().unwrap_or(&0);
        let dev: i32 = naive.iter().map(|&x| (x - median).abs()).sum();
        assert_eq!(bm.median_dev(), dev);
    }
}