1use super::super::traits::count;
4use super::super::traits::find_nth;
5use super::super::utils::buf_range;
6
7use std::fmt::Debug;
8use std::ops::{Range, RangeBounds};
9
10use buf_range::bounds_within;
11use count::Count;
12use find_nth::FindNth;
13
/// Number of bits packed into each `u64` word of `RsDict::buf`.
const WORD_SIZE: usize = 64;
/// `WORD_SIZE` squared. `preprocess_select` stores a full chunk of
/// `WORD_SIZE` occurrences as `Dense` only when it spans fewer than this
/// many bit positions; otherwise it stores the positions as `Sparse`.
const WORD_SIZE_2: usize = WORD_SIZE * WORD_SIZE;
16
/// Rank/select dictionary over an immutable bit vector.
///
/// Built from a `Vec<bool>` via `From`. It answers rank queries (`rank`,
/// `Count::count`) and select queries (`select`, `FindNth::find_nth`).
#[derive(Clone, Debug)]
pub struct RsDict {
    // Number of bits stored (the length of the source `Vec<bool>`).
    len: usize,
    // The bits packed into `u64` words, least-significant bit first;
    // see `compress_vec_bool` for the exact layout.
    buf: Vec<u64>,
    // Word-level rank table: `rank[i]` is the number of 1 bits in `buf[..i]`.
    rank: Vec<usize>,
    // Select directory for 0 bits; entry `k` covers the 0-indexed
    // occurrences `k * WORD_SIZE .. (k + 1) * WORD_SIZE`.
    sel0: Vec<SelectPreprocess>,
    // Select directory for 1 bits, laid out like `sel0`.
    sel1: Vec<SelectPreprocess>,
}
60
/// One entry of a select directory, covering up to `WORD_SIZE` consecutive
/// occurrences of a bit value (only the last entry may cover fewer).
#[derive(Clone, Debug)]
enum SelectPreprocess {
    /// Explicit bit positions of every occurrence in this chunk.
    Sparse(Vec<usize>),
    /// Bit-position range bracketing this chunk. The exact position is
    /// recovered by a binary search over `RsDict::rank` in
    /// `find_nth_internal`.
    Dense(Range<usize>),
}
use SelectPreprocess::{Dense, Sparse};
67
68impl From<Vec<bool>> for RsDict {
69 fn from(buf: Vec<bool>) -> Self {
70 let len = buf.len();
71 let buf = Self::compress_vec_bool(buf);
72 let rank = Self::preprocess_rank(&buf);
73 let sel0 = Self::preprocess_select(&buf, len, 0);
74 let sel1 = Self::preprocess_select(&buf, len, 1);
75 Self { len, buf, rank, sel0, sel1 }
76 }
77}
78
79impl RsDict {
80 fn compress_vec_bool(buf: Vec<bool>) -> Vec<u64> {
81 if buf.is_empty() {
82 return vec![];
83 }
84 let n = buf.len();
85 let nc = 1 + (n - 1) / WORD_SIZE;
86 let mut res = vec![0; nc + 1];
87 for i in 0..n {
88 if buf[i] {
89 res[i / WORD_SIZE] |= 1_u64 << (i % WORD_SIZE);
90 }
91 }
92 res
93 }
94 fn preprocess_rank(buf: &[u64]) -> Vec<usize> {
95 let n = buf.len();
96 let mut res = vec![0; n];
97 for i in 1..n {
98 res[i] = res[i - 1] + buf[i - 1].count_ones() as usize;
99 }
100 res
101 }
102 fn preprocess_select(
103 buf: &[u64],
104 n: usize,
105 x: u64,
106 ) -> Vec<SelectPreprocess> {
107 let mut sel = vec![];
108 let mut tmp = vec![];
109 let mut last = 0;
110 for i in 0..n {
111 if buf[i / WORD_SIZE] >> (i % WORD_SIZE) & 1 != x {
112 continue;
113 }
114 if tmp.len() == WORD_SIZE {
115 let len = i - last;
116 if len < WORD_SIZE_2 {
117 sel.push(Dense(last..i));
118 } else {
119 sel.push(Sparse(tmp));
120 }
121 tmp = vec![];
122 last = i;
123 }
124 tmp.push(i);
125 }
126 if !tmp.is_empty() {
127 sel.push(Sparse(tmp));
128 }
129 sel
130 }
131 pub fn rank(&self, end: usize, x: u64) -> usize {
132 let il = end / WORD_SIZE;
133 let is = end % WORD_SIZE;
134 let rank1 = self.rank[il]
135 + (self.buf[il] & !(!0_u64 << is)).count_ones() as usize;
136 let rank = if x == 0 { end - rank1 } else { rank1 };
137 rank
138 }
139 pub fn select(&self, x: u64, k: usize) -> Option<usize> {
140 if self.rank(self.len, x) < k {
141 None
142 } else if k == 0 {
143 Some(0)
144 } else {
145 Some(self.find_nth_internal(x, k - 1) + 1)
146 }
147 }
148}
149
150impl Count<u64> for RsDict {
151 fn count(&self, r: impl RangeBounds<usize>, x: u64) -> usize {
152 let Range { start, end } = bounds_within(r, self.len);
153 if start > 0 {
154 self.rank(end, x) - self.rank(start, x)
155 } else {
156 self.rank(end, x)
157 }
158 }
159}
160
161impl FindNth<u64> for RsDict {
162 fn find_nth(
163 &self,
164 r: impl RangeBounds<usize>,
165 x: u64,
166 n: usize,
167 ) -> Option<usize> {
168 let Range { start, end } = bounds_within(r, self.len);
169 if self.count(start..end, x) <= n {
170 None
171 } else {
172 let offset = self.rank(start, x);
173 Some(self.find_nth_internal(x, offset + n))
174 }
175 }
176}
177
178impl RsDict {
179 fn find_nth_internal(&self, x: u64, n: usize) -> usize {
180 if self.rank(self.len, x) < n {
181 panic!("the number of {}s is less than {}", x, n);
182 }
183 let sel = if x == 0 { &self.sel0 } else { &self.sel1 };
184 let il = n / WORD_SIZE;
185 let is = n % WORD_SIZE;
186 match &sel[il] {
187 Sparse(dir) => dir[is],
188 Dense(range) => {
189 let mut lo = range.start / WORD_SIZE;
190 let mut hi = 1 + (range.end - 1) / WORD_SIZE;
191 while hi - lo > 1 {
192 let mid = lo + (hi - lo) / 2;
193 let rank = self.rank_rough(mid, x);
194 *(if rank <= n { &mut lo } else { &mut hi }) = mid;
195 }
196 let rank_frac = n - self.rank_rough(lo, x);
197 lo * WORD_SIZE
198 + Self::find_nth_small(self.buf[lo], x, rank_frac)
199 }
200 }
201 }
202 fn rank_rough(&self, n: usize, x: u64) -> usize {
203 let rank1 = self.rank[n];
204 let rank = if x == 0 { n * WORD_SIZE - rank1 } else { rank1 };
205 rank
206 }
207 fn find_nth_small(word: u64, x: u64, n: usize) -> usize {
208 let mut word = if x == 0 { !word } else { word };
209 let mut n = n as u32;
210 let mut res = 0;
211 for &mid in &[32, 16, 8, 4, 2, 1] {
212 let count = (word & !(!0 << mid)).count_ones();
213 if count <= n {
214 n -= count;
215 word >>= mid;
216 res += mid;
217 }
218 }
219 res
220 }
221}
222
#[test]
fn select_internal() {
    // (word, bit value, 0-indexed occurrence, expected bit offset)
    let cases: [(u64, u64, usize, usize); 4] = [
        (0x00000000_00000001, 1, 0, 0),
        (0x00000000_00000003, 1, 1, 1),
        (0x00000000_00000010, 1, 0, 4),
        (0xffffffff_ffffffff, 1, 63, 63),
    ];
    for &(word, x, n, expected) in cases.iter() {
        assert_eq!(RsDict::find_nth_small(word, x, n), expected);
    }
}
230
#[test]
fn test_rs() {
    let n = 65536 + 4096;
    // A zero every 1024 bits: the zeros end up in Sparse chunks while the
    // ones end up in Dense chunks, so both select paths are exercised.
    let buf: Vec<_> = (0..n).map(|i| i % 1024 != 0).collect();
    let rs: RsDict = buf.clone().into();

    // Prefix counts must match a running tally taken before bit `i` is seen.
    let (mut zero, mut one) = (0, 0);
    for (i, &bit) in buf.iter().enumerate() {
        assert_eq!(rs.count(0..i, 0), zero);
        assert_eq!(rs.count(0..i, 1), one);
        if bit {
            one += 1;
        } else {
            zero += 1;
        }
    }
    assert_eq!(rs.count(.., 0), zero);
    assert_eq!(rs.count(.., 1), one);

    let zeros: Vec<_> = (0..n).filter(|&i| !buf[i]).collect();
    let ones: Vec<_> = (0..n).filter(|&i| buf[i]).collect();

    for &(x, positions) in [(0_u64, &zeros), (1_u64, &ones)].iter() {
        for (k, &pos) in positions.iter().enumerate() {
            let found = rs.find_nth(.., x, k);
            assert_eq!(found, Some(pos));
            assert_eq!(rs.count(..=pos, x), k + 1);
        }
        assert_eq!(rs.find_nth(.., x, positions.len()), None);
    }
}
266}