1use super::rs_dict;
4use super::super::traits::count;
5use super::super::traits::find_nth;
6use super::super::traits::quantile;
7use super::super::utils::buf_range;
8
9use std::ops::{Index, Range, RangeBounds, RangeInclusive};
10
11use buf_range::bounds_within;
12use count::{Count, Count3way, Count3wayResult};
13use find_nth::FindNth;
14use quantile::Quantile;
15use rs_dict::RsDict;
16
17pub struct WaveletMatrix<I> {
49 len: usize,
50 bitlen: usize,
51 buf: Vec<RsDict>,
52 zeros: Vec<usize>,
53 orig: Vec<I>,
54}
55
56impl<I: WmInt> From<Vec<I>> for WaveletMatrix<I> {
57 fn from(orig: Vec<I>) -> Self {
58 let len = orig.len();
59 let bitlen =
60 orig.iter().map(|ai| ai.bitlen()).max().unwrap_or(0) as usize;
61 let mut whole = orig.clone();
62 let mut zeros = vec![0; bitlen];
63 let mut buf = vec![];
64 for i in (0..bitlen).rev() {
65 let mut zero = vec![];
66 let mut one = vec![];
67 let mut vb = vec![false; len];
68 for (j, aj) in whole.into_iter().enumerate() {
69 (if aj.test(i) { &mut one } else { &mut zero }).push(aj);
70 vb[j] = aj.test(i);
71 }
72 zeros[i] = zero.len();
73 buf.push(vb.into());
74 whole = zero;
75 whole.append(&mut one);
76 }
77 buf.reverse();
78 Self { len, bitlen, buf, zeros, orig }
79 }
80}
81
82impl<I: WmInt> Count<I> for WaveletMatrix<I> {
83 fn count(&self, range: impl RangeBounds<usize>, value: I) -> usize {
84 self.count_3way(range, value).eq()
85 }
86}
87
88impl<I: WmInt> Count<RangeInclusive<I>> for WaveletMatrix<I> {
89 fn count(
90 &self,
91 range: impl RangeBounds<usize>,
92 value: RangeInclusive<I>,
93 ) -> usize {
94 self.count_3way(range, value).eq()
95 }
96}
97
98impl<I: WmInt> Count3way<I> for WaveletMatrix<I> {
99 fn count_3way(
100 &self,
101 range: impl RangeBounds<usize>,
102 value: I,
103 ) -> Count3wayResult {
104 let Range { start, end } = bounds_within(range, self.len);
105 let (lt, gt) = self.count_3way_internal(start..end, value);
106 let eq = (end - start) - (lt + gt);
107 Count3wayResult::new(lt, eq, gt)
108 }
109}
110
111impl<I: WmInt> Count3way<RangeInclusive<I>> for WaveletMatrix<I> {
112 fn count_3way(
113 &self,
114 range: impl RangeBounds<usize>,
115 value: RangeInclusive<I>,
116 ) -> Count3wayResult {
117 let Range { start: il, end: ir } = bounds_within(range, self.len);
118 let vl = *value.start();
119 let vr = *value.end();
120 let lt = self.count_3way_internal(il..ir, vl).0;
121 let gt = self.count_3way_internal(il..ir, vr).1;
122 let eq = (ir - il) - (lt + gt);
123 Count3wayResult::new(lt, eq, gt)
124 }
125}
126
127impl<I: WmInt> WaveletMatrix<I> {
128 fn count_3way_internal(
129 &self,
130 Range { mut start, mut end }: Range<usize>,
131 value: I,
132 ) -> (usize, usize) {
133 if start == end {
134 return (0, 0);
135 }
136 if value.bitlen() > self.bitlen {
137 return (end - start, 0);
138 }
139 let mut lt = 0;
140 let mut gt = 0;
141 for i in (0..self.bitlen).rev() {
142 let tmp = end - start;
143 if !value.test(i) {
144 start = self.buf[i].rank(start, 0);
145 end = self.buf[i].rank(end, 0);
146 } else {
147 start = self.zeros[i] + self.buf[i].rank(start, 1);
148 end = self.zeros[i] + self.buf[i].rank(end, 1);
149 }
150 *(if value.test(i) { &mut lt } else { &mut gt }) +=
151 tmp - (end - start);
152 }
153 (lt, gt)
154 }
155}
156
157impl<I: WmInt> Quantile for WaveletMatrix<I> {
158 type Output = I;
159 fn quantile(
160 &self,
161 range: impl RangeBounds<usize>,
162 mut n: usize,
163 ) -> Option<I> {
164 let Range { mut start, mut end } = bounds_within(range, self.len);
165 if end - start <= n {
166 return None;
167 }
168 let mut res = I::zero();
169 for i in (0..self.bitlen).rev() {
170 let z = self.buf[i].count(start..end, 0);
171 if n < z {
172 start = self.buf[i].rank(start, 0);
173 end = self.buf[i].rank(end, 0);
174 } else {
175 res.set(i);
176 start = self.zeros[i] + self.buf[i].rank(start, 1);
177 end = self.zeros[i] + self.buf[i].rank(end, 1);
178 n -= z;
179 }
180 }
181 Some(res)
182 }
183}
184
185impl<I: WmInt> WaveletMatrix<I> {
186 pub fn xored_quantile(
187 &self,
188 range: impl RangeBounds<usize>,
189 mut n: usize,
190 x: I,
191 ) -> Option<I> {
192 let Range { mut start, mut end } = bounds_within(range, self.len);
193 if end - start <= n {
194 return None;
195 }
196 let mut res = I::zero();
197 for i in (0..self.bitlen).rev() {
198 let z = self.buf[i].count(start..end, 0);
199 if !x.test(i) {
200 if n < z {
201 start = self.buf[i].rank(start, 0);
202 end = self.buf[i].rank(end, 0);
203 } else {
204 res.set(i);
205 start = self.zeros[i] + self.buf[i].rank(start, 1);
206 end = self.zeros[i] + self.buf[i].rank(end, 1);
207 n -= z;
208 }
209 } else {
210 let z = (end - start) - z;
211 if n < z {
212 start = self.zeros[i] + self.buf[i].rank(start, 1);
213 end = self.zeros[i] + self.buf[i].rank(end, 1);
214 } else {
215 res.set(i);
216 start = self.buf[i].rank(start, 0);
217 end = self.buf[i].rank(end, 0);
218 n -= z;
219 }
220 }
221 }
222 Some(res)
223 }
224}
225
226impl<I: WmInt> FindNth<I> for WaveletMatrix<I> {
227 fn find_nth(
228 &self,
229 range: impl RangeBounds<usize>,
230 value: I,
231 n: usize,
232 ) -> Option<usize> {
233 let start = bounds_within(range, self.len).start;
234 let (lt, gt) = self.count_3way_internal(0..start, value);
235 let offset = start - (lt + gt);
236 Some(self.select(value, n + offset + 1)? - 1)
237 }
238}
239
240impl<I: WmInt> WaveletMatrix<I> {
241 pub fn len(&self) -> usize { self.len }
242 pub fn is_empty(&self) -> bool { self.len == 0 }
243
244 pub fn rank(&self, end: usize, value: I) -> usize {
245 self.count(0..end, value)
246 }
247 pub fn select(&self, value: I, mut n: usize) -> Option<usize> {
248 if n == 0 {
249 return Some(0);
250 }
251 let (lt, gt) = self.count_3way_internal(0..self.len, value);
252 let count = self.len - (lt + gt);
253 if count < n {
254 return None;
255 }
256 let si = self.start_pos(value);
257 let value0 = value.test(0) as u64;
258 n += self.buf[0].rank(si, value0);
259 n = self.buf[0].select(value0, n).unwrap();
260
261 for i in 1..self.bitlen {
262 if !value.test(i) {
263 n = self.buf[i].select(0, n).unwrap();
264 } else {
265 n -= self.zeros[i];
266 n = self.buf[i].select(1, n).unwrap();
267 }
268 }
269 Some(n)
270 }
271 fn start_pos(&self, value: I) -> usize {
272 let mut start = 0;
273 let mut end = 0;
274 for i in (1..self.bitlen).rev() {
275 if !value.test(i) {
276 start = self.buf[i].rank(start, 0);
277 end = self.buf[i].rank(end, 0);
278 } else {
279 start = self.zeros[i] + self.buf[i].rank(start, 1);
280 end = self.zeros[i] + self.buf[i].rank(end, 1);
281 }
282 }
283 start
284 }
285}
286
287impl<I: WmInt> Index<usize> for WaveletMatrix<I> {
288 type Output = I;
289 fn index(&self, i: usize) -> &I { &self.orig[i] }
290}
291
/// Integer operations the wavelet matrix needs from its element type.
pub trait WmInt: Copy {
    /// Whether bit `i` (0 = least significant) is set.
    fn test(self, i: usize) -> bool;
    /// Sets bit `i` (0 = least significant).
    fn set(&mut self, i: usize);
    /// Number of bits needed to represent the value; 0 for zero.
    fn bitlen(self) -> usize;
    /// The value with no bits set.
    fn zero() -> Self;
}

