// small_rank_select/lib.rs

1use std::ops::Range;
2
/// Number of bits in one storage word.
const W: usize = u64::BITS as usize;
/// Bits per SWAR lane: a `u64` is treated as eight one-byte lanes.
const BLOCK_LEN: usize = 8;
/// Mask selecting a single lane (`0xFF`).
const BLOCK: u64 = u8::MAX as u64;
/// Index of the most significant bit inside a lane.
const HI_POS: usize = BLOCK_LEN - 1;
/// `0x01` in every lane.
const M_LO: u64 = 0x0101_0101_0101_0101;
/// Lane `k` holds `1 << k`, so AND-ing with a splatted byte isolates bit `k` in lane `k`.
const M_IOTA: u64 = 0x8040_2010_0804_0201;
/// `0x80` in every lane.
const M_HI: u64 = 0x8080_8080_8080_8080;

/// Broadcasts `w` into all eight lanes.
#[inline(always)]
fn splat(w: u8) -> u64 {
    u64::from(w) * M_LO
}

/// Maps every lane to `1` if it is non-zero and to `0` otherwise.
///
/// Adding `0x7F` to a lane below `0x80` sets its top bit exactly when the lane is non-zero;
/// OR-ing the lane back in covers a lane that already had its top bit set. The only caller
/// (`expand`) passes lanes no larger than `0x80`, so the addition never carries across lanes.
#[inline(always)]
fn nonzero(w: u64) -> u64 {
    let bumped = w + (M_HI - M_LO);
    ((bumped | w) & M_HI) >> HI_POS
}

/// Spreads the bits of `w` so that lane `k` holds bit `k` of `w` as `0` or `1`.
#[inline(always)]
fn expand(w: u8) -> u64 {
    let isolated = splat(w) & M_IOTA;
    nonzero(isolated)
}

/// Inclusive prefix sum over lanes: lane `k` becomes the sum of lanes `0..=k`
/// (exact while no partial sum exceeds `0xFF`).
#[inline(always)]
fn accumulate(w: u64) -> u64 {
    M_LO.wrapping_mul(w)
}

/// Extracts lane `i` as an integer; `i` must be below 8 or the shift overflows.
#[inline(always)]
fn get(w: u64, i: usize) -> usize {
    let lane = (w >> (i * BLOCK_LEN)) & BLOCK;
    lane as usize
}

/// Lane-wise `wl >= wr`, giving `1` or `0` per lane.
///
/// Correct when every lane of both operands is below `0x80`: forcing the top bit of each
/// `wl` lane means the subtraction never borrows across lanes, and the top bit survives
/// exactly when `wl >= wr` in that lane.
#[inline(always)]
fn gt_eq(wl: u64, wr: u64) -> u64 {
    let diff = (wl | M_HI) - wr;
    (diff & M_HI) >> HI_POS
}

/// Moves every lane up one position, dropping the top lane and zeroing lane 0.
#[inline(always)]
fn shift(w: u64) -> u64 {
    w << BLOCK_LEN
}

/// Sum of all lanes, read from the top lane of the inclusive prefix sum.
#[inline(always)]
fn popcnt(w: u64) -> usize {
    let total = accumulate(w) >> (W - BLOCK_LEN);
    total as usize
}

/// Counts the set bits of `w` strictly below bit position `i`.
///
/// Expects `i < 8`; larger values overflow the lane shift.
#[inline(always)]
pub fn rank(w: u8, i: usize) -> usize {
    // Inclusive prefix sums moved up one lane are exclusive prefix sums.
    let exclusive = shift(accumulate(expand(w)));
    get(exclusive, i)
}

/// Returns the bit position of the set bit of `w` with zero-based index `i`
/// (i.e. of its `(i + 1)`-th lowest set bit).
///
/// Counts the lanes whose inclusive prefix popcount is at most `i`. Prefix counts never
/// decrease, so that count equals the first position whose prefix exceeds `i`. For `i` in
/// `w.count_ones()..8` every lane qualifies and the result is `8`.
#[inline(always)]
pub fn select(w: u8, i: usize) -> usize {
    let prefix = accumulate(expand(w));
    let at_most_i = gt_eq(splat(i as u8), prefix);
    popcnt(at_most_i)
}

/// Builds a rank lookup table at compile time.
///
/// `table[p][j]` is the number of set bits of pattern `p` strictly below bit `j`, for every
/// pattern `p < PAT` and position `j < LEN`. `LEN` must not exceed `usize::BITS`, since
/// each pattern is shifted right by the position.
pub const fn const_rank_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    let mut table = [[0u8; LEN]; PAT];
    let mut pat = 0;
    while pat < PAT {
        let mut ones: u8 = 0;
        let mut pos = 0;
        while pos < LEN {
            // Record the count *before* looking at bit `pos` (exclusive rank).
            table[pat][pos] = ones;
            ones += ((pat >> pos) & 1) as u8;
            pos += 1;
        }
        pat += 1;
    }
    table
}

