rs01dict_runtime/
lib.rs

1#![allow(unused)]
2
3use std::ops::Range;
4
/// Number of bits in one `u64` storage word.
const W: usize = u64::BITS as usize;
6
/// Vector of fixed-width unsigned integers, bit-packed into `u64` words.
struct IntVec {
    /// Width of every element, in bits.
    unit: usize,
    /// Backing words. Elements are packed from the low bits upwards and may
    /// straddle two consecutive words.
    buf: Vec<u64>,
    /// Number of elements stored.
    len: usize,
}
12
/// Bit vector bundled with the rank and select indexes built over it.
pub struct Rs01DictRuntime {
    /// The bits themselves, stored as an `IntVec` with one bit per element.
    buf: IntVec,
    /// Index used to answer rank queries.
    rank_index: RankIndex,
    /// Select indexes: `.0` is built for 0-bits, `.1` for 1-bits.
    select_index: (SelectIndex, SelectIndex),
}
18
/// Two-level prefix-count index plus a lookup table for in-chunk counts.
struct RankIndex {
    /// Number of 1-bits before the start of each `large_len`-bit block.
    large: IntVec,
    /// Number of 1-bits between the start of the enclosing large block and
    /// the start of each `small_len`-bit chunk.
    small: IntVec,
    /// For every `small_len`-bit pattern `w` and offset `j`, the number of
    /// 1-bits among the low `j` bits of `w`, stored at `w * small_len + j`.
    table: IntVec,
    /// Size of a large block, in bits.
    large_len: usize,
    /// Size of a small chunk, in bits.
    small_len: usize,
}
26
/// Two-level select index for one bit value (0 or 1), built by
/// `SelectIndex::new::<X>`.
struct SelectIndex {
    /// Number of target bits grouped into one small block.
    small_popcnt: usize,
    /// Start of each small block, relative to its large block's start.
    small_start: IntVec,
    /// Tagged entry per small block. Low bit 1: dense, resolved through
    /// `table`. Low bit 0: sparse, the upper bits give the group index inside
    /// this large block's part of `small_sparse`.
    small_indir: IntVec,
    /// Target-bit positions of sparse small blocks, relative to the small
    /// block's start.
    small_sparse: IntVec,
    /// Per large block: length of `small_sparse` when that block was started.
    small_sparse_offset: IntVec,
    /// Largest span (in bits) of a small block that is treated as dense; also
    /// the pattern width of `table`.
    small_dense_max: usize,
    /// Number of target bits grouped into one large block.
    large_popcnt: usize,
    /// Absolute start position of each large block.
    large_start: IntVec,
    /// Tagged entry per large block. Low bit 0: sparse, the upper bits index
    /// into `large_sparse`. Low bit 1: dense, the upper bits index the block's
    /// first small block in `small_start` / `small_indir`.
    large_indir: IntVec,
    /// Absolute target-bit positions of sparse large blocks.
    large_sparse: IntVec,
    /// For every `small_dense_max`-bit pattern, the positions of its set bits
    /// in ascending order, zero-padded to `small_dense_max` entries.
    table: IntVec,
}
40
41impl IntVec {
42    pub fn new(unit: usize) -> Self { Self { unit, buf: vec![], len: 0 } }
43    pub fn len(&self) -> usize { self.len }
44    pub fn bitlen(&self) -> usize { self.len * self.unit }
45
46    pub fn push(&mut self, w: u64) {
47        let unit = self.unit;
48        assert!(unit == W || w & (!0 << unit) == 0);
49
50        let bitlen = self.bitlen();
51        if unit == 0 {
52            // nothing to do
53        } else if bitlen % W == 0 {
54            self.buf.push(w);
55        } else {
56            self.buf[bitlen / W] |= w << (bitlen % W);
57            if bitlen % W + unit > W {
58                self.buf.push(w >> (W - bitlen % W));
59            }
60        }
61        self.len += 1;
62    }
63
64    pub fn get_usize(&self, i: usize) -> usize { self.get::<true>(i) as _ }
65
66    pub fn get<const X: bool>(&self, i: usize) -> u64 {
67        let start = i * self.unit;
68        let end = start + self.unit;
69        self.bits_range::<X>(start..end)
70    }
71
72    pub fn bits_range<const X: bool>(
73        &self,
74        Range { start, end }: Range<usize>,
75    ) -> u64 {
76        let end = end.min(self.bitlen()); // (!)
77        let mask = if end - start == W { !0 } else { !(!0 << (end - start)) };
78        let res = if start == end {
79            0
80        } else if start % W == 0 {
81            self.buf[start / W]
82        } else if end <= (start / W + 1) * W {
83            self.buf[start / W] >> (start % W)
84        } else {
85            self.buf[start / W] >> (start % W)
86                | self.buf[end / W] << (W - start % W)
87        };
88        (if X { res } else { !res }) & mask
89    }
90}
91
92impl std::fmt::Debug for IntVec {
93    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        fmt.debug_list()
95            .entries((0..self.len).map(|i| self.get::<true>(i)))
96            .finish()
97    }
98}
99
/// Number of bits needed to store any value in `{0, 1, ..., n}`, but at
/// least 1: `max(1, ceil(log2(n + 1)))`.
///
/// Computed from the position of the highest set bit of `n`, so it cannot
/// overflow even when `n` is close to `usize::MAX`.
fn bitlen(n: usize) -> usize {
    let width = (usize::BITS - n.leading_zeros()) as usize;
    width.max(1)
}
/// Smallest exponent `e >= 1` with `4^e >= n`, i.e. roughly `log2(n) / 2`.
fn lg_half(n: usize) -> usize {
    let mut exp: u32 = 1;
    // `saturating_pow` caps at `usize::MAX`, so the loop always terminates.
    while 4_usize.saturating_pow(exp) < n {
        exp += 1;
    }
    exp as usize
}
108
109impl RankIndex {
110    pub fn new(buf: &[bool]) -> Self {
111        let len = buf.len();
112        let small_len = (1_usize..)
113            .find(|&i| 4_usize.saturating_pow(i as _) >= len)
114            .unwrap(); // log(n)/2
115        let large_len = (2 * small_len).pow(2); // log(n)^2
116
117        let small_bitlen = bitlen(len.min(large_len));
118        let large_bitlen = bitlen(len);
119
120        let mut small = IntVec::new(small_bitlen);
121        let mut large = IntVec::new(large_bitlen);
122        let mut small_acc = 0;
123        let mut large_acc = 0;
124        let per = large_len / small_len;
125        for (c, i) in buf
126            .chunks(small_len)
127            .map(|ch| ch.iter().filter(|&&x| x).count() as u64)
128            .zip((0..per).cycle())
129        {
130            small.push(small_acc);
131            small_acc = if i < per - 1 { small_acc + c } else { 0 };
132
133            if i == 0 {
134                large.push(large_acc);
135            }
136            large_acc += c as u64;
137        }
138
139        let table = Self::table(small_len);
140        Self { large, small, table, large_len, small_len }
141    }
142
143    fn table(len: usize) -> IntVec {
144        let unit = bitlen(len);
145        let mut table = IntVec::new(unit);
146        for i in 0..1 << len {
147            let mut cur = 0;
148            for j in 0..len {
149                table.push(cur);
150                if i >> j & 1 != 0 {
151                    cur += 1;
152                }
153            }
154        }
155        table
156    }
157
158    fn lookup(&self, w: u64, i: usize) -> usize {
159        let wi = w as usize * self.small_len + i;
160        self.table.get_usize(wi)
161    }
162
163    pub fn rank1(&self, i: usize, b: &IntVec) -> usize {
164        let large_acc = self.large.get_usize(i / self.large_len);
165        let small_acc = self.small.get_usize(i / self.small_len);
166        let il = i / self.small_len * self.small_len;
167        let ir = il + self.small_len;
168        let w = b.bits_range::<true>(il..ir);
169        let small = self.lookup(w, i % self.small_len);
170        large_acc + small_acc + small
171    }
172    pub fn rank0(&self, i: usize, b: &IntVec) -> usize { i - self.rank1(i, b) }
173    pub fn rank<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
174        if X { self.rank1(i, b) } else { self.rank0(i, b) }
175    }
176
177    #[cfg(test)]
178    pub fn size_info(&self) -> usize {
179        // eprintln!("large: {} bits", self.large.bitlen());
180        // eprintln!("small: {} bits", self.small.bitlen());
181        // eprintln!("table: {} bits", self.table.bitlen());
182
183        self.large.bitlen() + self.small.bitlen() + self.table.bitlen()
184    }
185}
186
impl SelectIndex {
    /// Builds the select index for the bits of `buf` that are equal to `X`.
    ///
    /// Target bits are grouped into large blocks of `large_popcnt` bits each
    /// (the last block may hold fewer). A large block spanning more than
    /// `large_dense_max` positions is "sparse" and stores every position
    /// explicitly; otherwise it is split into small blocks of `small_popcnt`
    /// target bits, each either sparse (explicit relative positions) or dense
    /// (answered through the lookup table).
    pub fn new<const X: bool>(buf: &[bool]) -> Self {
        let len = buf.len();
        let small_popcnt = lg_half(len);
        let large_popcnt = (2 * small_popcnt).pow(2); // log(n)^2
        // ceil(max(1, log2(max(1, log2(n))))^4 / 24)
        let small_dense_max =
            (((len as f64).log2().max(1.0).log2().max(1.0).powi(4) / 24.0)
                .ceil()) as usize;
        let large_dense_max = large_popcnt.pow(2); // log(n)^4
        // The `*_indir` vectors get one extra bit for the sparse/dense tag.
        let mut large_start = IntVec::new(bitlen(len));
        let mut large_indir = IntVec::new(bitlen(len) + 1);
        let mut large_sparse = IntVec::new(bitlen(len));
        let mut small_start = IntVec::new(bitlen(large_dense_max));
        let mut small_indir = IntVec::new(bitlen(large_dense_max) + 1);
        let mut small_sparse = IntVec::new(bitlen(large_dense_max));
        let mut small_sparse_offset = IntVec::new(bitlen(len));

        // `start` is the first position of the large block being collected;
        // `pos` holds the target-bit positions seen in it so far.
        let mut start = 0;
        let mut pos = vec![];
        for i in 0..len {
            if buf[i] == X {
                pos.push(i);
            }
            // Keep collecting until the block is full or the input ends.
            if !(pos.len() == large_popcnt || i == len - 1) {
                continue;
            }

            let cur_large_start = start;
            let cur_large_end = i;
            large_start.push(cur_large_start as _);
            small_sparse_offset.push(small_sparse.len() as _);
            if cur_large_end + 1 - cur_large_start > large_dense_max {
                // Sparse large block: tag 0, store absolute positions.
                large_indir.push((large_sparse.len() << 1 | 0) as _);
                for p in pos.drain(..) {
                    large_sparse.push(p as _);
                }
            } else {
                // Dense large block: tag 1, pointing at its first small block.
                large_indir.push((small_start.len() << 1 | 1) as _);
                let small_start_offset = small_start.len();
                let small_sparse_offset = small_sparse.len();
                let mut cur_small_start = cur_large_start;
                for j in (0..pos.len()).step_by(small_popcnt) {
                    let start = cur_small_start;
                    // A small block runs up to just before the first target
                    // bit of the next small block; the last one ends with the
                    // large block.
                    let end = if j + small_popcnt < pos.len() {
                        pos[j + small_popcnt] - 1
                    } else if i == len - 1 {
                        i
                    } else {
                        pos[pos.len() - 1]
                    };
                    small_start.push((start - cur_large_start) as _);
                    if end + 1 - start > small_dense_max {
                        // Sparse small block: tag 0 plus its group index
                        // within this large block's part of `small_sparse`.
                        let tmp = (small_sparse.len() - small_sparse_offset)
                            / small_popcnt;
                        small_indir.push((tmp << 1 | 0) as _);
                        for &p in &pos[j..pos.len().min(j + small_popcnt)] {
                            let pos_offset = p - start;
                            small_sparse.push(pos_offset as _);
                        }
                    } else {
                        // Dense small block: tag 1, resolved via the table.
                        small_indir.push(0 << 1 | 1);
                    }
                    cur_small_start = end + 1;
                }

                pos.clear();
            }
            start = i + 1;
        }

        let table = Self::table(small_dense_max);
        Self {
            small_popcnt,
            small_start,
            small_indir,
            small_sparse,
            small_sparse_offset,
            small_dense_max,
            large_popcnt,
            large_start,
            large_indir,
            large_sparse,
            table,
        }
    }

