1#![allow(unused)]
2
3use std::ops::Range;
4
/// Number of bits in one storage word of [`IntVec`] (a `u64`).
const W: usize = 8 * std::mem::size_of::<u64>();
6
/// A packed vector of fixed-width unsigned integers.
///
/// Every element occupies exactly `unit` bits. Elements are laid out
/// back-to-back in `buf`, low bits first, and may straddle word boundaries.
struct IntVec {
    /// Backing storage words.
    buf: Vec<u64>,
    /// Number of elements currently stored.
    len: usize,
    /// Width in bits of every element (expected to be at most `W`).
    unit: usize,
}
12
13pub struct Rs01DictRuntime {
14 buf: IntVec,
15 rank_index: RankIndex,
16 select_index: (SelectIndex, SelectIndex),
17}
18
19struct RankIndex {
20 large: IntVec,
21 small: IntVec,
22 table: IntVec,
23 large_len: usize,
24 small_len: usize,
25}
26
27struct SelectIndex {
28 small_popcnt: usize,
29 small_start: IntVec,
30 small_indir: IntVec,
31 small_sparse: IntVec,
32 small_sparse_offset: IntVec,
33 small_dense_max: usize,
34 large_popcnt: usize,
35 large_start: IntVec,
36 large_indir: IntVec,
37 large_sparse: IntVec,
38 table: IntVec,
39}
40
41impl IntVec {
42 pub fn new(unit: usize) -> Self { Self { unit, buf: vec![], len: 0 } }
43 pub fn len(&self) -> usize { self.len }
44 pub fn bitlen(&self) -> usize { self.len * self.unit }
45
46 pub fn push(&mut self, w: u64) {
47 let unit = self.unit;
48 assert!(unit == W || w & (!0 << unit) == 0);
49
50 let bitlen = self.bitlen();
51 if unit == 0 {
52 } else if bitlen % W == 0 {
54 self.buf.push(w);
55 } else {
56 self.buf[bitlen / W] |= w << (bitlen % W);
57 if bitlen % W + unit > W {
58 self.buf.push(w >> (W - bitlen % W));
59 }
60 }
61 self.len += 1;
62 }
63
64 pub fn get_usize(&self, i: usize) -> usize { self.get::<true>(i) as _ }
65
66 pub fn get<const X: bool>(&self, i: usize) -> u64 {
67 let start = i * self.unit;
68 let end = start + self.unit;
69 self.bits_range::<X>(start..end)
70 }
71
72 pub fn bits_range<const X: bool>(
73 &self,
74 Range { start, end }: Range<usize>,
75 ) -> u64 {
76 let end = end.min(self.bitlen()); let mask = if end - start == W { !0 } else { !(!0 << (end - start)) };
78 let res = if start == end {
79 0
80 } else if start % W == 0 {
81 self.buf[start / W]
82 } else if end <= (start / W + 1) * W {
83 self.buf[start / W] >> (start % W)
84 } else {
85 self.buf[start / W] >> (start % W)
86 | self.buf[end / W] << (W - start % W)
87 };
88 (if X { res } else { !res }) & mask
89 }
90}
91
92impl std::fmt::Debug for IntVec {
93 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 fmt.debug_list()
95 .entries((0..self.len).map(|i| self.get::<true>(i)))
96 .finish()
97 }
98}
99
/// Number of bits needed to write `n` in binary, with a minimum of 1.
///
/// This equals `ceil(log2(n + 1)).max(1)`. It is computed from the leading-zero
/// count rather than from `(n + 1).next_power_of_two()`. The old form overflowed
/// for `n >= 2^63`: it panicked in debug builds and returned 1 for
/// `usize::MAX` in release builds.
fn bitlen(n: usize) -> usize {
    ((usize::BITS - n.leading_zeros()) as usize).max(1)
}
/// Smallest `k >= 1` such that `4^k >= n`, i.e. roughly `ceil(log2(n) / 2)`.
///
/// The power is computed with saturation, so the search stops for any `n`.
fn lg_half(n: usize) -> usize {
    let mut k: usize = 1;
    while 4_usize.saturating_pow(k as u32) < n {
        k += 1;
    }
    k
}
108
109impl RankIndex {
110 pub fn new(buf: &[bool]) -> Self {
111 let len = buf.len();
112 let small_len = (1_usize..)
113 .find(|&i| 4_usize.saturating_pow(i as _) >= len)
114 .unwrap(); let large_len = (2 * small_len).pow(2); let small_bitlen = bitlen(len.min(large_len));
118 let large_bitlen = bitlen(len);
119
120 let mut small = IntVec::new(small_bitlen);
121 let mut large = IntVec::new(large_bitlen);
122 let mut small_acc = 0;
123 let mut large_acc = 0;
124 let per = large_len / small_len;
125 for (c, i) in buf
126 .chunks(small_len)
127 .map(|ch| ch.iter().filter(|&&x| x).count() as u64)
128 .zip((0..per).cycle())
129 {
130 small.push(small_acc);
131 small_acc = if i < per - 1 { small_acc + c } else { 0 };
132
133 if i == 0 {
134 large.push(large_acc);
135 }
136 large_acc += c as u64;
137 }
138
139 let table = Self::table(small_len);
140 Self { large, small, table, large_len, small_len }
141 }
142
143 fn table(len: usize) -> IntVec {
144 let unit = bitlen(len);
145 let mut table = IntVec::new(unit);
146 for i in 0..1 << len {
147 let mut cur = 0;
148 for j in 0..len {
149 table.push(cur);
150 if i >> j & 1 != 0 {
151 cur += 1;
152 }
153 }
154 }
155 table
156 }
157
158 fn lookup(&self, w: u64, i: usize) -> usize {
159 let wi = w as usize * self.small_len + i;
160 self.table.get_usize(wi)
161 }
162
163 pub fn rank1(&self, i: usize, b: &IntVec) -> usize {
164 let large_acc = self.large.get_usize(i / self.large_len);
165 let small_acc = self.small.get_usize(i / self.small_len);
166 let il = i / self.small_len * self.small_len;
167 let ir = il + self.small_len;
168 let w = b.bits_range::<true>(il..ir);
169 let small = self.lookup(w, i % self.small_len);
170 large_acc + small_acc + small
171 }
172 pub fn rank0(&self, i: usize, b: &IntVec) -> usize { i - self.rank1(i, b) }
173 pub fn rank<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
174 if X { self.rank1(i, b) } else { self.rank0(i, b) }
175 }
176
177 #[cfg(test)]
178 pub fn size_info(&self) -> usize {
179 self.large.bitlen() + self.small.bitlen() + self.table.bitlen()
184 }
185}
186
impl SelectIndex {
    /// Builds the select directory for the positions of bits equal to `X` in `buf`.
    ///
    /// Occurrences of `X` are grouped into large blocks of `large_popcnt`
    /// occurrences; the final block may hold fewer. Each large block is handled
    /// one of two ways:
    ///
    /// - If its span in bits exceeds `large_dense_max`, every occurrence
    ///   position is stored verbatim in `large_sparse`.
    /// - Otherwise it is cut into small blocks of `small_popcnt` occurrences.
    ///   A small block whose span exceeds `small_dense_max` stores its
    ///   occurrence offsets, relative to the small block start, in
    ///   `small_sparse`. Shorter small blocks store nothing and are resolved
    ///   with `table` at query time.
    pub fn new<const X: bool>(buf: &[bool]) -> Self {
        let len = buf.len();
        let small_popcnt = lg_half(len);
        let large_popcnt = (2 * small_popcnt).pow(2);
        // Longest small-block span decoded via `table`. The exponent (4) and the
        // divisor (24) look like tuning constants.
        // NOTE(review): their rationale is not shown here.
        let small_dense_max =
            (((len as f64).log2().max(1.0).log2().max(1.0).powi(4) / 24.0)
                .ceil()) as usize;
        // Longest large-block span that is split into small blocks.
        let large_dense_max = large_popcnt.pow(2);
        let mut large_start = IntVec::new(bitlen(len));
        let mut large_indir = IntVec::new(bitlen(len) + 1);
        let mut large_sparse = IntVec::new(bitlen(len));
        let mut small_start = IntVec::new(bitlen(large_dense_max));
        let mut small_indir = IntVec::new(bitlen(large_dense_max) + 1);
        let mut small_sparse = IntVec::new(bitlen(large_dense_max));
        let mut small_sparse_offset = IntVec::new(bitlen(len));

        // Start position of the large block currently being collected.
        let mut start = 0;
        // Positions of the occurrences of `X` in the current large block.
        let mut pos = vec![];
        for i in 0..len {
            if buf[i] == X {
                pos.push(i);
            }
            // Keep scanning until the large block is full or the input ends.
            if !(pos.len() == large_popcnt || i == len - 1) {
                continue;
            }

            let cur_large_start = start;
            let cur_large_end = i;
            large_start.push(cur_large_start as _);
            small_sparse_offset.push(small_sparse.len() as _);
            if cur_large_end + 1 - cur_large_start > large_dense_max {
                // Sparse large block: tag 0, payload = first index in `large_sparse`.
                large_indir.push((large_sparse.len() << 1 | 0) as _);
                for p in pos.drain(..) {
                    large_sparse.push(p as _);
                }
            } else {
                // Dense large block: tag 1, payload = index of its first small block.
                large_indir.push((small_start.len() << 1 | 1) as _);
                // NOTE(review): `small_start_offset` is never read.
                let small_start_offset = small_start.len();
                // Shadows the `small_sparse_offset` IntVec inside this branch.
                // It holds the same value just pushed there: where this large
                // block's `small_sparse` entries begin.
                let small_sparse_offset = small_sparse.len();
                let mut cur_small_start = cur_large_start;
                for j in (0..pos.len()).step_by(small_popcnt) {
                    let start = cur_small_start;
                    // A small block ends just before the next small block's first
                    // occurrence. The last one ends at the input end if this is the
                    // final large block, or at its last occurrence (`i`) otherwise.
                    let end = if j + small_popcnt < pos.len() {
                        pos[j + small_popcnt] - 1
                    } else if i == len - 1 {
                        i
                    } else {
                        pos[pos.len() - 1]
                    };
                    small_start.push((start - cur_large_start) as _);
                    if end + 1 - start > small_dense_max {
                        // Sparse small block. Every earlier sparse small block of
                        // this large block is full (`small_popcnt` entries), so this
                        // yields its ordinal among them.
                        let tmp = (small_sparse.len() - small_sparse_offset)
                            / small_popcnt;
                        small_indir.push((tmp << 1 | 0) as _);
                        for &p in &pos[j..pos.len().min(j + small_popcnt)] {
                            let pos_offset = p - start;
                            small_sparse.push(pos_offset as _);
                        }
                    } else {
                        // Dense small block: tag 1, no payload; decoded with `table`.
                        small_indir.push(0 << 1 | 1);
                    }
                    cur_small_start = end + 1;
                }

                pos.clear();
            }
            start = i + 1;
        }

        let table = Self::table(small_dense_max);
        Self {
            small_popcnt,
            small_start,
            small_indir,
            small_sparse,
            small_sparse_offset,
            small_dense_max,
            large_popcnt,
            large_start,
            large_indir,
            large_sparse,
            table,
        }
    }

