1use std::ops::Range;
2
/// Bits in one storage word.
const W: usize = u64::BITS as usize;
/// Width in bits of one SIMD-within-a-register lane (a byte).
const BLOCK_LEN: usize = 8;
/// Mask covering a single lane (`0xFF`).
const BLOCK: u64 = (1 << BLOCK_LEN) - 1;
/// Bit index of the top bit inside a lane.
const HI_POS: usize = BLOCK_LEN - 1;
/// `0x01` in every lane: the all-ones word divided by the lane mask.
const M_LO: u64 = u64::MAX / BLOCK;
/// Lane `k` holds the single bit `1 << k`.
const M_IOTA: u64 = 0x8040_2010_0804_0201;
/// `0x80` in every lane.
const M_HI: u64 = M_LO << HI_POS;
10
11#[inline(always)]
12fn splat(w: u8) -> u64 { M_LO * w as u64 }
13#[inline(always)]
14fn nonzero(w: u64) -> u64 { ((w | (M_HI - M_LO + w)) & M_HI) >> HI_POS }
15#[inline(always)]
16fn expand(w: u8) -> u64 { nonzero(splat(w) & M_IOTA) }
17#[inline(always)]
18fn accumulate(w: u64) -> u64 { w.wrapping_mul(M_LO) }
19#[inline(always)]
20fn get(w: u64, i: usize) -> usize { (w >> (BLOCK_LEN * i) & BLOCK) as _ }
21#[inline(always)]
22fn gt_eq(wl: u64, wr: u64) -> u64 { (((wl | M_HI) - wr) & M_HI) >> HI_POS }
23#[inline(always)]
24fn shift(w: u64) -> u64 { w << BLOCK_LEN }
25#[inline(always)]
26fn popcnt(w: u64) -> usize { (accumulate(w) >> (W - BLOCK_LEN)) as _ }
27
28#[inline(always)]
29pub fn rank(w: u8, i: usize) -> usize { get(shift(accumulate(expand(w))), i) }
30#[inline(always)]
31pub fn select(w: u8, i: usize) -> usize {
32 popcnt(gt_eq(splat(i as _), accumulate(expand(w))))
33}
34
/// Builds an exclusive-rank lookup table at compile time.
///
/// Entry `[pat][pos]` holds the number of set bits of `pat` strictly below
/// bit `pos`, for `pat < PAT` and `pos < LEN`.
pub const fn const_rank_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    let mut table = [[0u8; LEN]; PAT];
    let mut pat = 0;
    while pat < PAT {
        // Running count of ones seen below the current position.
        let mut ones: u8 = 0;
        let mut pos = 0;
        while pos < LEN {
            table[pat][pos] = ones;
            ones += ((pat >> pos) & 1) as u8;
            pos += 1;
        }
        pat += 1;
    }
    table
}
53
/// Builds a select lookup table at compile time.
///
/// Row `pat` lists, in increasing order, the positions `< LEN` of the set bits
/// of `pat`. Slots past the last set bit stay 0.
pub const fn const_select_table<const LEN: usize, const PAT: usize>() -> [[u8; LEN]; PAT] {
    let mut table = [[0u8; LEN]; PAT];
    let mut pat = 0;
    while pat < PAT {
        // Next free slot in this row, i.e. the number of set bits seen so far.
        let mut slot = 0;
        let mut pos = 0;
        while pos < LEN {
            let bit = (pat >> pos) & 1;
            if bit == 1 {
                table[pat][slot] = pos as u8;
                slot += 1;
            }
            pos += 1;
        }
        pat += 1;
    }
    table
}
72
#[cfg(test)]
mod tests {
    use super::*;

    // Reference tables produced at compile time by the `const fn` builders.
    const RANK_TABLE: [[u8; 8]; 256] = const_rank_table::<8, 256>();
    const SELECT_TABLE: [[u8; 8]; 256] = const_select_table::<8, 256>();

