1#![allow(dead_code)]
2
3use std::ops::Range;
4
/// Number of bits in one storage word (`u64`) of an `IntVec`.
const W: usize = u64::BITS as usize;
6
/// A vector of fixed-width unsigned integers packed back to back into `u64`
/// words.
struct IntVec {
    /// Width of every element, in bits.
    unit: usize,
    /// Backing storage; element `i` occupies bits `i * unit..(i + 1) * unit`
    /// of the concatenated words, least significant bit first.
    buf: Vec<u64>,
    /// Number of elements stored (not the number of words).
    len: usize,
}
12
/// A static bit-vector dictionary answering rank and select queries for both
/// bit values through precomputed indices.
pub struct Rs01DictTree {
    /// The bits themselves, stored as an `IntVec` with one bit per element.
    buf: IntVec,
    /// Cumulative counts used by the rank queries.
    rank_index: RankIndex,
    /// Select indices: `.0` locates unset bits, `.1` locates set bits.
    select_index: (SelectIndex, SelectIndex),
}
18
/// Two-level table of cumulative set-bit counts backing the rank queries.
struct RankIndex {
    /// For every large block: number of set bits before the block.
    large: IntVec,
    /// For every small block: number of set bits between the start of the
    /// enclosing large block and the start of the small block.
    small: IntVec,
    /// Size of a large block, in bits.
    large_len: usize,
    /// Size of a small block, in bits.
    small_len: usize,
}
25
/// Index answering select queries for one bit value.
///
/// The matching bits are split into groups of `large_popcnt` occurrences;
/// each group is stored either as an explicit position list (`sparse`) or as
/// a tree of per-block counts (`dense`).
struct SelectIndex {
    /// Three entries per group: `offset << 1 | kind`, followed by `(0, 0)`
    /// for a sparse group or `(end of its tree in dense, first bit position
    /// of the group)` for a dense group.
    indir: IntVec,
    /// Explicit bit positions of the occurrences in sparse groups.
    sparse: IntVec,
    /// Occurrence counts of the leaf blocks of dense groups, followed by the
    /// sums of each run of `branch` counts, level by level.
    dense: IntVec,
    /// Lookup table produced by `SelectIndex::table_tree`.
    table_tree: Vec<u8>,
    /// Number of occurrences per group (the last group may hold fewer).
    large_popcnt: usize,
    /// Fan-out of the count tree of a dense group.
    branch: usize,
    /// Number of bits covered by one leaf of a dense group's tree.
    small_len: usize,
}
36
/// `RANK_TABLE[w][j]` is the number of set bits of the 12-bit pattern `w`
/// strictly below bit position `j`.
const RANK_TABLE: [[u8; 12]; 4096] = rank_table::<12, 4096>();
/// `SELECT_TABLE[w][k]` is the position of the `k`-th (0-based) set bit of the
/// 12-bit pattern `w`; entries past the last set bit are 0.
const SELECT_TABLE: [[u8; 12]; 4096] = select_table::<12, 4096>();
39
/// Builds the word-level rank table at compile time.
///
/// Entry `[p][j]` holds the number of set bits of the pattern `p` that lie
/// strictly below bit position `j`, for every `p < PAT` and `j < LEN`.
const fn rank_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    let mut table = [[0_u8; LEN]; PAT];
    let mut pattern = 0_usize;
    while pattern < PAT {
        let mut pos = 0;
        while pos < LEN {
            // Keep only the bits below `pos` and count them.
            let below = pattern & ((1 << pos) - 1);
            table[pattern][pos] = below.count_ones() as u8;
            pos += 1;
        }
        pattern += 1;
    }
    table
}
57
/// Builds the word-level select table at compile time.
///
/// Entry `[p][k]` holds the position of the `k`-th (0-based) set bit among the
/// low `LEN` bits of the pattern `p`; slots past the last set bit stay 0.
const fn select_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    // Only the low `LEN` bits of a pattern are ever inspected.
    let low_mask: usize = if LEN >= usize::BITS as usize {
        usize::MAX
    } else {
        (1 << LEN) - 1
    };

    let mut table = [[0_u8; LEN]; PAT];
    let mut pattern = 0;
    while pattern < PAT {
        let mut pending = pattern & low_mask;
        let mut found = 0;
        while pending != 0 {
            table[pattern][found] = pending.trailing_zeros() as u8;
            found += 1;
            // Drop the lowest set bit and continue with the next one.
            pending &= pending - 1;
        }
        pattern += 1;
    }
    table
}
76
77impl IntVec {
78 pub fn new(unit: usize) -> Self { Self { unit, buf: vec![], len: 0 } }
79 pub fn len(&self) -> usize { self.len }
80 pub fn bitlen(&self) -> usize { self.len * self.unit }
81
82 pub fn push(&mut self, w: u64) {
83 let unit = self.unit;
84 debug_assert!(unit == W || w & (!0 << unit) == 0);
85
86 let bitlen = self.bitlen();
87 if unit == 0 {
88 } else if bitlen % W == 0 {
90 self.buf.push(w);
91 } else {
92 self.buf[bitlen / W] |= w << (bitlen % W);
93 if bitlen % W + unit > W {
94 self.buf.push(w >> (W - bitlen % W));
95 }
96 }
97 self.len += 1;
98 }
99
100 #[inline(always)]
101 pub fn get_usize(&self, i: usize) -> usize { self.get::<true>(i) as _ }
102
103 #[inline(always)]
104 pub fn get<const X: bool>(&self, i: usize) -> u64 {
105 let start = i * self.unit;
106 let end = start + self.unit;
107 self.bits_range::<X>(start..end)
108 }
109
110 #[inline(always)]
111 pub fn bits_range<const X: bool>(
112 &self,
113 Range { start, end }: Range<usize>,
114 ) -> u64 {
115 let end = end.min(self.bitlen()); let mask = !(!0 << (end - start));
117
118 let mut res = self.buf[start / W] >> (start % W);
119 if end > (start / W + 1) * W {
120 res |= self.buf[end / W] << (W - start % W);
121 }
122
123 ((if X { res } else { !res }) & mask) as _
124 }
125}
126
127impl std::fmt::Debug for IntVec {
128 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 fmt.debug_list()
130 .entries((0..self.len).map(|i| self.get::<true>(i)))
131 .finish()
132 }
133}
134
/// Returns the number of bits needed to represent `n`, with a minimum of 1
/// (so `bitlen(0) == 1`).
///
/// Computed from the leading-zero count so that it is defined for every
/// `usize`, including values at or above `2^(usize::BITS - 1)`, where
/// rounding `n + 1` up to a power of two would overflow.
fn bitlen(n: usize) -> usize {
    let significant = usize::BITS - n.leading_zeros();
    significant.max(1) as usize
}
139
140impl RankIndex {
141 pub fn new(buf: &[bool]) -> Self {
142 let len = buf.len();
143 let small_len = (1_usize..)
144 .find(|&i| 4_usize.saturating_pow(i as _) >= len)
145 .unwrap(); let large_len = (2 * small_len).pow(2); let small_bitlen = bitlen(len.min(large_len));
149 let large_bitlen = bitlen(len);
150
151 let mut small = IntVec::new(small_bitlen);
152 let mut large = IntVec::new(large_bitlen);
153 let mut small_acc = 0;
154 let mut large_acc = 0;
155 let per = large_len / small_len;
156 for (c, i) in buf
157 .chunks(small_len)
158 .map(|ch| ch.iter().filter(|&&x| x).count() as u64)
159 .zip((0..per).cycle())
160 {
161 small.push(small_acc);
162 small_acc = if i < per - 1 { small_acc + c } else { 0 };
163
164 if i == 0 {
165 large.push(large_acc);
166 }
167 large_acc += c as u64;
168 }
169
170 Self { large, small, large_len, small_len }
172 }
173
174 fn table(len: usize) -> IntVec {
175 let unit = bitlen(len);
176 let mut table = IntVec::new(unit);
177 for i in 0..1 << len {
178 let mut cur = 0;
179 for j in 0..len {
180 table.push(cur);
181 if i >> j & 1 != 0 {
182 cur += 1;
183 }
184 }
185 }
186 table
187 }
188
189 #[inline(always)]
190 fn lookup(&self, w: u64, i: usize) -> usize {
191 RANK_TABLE[w as usize][i] as _
192 }
193
194 #[inline(always)]
195 pub fn rank1(&self, i: usize, b: &IntVec) -> usize {
196 let large_acc = self.large.get_usize(i / self.large_len);
197 let small_acc = self.small.get_usize(i / self.small_len);
198 let il = i / self.small_len * self.small_len;
199 let ir = il + self.small_len;
200 let w = b.bits_range::<true>(il..ir);
201 let small = self.lookup(w, i % self.small_len);
202 large_acc + small_acc + small
203 }
204 pub fn rank0(&self, i: usize, b: &IntVec) -> usize { i - self.rank1(i, b) }
205 pub fn rank<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
206 if X { self.rank1(i, b) } else { self.rank0(i, b) }
207 }
208
209 #[cfg(test)]
210 pub fn size_info(&self) -> (usize, usize) {
211 let rt = self.large.bitlen() + self.small.bitlen();
216 (rt, rt + 8 * RANK_TABLE.len() * RANK_TABLE[0].len())
218 }
219}
220
impl SelectIndex {
    /// Builds the select index for the bits of `buf` equal to `X`.
    ///
    /// The matching bits are cut into groups of `large_popcnt` occurrences
    /// (the last group takes whatever remains up to the end of `buf`). A
    /// group spanning more than `dense_max` bits stores its positions
    /// explicitly in `sparse`; any other group stores a tree of per-block
    /// occurrence counts in `dense`. All size parameters are derived from
    /// `log2(buf.len())`, clamped below at 1.
    pub fn new<const X: bool>(buf: &[bool]) -> Self {
        let len = buf.len();
        let len_lg = (len as f64).log2().max(1.0);

        let dense_max = (len_lg.powi(4) / 128.0).ceil() as usize;
        let large_popcnt = (len_lg.powi(2) / 16.0).ceil() as usize;
        let small_len = (len_lg / 2.0).ceil().max(2.0) as usize;
        let branch = len_lg.cbrt().ceil() as usize;

        let mut indir = IntVec::new(bitlen(len) + 2);
        let mut sparse = IntVec::new(bitlen(len));
        let mut dense = IntVec::new(bitlen(large_popcnt));

        // `start` is the first bit position of the group being collected and
        // `pos` holds the matching positions seen in it so far.
        let mut start = 0;
        let mut pos = vec![];
        for i in 0..len {
            if buf[i] == X {
                pos.push(i);
            }
            // Keep scanning until the group is full or the input ends.
            if !(pos.len() == large_popcnt || i == len - 1) {
                continue;
            }

            let end = i;
            if end + 1 - start > dense_max {
                // Sparse group: tag bit 0, then two unused slots so that
                // every group occupies exactly three `indir` entries.
                indir.push((sparse.len() << 1 | 0) as _);
                indir.push(0);
                indir.push(0);
                for &p in &pos {
                    sparse.push(p as _);
                }
            } else {
                // Dense group: tag bit 1 with the tree's start in `dense`.
                indir.push((dense.len() << 1 | 1) as _);
                // Smallest `branch^k * small_len` (k >= 1) covering the
                // group, so that the leaves form a complete tree.
                let ceil_len = (1..)
                    .map(|i| branch.pow(i) * small_len)
                    .find(|&b| b >= end + 1 - start)
                    .unwrap();
                let mut cur = dense.len();
                // Leaf counts are pushed last-to-first; blocks past the end
                // of the group are clamped to empty ranges and count 0.
                for i in (start..start + ceil_len).step_by(small_len).rev() {
                    let il = i.min(end + 1);
                    let ir = (il + small_len).min(end + 1);
                    let w = (il..ir).filter(|&i| buf[i] == X).count();
                    dense.push(w as _);
                }
                // Append the sum of each run of `branch` entries, level by
                // level, stopping once only the top `branch` entries remain
                // unsummed.
                while cur + branch < dense.len() {
                    let mut sum = 0;
                    for _ in 0..branch {
                        sum += dense.get::<true>(cur);
                        cur += 1;
                    }
                    dense.push(sum);
                }
                // Record where the tree ends and where the group starts.
                indir.push(dense.len() as _);
                indir.push(start as _);
            }

            pos.clear();
            start = i + 1;
        }

        let table_tree = Self::table_tree(large_popcnt, branch);

        Self {
            indir,
            sparse,
            dense,
            table_tree,
            large_popcnt,
            branch,
            small_len,
        }
    }