    /// Builds the in-word select table for `len`-bit patterns.
    ///
    /// For every `pattern` in `0..2^len`, it pushes the positions of its set bits
    /// in ascending order, padded with zeros to `len` entries. Entry
    /// `pattern * len + k` is therefore the position of the `k`-th set bit.
    fn table(len: usize) -> IntVec {
        let unit = bitlen(len);
        let mut table = IntVec::new(unit);
        // NOTE(review): `i` has no type annotation and is presumably inferred as
        // the default `i32`, so `1 << len` would overflow for `len >= 31`.
        for i in 0..1 << len {
            let mut cur = 0;
            for j in 0..len {
                if i >> j & 1 != 0 {
                    table.push(j as _);
                    cur += 1;
                }
            }
            for _ in cur..len {
                table.push(0);
            }
        }
        table
    }

    /// Position of the `i`-th set bit of the `small_dense_max`-bit word `w`.
    fn lookup(&self, w: u64, i: usize) -> usize {
        let wi = w as usize * self.small_dense_max + i;
        self.table.get_usize(wi)
    }

    /// Position of the `i`-th (0-based) bit equal to `X`.
    ///
    /// `b` is expected to be the packed form of the sequence passed to
    /// [`SelectIndex::new`]. `i` is expected to be less than the number of such
    /// bits; larger values are not checked here.
    pub fn select<const X: bool>(&self, i: usize, b: &IntVec) -> usize {
        // Which large block holds the answer, and the rank within it.
        let (il_div, il_mod) = (i / self.large_popcnt, i % self.large_popcnt);
        let large = self.large_indir.get_usize(il_div);
        let (large_i, large_ty) = (large >> 1, large & 1);
        if large_ty == 0 {
            // Sparse large block: the absolute position is stored directly.
            self.large_sparse.get_usize(large_i + il_mod)
        } else {
            let large_start = self.large_start.get_usize(il_div);
            // Small block index within the large block, and rank within it.
            let per = self.large_popcnt / self.small_popcnt;
            let is_div = i / self.small_popcnt % per;
            let is_mod = i % self.small_popcnt;

            let small = self.small_indir.get_usize(large_i + is_div);
            let (small_i, small_ty) = (small >> 1, small & 1);
            let small_start = self.small_start.get_usize(large_i + is_div);
            if small_ty == 0 {
                // Sparse small block: the offset from its start is stored.
                let offset = self.small_sparse_offset.get_usize(il_div);
                let small_sparse = self
                    .small_sparse
                    .get_usize(offset + small_i * self.small_popcnt + is_mod);
                large_start + small_start + small_sparse
            } else {
                // Dense small block: read a `small_dense_max`-bit window (bits
                // equal to `X` show up as ones) and decode with the table.
                let offset = large_start + small_start;
                let w =
                    b.bits_range::<X>(offset..offset + self.small_dense_max);
                offset + self.lookup(w, is_mod)
            }
        }
    }