/// Builds a select lookup table at compile time.
///
/// Row `p` lists the positions of the set bits of pattern `p` in increasing order:
/// `table[p][k]` is the position of the set bit with zero-based index `k`. Slots past the
/// last set bit keep the fill value `0`.
pub const fn const_select_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    let mut table = [[0u8; LEN]; PAT];
    let mut pat = 0;
    while pat < PAT {
        // Next free slot in this row == number of set bits seen so far.
        let mut filled = 0;
        let mut pos = 0;
        while pos < LEN {
            if (pat >> pos) & 1 == 1 {
                table[pat][filled] = pos as u8;
                filled += 1;
            }
            pos += 1;
        }
        pat += 1;
    }
    table
}

#[cfg(test)]
mod tests {
    use crate::*;

    const RANK_TABLE: [[u8; 8]; 256] = const_rank_table::<8, 256>();
    const SELECT_TABLE: [[u8; 8]; 256] = const_select_table::<8, 256>();

    /// The SWAR `rank` agrees with the compile-time table for every byte and position.
    #[test]
    fn test_rank() {
        for (w, row) in (0..=u8::MAX).zip(RANK_TABLE.iter()) {
            for (i, &expected) in row.iter().enumerate() {
                assert_eq!(rank(w, i), usize::from(expected));
            }
        }
    }

    /// The SWAR `select` agrees with the compile-time table for every byte and every
    /// index below that byte's popcount.
    #[test]
    fn test_select() {
        for (w, row) in (0..=u8::MAX).zip(SELECT_TABLE.iter()) {
            let ones = w.count_ones() as usize;
            for (i, &expected) in row.iter().take(ones).enumerate() {
                assert_eq!(select(w, i), usize::from(expected));
            }
        }
    }
}

99pub struct IntVec {
100    unit: usize,
101    buf: Vec<u64>,
102    len: usize,
103}
104
105impl IntVec {
106    pub fn new(unit: usize) -> Self { Self { unit, buf: vec![], len: 0 } }
107    pub fn len(&self) -> usize { self.len }
108    pub fn bitlen(&self) -> usize { self.len * self.unit }
109
110    pub fn push(&mut self, w: u64) {
111        let unit = self.unit;
112        debug_assert!(unit == W || w & (!0 << unit) == 0);
113
114        let bitlen = self.bitlen();
115        if unit == 0 {
116            // nothing to do
117        } else if bitlen % W == 0 {
118            self.buf.push(w);
119        } else {
120            self.buf[bitlen / W] |= w << (bitlen % W);
121            if bitlen % W + unit > W {
122                self.buf.push(w >> (W - bitlen % W));
123            }
124        }
125        self.len += 1;
126    }
127
128    #[inline(always)]
129    pub fn get_usize(&self, i: usize) -> usize { self.get::<true>(i) as _ }
130
131    #[inline(always)]
132    pub fn get<const X: bool>(&self, i: usize) -> u64 {
133        let start = i * self.unit;
134        let end = start + self.unit;
135        self.bits_range::<X>(start..end)
136    }
137
138    #[inline(always)]
139    pub fn bits_range<const X: bool>(
140        &self,
141        Range { start, end }: Range<usize>,
142    ) -> u64 {
143        let end = end.min(self.bitlen()); // (!)
144        let mask = !(!0 << (end - start));
145
146        let mut res = self.buf[start / W] >> (start % W);
147        if end > (start / W + 1) * W {
148            res |= self.buf[end / W] << (W - start % W);
149        }
150
151        ((if X { res } else { !res }) & mask) as _
152    }
153}
154
155pub struct RankTable(IntVec);
156pub struct SelectTable(IntVec);
157
158impl RankTable {
159    pub fn new() -> Self {
160        let len = 8;
161        let unit = 3;
162        let mut table = IntVec::new(unit);
163        for i in 0..1 << len {
164            let mut cur = 0;
165            for j in 0..len {
166                table.push(cur);
167                if i >> j & 1 != 0 {
168                    cur += 1;
169                }
170            }
171        }
172        Self(table)
173    }
174    pub fn rank(&self, w: u64, i: usize) -> usize {
175        let wi = w as usize * 8 + i;
176        self.0.get_usize(wi)
177    }
178}
179
180impl SelectTable {
181    pub fn new() -> Self {
182        let len = 8;
183        let unit = 3;
184        let mut table = IntVec::new(unit);
185        for i in 0..1 << len {
186            let mut cur = 0;
187            for j in 0..len {
188                if i >> j & 1 != 0 {
189                    table.push(j as _);
190                    cur += 1;
191                }
192            }
193            for _ in cur..len {
194                table.push(0);
195            }
196        }
197        Self(table)
198    }
199
200    pub fn select(&self, w: u64, i: usize) -> usize {
201        let wi = w as usize * 8 + i;
202        self.0.get_usize(wi)
203    }
204}