    /// Looks up `table_tree` for the packed child counts `w` and the rank
    /// `i`, returning `(occurrences in the children before the chosen one,
    /// index of the chosen child)`.
    #[inline(always)]
    fn lookup_tree(&self, w: u64, i: usize) -> (usize, usize) {
        let bitlen_branch = bitlen(self.branch);
        // Each packed pattern owns `large_popcnt` consecutive entries.
        let wi = w as usize * self.large_popcnt + i;
        let res = self.table_tree[wi] as usize;
        (res >> bitlen_branch, res & !(!0 << bitlen_branch))
    }

    /// Position of the `i`-th (0-based) set bit of the pattern `w`, read from
    /// `SELECT_TABLE`.
    #[inline(always)]
    fn lookup_word(&self, w: u64, i: usize) -> usize {
        SELECT_TABLE[w as usize][i] as _
    }

    /// Builds the table consulted by `lookup_tree`.
    ///
    /// Every value of `branch` packed counts, each `bitlen(popcnt)` bits wide
    /// with child 0 in the highest bits, gets `popcnt` entries. Entry `r`
    /// encodes `preceding << bitlen(branch) | child` for the child containing
    /// the `r`-th occurrence; entries not reached are 0. Children are
    /// processed only while the running total stays within `popcnt`.
    fn table_tree(popcnt: usize, branch: usize) -> Vec<u8> {
        let len = bitlen(popcnt);
        let enc = |i, j| i << bitlen(branch) | j;
        let mut table = vec![];
        for i in 0..1 << (len * branch) {
            let mut count = 0;
            for b in 0..branch {
                let sh = (branch - 1 - b) * len;
                let c = i >> sh & !(!0 << len);
                if count + c > popcnt {
                    break;
                }
                for _ in 0..c {
                    table.push(enc(count, b) as _);
                }
                count += c;
            }
            // Pad so that every pattern owns exactly `popcnt` entries.
            for _ in count..popcnt {
                table.push(0);
            }
        }
        table
    }

