Skip to main content

nekolib/ds/
wavelet_matrix.rs

//! Wavelet matrix over sequences of unsigned integers.
2
3use super::rs_dict;
4use super::super::traits::count;
5use super::super::traits::find_nth;
6use super::super::traits::quantile;
7use super::super::utils::buf_range;
8
9use std::ops::{Index, Range, RangeBounds, RangeInclusive};
10
11use buf_range::bounds_within;
12use count::{Count, Count3way, Count3wayResult};
13use find_nth::FindNth;
14use quantile::Quantile;
15use rs_dict::RsDict;
16
/// Wavelet matrix.
///
/// Answers many kinds of range queries over a sequence of unsigned integers.
///
/// # Examples
/// ```
/// use nekolib::ds::WaveletMatrix;
/// use nekolib::traits::{Count3way, FindNth, Quantile};
///
/// let wm: WaveletMatrix<u32> = vec![1, 8, 4, 9, 2, 7, 5, 2].into();
///
/// assert_eq!(wm.count_3way(2.., 5).lt(), 3); // [4, _, 2, _, _, 2]
///
/// let c3 = wm.count_3way(..6, 2..=7); // [1, 8, 4, 9, 2, 7]
/// assert_eq!(c3.lt(), 1); // [1, _, _, _, _, _]
/// assert_eq!(c3.eq(), 3); // [_, _, 4, _, 2, 7]
/// assert_eq!(c3.gt(), 2); // [_, 8, _, 9, _, _]
/// assert_eq!(c3.le(), 4);
/// assert_eq!(c3.ge(), 5);
///
/// assert_eq!(wm.quantile(2..=4, 0), Some(2)); // [_, _, 2]
/// assert_eq!(wm.quantile(2..=4, 1), Some(4)); // [4, _, _]
/// assert_eq!(wm.quantile(2..=4, 2), Some(9)); // [_, 9, _]
/// assert_eq!(wm.quantile(2..=4, 3), None);
///
/// assert_eq!(wm.find_nth(3.., 2, 0), Some(4));
/// assert_eq!(wm.find_nth(3.., 2, 1), Some(7));
/// assert_eq!(wm.find_nth(4.., 2, 0), Some(4));
/// assert_eq!(wm.find_nth(5.., 2, 0), Some(7));
/// assert_eq!(wm.find_nth(5.., 2, 1), None);
/// ```
pub struct WaveletMatrix<I> {
    /// Number of elements in the sequence.
    len: usize,
    /// Largest `bitlen()` over all elements; this is also the number of levels
    /// (0 when every element is 0 or the sequence is empty).
    bitlen: usize,
    /// `buf[i]` holds bit `i` of every element, in the order the elements have
    /// at level `i`, i.e. after stable partitioning by all higher bits.
    buf: Vec<RsDict>,
    /// `zeros[i]` is the number of elements whose bit `i` is 0.
    zeros: Vec<usize>,
    /// The original sequence, returned by indexing.
    orig: Vec<I>,
}
55
56impl<I: WmInt> From<Vec<I>> for WaveletMatrix<I> {
57    fn from(orig: Vec<I>) -> Self {
58        let len = orig.len();
59        let bitlen =
60            orig.iter().map(|ai| ai.bitlen()).max().unwrap_or(0) as usize;
61        let mut whole = orig.clone();
62        let mut zeros = vec![0; bitlen];
63        let mut buf = vec![];
64        for i in (0..bitlen).rev() {
65            let mut zero = vec![];
66            let mut one = vec![];
67            let mut vb = vec![false; len];
68            for (j, aj) in whole.into_iter().enumerate() {
69                (if aj.test(i) { &mut one } else { &mut zero }).push(aj);
70                vb[j] = aj.test(i);
71            }
72            zeros[i] = zero.len();
73            buf.push(vb.into());
74            whole = zero;
75            whole.append(&mut one);
76        }
77        buf.reverse();
78        Self { len, bitlen, buf, zeros, orig }
79    }
80}
81
82impl<I: WmInt> Count<I> for WaveletMatrix<I> {
83    fn count(&self, range: impl RangeBounds<usize>, value: I) -> usize {
84        self.count_3way(range, value).eq()
85    }
86}
87
88impl<I: WmInt> Count<RangeInclusive<I>> for WaveletMatrix<I> {
89    fn count(
90        &self,
91        range: impl RangeBounds<usize>,
92        value: RangeInclusive<I>,
93    ) -> usize {
94        self.count_3way(range, value).eq()
95    }
96}
97
98impl<I: WmInt> Count3way<I> for WaveletMatrix<I> {
99    fn count_3way(
100        &self,
101        range: impl RangeBounds<usize>,
102        value: I,
103    ) -> Count3wayResult {
104        let Range { start, end } = bounds_within(range, self.len);
105        let (lt, gt) = self.count_3way_internal(start..end, value);
106        let eq = (end - start) - (lt + gt);
107        Count3wayResult::new(lt, eq, gt)
108    }
109}
110
111impl<I: WmInt> Count3way<RangeInclusive<I>> for WaveletMatrix<I> {
112    fn count_3way(
113        &self,
114        range: impl RangeBounds<usize>,
115        value: RangeInclusive<I>,
116    ) -> Count3wayResult {
117        let Range { start: il, end: ir } = bounds_within(range, self.len);
118        let vl = *value.start();
119        let vr = *value.end();
120        let lt = self.count_3way_internal(il..ir, vl).0;
121        let gt = self.count_3way_internal(il..ir, vr).1;
122        let eq = (ir - il) - (lt + gt);
123        Count3wayResult::new(lt, eq, gt)
124    }
125}
126
127impl<I: WmInt> WaveletMatrix<I> {
128    fn count_3way_internal(
129        &self,
130        Range { mut start, mut end }: Range<usize>,
131        value: I,
132    ) -> (usize, usize) {
133        if start == end {
134            return (0, 0);
135        }
136        if value.bitlen() > self.bitlen {
137            return (end - start, 0);
138        }
139        let mut lt = 0;
140        let mut gt = 0;
141        for i in (0..self.bitlen).rev() {
142            let tmp = end - start;
143            if !value.test(i) {
144                start = self.buf[i].rank(start, 0);
145                end = self.buf[i].rank(end, 0);
146            } else {
147                start = self.zeros[i] + self.buf[i].rank(start, 1);
148                end = self.zeros[i] + self.buf[i].rank(end, 1);
149            }
150            *(if value.test(i) { &mut lt } else { &mut gt }) +=
151                tmp - (end - start);
152        }
153        (lt, gt)
154    }
155}
156
157impl<I: WmInt> Quantile for WaveletMatrix<I> {
158    type Output = I;
159    fn quantile(
160        &self,
161        range: impl RangeBounds<usize>,
162        mut n: usize,
163    ) -> Option<I> {
164        let Range { mut start, mut end } = bounds_within(range, self.len);
165        if end - start <= n {
166            return None;
167        }
168        let mut res = I::zero();
169        for i in (0..self.bitlen).rev() {
170            let z = self.buf[i].count(start..end, 0);
171            if n < z {
172                start = self.buf[i].rank(start, 0);
173                end = self.buf[i].rank(end, 0);
174            } else {
175                res.set(i);
176                start = self.zeros[i] + self.buf[i].rank(start, 1);
177                end = self.zeros[i] + self.buf[i].rank(end, 1);
178                n -= z;
179            }
180        }
181        Some(res)
182    }
183}
184
185impl<I: WmInt> WaveletMatrix<I> {
186    pub fn xored_quantile(
187        &self,
188        range: impl RangeBounds<usize>,
189        mut n: usize,
190        x: I,
191    ) -> Option<I> {
192        let Range { mut start, mut end } = bounds_within(range, self.len);
193        if end - start <= n {
194            return None;
195        }
196        let mut res = I::zero();
197        for i in (0..self.bitlen).rev() {
198            let z = self.buf[i].count(start..end, 0);
199            if !x.test(i) {
200                if n < z {
201                    start = self.buf[i].rank(start, 0);
202                    end = self.buf[i].rank(end, 0);
203                } else {
204                    res.set(i);
205                    start = self.zeros[i] + self.buf[i].rank(start, 1);
206                    end = self.zeros[i] + self.buf[i].rank(end, 1);
207                    n -= z;
208                }
209            } else {
210                let z = (end - start) - z;
211                if n < z {
212                    start = self.zeros[i] + self.buf[i].rank(start, 1);
213                    end = self.zeros[i] + self.buf[i].rank(end, 1);
214                } else {
215                    res.set(i);
216                    start = self.buf[i].rank(start, 0);
217                    end = self.buf[i].rank(end, 0);
218                    n -= z;
219                }
220            }
221        }
222        Some(res)
223    }
224}
225
226impl<I: WmInt> FindNth<I> for WaveletMatrix<I> {
227    fn find_nth(
228        &self,
229        range: impl RangeBounds<usize>,
230        value: I,
231        n: usize,
232    ) -> Option<usize> {
233        let start = bounds_within(range, self.len).start;
234        let (lt, gt) = self.count_3way_internal(0..start, value);
235        let offset = start - (lt + gt);
236        Some(self.select(value, n + offset + 1)? - 1)
237    }
238}
239
240impl<I: WmInt> WaveletMatrix<I> {
241    pub fn len(&self) -> usize { self.len }
242    pub fn is_empty(&self) -> bool { self.len == 0 }
243
244    pub fn rank(&self, end: usize, value: I) -> usize {
245        self.count(0..end, value)
246    }
247    pub fn select(&self, value: I, mut n: usize) -> Option<usize> {
248        if n == 0 {
249            return Some(0);
250        }
251        let (lt, gt) = self.count_3way_internal(0..self.len, value);
252        let count = self.len - (lt + gt);
253        if count < n {
254            return None;
255        }
256        let si = self.start_pos(value);
257        let value0 = value.test(0) as u64;
258        n += self.buf[0].rank(si, value0);
259        n = self.buf[0].select(value0, n).unwrap();
260
261        for i in 1..self.bitlen {
262            if !value.test(i) {
263                n = self.buf[i].select(0, n).unwrap();
264            } else {
265                n -= self.zeros[i];
266                n = self.buf[i].select(1, n).unwrap();
267            }
268        }
269        Some(n)
270    }
271    fn start_pos(&self, value: I) -> usize {
272        let mut start = 0;
273        let mut end = 0;
274        for i in (1..self.bitlen).rev() {
275            if !value.test(i) {
276                start = self.buf[i].rank(start, 0);
277                end = self.buf[i].rank(end, 0);
278            } else {
279                start = self.zeros[i] + self.buf[i].rank(start, 1);
280                end = self.zeros[i] + self.buf[i].rank(end, 1);
281            }
282        }
283        start
284    }
285}
286
287impl<I: WmInt> Index<usize> for WaveletMatrix<I> {
288    type Output = I;
289    fn index(&self, i: usize) -> &I { &self.orig[i] }
290}
291
/// Unsigned integer types that can be stored in a wavelet matrix.
pub trait WmInt: Copy {
    /// Returns whether bit `i` (0 = least significant) is set.
    fn test(self, i: usize) -> bool;
    /// Sets bit `i`.
    fn set(&mut self, i: usize);
    /// Returns the number of bits needed to represent `self` (0 for zero).
    fn bitlen(self) -> usize;
    /// Returns the value 0.
    fn zero() -> Self;
}

