rs01dict_tree/
lib.rs

1#![allow(dead_code)]
2
3use std::ops::Range;
4
/// Width in bits of one backing storage word (`u64`) of an `IntVec`.
const W: usize = 8 * std::mem::size_of::<u64>();
6
/// Packed vector of fixed-width unsigned integers.
///
/// Every element occupies exactly `unit` bits. Elements are laid out
/// back-to-back, least-significant bits first, across the `u64` words of
/// `buf`, so a single element may straddle two words.
struct IntVec {
    /// Width of every element, in bits.
    unit: usize,
    /// Backing storage words.
    buf: Vec<u64>,
    /// Number of elements stored so far.
    len: usize,
}
12
13pub struct Rs01DictTree {
14    buf: IntVec,
15    rank_index: RankIndex,
16    select_index: (SelectIndex, SelectIndex),
17}
18
19struct RankIndex {
20    large: IntVec,
21    small: IntVec,
22    large_len: usize,
23    small_len: usize,
24}
25
26struct SelectIndex {
27    indir: IntVec,
28    sparse: IntVec,
29    dense: IntVec,
30    table_tree: Vec<u8>,
31    // table_tree: IntVec,
32    large_popcnt: usize,
33    branch: usize,
34    small_len: usize,
35}
36
37const RANK_TABLE: [[u8; 12]; 4096] = rank_table::<12, 4096>();
38const SELECT_TABLE: [[u8; 12]; 4096] = select_table::<12, 4096>();
39
/// Builds a prefix-popcount table at compile time.
///
/// Entry `[w][j]` is the number of set bits of pattern `w` at positions
/// strictly below `j`. `PAT` is the number of patterns (rows) and `LEN` the
/// number of bit positions (columns).
const fn rank_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    let mut table = [[0u8; LEN]; PAT];
    let mut pat = 0;
    while pat < PAT {
        // Column 0 is always 0; every later column adds the bit just below it.
        let mut pos = 1;
        while pos < LEN {
            let below = ((pat >> (pos - 1)) & 1) as u8;
            table[pat][pos] = table[pat][pos - 1] + below;
            pos += 1;
        }
        pat += 1;
    }
    table
}
57
/// Builds a select table at compile time.
///
/// Entry `[w][k]` is the bit position of the `k`-th (0-based) set bit of
/// pattern `w`, considering only positions `0..LEN`. Slots past the last
/// set bit stay 0.
const fn select_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    let mut table = [[0u8; LEN]; PAT];
    let mut pat = 0;
    while pat < PAT {
        // Peel set bits off the low end, so positions come out in ascending order.
        let mut rest = pat;
        let mut rank = 0;
        while rest != 0 {
            let pos = rest.trailing_zeros() as usize;
            if pos >= LEN {
                break;
            }
            table[pat][rank] = pos as u8;
            rank += 1;
            rest &= rest - 1;
        }
        pat += 1;
    }
    table
}
76
77impl IntVec {
78    pub fn new(unit: usize) -> Self { Self { unit, buf: vec![], len: 0 } }
79    pub fn len(&self) -> usize { self.len }
80    pub fn bitlen(&self) -> usize { self.len * self.unit }
81
82    pub fn push(&mut self, w: u64) {
83        let unit = self.unit;
84        debug_assert!(unit == W || w & (!0 << unit) == 0);
85
86        let bitlen = self.bitlen();
87        if unit == 0 {
88            // nothing to do
89        } else if bitlen % W == 0 {
90            self.buf.push(w);
91        } else {
92            self.buf[bitlen / W] |= w << (bitlen % W);
93            if bitlen % W + unit > W {
94                self.buf.push(w >> (W - bitlen % W));
95            }
96        }
97        self.len += 1;
98    }
99
100    #[inline(always)]
101    pub fn get_usize(&self, i: usize) -> usize { self.get::<true>(i) as _ }
102
103    #[inline(always)]
104    pub fn get<const X: bool>(&self, i: usize) -> u64 {
105        let start = i * self.unit;
106        let end = start + self.unit;
107        self.bits_range::<X>(start..end)
108    }
109
110    #[inline(always)]
111    pub fn bits_range<const X: bool>(
112        &self,
113        Range { start, end }: Range<usize>,
114    ) -> u64 {
115        let end = end.min(self.bitlen()); // (!)
116        let mask = !(!0 << (end - start));
117
118        let mut res = self.buf[start / W] >> (start % W);
119        if end > (start / W + 1) * W {
120            res |= self.buf[end / W] << (W - start % W);
121        }
122
123        ((if X { res } else { !res }) & mask) as _
124    }
125}
126
127impl std::fmt::Debug for IntVec {
128    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        fmt.debug_list()
130            .entries((0..self.len).map(|i| self.get::<true>(i)))
131            .finish()
132    }
133}
134
/// Number of bits needed to store any value in `0..=n`, but at least 1.
///
/// This is the bit length of `n` itself (`floor(log2(n)) + 1` for `n > 0`),
/// clamped below to 1 so that callers always get a non-zero width.
/// Examples: `bitlen(0) == 1`, `bitlen(1) == 1`, `bitlen(4) == 3`,
/// `bitlen(7) == 3`.
fn bitlen(n: usize) -> usize {
    // Derived from the leading-zero count rather than
    // `(n + 1).next_power_of_two()`, which overflows for `n == usize::MAX`.
    ((usize::BITS - n.leading_zeros()) as usize).max(1)
}
139
140impl RankIndex {
141    pub fn new(buf: &[bool]) -> Self {
142        let len = buf.len();
143        let small_len = (1_usize..)
144            .find(|&i| 4_usize.saturating_pow(i as _) >= len)
145            .unwrap(); // log(n)/2
146        let large_len = (2 * small_len).pow(2); // log(n)^2
147
148        let small_bitlen = bitlen(len.min(large_len));
149        let large_bitlen = bitlen(len);
150
151        let mut small = IntVec::new(small_bitlen);
152        let mut large = IntVec::new(large_bitlen);
153        let mut small_acc = 0;
154        let mut large_acc = 0;
155        let per = large_len / small_len;
156        for (c, i) in buf
157            .chunks(small_len)
158            .map(|ch| ch.iter().filter(|&&x| x).count() as u64)
159            .zip((0..per).cycle())
160        {
161            small.push(small_acc);
162            small_acc = if i < per - 1 { small_acc + c } else { 0 };
163
164            if i == 0 {
165                large.push(large_acc);
166            }
167            large_acc += c as u64;
168        }
169
170        // let table = Self::table(small_len);
171        Self { large, small, large_len, small_len }
172    }
173
174    fn table(len: usize) -> IntVec {
175        let unit = bitlen(len);
176        let mut table = IntVec::new(unit);
177        for i in 0..1 << len {
178            let mut cur = 0;
179            for j in 0..len {
180                table.push(cur);
181                if i >> j & 1 != 0 {
182                    cur += 1;
183                }
184            }
185        }
186        table
187    }
188
189    #[inline(always)]
190    fn lookup(&self, w: u64, i: usize) -> usize {
191        RANK_TABLE[w as usize][i] as _
192    }
193
194    #[inline(always)]
195    pub fn rank1(&self, i: usize, b: &IntVec) -> usize {
196        let large_acc = self.large.get_usize(i / self.large_len);
197        let small_acc = self.small.get_usize(i / self.small_len);
198        let il = i / self.small_len * self.small_len;
199        let ir = il + self.small_len;
200        let w = b.bits_range::<true>(il..ir);
201        let small = self.lookup(w, i % self.small_len);
202        large_acc + small_acc + small
203    }
204    pub fn rank0(&self, i: usize, b: &IntVec) -> usize { i - self.rank1(i, b) }
205    pub fn rank<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
206        if X { self.rank1(i, b) } else { self.rank0(i, b) }
207    }
208
209    #[cfg(test)]
210    pub fn size_info(&self) -> (usize, usize) {
211        // eprintln!("large: {} bits", self.large.bitlen());
212        // eprintln!("small: {} bits", self.small.bitlen());
213        // eprintln!("table: {} bits", self.table.bitlen());
214
215        let rt = self.large.bitlen() + self.small.bitlen();
216        // (rt, rt + self.table.bitlen())
217        (rt, rt + 8 * RANK_TABLE.len() * RANK_TABLE[0].len())
218    }
219}
220
221impl SelectIndex {
222    pub fn new<const X: bool>(buf: &[bool]) -> Self {
223        let len = buf.len();
224        let len_lg = (len as f64).log2().max(1.0);
225
226        let dense_max = (len_lg.powi(4) / 128.0).ceil() as usize;
227        let large_popcnt = (len_lg.powi(2) / 16.0).ceil() as usize;
228        let small_len = (len_lg / 2.0).ceil().max(2.0) as usize;
229        let branch = len_lg.cbrt().ceil() as usize;
230
231        let mut indir = IntVec::new(bitlen(len) + 2);
232        let mut sparse = IntVec::new(bitlen(len));
233        let mut dense = IntVec::new(bitlen(large_popcnt));
234
235        let mut start = 0;
236        let mut pos = vec![];
237        for i in 0..len {
238            if buf[i] == X {
239                pos.push(i);
240            }
241            if !(pos.len() == large_popcnt || i == len - 1) {
242                continue;
243            }
244
245            let end = i;
246            if end + 1 - start > dense_max {
247                indir.push((sparse.len() << 1 | 0) as _);
248                indir.push(0);
249                indir.push(0);
250                for &p in &pos {
251                    sparse.push(p as _);
252                }
253            } else {
254                indir.push((dense.len() << 1 | 1) as _);
255                let ceil_len = (1..)
256                    .map(|i| branch.pow(i) * small_len)
257                    .find(|&b| b >= end + 1 - start)
258                    .unwrap();
259                let mut cur = dense.len();
260                for i in (start..start + ceil_len).step_by(small_len).rev() {
261                    let il = i.min(end + 1);
262                    let ir = (il + small_len).min(end + 1);
263                    let w = (il..ir).filter(|&i| buf[i] == X).count();
264                    dense.push(w as _);
265                }
266                while cur + branch < dense.len() {
267                    let mut sum = 0;
268                    for _ in 0..branch {
269                        sum += dense.get::<true>(cur);
270                        cur += 1;
271                    }
272                    dense.push(sum);
273                }
274                indir.push(dense.len() as _);
275                indir.push(start as _);
276            }
277
278            pos.clear();
279            start = i + 1;
280        }
281
282        let table_tree = Self::table_tree(large_popcnt, branch);
283
284        Self {
285            indir,
286            sparse,
287            dense,
288            table_tree,
289            large_popcnt,
290            branch,
291            small_len,
292        }
293    }
294
295    #[inline(always)]
296    fn lookup_tree(&self, w: u64, i: usize) -> (usize, usize) {
297        let bitlen_branch = bitlen(self.branch);
298        let wi = w as usize * self.large_popcnt + i;
299        let res = self.table_tree[wi] as usize;
300        // let res = self.table_tree.get_usize(wi);
301        (res >> bitlen_branch, res & !(!0 << bitlen_branch))
302    }
303
304    #[inline(always)]
305    fn lookup_word(&self, w: u64, i: usize) -> usize {
306        SELECT_TABLE[w as usize][i] as _
307    }
308
309    // fn table_tree(popcnt: usize, branch: usize) -> IntVec {
310    fn table_tree(popcnt: usize, branch: usize) -> Vec<u8> {
311        let len = bitlen(popcnt);
312        // let unit = len + bitlen(branch);
313
314        let enc = |i, j| i << bitlen(branch) | j;
315        let mut table = vec![];
316        // let mut table = IntVec::new(unit);
317        for i in 0..1 << (len * branch) {
318            let mut count = 0;
319            for b in 0..branch {
320                let sh = (branch - 1 - b) * len;
321                let c = i >> sh & !(!0 << len);
322                if count + c > popcnt {
323                    break;
324                }
325                for _ in 0..c {
326                    table.push(enc(count, b) as _);
327                }
328                count += c;
329            }
330            for _ in count..popcnt {
331                table.push(0);
332            }
333        }
334        table
335    }
336
337    fn table_word(len: usize) -> IntVec {
338        let unit = bitlen(len);
339        let mut table = IntVec::new(unit);
340        for i in 0..1 << len {
341            let mut cur = 0;
342            for j in 0..len {
343                if i >> j & 1 != 0 {
344                    table.push(j as _);
345                    cur += 1;
346                }
347            }
348            for _ in cur..len {
349                table.push(0);
350            }
351        }
352        table
353    }
354
355    #[inline(always)]
356    pub fn select<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
357        let (il_div, il_mod) = (i / self.large_popcnt, i % self.large_popcnt);
358        let large = self.indir.get_usize(3 * il_div);
359        let (large_i, large_ty) = (large >> 1, large & 1);
360        if large_ty == 0 {
361            self.sparse.get_usize(large_i + il_mod)
362        } else {
363            let start = large_i;
364            let end = self.indir.get_usize(3 * il_div + 1);
365            let b_start = self.indir.get_usize(3 * il_div + 2);
366            let unit = bitlen(self.large_popcnt);
367            let branch = self.branch;
368            let mut cur = 0;
369            let mut i = il_mod;
370            let mut b_i = 0;
371            loop {
372                // let il = (end - (cur + branch)) * unit;
373                let ir = (end - cur) * unit;
374                let il = ir - branch * unit;
375                let w = self.dense.bits_range::<true>(il..ir);
376                let (acc, br) = self.lookup_tree(w, i);
377                let tmp = (cur + br + 1) * branch;
378                if end - start <= tmp {
379                    let il = b_start + (b_i * branch + br) * self.small_len;
380                    let ir = il + self.small_len;
381                    let w = b.bits_range::<X>(il..ir);
382                    break il + self.lookup_word(w, i - acc);
383                }
384                b_i = b_i * branch + br;
385                cur = tmp;
386                i -= acc;
387            }
388        }
389    }
390
391    #[cfg(test)]
392    pub fn size_info(&self) -> (usize, usize) {
393        // eprintln!("indir:  {} bits", self.indir.bitlen());
394        // eprintln!("sparse: {} bits", self.sparse.bitlen());
395        // eprintln!("dense:  {} bits", self.dense.bitlen());
396
397        let rt =
398            self.indir.bitlen() + self.sparse.bitlen() + self.dense.bitlen();
399
400        // (rt, rt + self.table_tree.bitlen() + self.table_word.bitlen())
401        (
402            rt,
403            rt + 8 * self.table_tree.len()
404                + 8 * SELECT_TABLE.len() * SELECT_TABLE[0].len(),
405        )
406    }
407}
408
409impl Rs01DictTree {
410    pub fn new(a: &[bool]) -> Self {
411        let rank_index = RankIndex::new(&a);
412        let mut buf = IntVec::new(1);
413        for &x in a {
414            buf.push(x as _);
415        }
416        let select_index =
417            (SelectIndex::new::<false>(&a), SelectIndex::new::<true>(&a));
418
419        // select.0 と select.1 で同じ lookup table を作るの無駄だから、
420        // rank も含めてそれらは親のクラスで持つ設計でもいいかも?
421        // buf と一緒に table も渡す感じで
422        Self { buf, rank_index, select_index }
423    }
424
425    pub fn rank<const X: bool>(&self, i: usize) -> usize {
426        self.rank_index.rank::<X>(i, &self.buf)
427    }
428    pub fn rank0(&self, i: usize) -> usize { self.rank::<false>(i) }
429    pub fn rank1(&self, i: usize) -> usize { self.rank::<true>(i) }
430
431    pub fn select<const X: bool>(&self, i: usize) -> usize {
432        if X {
433            self.select_index.1.select::<X>(i, &self.buf)
434        } else {
435            self.select_index.0.select::<X>(i, &self.buf)
436        }
437    }
438    pub fn select0(&self, i: usize) -> usize { self.select::<false>(i) }
439    pub fn select1(&self, i: usize) -> usize { self.select::<true>(i) }
440
441    #[cfg(test)]
442    pub fn size_info(&self) {
443        let len = self.buf.bitlen();
444        let naive = 3 * len * bitlen(len);
445        eprintln!("* naive: {naive:>10} bits, {:>10} words", naive / 64);
446
447        let (r, r_table) = self.rank_index.size_info();
448        let (s0, s0_table) = self.select_index.0.size_info();
449        let (s1, s1_table) = self.select_index.1.size_info();
450        let sum = r + s0 + s1;
451        let sum_table = r_table + s0_table + s1_table;
452
453        let ratio = sum as f64 / naive as f64;
454        eprintln!(
455            "- table: {sum:>10} bits, {:>10} words (x{ratio:.03})",
456            sum / 64
457        );
458        let ratio = sum_table as f64 / naive as f64;
459        eprintln!(
460            "+ table: {sum_table:>10} bits, {:>10} words (x{ratio:.03})",
461            sum_table / 64
462        );
463    }
464}
465
#[cfg(test)]
mod tests {
    use rand::{
        distributions::{Bernoulli, Distribution},
        Rng, SeedableRng,
    };
    use rand_chacha::ChaCha20Rng;

