1#![allow(dead_code)]
2
3use std::ops::{Range, RangeBounds, RangeInclusive};
4
5use rs01_dict::Rs01Dict;
6use usize_bounds::UsizeBounds;
7
8pub struct WaveletMatrix<I> {
9 len: usize,
10 bitlen: usize,
11 buf: Vec<Rs01Dict>,
12 zeros: Vec<usize>,
13 orig: Vec<I>,
14}
15
/// Three-way split of a counted position range relative to a queried value
/// (or inclusive value range), as filled in by `WaveletMatrix::count_3way`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Count3wayResult {
    /// Elements strictly below the queried value (range).
    lt: usize,
    /// Elements equal to the queried value / inside the queried range.
    eq: usize,
    /// Elements strictly above the queried value (range).
    gt: usize,
}

impl Count3wayResult {
    fn new(lt: usize, eq: usize, gt: usize) -> Self {
        Count3wayResult { lt, eq, gt }
    }

    /// Count of elements below the query.
    pub fn lt(self) -> usize {
        let Self { lt, .. } = self;
        lt
    }

    /// Count of elements below or matching the query.
    pub fn le(self) -> usize {
        let Self { lt, eq, .. } = self;
        lt + eq
    }

    /// Count of elements matching the query.
    pub fn eq(self) -> usize {
        let Self { eq, .. } = self;
        eq
    }

    /// Count of elements matching or above the query.
    pub fn ge(self) -> usize {
        let Self { eq, gt, .. } = self;
        eq + gt
    }

    /// Count of elements above the query.
    pub fn gt(self) -> usize {
        let Self { gt, .. } = self;
        gt
    }

    /// Count of elements not matching the query (below or above it).
    pub fn ne(self) -> usize {
        let Self { lt, gt, .. } = self;
        gt + lt
    }
}
32
33impl<I: WmInt> From<Vec<I>> for WaveletMatrix<I> {
34 fn from(orig: Vec<I>) -> Self {
35 let len = orig.len();
36 let bitlen =
37 orig.iter().map(|ai| ai.bitlen()).max().unwrap_or(0) as usize;
38 let mut whole = orig.clone();
39 let mut zeros = vec![0; bitlen];
40 let mut buf = vec![];
41 for i in (0..bitlen).rev() {
42 let mut zero = vec![];
43 let mut one = vec![];
44 let mut vb = vec![false; len];
45 for (j, aj) in whole.into_iter().enumerate() {
46 (if aj.test(i) { &mut one } else { &mut zero }).push(aj);
47 vb[j] = aj.test(i);
48 }
49 zeros[i] = zero.len();
50 buf.push(Rs01Dict::new(&vb));
51 whole = zero;
52 whole.append(&mut one);
53 }
54 buf.reverse();
55 Self { len, bitlen, buf, zeros, orig }
56 }
57}
58
59impl<I: WmInt> WaveletMatrix<I> {
60 pub fn count<R: WmIntRange<Int = I>>(
61 &self,
62 range: impl RangeBounds<usize>,
63 value: R,
64 ) -> usize {
65 self.count_3way(range, value).eq()
66 }
67 pub fn count_3way<R: WmIntRange<Int = I>>(
68 &self,
69 range: impl RangeBounds<usize>,
70 value: R,
71 ) -> Count3wayResult {
72 let Range { start: il, end: ir } = range.to_range(self.len);
73 let value = value.to_inclusive_range();
74 let vl = *value.start();
75 let vr = *value.end();
76 let (lt, gt) = if vl == vr {
77 self.count_3way_internal(il..ir, vl)
78 } else {
79 let lt = self.count_3way_internal(il..ir, vl).0;
80 let gt = self.count_3way_internal(il..ir, vr).1;
81 (lt, gt)
82 };
83 let eq = (ir - il) - (lt + gt);
84 Count3wayResult::new(lt, eq, gt)
85 }
86 fn count_3way_internal(
87 &self,
88 Range { mut start, mut end }: Range<usize>,
89 value: I,
90 ) -> (usize, usize) {
91 if start == end {
92 return (0, 0);
93 }
94 if value.bitlen() > self.bitlen {
95 return (end - start, 0);
96 }
97 let mut lt = 0;
98 let mut gt = 0;
99 for i in (0..self.bitlen).rev() {
100 let tmp = end - start;
101 if !value.test(i) {
102 start = self.buf[i].count0(..start);
103 end = self.buf[i].count0(..end);
104 } else {
105 start = self.zeros[i] + self.buf[i].count1(..start);
106 end = self.zeros[i] + self.buf[i].count1(..end);
107 }
108 let len = end - start;
109 *(if value.test(i) { &mut lt } else { &mut gt }) += tmp - len;
110 }
111 (lt, gt)
112 }
113
114 pub fn quantile(
115 &self,
116 range: impl RangeBounds<usize>,
117 mut n: usize,
118 ) -> Option<I> {
119 let Range { mut start, mut end } = range.to_range(self.len);
120 if end - start <= n {
121 return None;
122 }
123 let mut res = I::zero();
124 for i in (0..self.bitlen).rev() {
125 let z = self.buf[i].count0(start..end);
126 if n < z {
127 start = self.buf[i].count0(..start);
128 end = self.buf[i].count0(..end);
129 } else {
130 res.set(i);
131 start = self.zeros[i] + self.buf[i].count1(..start);
132 end = self.zeros[i] + self.buf[i].count1(..end);
133 n -= z;
134 }
135 }
136 Some(res)
137 }
138}
139
/// Integer-like element type that can be stored in a `WaveletMatrix`.
///
/// Values are handled bit by bit; in the impls in this file, bit `i` is the `2^i` place.
pub trait WmInt: Copy + Eq {
    /// The value with no bits set; query results are assembled from it bit by bit.
    fn zero() -> Self;