    /// Builds the dense-block lookup table: for every `len`-bit pattern `i`
    /// it stores `len` entries, namely the positions of the set bits of `i`
    /// in ascending order followed by zero padding.
    fn table(len: usize) -> IntVec {
        let unit = bitlen(len);
        let mut table = IntVec::new(unit);
        for i in 0..1 << len {
            let mut cur = 0;
            for j in 0..len {
                if i >> j & 1 != 0 {
                    table.push(j as _);
                    cur += 1;
                }
            }
            // Pad so that every pattern occupies exactly `len` entries.
            for _ in cur..len {
                table.push(0);
            }
        }
        table
    }

    /// Position of the `i`-th (0-based) set bit within the pattern `w`.
    fn lookup(&self, w: u64, i: usize) -> usize {
        let wi = w as usize * self.small_dense_max + i;
        self.table.get_usize(wi)
    }

    /// Returns the position in `b` of the `i`-th (0-based) bit equal to `X`.
    ///
    /// `X` should match the value this index was built for, and `b` is
    /// presumably the bit vector it was built from; neither is checked here.
    pub fn select<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
        let (il_div, il_mod) = (i / self.large_popcnt, i % self.large_popcnt);
        let large = self.large_indir.get_usize(il_div);
        // Low bit is the sparse/dense tag, the rest is the payload index.
        let (large_i, large_ty) = (large >> 1, large & 1);
        if large_ty == 0 {
            self.large_sparse.get_usize(large_i + il_mod)
        } else {
            let large_start = self.large_start.get_usize(il_div);
            let per = self.large_popcnt / self.small_popcnt;
            // Index of the small block inside the large block, and of the
            // target bit inside that small block.
            let is_div = i / self.small_popcnt % per;
            let is_mod = i % self.small_popcnt;

            let small = self.small_indir.get_usize(large_i + is_div);
            let (small_i, small_ty) = (small >> 1, small & 1);
            let small_start = self.small_start.get_usize(large_i + is_div);
            if small_ty == 0 {
                let offset = self.small_sparse_offset.get_usize(il_div);
                let small_sparse = self
                    .small_sparse
                    .get_usize(offset + small_i * self.small_popcnt + is_mod);
                large_start + small_start + small_sparse
            } else {
                // Dense small block: read a window of bits starting at the
                // block (inverted when `X` is false) and look the answer up.
                let offset = large_start + small_start;
                let w =
                    b.bits_range::<X>(offset..offset + self.small_dense_max);
                offset + self.lookup(w, is_mod)
            }
        }
    }

    /// Prints the size of every component to stderr and returns the total,
    /// in bits.
    #[cfg(test)]
    pub fn size_info(&self) -> usize {
        eprintln!("small_start: {} bits", self.small_start.bitlen());
        eprintln!("small_indir: {} bits", self.small_indir.bitlen());
        eprintln!("small_sparse: {} bits", self.small_sparse.bitlen());
        eprintln!(
            "small_sparse_offset: {} bits",
            self.small_sparse_offset.bitlen()
        );
        eprintln!("large_start: {} bits", self.large_start.bitlen());
        eprintln!("large_indir: {} bits", self.large_indir.bitlen());
        eprintln!("large_sparse: {} bits", self.large_sparse.bitlen());
        eprintln!("table: {} bits", self.table.bitlen());

        self.small_start.bitlen()
            + self.small_indir.bitlen()
            + self.small_sparse.bitlen()
            + self.small_sparse_offset.bitlen()
            + self.large_start.bitlen()
            + self.large_indir.bitlen()
            + self.large_sparse.bitlen()
            + self.table.bitlen()
    }
}
350
351impl Rs01DictRuntime {
352    pub fn new(a: &[bool]) -> Self {
353        let rank_index = RankIndex::new(&a);
354        let mut buf = IntVec::new(1);
355        for &x in a {
356            buf.push(x as _);
357        }
358        let select_index =
359            (SelectIndex::new::<false>(&a), SelectIndex::new::<true>(&a));
360
361        // select.0 と select.1 で同じ lookup table を作るの無駄だから、
362        // rank も含めてそれらは親のクラスで持つ設計でもいいかも?
363        // buf と一緒に table も渡す感じで
364        Self { buf, rank_index, select_index }
365    }
366
367    pub fn rank<const X: bool>(&self, i: usize) -> usize {
368        self.rank_index.rank::<X>(i, &self.buf)
369    }
370    pub fn rank0(&self, i: usize) -> usize { self.rank::<false>(i) }
371    pub fn rank1(&self, i: usize) -> usize { self.rank::<true>(i) }
372
373    pub fn select<const X: bool>(&self, i: usize) -> usize {
374        if X {
375            self.select_index.1.select::<X>(i, &self.buf)
376        } else {
377            self.select_index.0.select::<X>(i, &self.buf)
378        }
379    }
380    pub fn select0(&self, i: usize) -> usize { self.select::<false>(i) }
381    pub fn select1(&self, i: usize) -> usize { self.select::<true>(i) }
382
383    #[cfg(test)]
384    pub fn size_info(&self) {
385        let mut sum = 0;
386        sum += self.rank_index.size_info();
387        sum += self.select_index.0.size_info();
388        sum += self.select_index.1.size_info();
389        let ratio = sum as f64 / self.buf.len() as f64;
390        eprintln!("total: {sum} bits (x{ratio:.03})");
391    }
392}
393
#[cfg(test)]
mod tests {
    use rand::{
        distributions::{Bernoulli, Distribution},
        Rng, SeedableRng,
    };
    use rand_chacha::ChaCha20Rng;