    /// Builds, as an `IntVec`, the set-bit positions of every `len`-bit
    /// pattern, padded with zeros to `len` entries per pattern.
    // NOTE(review): not called anywhere in the visible source; the constant
    // SELECT_TABLE is what `lookup_word` reads.
    fn table_word(len: usize) -> IntVec {
        let unit = bitlen(len);
        let mut table = IntVec::new(unit);
        for i in 0..1 << len {
            let mut cur = 0;
            for j in 0..len {
                if i >> j & 1 != 0 {
                    table.push(j as _);
                    cur += 1;
                }
            }
            for _ in cur..len {
                table.push(0);
            }
        }
        table
    }

    /// Returns the position in `b` of the `i`-th (0-based) bit equal to `X`.
    ///
    /// `b` must be the bit vector this index was built from, and `X` must be
    /// the value it was built for.
    #[inline(always)]
    pub fn select<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
        // Group number and rank inside the group.
        let (il_div, il_mod) = (i / self.large_popcnt, i % self.large_popcnt);
        let large = self.indir.get_usize(3 * il_div);
        let (large_i, large_ty) = (large >> 1, large & 1);
        if large_ty == 0 {
            // Sparse group: positions are stored explicitly.
            self.sparse.get_usize(large_i + il_mod)
        } else {
            // Dense group: the tree occupies `dense[start..end]`, and the
            // group begins at bit `b_start`.
            let start = large_i;
            let end = self.indir.get_usize(3 * il_div + 1);
            let b_start = self.indir.get_usize(3 * il_div + 2);
            let unit = bitlen(self.large_popcnt);
            let branch = self.branch;
            // `cur` is the distance, in entries back from `end`, of the
            // sibling run being inspected; `i` is the rank inside the current
            // subtree; `b_i` is the index of the current node in its level.
            let mut cur = 0;
            let mut i = il_mod;
            let mut b_i = 0;
            loop {
                let ir = (end - cur) * unit;
                let il = ir - branch * unit;
                let w = self.dense.bits_range::<true>(il..ir);
                let (acc, br) = self.lookup_tree(w, i);
                let tmp = (cur + br + 1) * branch;
                if end - start <= tmp {
                    // The children would lie outside the tree, so the chosen
                    // node is a leaf: finish inside its block of `b`.
                    let il = b_start + (b_i * branch + br) * self.small_len;
                    let ir = il + self.small_len;
                    let w = b.bits_range::<X>(il..ir);
                    break il + self.lookup_word(w, i - acc);
                }
                b_i = b_i * branch + br;
                cur = tmp;
                i -= acc;
            }
        }
    }

    /// Returns `(bits used by indir, sparse and dense, the same plus 8 bits
    /// per entry of table_tree and of SELECT_TABLE)`.
    #[cfg(test)]
    pub fn size_info(&self) -> (usize, usize) {
        let rt =
            self.indir.bitlen() + self.sparse.bitlen() + self.dense.bitlen();

        (
            rt,
            rt + 8 * self.table_tree.len()
                + 8 * SELECT_TABLE.len() * SELECT_TABLE[0].len(),
        )
    }
}
408
409impl Rs01DictTree {
410 pub fn new(a: &[bool]) -> Self {
411 let rank_index = RankIndex::new(&a);
412 let mut buf = IntVec::new(1);
413 for &x in a {
414 buf.push(x as _);
415 }
416 let select_index =
417 (SelectIndex::new::<false>(&a), SelectIndex::new::<true>(&a));
418
419 Self { buf, rank_index, select_index }
423 }
424
425 pub fn rank<const X: bool>(&self, i: usize) -> usize {
426 self.rank_index.rank::<X>(i, &self.buf)
427 }
428 pub fn rank0(&self, i: usize) -> usize { self.rank::<false>(i) }
429 pub fn rank1(&self, i: usize) -> usize { self.rank::<true>(i) }
430
431 pub fn select<const X: bool>(&self, i: usize) -> usize {
432 if X {
433 self.select_index.1.select::<X>(i, &self.buf)
434 } else {
435 self.select_index.0.select::<X>(i, &self.buf)
436 }
437 }
438 pub fn select0(&self, i: usize) -> usize { self.select::<false>(i) }
439 pub fn select1(&self, i: usize) -> usize { self.select::<true>(i) }
440
441 #[cfg(test)]
442 pub fn size_info(&self) {
443 let len = self.buf.bitlen();
444 let naive = 3 * len * bitlen(len);
445 eprintln!("* naive: {naive:>10} bits, {:>10} words", naive / 64);
446
447 let (r, r_table) = self.rank_index.size_info();
448 let (s0, s0_table) = self.select_index.0.size_info();
449 let (s1, s1_table) = self.select_index.1.size_info();
450 let sum = r + s0 + s1;
451 let sum_table = r_table + s0_table + s1_table;
452
453 let ratio = sum as f64 / naive as f64;
454 eprintln!(
455 "- table: {sum:>10} bits, {:>10} words (x{ratio:.03})",
456 sum / 64
457 );
458 let ratio = sum_table as f64 / naive as f64;
459 eprintln!(
460 "+ table: {sum_table:>10} bits, {:>10} words (x{ratio:.03})",
461 sum_table / 64
462 );
463 }
464}
465
#[cfg(test)]
mod tests {
    use rand::{
        distributions::{Bernoulli, Distribution},
        Rng, SeedableRng,
    };
    use rand_chacha::ChaCha20Rng;