macro_rules! impl_wm_int {
    ( $( $ty:ty )* ) => { $(
        impl WmInt for $ty {
            fn test(self, i: usize) -> bool { (self >> i) & 1 == 1 }
            fn set(&mut self, i: usize) { *self = *self | (1 << i); }
            fn bitlen(self) -> usize {
                // Index of the highest set bit plus one; 0 for zero.
                let width = (0 as $ty).count_zeros();
                (width - self.leading_zeros()) as usize
            }
            fn zero() -> $ty { 0 }
        }
    )* };
}

impl_wm_int! { u8 u16 u32 u64 u128 usize }
318
// Brute-force cross-check of count / count_3way / find_nth / quantile /
// xored_quantile against a deterministic 300-element sequence in `0..8`.
#[test]
fn test_simple() {
    let n = 300;
    // Linear congruential sequence, reduced to its low 3 bits.
    let f = std::iter::successors(Some(296), |&x| Some((x * 258 + 185) % 397))
        .map(|x| x & 7);
    let buf: Vec<_> = f.take(n).collect();
    let wm: WaveletMatrix<u32> = buf.clone().into();
    // `count[v]` is the number of occurrences of `v` in `buf[start..end]`;
    // it is updated at the bottom of the `end` loop, after the checks.
    for start in 0..n {
        let mut count = vec![0; 8];
        for end in start..=n {
            for xl in 0..=7 {
                // Interval queries `xl..=xr` against prefix/suffix sums.
                for xr in xl..=7 {
                    let lt: usize = count[..xl as usize].iter().sum();
                    let gt: usize = count[xr as usize + 1..].iter().sum();
                    let eq = (end - start) - (lt + gt);
                    let c3 = Count3wayResult::new(lt, eq, gt);
                    assert_eq!(wm.count_3way(start..end, xl..=xr), c3);
                }

                // Single-value queries, both as `xl` and as `xl..=xl`.
                let lt: usize = count[..xl as usize].iter().sum();
                let gt: usize = count[xl as usize + 1..].iter().sum();
                let eq = (end - start) - (lt + gt);
                let c3 = Count3wayResult::new(lt, eq, gt);
                assert_eq!(wm.count(start..end, xl), eq);
                assert_eq!(wm.count(start..end, xl..=xl), eq);
                assert_eq!(wm.count_3way(start..end, xl), c3);
                assert_eq!(wm.count_3way(start..end, xl..=xl), c3);
            }

            if end < n {
                count[buf[end] as usize] += 1;
            }
        }
    }

    // find_nth: each position is the `count[x]`-th occurrence of its value
    // after `start`; one past the last occurrence must be `None`.
    // NOTE(review): only unbounded `start..` ranges are exercised here.
    for start in 0..n {
        let mut count = vec![0; 8];
        for end in start..n {
            let x = buf[end];
            assert_eq!(wm.find_nth(start.., x, count[x as usize]), Some(end));
            count[x as usize] += 1;
        }
        for x in 0..8 {
            assert_eq!(wm.find_nth(start.., x, count[x as usize]), None);
        }
    }

    // quantile against a sorted copy of each subrange.
    for start in 0..n {
        for end in start..n {
            let mut tmp = buf[start..end].to_vec();
            tmp.sort_unstable();
            for i in 0..tmp.len() {
                assert_eq!(wm.quantile(start..end, i), Some(tmp[i]));
            }
            assert_eq!(wm.quantile(start..end, tmp.len()), None);
        }
    }

    // xored_quantile against a sorted copy of each XOR-ed subrange.
    // NOTE(review): `x` stays within `0..8`, the same bit range as the data.
    for start in 0..n {
        for end in start..n {
            for x in 0..8 {
                let mut tmp: Vec<_> =
                    buf[start..end].iter().map(|&y| x ^ y).collect();
                tmp.sort_unstable();
                for i in 0..tmp.len() {
                    assert_eq!(
                        wm.xored_quantile(start..end, i, x),
                        Some(tmp[i])
                    );
                }
                assert_eq!(wm.xored_quantile(start..end, tmp.len(), x), None);
            }
        }
    }
}
394
// Boundary values on constant sequences of the extreme u8 values.
#[test]
fn test_count() {
    let len = 8;
    let res = |lt, eq, gt| Count3wayResult::new(lt, eq, gt);
    let queries = [0u8, 1, 254, 255];

    // All elements are 0: only 0 matches, every other query exceeds them all.
    let zero: WaveletMatrix<u8> = vec![0; len].into();
    for &x in queries.iter() {
        let want = if x == 0 { res(0, len, 0) } else { res(len, 0, 0) };
        assert_eq!(zero.count_3way(.., x), want);
        assert_eq!(zero.count_3way(.., x..=x), want);
    }

    // All elements are 255: only 255 matches, every other query is below.
    let full: WaveletMatrix<u8> = vec![!0; len].into();
    for &x in queries.iter() {
        let want = if x == 255 { res(0, len, 0) } else { res(0, 0, len) };
        assert_eq!(full.count_3way(.., x), want);
        assert_eq!(full.count_3way(.., x..=x), want);
    }
}