    use crate::*;

    /// Probabilities of a 1-bit exercised by the randomized tests.
    const PROBS: [f64; 7] = [1.0, 0.999, 0.9, 0.5, 0.1, 1.0e-3, 0.0];

    /// Fixed-seed generator, so every run sees the same bit sequences.
    fn rng() -> ChaCha20Rng {
        const SEED: [u8; 32] = [
            0x55, 0xEF, 0xE0, 0x3C, 0x71, 0xDA, 0xFC, 0xAB, 0x5C, 0x1A, 0x9F,
            0xEB, 0xA4, 0x9E, 0x61, 0xE6, 0x1E, 0x7E, 0x29, 0x77, 0x38, 0x9A,
            0xF5, 0x67, 0xF5, 0xDD, 0x07, 0x06, 0xAE, 0xE4, 0x5A, 0xDC,
        ];
        ChaCha20Rng::from_seed(SEED)
    }

    /// Input lengths under test: 0, then 10^0 through 10^7.
    fn lens() -> impl Iterator<Item = usize> {
        std::iter::once(0).chain((0..=7).map(|e| 10_usize.pow(e)))
    }

    /// Draws `len` bits, each set with probability `p`.
    fn sample_bits(len: usize, p: f64) -> Vec<bool> {
        let mut source = rng();
        let dist = Bernoulli::new(p).unwrap();
        let mut bits = Vec::with_capacity(len);
        for _ in 0..len {
            bits.push(dist.sample(&mut source));
        }
        bits
    }