    use crate::*;

    /// Fixed seed so that every run samples the same bit sequences.
    const SEED: [u8; 32] = [
        0x55, 0xEF, 0xE0, 0x3C, 0x71, 0xDA, 0xFC, 0xAB, 0x5C, 0x1A, 0x9F,
        0xEB, 0xA4, 0x9E, 0x61, 0xE6, 0x1E, 0x7E, 0x29, 0x77, 0x38, 0x9A,
        0xF5, 0x67, 0xF5, 0xDD, 0x07, 0x06, 0xAE, 0xE4, 0x5A, 0xDC,
    ];

    /// Probabilities of a set bit exercised by the randomized tests.
    const PROBABILITIES: [f64; 7] = [1.0, 0.999, 0.9, 0.5, 0.1, 1.0e-3, 0.0];

    fn seeded_rng() -> ChaCha20Rng { ChaCha20Rng::from_seed(SEED) }

    /// Samples `len` independent bits, each set with probability `p`.
    fn sample_bits(len: usize, p: f64) -> Vec<bool> {
        let mut rng = seeded_rng();
        let dist = Bernoulli::new(p).unwrap();
        (0..len).map(|_| dist.sample(&mut rng)).collect()
    }

    /// Lengths covered by the randomized tests: 0 and 10^0 through 10^7.
    fn lengths() -> impl Iterator<Item = usize> {
        std::iter::once(0).chain((0..=7).map(|e| 10_usize.pow(e)))
    }