    /// The lane-arithmetic `rank` agrees with the table for every byte and position.
    #[test]
    fn test_rank() {
        for (w, row) in (0..=u8::MAX).zip(RANK_TABLE.iter()) {
            for (i, &expected) in row.iter().enumerate() {
                assert_eq!(rank(w, i), usize::from(expected));
            }
        }
    }

    /// The lane-arithmetic `select` agrees with the table for every valid rank.
    #[test]
    fn test_select() {
        for (w, row) in (0..=u8::MAX).zip(SELECT_TABLE.iter()) {
            let ones = w.count_ones() as usize;
            for (i, &expected) in row[..ones].iter().enumerate() {
                assert_eq!(select(w, i), usize::from(expected));
            }
        }
    }
}
98
/// A growable vector of fixed-width unsigned integers, bit-packed into `u64` words.
///
/// Element `i` occupies bits `i * unit .. (i + 1) * unit` of the concatenated
/// words, least-significant bit first, and may straddle a word boundary.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntVec {
    /// Width of every element in bits.
    unit: usize,
    /// Packed backing storage.
    buf: Vec<u64>,
    /// Number of elements pushed so far.
    len: usize,
}
104
105impl IntVec {
106 pub fn new(unit: usize) -> Self { Self { unit, buf: vec![], len: 0 } }
107 pub fn len(&self) -> usize { self.len }
108 pub fn bitlen(&self) -> usize { self.len * self.unit }
109
110 pub fn push(&mut self, w: u64) {
111 let unit = self.unit;
112 debug_assert!(unit == W || w & (!0 << unit) == 0);
113
114 let bitlen = self.bitlen();
115 if unit == 0 {
116 } else if bitlen % W == 0 {
118 self.buf.push(w);
119 } else {
120 self.buf[bitlen / W] |= w << (bitlen % W);
121 if bitlen % W + unit > W {
122 self.buf.push(w >> (W - bitlen % W));
123 }
124 }
125 self.len += 1;
126 }
127
128 #[inline(always)]
129 pub fn get_usize(&self, i: usize) -> usize { self.get::<true>(i) as _ }
130
131 #[inline(always)]
132 pub fn get<const X: bool>(&self, i: usize) -> u64 {
133 let start = i * self.unit;
134 let end = start + self.unit;
135 self.bits_range::<X>(start..end)
136 }
137
138 #[inline(always)]
139 pub fn bits_range<const X: bool>(
140 &self,
141 Range { start, end }: Range<usize>,
142 ) -> u64 {
143 let end = end.min(self.bitlen()); let mask = !(!0 << (end - start));
145
146 let mut res = self.buf[start / W] >> (start % W);
147 if end > (start / W + 1) * W {
148 res |= self.buf[end / W] << (W - start % W);
149 }
150
151 ((if X { res } else { !res }) & mask) as _
152 }
153}
154
155pub struct RankTable(IntVec);
156pub struct SelectTable(IntVec);
157
158impl RankTable {
159 pub fn new() -> Self {
160 let len = 8;
161 let unit = 3;
162 let mut table = IntVec::new(unit);
163 for i in 0..1 << len {
164 let mut cur = 0;
165 for j in 0..len {
166 table.push(cur);
167 if i >> j & 1 != 0 {
168 cur += 1;
169 }
170 }
171 }
172 Self(table)
173 }
174 pub fn rank(&self, w: u64, i: usize) -> usize {
175 let wi = w as usize * 8 + i;
176 self.0.get_usize(wi)
177 }
178}
179
180impl SelectTable {
181 pub fn new() -> Self {
182 let len = 8;
183 let unit = 3;
184 let mut table = IntVec::new(unit);
185 for i in 0..1 << len {
186 let mut cur = 0;
187 for j in 0..len {
188 if i >> j & 1 != 0 {
189 table.push(j as _);
190 cur += 1;
191 }
192 }
193 for _ in cur..len {
194 table.push(0);
195 }
196 }
197 Self(table)
198 }
199
200 pub fn select(&self, w: u64, i: usize) -> usize {
201 let wi = w as usize * 8 + i;
202 self.0.get_usize(wi)
203 }
204}