    /// Prints the size of each component to stderr and returns the total in bits.
    #[cfg(test)]
    pub fn size_info(&self) -> usize {
        eprintln!("small_start: {} bits", self.small_start.bitlen());
        eprintln!("small_indir: {} bits", self.small_indir.bitlen());
        eprintln!("small_sparse: {} bits", self.small_sparse.bitlen());
        eprintln!(
            "small_sparse_offset: {} bits",
            self.small_sparse_offset.bitlen()
        );
        eprintln!("large_start: {} bits", self.large_start.bitlen());
        eprintln!("large_indir: {} bits", self.large_indir.bitlen());
        eprintln!("large_sparse: {} bits", self.large_sparse.bitlen());
        eprintln!("table: {} bits", self.table.bitlen());

        self.small_start.bitlen()
            + self.small_indir.bitlen()
            + self.small_sparse.bitlen()
            + self.small_sparse_offset.bitlen()
            + self.large_start.bitlen()
            + self.large_indir.bitlen()
            + self.large_sparse.bitlen()
            + self.table.bitlen()
    }
}
350
351impl Rs01DictRuntime {
352 pub fn new(a: &[bool]) -> Self {
353 let rank_index = RankIndex::new(&a);
354 let mut buf = IntVec::new(1);
355 for &x in a {
356 buf.push(x as _);
357 }
358 let select_index =
359 (SelectIndex::new::<false>(&a), SelectIndex::new::<true>(&a));
360
361 Self { buf, rank_index, select_index }
365 }
366
367 pub fn rank<const X: bool>(&self, i: usize) -> usize {
368 self.rank_index.rank::<X>(i, &self.buf)
369 }
370 pub fn rank0(&self, i: usize) -> usize { self.rank::<false>(i) }
371 pub fn rank1(&self, i: usize) -> usize { self.rank::<true>(i) }
372
373 pub fn select<const X: bool>(&self, i: usize) -> usize {
374 if X {
375 self.select_index.1.select::<X>(i, &self.buf)
376 } else {
377 self.select_index.0.select::<X>(i, &self.buf)
378 }
379 }
380 pub fn select0(&self, i: usize) -> usize { self.select::<false>(i) }
381 pub fn select1(&self, i: usize) -> usize { self.select::<true>(i) }
382
383 #[cfg(test)]
384 pub fn size_info(&self) {
385 let mut sum = 0;
386 sum += self.rank_index.size_info();
387 sum += self.select_index.0.size_info();
388 sum += self.select_index.1.size_info();
389 let ratio = sum as f64 / self.buf.len() as f64;
390 eprintln!("total: {sum} bits (x{ratio:.03})");
391 }
392}
393
#[cfg(test)]
mod tests {
    use rand::{
        distributions::{Bernoulli, Distribution},
        Rng, SeedableRng,
    };
    use rand_chacha::ChaCha20Rng;