    /// Prints the size report, but only for the all-ones configuration.
    fn report_sizes(p: f64, a: &[bool], dict: &Rs01DictTree) {
        if p == 1.0 {
            eprintln!("---");
            eprintln!("a.len(): {}", a.len());
            dict.size_info();
        }
    }

    /// Checks `rank0`/`rank1` at every position against prefix counts.
    fn test_rank_internal(len: usize, p: f64) {
        let a = sample_bits(len, p);

        // ones_before[i] = number of set bits in a[..i].
        let mut ones_before = Vec::with_capacity(len);
        let mut running = 0_usize;
        for &bit in &a {
            ones_before.push(running);
            running += bit as usize;
        }

        let dict = Rs01DictTree::new(&a);
        for (i, &ones) in ones_before.iter().enumerate() {
            assert_eq!(dict.rank1(i), ones, "i: {}", i);
            assert_eq!(dict.rank0(i), i - ones, "i: {}", i);
        }
        report_sizes(p, &a, &dict);
    }

    /// Checks `select0`/`select1` for every occurrence against the explicit
    /// position lists.
    fn test_select_internal(len: usize, p: f64) {
        eprintln!("{:?}", (len, p));
        let a = sample_bits(len, p);
        let zeros: Vec<usize> = (0..len).filter(|&i| !a[i]).collect();
        let ones: Vec<usize> = (0..len).filter(|&i| a[i]).collect();
        let dict = Rs01DictTree::new(&a);

        for (i, &expected) in zeros.iter().enumerate() {
            assert_eq!(dict.select0(i), expected, "i: {}", i);
        }
        for (i, &expected) in ones.iter().enumerate() {
            assert_eq!(dict.select1(i), expected, "i: {}", i);
        }
        report_sizes(p, &a, &dict);
    }