    use crate::*;

    /// Bit-vector lengths exercised by the randomized tests: 0 and 10^0..=10^7.
    fn lengths() -> impl Iterator<Item = usize> {
        std::iter::once(0).chain((0..=7).map(|e| 10_usize.pow(e)))
    }

    /// Probabilities of a `true` bit exercised by the randomized tests.
    const PROBS: [f64; 7] = [1.0, 0.999, 0.9, 0.5, 0.1, 1.0e-3, 0.0];

    /// Fixed-seed RNG, so every run tests the same bit vectors.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed([
            0x55, 0xEF, 0xE0, 0x3C, 0x71, 0xDA, 0xFC, 0xAB, 0x5C, 0x1A, 0x9F,
            0xEB, 0xA4, 0x9E, 0x61, 0xE6, 0x1E, 0x7E, 0x29, 0x77, 0x38, 0x9A,
            0xF5, 0x67, 0xF5, 0xDD, 0x07, 0x06, 0xAE, 0xE4, 0x5A, 0xDC,
        ])
    }

    /// Draws `len` independent bits, each `true` with probability `p`.
    fn random_bits(len: usize, p: f64) -> Vec<bool> {
        let mut rng = rng();
        let dist = Bernoulli::new(p).unwrap();
        (0..len).map(|_| dist.sample(&mut rng)).collect()
    }

    /// Checks `rank0`/`rank1` at every position against a running count.
    fn test_rank_internal(len: usize, p: f64) {
        let a = random_bits(len, p);
        let dict = Rs01DictTree::new(&a);
        let mut ones = 0;
        for (i, &bit) in a.iter().enumerate() {
            assert_eq!(dict.rank1(i), ones, "i: {}", i);
            assert_eq!(dict.rank0(i), i - ones, "i: {}", i);
            ones += bit as usize;
        }
        if p == 1.0 {
            eprintln!("---");
            eprintln!("a.len(): {}", a.len());
            dict.size_info();
        }
    }

    /// Checks `select0`/`select1` for every occurrence against the explicit
    /// position lists.
    fn test_select_internal(len: usize, p: f64) {
        eprintln!("{:?}", (len, p));
        let a = random_bits(len, p);
        let (zeros, ones): (Vec<usize>, Vec<usize>) =
            (0..len).partition(|&i| !a[i]);
        let dict = Rs01DictTree::new(&a);

        for (i, &pos) in zeros.iter().enumerate() {
            assert_eq!(dict.select0(i), pos, "i: {}", i);
        }
        for (i, &pos) in ones.iter().enumerate() {
            assert_eq!(dict.select1(i), pos, "i: {}", i);
        }
        if p == 1.0 {
            eprintln!("---");
            eprintln!("a.len(): {}", a.len());
            dict.size_info();
        }
    }

    #[test]
    fn test_rank() {
        for len in lengths() {
            for &p in &PROBS {
                test_rank_internal(len, p);
            }
        }
    }

    #[test]
    fn test_select() {
        for len in lengths() {
            for &p in &PROBS {
                test_select_internal(len, p);
            }
        }
    }

    #[test]
    fn sanity_check() { test_select_internal(100, 0.2); }
}
547
/// Parses a byte-string literal of `0`/`1` characters into a `Vec<bool>`,
/// ignoring every other byte so that spaces can be used as separators.
#[cfg(test)]
macro_rules! bitvec {
    ($lit:literal) => {
        $lit.iter()
            .copied()
            .filter_map(|byte| match byte {
                b'0' => Some(false),
                b'1' => Some(true),
                _ => None,
            })
            .collect::<Vec<bool>>()
    };
}
557
#[test]
fn simple() {
    // Hand-worked examples of the dense select tree, kept for reference.
    // Each diagram shows the bit blocks (bottom row), the per-block counts,
    // and the sums at the upper levels, followed by the expected flattened
    // `dense` contents.
    //
    // let a = bitvec!(b"110 100 000 101 001 010 111 110 010");
    // let _ = SelectIndex::new::<true>(&a);
    //
    // let a = bitvec!(b"1101 1000 0001 1010 0010 0101 1110 1101 0100");
    // let _ = SelectIndex::new::<true>(&a);
    //
    // let a = bitvec!(b"1101 1000 0001 1010 0010 0101 11");
    // let _ = SelectIndex::new::<true>(&a);

    //                 (13)
    //       3           4           6
    //   2   1   0   2   1   1   3   2   1  <- sum of length log(log(n)^2)
    // 110 100 000 101 001 010 111 110 010  <- block of length log(n)/2
    //
    // 3 4 6; 2 1 0; 2 1 1; 3 2 1
    //
    // 1 2 3; 1 1 2; 0 1 2; _ _ _; _
    //
    // 0..26
    // [1, 2, 3, 1, 1, 2, 0, 1, 2, 6, 4, 3]
    // ok

    //         5              5              7
    //    3    1    1    2    1    2    3    3    1
    // 1101 1000 0001 1010 0010 0101 1110 1101 0100
    //
    // [1, 3, 3, 2, 1, 2, 1, 1, 3, 7, 5, 5]
    // ok

    //         5              5              2
    //    3    1    1    2    1    2    2    0    0
    // 1101 1000 0001 1010 0010 0101 11__ ____ ____
    //
    // [0, 0, 2, 2, 1, 2, 1, 1, 3, 2, 5, 5]
    // ok

    // Print index sizes for all-zero inputs of length 10^0 through 10^7.
    for len in (0..=7).map(|e| 10_usize.pow(e)) {
        let bits = vec![false; len];
        let dict = Rs01DictTree::new(&bits);
        dict.size_info();
    }
}
602
#[test]
fn table() {
    // Set bits at positions 0, 3, 5, 7, 8, 10, 11.
    let pattern: usize = 0b1101_1010_1001;
    let expected_rank: [u8; 12] = [0, 1, 1, 1, 2, 2, 3, 3, 4, 5, 5, 6];
    let expected_select: [u8; 12] = [0, 3, 5, 7, 8, 10, 11, 0, 0, 0, 0, 0];
    assert_eq!(RANK_TABLE[pattern], expected_rank);
    assert_eq!(SELECT_TABLE[pattern], expected_select);
}