    use crate::*;

    /// Bit densities exercised by the exhaustive tests.
    const PROBS: [f64; 7] = [1.0, 0.999, 0.9, 0.5, 0.1, 1.0e-3, 0.0];

    /// Deterministically seeded RNG so failures are reproducible.
    fn rng() -> ChaCha20Rng {
        ChaCha20Rng::from_seed([
            0x55, 0xEF, 0xE0, 0x3C, 0x71, 0xDA, 0xFC, 0xAB, 0x5C, 0x1A, 0x9F,
            0xEB, 0xA4, 0x9E, 0x61, 0xE6, 0x1E, 0x7E, 0x29, 0x77, 0x38, 0x9A,
            0xF5, 0x67, 0xF5, 0xDD, 0x07, 0x06, 0xAE, 0xE4, 0x5A, 0xDC,
        ])
    }

    /// Lengths exercised by the exhaustive tests: 0, then 10^0 through 10^7.
    fn lengths() -> impl Iterator<Item = usize> {
        std::iter::once(0).chain((0..=7).map(|e| 10_usize.pow(e)))
    }

    /// A fresh-seeded random bit vector of length `len`; each bit is set with probability `p`.
    fn random_bits(len: usize, p: f64) -> Vec<bool> {
        let mut rng = rng();
        let dist = Bernoulli::new(p).unwrap();
        (0..len).map(|_| dist.sample(&mut rng)).collect()
    }