    #[test]
    fn test_rank() {
        for len in lengths() {
            for p in PROBABILITIES {
                test_rank_internal(len, p);
            }
        }
    }

    #[test]
    fn test_select() {
        for len in lengths() {
            for p in PROBABILITIES {
                test_select_internal(len, p);
            }
        }
    }

    #[test]
    fn sanity_check() { test_select_internal(100, 0.2); }
}
547
/// Turns a byte-string literal into a `Vec<bool>`: `b'0'` becomes `false`,
/// `b'1'` becomes `true`, and every other byte is skipped, so separators may
/// be used freely inside the literal.
#[cfg(test)]
macro_rules! bitvec {
    ($lit:literal) => {
        $lit.iter()
            .filter_map(|&byte| match byte {
                b'0' => Some(false),
                b'1' => Some(true),
                _ => None,
            })
            .collect::<Vec<_>>()
    };
}
557
/// Builds all-zero dictionaries of length 10^0 through 10^7 and prints the
/// size report for each.
#[test]
fn simple() {
    for exp in 0..=7 {
        let len = 10_usize.pow(exp);
        let zeros = vec![false; len];
        Rs01DictTree::new(&zeros).size_info();
    }
}
602
/// Spot-checks one row of each word-level lookup table.
#[test]
fn table() {
    let pattern: usize = 0b_1101_1010_1001;

    let ranks: [u8; 12] = [0, 1, 1, 1, 2, 2, 3, 3, 4, 5, 5, 6];
    assert_eq!(RANK_TABLE[pattern], ranks);

    let selects: [u8; 12] = [0, 3, 5, 7, 8, 10, 11, 0, 0, 0, 0, 0];
    assert_eq!(SELECT_TABLE[pattern], selects);
}
608}