/// Implements [`WmInt`] for each listed primitive unsigned integer type.
macro_rules! impl_wm_int {
    ( $( $ty:ty )* ) => { $(
        impl WmInt for $ty {
            fn test(self, i: usize) -> bool {
                (self >> i) & 1 == 1
            }
            fn set(&mut self, i: usize) {
                *self |= (1 as $ty) << i;
            }
            fn bitlen(self) -> usize {
                // Width of the type in bits, minus the unused high bits.
                let width = (0 as $ty).count_zeros() as usize;
                width - self.leading_zeros() as usize
            }
            fn zero() -> $ty {
                0
            }
        }
    )* };
}

impl_wm_int! { u8 u16 u32 u64 u128 usize }
318
/// Exhaustively cross-checks counting, `find_nth`, `quantile` and
/// `xored_quantile` against brute force on a pseudo-random sequence of
/// 3-bit values.
#[test]
fn test_simple() {
    let n = 300;
    let buf: Vec<u32> =
        std::iter::successors(Some(296), |&x| Some((x * 258 + 185) % 397))
            .map(|x| x & 7)
            .take(n)
            .collect();
    let wm = WaveletMatrix::<u32>::from(buf.clone());

    // Counting: `freq[v]` tracks the occurrences of `v` in `buf[start..end]`.
    for start in 0..n {
        let mut freq = vec![0_usize; 8];
        for end in start..=n {
            let total = end - start;
            for lo in 0..=7_u32 {
                let below: usize = freq[..lo as usize].iter().sum();

                for hi in lo..=7 {
                    let above: usize = freq[hi as usize + 1..].iter().sum();
                    let inside = total - (below + above);
                    let expected = Count3wayResult::new(below, inside, above);
                    assert_eq!(wm.count_3way(start..end, lo..=hi), expected);
                }

                let above: usize = freq[lo as usize + 1..].iter().sum();
                let same = total - (below + above);
                let expected = Count3wayResult::new(below, same, above);
                assert_eq!(wm.count(start..end, lo), same);
                assert_eq!(wm.count(start..end, lo..=lo), same);
                assert_eq!(wm.count_3way(start..end, lo), expected);
                assert_eq!(wm.count_3way(start..end, lo..=lo), expected);
            }

            if end < n {
                freq[buf[end] as usize] += 1;
            }
        }
    }

    // find_nth: every position is the next unseen occurrence of its value,
    // and asking for one more occurrence than exists yields `None`.
    for start in 0..n {
        let mut seen = vec![0_usize; 8];
        for (end, &x) in buf.iter().enumerate().skip(start) {
            assert_eq!(wm.find_nth(start.., x, seen[x as usize]), Some(end));
            seen[x as usize] += 1;
        }
        for x in 0..8_u32 {
            assert_eq!(wm.find_nth(start.., x, seen[x as usize]), None);
        }
    }

    // quantile: compare with the sorted window.
    for start in 0..n {
        for end in start..n {
            let mut sorted = buf[start..end].to_vec();
            sorted.sort_unstable();
            for (k, &expected) in sorted.iter().enumerate() {
                assert_eq!(wm.quantile(start..end, k), Some(expected));
            }
            assert_eq!(wm.quantile(start..end, sorted.len()), None);
        }
    }

    // xored_quantile: compare with the sorted window of XORed values.
    for start in 0..n {
        for end in start..n {
            for x in 0..8_u32 {
                let mut sorted: Vec<u32> =
                    buf[start..end].iter().map(|&y| x ^ y).collect();
                sorted.sort_unstable();
                for (k, &expected) in sorted.iter().enumerate() {
                    assert_eq!(
                        wm.xored_quantile(start..end, k, x),
                        Some(expected)
                    );
                }
                assert_eq!(
                    wm.xored_quantile(start..end, sorted.len(), x),
                    None
                );
            }
        }
    }
}
394
/// Checks 3-way counting on the two extreme inputs: all zeros (no bit
/// planes) and all ones (every bit plane full).
#[test]
fn test_count() {
    let n = 8;

    // Asserts both the single-value and the degenerate-interval flavour.
    let check = |wm: &WaveletMatrix<u8>,
                 v: u8,
                 (lt, eq, gt): (usize, usize, usize)| {
        let expected = Count3wayResult::new(lt, eq, gt);
        assert_eq!(wm.count_3way(.., v), expected);
        assert_eq!(wm.count_3way(.., v..=v), expected);
    };

    let zero = WaveletMatrix::<u8>::from(vec![0_u8; n]);
    let zero_cases = [
        (0_u8, (0, n, 0)),
        (1, (n, 0, 0)),
        (254, (n, 0, 0)),
        (255, (n, 0, 0)),
    ];
    for &(v, expected) in zero_cases.iter() {
        check(&zero, v, expected);
    }

    let full = WaveletMatrix::<u8>::from(vec![!0_u8; n]);
    let full_cases = [
        (0_u8, (0, 0, n)),
        (1, (0, 0, n)),
        (254, (0, 0, n)),
        (255, (0, n, 0)),
    ];
    for &(v, expected) in full_cases.iter() {
        check(&full, v, expected);
    }
}