wavelet_matrix/
lib.rs

1#![allow(dead_code)]
2
3use std::ops::{Range, RangeBounds, RangeInclusive};
4
5use rs01_dict::Rs01Dict;
6use usize_bounds::UsizeBounds;
7
8pub struct WaveletMatrix<I> {
9    len: usize,
10    bitlen: usize,
11    buf: Vec<Rs01Dict>,
12    zeros: Vec<usize>,
13    orig: Vec<I>,
14}
15
/// Outcome of a three-way count: the number of elements below, inside and
/// above the queried value (or value range).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Count3wayResult {
    /// Elements counted as "less than" the query.
    lt: usize,
    /// Elements counted as "equal to" (inside) the query.
    eq: usize,
    /// Elements counted as "greater than" the query.
    gt: usize,
}

impl Count3wayResult {
    /// Bundles the three partial counts into a result.
    fn new(lt: usize, eq: usize, gt: usize) -> Self {
        Count3wayResult { lt, eq, gt }
    }

    /// Count of elements less than the query.
    pub fn lt(self) -> usize {
        let Self { lt, .. } = self;
        lt
    }

    /// Count of elements less than or equal to the query (`lt + eq`).
    pub fn le(self) -> usize {
        let Self { lt, eq, .. } = self;
        lt + eq
    }

    /// Count of elements equal to the query.
    pub fn eq(self) -> usize {
        let Self { eq, .. } = self;
        eq
    }

    /// Count of elements greater than or equal to the query (`eq + gt`).
    pub fn ge(self) -> usize {
        let Self { eq, gt, .. } = self;
        eq + gt
    }

    /// Count of elements greater than the query.
    pub fn gt(self) -> usize {
        let Self { gt, .. } = self;
        gt
    }