    /// Number of bit levels `self` needs.
    ///
    /// The matrix uses the largest `bitlen()` among its elements as its level count.
    fn bitlen(self) -> usize;

    /// Returns whether bit `i` of `self` is 1.
    fn test(self, i: usize) -> bool;

    /// Sets bit `i` of `self` to 1.
    fn set(&mut self, i: usize);
}
146
/// A query argument that denotes an inclusive range of element values.
///
/// In this file it is implemented for a bare integer `v` (meaning `v..=v`) and
/// for `RangeInclusive` of the same integer type.
pub trait WmIntRange {
    /// Element type the range is over.
    type Int;

    /// Converts the argument into an explicit inclusive range.
    fn to_inclusive_range(self) -> RangeInclusive<<Self as WmIntRange>::Int>;
}
151
/// Implements `WmInt` and `WmIntRange` for each listed unsigned integer type.
///
/// For every type `T`, both a bare `T` and a `RangeInclusive<T>` become valid
/// `WmIntRange` query arguments.
macro_rules! impl_uint {
    ( $($t:ty)* ) => {
        $(
            impl WmInt for $t {
                fn test(self, i: usize) -> bool {
                    (self >> i) & 1 == 1
                }

                fn set(&mut self, i: usize) {
                    let mask: $t = 1 << i;
                    *self |= mask;
                }

                fn bitlen(self) -> usize {
                    // 0 is reported as needing one bit.
                    match self {
                        0 => 1,
                        nonzero => (<$t>::BITS - nonzero.leading_zeros()) as usize,
                    }
                }

                fn zero() -> $t {
                    0
                }
            }

            impl WmIntRange for RangeInclusive<$t> {
                type Int = $t;

                fn to_inclusive_range(self) -> RangeInclusive<$t> {
                    self
                }
            }

            impl WmIntRange for $t {
                type Int = $t;

                fn to_inclusive_range(self) -> RangeInclusive<$t> {
                    self..=self
                }
            }
        )*
    };
}
173
174impl_uint! { u8 u16 u32 u64 u128 usize }