    /// Prints the size report for `dict`, built over `a`.
    fn report_size(a: &[bool], dict: &Rs01DictRuntime) {
        eprintln!("---");
        eprintln!("a.len(): {}", a.len());
        dict.size_info();
    }

    /// Checks `rank0` / `rank1` at every position against prefix sums.
    fn test_rank_internal(len: usize, p: f64) {
        let a = sample_bits(len, p);

        // naive[i] = number of 1-bits in a[0..i]
        let mut naive = Vec::with_capacity(len);
        let mut ones = 0_usize;
        for &bit in &a {
            naive.push(ones);
            ones += bit as usize;
        }

        let dict = Rs01DictRuntime::new(&a);
        for (i, &expected) in naive.iter().enumerate() {
            assert_eq!(dict.rank1(i), expected, "i: {}", i);
            assert_eq!(dict.rank0(i), i - expected, "i: {}", i);
        }
        if p == 1.0 {
            report_size(&a, &dict);
        }
    }

    /// Checks `select0` / `select1` against explicit position lists.
    fn test_select_internal(len: usize, p: f64) {
        let a = sample_bits(len, p);

        let mut zeros = Vec::new();
        let mut ones = Vec::new();
        for (idx, &bit) in a.iter().enumerate() {
            if bit {
                ones.push(idx);
            } else {
                zeros.push(idx);
            }
        }

        let dict = Rs01DictRuntime::new(&a);
        for (i, &expected) in zeros.iter().enumerate() {
            assert_eq!(dict.select0(i), expected, "i: {}", i);
        }
        for (i, &expected) in ones.iter().enumerate() {
            assert_eq!(dict.select1(i), expected, "i: {}", i);
        }
        if p == 1.0 {
            report_size(&a, &dict);
        }
    }

    #[test]
    fn test_rank() {
        for len in lens() {
            for p in PROBS {
                test_rank_internal(len, p);
            }
        }
    }

    #[test]
    fn test_select() {
        for len in lens() {
            for p in PROBS {
                test_select_internal(len, p);
            }
        }
    }

    #[test]
    fn sanity_check() {
        test_select_internal(100, 0.2);
    }
}