    /// Count of elements not equal to the query (`gt + lt`).
    pub fn ne(self) -> usize {
        let Self { lt, gt, .. } = self;
        gt + lt
    }
}
32
33impl<I: WmInt> From<Vec<I>> for WaveletMatrix<I> {
34    fn from(orig: Vec<I>) -> Self {
35        let len = orig.len();
36        let bitlen =
37            orig.iter().map(|ai| ai.bitlen()).max().unwrap_or(0) as usize;
38        let mut whole = orig.clone();
39        let mut zeros = vec![0; bitlen];
40        let mut buf = vec![];
41        for i in (0..bitlen).rev() {
42            let mut zero = vec![];
43            let mut one = vec![];
44            let mut vb = vec![false; len];
45            for (j, aj) in whole.into_iter().enumerate() {
46                (if aj.test(i) { &mut one } else { &mut zero }).push(aj);
47                vb[j] = aj.test(i);
48            }
49            zeros[i] = zero.len();
50            buf.push(Rs01Dict::new(&vb));
51            whole = zero;
52            whole.append(&mut one);
53        }
54        buf.reverse();
55        Self { len, bitlen, buf, zeros, orig }
56    }
57}
58
59impl<I: WmInt> WaveletMatrix<I> {
60    pub fn count<R: WmIntRange<Int = I>>(
61        &self,
62        range: impl RangeBounds<usize>,
63        value: R,
64    ) -> usize {
65        self.count_3way(range, value).eq()
66    }
67    pub fn count_3way<R: WmIntRange<Int = I>>(
68        &self,
69        range: impl RangeBounds<usize>,
70        value: R,
71    ) -> Count3wayResult {
72        let Range { start: il, end: ir } = range.to_range(self.len);
73        let value = value.to_inclusive_range();
74        let vl = *value.start();
75        let vr = *value.end();
76        let (lt, gt) = if vl == vr {
77            self.count_3way_internal(il..ir, vl)
78        } else {
79            let lt = self.count_3way_internal(il..ir, vl).0;
80            let gt = self.count_3way_internal(il..ir, vr).1;
81            (lt, gt)
82        };
83        let eq = (ir - il) - (lt + gt);
84        Count3wayResult::new(lt, eq, gt)
85    }
86    fn count_3way_internal(
87        &self,
88        Range { mut start, mut end }: Range<usize>,
89        value: I,
90    ) -> (usize, usize) {
91        if start == end {
92            return (0, 0);
93        }
94        if value.bitlen() > self.bitlen {
95            return (end - start, 0);
96        }
97        let mut lt = 0;
98        let mut gt = 0;
99        for i in (0..self.bitlen).rev() {
100            let tmp = end - start;
101            if !value.test(i) {
102                start = self.buf[i].count0(..start);
103                end = self.buf[i].count0(..end);
104            } else {
105                start = self.zeros[i] + self.buf[i].count1(..start);
106                end = self.zeros[i] + self.buf[i].count1(..end);
107            }
108            let len = end - start;
109            *(if value.test(i) { &mut lt } else { &mut gt }) += tmp - len;
110        }
111        (lt, gt)
112    }
113
114    pub fn quantile(
115        &self,
116        range: impl RangeBounds<usize>,
117        mut n: usize,
118    ) -> Option<I> {
119        let Range { mut start, mut end } = range.to_range(self.len);
120        if end - start <= n {
121            return None;
122        }
123        let mut res = I::zero();
124        for i in (0..self.bitlen).rev() {
125            let z = self.buf[i].count0(start..end);
126            if n < z {
127                start = self.buf[i].count0(..start);
128                end = self.buf[i].count0(..end);
129            } else {
130                res.set(i);
131                start = self.zeros[i] + self.buf[i].count1(..start);
132                end = self.zeros[i] + self.buf[i].count1(..end);
133                n -= z;
134            }
135        }
136        Some(res)
137    }
138}
139
/// Integer type that can be stored in a wavelet matrix.
pub trait WmInt: Copy + Eq {
    /// Returns `true` if bit `i` (0 = least significant) is set.
    ///
    /// Bits at or beyond the width of the type read as `false`.
    fn test(self, i: usize) -> bool;
    /// Sets bit `i` (0 = least significant).
    fn set(&mut self, i: usize);
    /// Number of bits needed to write the value; zero reports 1.
    fn bitlen(self) -> usize;
    /// The value with no bits set.
    fn zero() -> Self;
}

/// A value query: either a single integer or an inclusive range of them.
pub trait WmIntRange {
    /// Integer type the query ranges over.
    type Int;
    /// Converts the query into an inclusive range of `Self::Int`.
    fn to_inclusive_range(self) -> RangeInclusive<Self::Int>;
}

/// Implements [`WmInt`] and [`WmIntRange`] for the listed unsigned integer
/// types, plus [`WmIntRange`] for `RangeInclusive` of each of them.
macro_rules! impl_uint {
    ( $($ty:ty)* ) => { $(
        impl WmInt for $ty {
            fn test(self, i: usize) -> bool {
                // Shifting by the full width or more panics in debug builds
                // and wraps the shift amount in release builds, so answer
                // out-of-width queries without shifting.
                let width = <$ty>::BITS as usize;
                i < width && (self >> i) & 1 != 0
            }
            fn set(&mut self, i: usize) {
                *self |= 1 << i;
            }
            fn bitlen(self) -> usize {
                if self == 0 {
                    // Zero is treated as occupying a single bit.
                    1
                } else {
                    (<$ty>::BITS - self.leading_zeros()) as usize
                }
            }
            fn zero() -> $ty {
                0
            }
        }
        impl WmIntRange for $ty {
            type Int = $ty;
            // A single value is the range containing just that value.
            fn to_inclusive_range(self) -> RangeInclusive<$ty> {
                self..=self
            }
        }
        impl WmIntRange for RangeInclusive<$ty> {
            type Int = $ty;
            fn to_inclusive_range(self) -> RangeInclusive<$ty> {
                self
            }
        }
    )* }
}

impl_uint! { u8 u16 u32 u64 u128 usize }