    /// Prints the space usage report for `dict`.
    fn report_size(len: usize, dict: &Rs01DictRuntime) {
        eprintln!("---");
        eprintln!("a.len(): {}", len);
        dict.size_info();
    }

    fn test_rank_internal(len: usize, p: f64) {
        let a = random_bits(len, p);
        let dict = Rs01DictRuntime::new(&a);
        // Running count of ones in `a[..i]`.
        let mut ones = 0;
        for (i, &bit) in a.iter().enumerate() {
            assert_eq!(dict.rank1(i), ones, "i: {}", i);
            assert_eq!(dict.rank0(i), i - ones, "i: {}", i);
            ones += bit as usize;
        }
        if p == 1.0 {
            report_size(len, &dict);
        }
    }

    fn test_select_internal(len: usize, p: f64) {
        let a = random_bits(len, p);
        let (zeros, ones): (Vec<usize>, Vec<usize>) =
            (0..len).partition(|&i| !a[i]);
        let dict = Rs01DictRuntime::new(&a);
        for (i, &want) in zeros.iter().enumerate() {
            assert_eq!(dict.select0(i), want, "i: {}", i);
        }
        for (i, &want) in ones.iter().enumerate() {
            assert_eq!(dict.select1(i), want, "i: {}", i);
        }
        if p == 1.0 {
            report_size(len, &dict);
        }
    }

    #[test]
    fn test_rank() {
        for len in lengths() {
            for &p in &PROBS {
                test_rank_internal(len, p);
            }
        }
    }

    #[test]
    fn test_select() {
        for len in lengths() {
            for &p in &PROBS {
                test_select_internal(len, p);
            }
        }
    }

    #[test]
    fn sanity_check() {
        test_select_internal(100, 0.2);
    }
}