Skip to main content

nekolib/ds/
rs_dict.rs

1//! rank/select 辞書。
2
3use super::super::traits::count;
4use super::super::traits::find_nth;
5use super::super::utils::buf_range;
6
7use std::fmt::Debug;
8use std::ops::{Range, RangeBounds};
9
10use buf_range::bounds_within;
11use count::Count;
12use find_nth::FindNth;
13
14const WORD_SIZE: usize = 64;
15const WORD_SIZE_2: usize = WORD_SIZE * WORD_SIZE;
16
17/// rank/select 辞書。
18///
19/// 要素が `0`/`1` からなる配列で、任意区間の `0`/`1` の個数を数えられる。
20///
21/// # Idea
22/// 要素数 $n$ のビット配列に対して、rank/select のクエリはそれぞれ $n+1$
23/// 通りしかないので、それらを $O(n)$ 時間で前計算しておけば、$3n+O(1)$ words で
24/// $O(1)$ query time を実現できる[^1]。
25///
26/// [^1]: $\\mathtt{rank}\_1$ の結果から $\\mathtt{rank}\_0$ の結果を求めることは可能だが、
27/// $\\mathtt{select}\_1$ の結果から $\\mathtt{select}\_0$ の結果を求めることはできない。
28///
29/// しかし、wavelet matrix などに用いる際はこれを 64 本持ったりする必要があることから、
30/// 空間を削減できた方がうれしそうなので、$6n/w+O(1)$ words, $O(\\log(w))$ query time
31/// の方法を用いた[^2]。
32///
33/// [^2]: やや複雑なので、もしかすると愚直の方がよい可能性もあるかも? 実測しましょう。
34///
35/// rank については、$w$ bits ごとに求めた個数の累積和 ($n/w$ words) を用いる。
36/// 端数については word size の popcount を行う[^3]。
37///
38/// [^3]: $O(\\log(w))$ time の方法は実装していません。`.count_ones()`
39/// の実測が遅いようなら考えます。ここを $O(1)$ time と見なしていいかはわかりません。
40/// 簡潔データ構造の文脈では、「どうせ表引きできるので...」となっていそうです。
41///
42/// select については `0` と `1` に対して用意する必要があり、以下では `1` のみ述べるが、
43/// `0` についても mutatis mutandis でできる。
44/// まず、`1` の $w$ 個おきの出現箇所を求める (at most $n/w$ words)。このうち、幅が
45/// $w^2$ 以上であるものを "疎" と呼び、そうでないところを "密" と呼ぶ。
46/// 疎である区間は高々 $n/w^2$ 個しかないので、その出現位置を陽に持っても $n/w$
47/// words で抑えられる。また、密である区間については、区間幅が $w^2$ 未満なので、
48/// クエリごとに二分探索しても $\\log(w)$ time で抑えられる。
49///
50/// # Complexity
51/// $O(n)$ preprocess, $O(n/w)$ space, $O(\\log(w))$ query time.
52#[derive(Clone, Debug)]
53pub struct RsDict {
54    len: usize,
55    buf: Vec<u64>,
56    rank: Vec<usize>,
57    sel0: Vec<SelectPreprocess>,
58    sel1: Vec<SelectPreprocess>,
59}
60
61#[derive(Clone, Debug)]
62enum SelectPreprocess {
63    Sparse(Vec<usize>),
64    Dense(Range<usize>),
65}
66use SelectPreprocess::{Dense, Sparse};
67
68impl From<Vec<bool>> for RsDict {
69    fn from(buf: Vec<bool>) -> Self {
70        let len = buf.len();
71        let buf = Self::compress_vec_bool(buf);
72        let rank = Self::preprocess_rank(&buf);
73        let sel0 = Self::preprocess_select(&buf, len, 0);
74        let sel1 = Self::preprocess_select(&buf, len, 1);
75        Self { len, buf, rank, sel0, sel1 }
76    }
77}
78
79impl RsDict {
80    fn compress_vec_bool(buf: Vec<bool>) -> Vec<u64> {
81        if buf.is_empty() {
82            return vec![];
83        }
84        let n = buf.len();
85        let nc = 1 + (n - 1) / WORD_SIZE;
86        let mut res = vec![0; nc + 1];
87        for i in 0..n {
88            if buf[i] {
89                res[i / WORD_SIZE] |= 1_u64 << (i % WORD_SIZE);
90            }
91        }
92        res
93    }
94    fn preprocess_rank(buf: &[u64]) -> Vec<usize> {
95        let n = buf.len();
96        let mut res = vec![0; n];
97        for i in 1..n {
98            res[i] = res[i - 1] + buf[i - 1].count_ones() as usize;
99        }
100        res
101    }
102    fn preprocess_select(
103        buf: &[u64],
104        n: usize,
105        x: u64,
106    ) -> Vec<SelectPreprocess> {
107        let mut sel = vec![];
108        let mut tmp = vec![];
109        let mut last = 0;
110        for i in 0..n {
111            if buf[i / WORD_SIZE] >> (i % WORD_SIZE) & 1 != x {
112                continue;
113            }
114            if tmp.len() == WORD_SIZE {
115                let len = i - last;
116                if len < WORD_SIZE_2 {
117                    sel.push(Dense(last..i));
118                } else {
119                    sel.push(Sparse(tmp));
120                }
121                tmp = vec![];
122                last = i;
123            }
124            tmp.push(i);
125        }
126        if !tmp.is_empty() {
127            sel.push(Sparse(tmp));
128        }
129        sel
130    }
131    pub fn rank(&self, end: usize, x: u64) -> usize {
132        let il = end / WORD_SIZE;
133        let is = end % WORD_SIZE;
134        let rank1 = self.rank[il]
135            + (self.buf[il] & !(!0_u64 << is)).count_ones() as usize;
136        let rank = if x == 0 { end - rank1 } else { rank1 };
137        rank
138    }
139    pub fn select(&self, x: u64, k: usize) -> Option<usize> {
140        if self.rank(self.len, x) < k {
141            None
142        } else if k == 0 {
143            Some(0)
144        } else {
145            Some(self.find_nth_internal(x, k - 1) + 1)
146        }
147    }
148}
149
150impl Count<u64> for RsDict {
151    fn count(&self, r: impl RangeBounds<usize>, x: u64) -> usize {
152        let Range { start, end } = bounds_within(r, self.len);
153        if start > 0 {
154            self.rank(end, x) - self.rank(start, x)
155        } else {
156            self.rank(end, x)
157        }
158    }
159}
160
161impl FindNth<u64> for RsDict {
162    fn find_nth(
163        &self,
164        r: impl RangeBounds<usize>,
165        x: u64,
166        n: usize,
167    ) -> Option<usize> {
168        let Range { start, end } = bounds_within(r, self.len);
169        if self.count(start..end, x) <= n {
170            None
171        } else {
172            let offset = self.rank(start, x);
173            Some(self.find_nth_internal(x, offset + n))
174        }
175    }
176}
177
178impl RsDict {
179    fn find_nth_internal(&self, x: u64, n: usize) -> usize {
180        if self.rank(self.len, x) < n {
181            panic!("the number of {}s is less than {}", x, n);
182        }
183        let sel = if x == 0 { &self.sel0 } else { &self.sel1 };
184        let il = n / WORD_SIZE;
185        let is = n % WORD_SIZE;
186        match &sel[il] {
187            Sparse(dir) => dir[is],
188            Dense(range) => {
189                let mut lo = range.start / WORD_SIZE;
190                let mut hi = 1 + (range.end - 1) / WORD_SIZE;
191                while hi - lo > 1 {
192                    let mid = lo + (hi - lo) / 2;
193                    let rank = self.rank_rough(mid, x);
194                    *(if rank <= n { &mut lo } else { &mut hi }) = mid;
195                }
196                let rank_frac = n - self.rank_rough(lo, x);
197                lo * WORD_SIZE
198                    + Self::find_nth_small(self.buf[lo], x, rank_frac)
199            }
200        }
201    }
202    fn rank_rough(&self, n: usize, x: u64) -> usize {
203        let rank1 = self.rank[n];
204        let rank = if x == 0 { n * WORD_SIZE - rank1 } else { rank1 };
205        rank
206    }
207    fn find_nth_small(word: u64, x: u64, n: usize) -> usize {
208        let mut word = if x == 0 { !word } else { word };
209        let mut n = n as u32;
210        let mut res = 0;
211        for &mid in &[32, 16, 8, 4, 2, 1] {
212            let count = (word & !(!0 << mid)).count_ones();
213            if count <= n {
214                n -= count;
215                word >>= mid;
216                res += mid;
217            }
218        }
219        res
220    }
221}
222
223#[test]
224fn select_internal() {
225    assert_eq!(RsDict::find_nth_small(0x00000000_00000001_u64, 1, 0), 0);
226    assert_eq!(RsDict::find_nth_small(0x00000000_00000003_u64, 1, 1), 1);
227    assert_eq!(RsDict::find_nth_small(0x00000000_00000010_u64, 1, 0), 4);
228    assert_eq!(RsDict::find_nth_small(0xffffffff_ffffffff_u64, 1, 63), 63);
229}
230
231#[test]
232fn test_rs() {
233    let n = 65536 + 4096;
234    let buf: Vec<_> = (0..n).map(|i| i % 1024 != 0).collect();
235
236    let rs: RsDict = buf.clone().into();
237    let mut zero = 0;
238    let mut one = 0;
239    for i in 0..n {
240        assert_eq!(rs.count(0..i, 0), zero);
241        assert_eq!(rs.count(0..i, 1), one);
242        if buf[i] {
243            one += 1;
244        } else {
245            zero += 1;
246        }
247    }
248    assert_eq!(rs.count(.., 0), zero);
249    assert_eq!(rs.count(.., 1), one);
250
251    let zeros: Vec<_> = (0..n).filter(|&i| !buf[i]).collect();
252    let ones: Vec<_> = (0..n).filter(|&i| buf[i]).collect();
253
254    for i in 0..zeros.len() {
255        let s0 = rs.find_nth(.., 0, i);
256        assert_eq!(s0, Some(zeros[i]));
257        assert_eq!(rs.count(..=s0.unwrap(), 0), i + 1);
258    }
259    for i in 0..ones.len() {
260        let s1 = rs.find_nth(.., 1, i);
261        assert_eq!(s1, Some(ones[i]));
262        assert_eq!(rs.count(..=s1.unwrap(), 1), i + 1);
263    }
264    assert_eq!(rs.find_nth(.., 0, zeros.len()), None);
265    assert_eq!(rs.find_nth(.., 1, ones.len()), None);
266}