Skip to main content

nekolib/ds/
bit_set.rs

1//! bit set。
2
3use super::super::utils::buf_range;
4
5use std::cmp::Ordering;
6use std::fmt;
7use std::iter::FromIterator;
8use std::ops::{
9    BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Range,
10    RangeBounds, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
11};
12
13use buf_range::{bounds_within, check_bounds};
14
/// Machine word used as the storage unit of the bit buffer.
type Word = u64;
/// Number of bits in one `Word`.
const WORD_SIZE: usize = std::mem::size_of::<Word>() * 8;
17
18/// Bit set。
19///
20/// # Implementation notes
21/// `&` `|` `^` について、左辺の capacity を持つ新たな `BitSet` を返すため、可換でない。
22/// 可換にすることにすると、 `x = x | y` と `x |= y` の整合性を取りたくなり、`|=`
23/// の操作でも capacity を変化させるのが妥当になるが、あまりうれしくなさそう?
24///
25/// capacity は暗黙に変わらないような設計にしているが、`eq()` などでは capacity
26/// は無視するべき? 立っているビットが同じかを比較するときに `eq()`
27/// で済むのがうれしいか、`cmp().is_eq()` にするか? そこを短くするために capacity
28/// の比較を別でやる必要がある方がつらいか?
29///
30/// [`u128` での実装](https://atcoder.jp/contests/past202203-open/submissions/33482505) より
31/// [`u64` での実装](https://atcoder.jp/contests/past202203-open/submissions/33482482)
32/// の方が高速だったので、とりあえずそうしている。`BitSet<u128>` のようにすると煩雑になりそう。
33#[derive(Default, Clone, Eq)]
34pub struct BitSet {
35    capacity: usize,
36    len: usize,
37    buf: Vec<Word>,
38    autofix: bool,
39}
40
41impl BitSet {
42    pub fn new(capacity: usize) -> Self {
43        let buf = vec![0; (capacity + WORD_SIZE - 1) / WORD_SIZE];
44        Self { capacity, len: 0, buf, autofix: true }
45    }
46
47    pub fn insert(&mut self, index: usize) {
48        check_bounds(index, self.capacity);
49
50        let (wi, bi) = (index / WORD_SIZE, index % WORD_SIZE);
51        if self.buf[wi] >> bi & 1 == 0 {
52            self.buf[wi] |= 1 << bi;
53            self.len += 1;
54        }
55    }
56
57    pub fn remove(&mut self, index: usize) {
58        check_bounds(index, self.capacity);
59
60        let (wi, bi) = (index / WORD_SIZE, index % WORD_SIZE);
61        if self.buf[wi] >> bi & 1 != 0 {
62            self.buf[wi] &= !(1 << bi);
63            self.len -= 1;
64        }
65    }
66
67    #[must_use]
68    pub fn contains(&self, index: usize) -> bool {
69        check_bounds(index, self.capacity);
70
71        let (wi, bi) = (index / WORD_SIZE, index % WORD_SIZE);
72        self.buf[wi] >> bi & 1 != 0
73    }
74
75    #[must_use]
76    pub fn len(&self) -> usize { self.len }
77    #[must_use]
78    pub fn is_empty(&self) -> bool { self.len == 0 }
79    #[must_use]
80    pub fn capacity(&self) -> usize { self.capacity }
81
82    // simple bit operations (assignment)
83    pub fn and_assign(&mut self, other: &Self) {
84        for (lhs, &rhs) in self.buf.iter_mut().zip(&other.buf) {
85            *lhs &= rhs;
86        }
87        if self.buf.len() > other.buf.len() {
88            for e in &mut self.buf[other.buf.len()..] {
89                *e = 0;
90            }
91        }
92        self.fixup();
93    }
94    pub fn or_assign(&mut self, other: &Self) {
95        for (lhs, &rhs) in self.buf.iter_mut().zip(&other.buf) {
96            *lhs |= rhs;
97        }
98        self.fixup();
99    }
100    pub fn ior_assign(&mut self, other: &Self) { self.or_assign(other) }
101    pub fn xor_assign(&mut self, other: &Self) {
102        for (lhs, &rhs) in self.buf.iter_mut().zip(&other.buf) {
103            *lhs ^= rhs;
104        }
105        self.fixup();
106    }
107    pub fn sub_assign(&mut self, other: &Self) {
108        for (lhs, &rhs) in self.buf.iter_mut().zip(&other.buf) {
109            *lhs &= !rhs;
110        }
111        self.fixup();
112    }
113    pub fn not_assign(&mut self) {
114        for lhs in &mut self.buf {
115            *lhs = !*lhs;
116        }
117        self.fixup();
118    }
119    pub fn shl_assign(&mut self, shl: usize) {
120        check_bounds(shl, self.capacity);
121        let (quot, rem) = (shl / WORD_SIZE, shl % WORD_SIZE);
122        let buf = &mut self.buf;
123        for i in (quot..buf.len()).rev() {
124            let mut tmp = buf[i - quot] << rem;
125            if rem > 0 && i - quot > 0 {
126                tmp |= buf[i - quot - 1] >> (WORD_SIZE - rem);
127            }
128            buf[i] = tmp;
129        }
130        for e in &mut buf[..quot] {
131            *e = 0;
132        }
133        self.fixup();
134    }
135    pub fn shr_assign(&mut self, shr: usize) {
136        check_bounds(shr, self.capacity);
137        let (quot, rem) = (shr / WORD_SIZE, shr % WORD_SIZE);
138        let buf = &mut self.buf;
139        let mid = buf.len() - quot;
140        for i in 0..mid {
141            let mut tmp = buf[i + quot] >> rem;
142            if rem > 0 && i + quot + 1 < buf.len() {
143                tmp |= buf[i + quot + 1] << (WORD_SIZE - rem);
144            }
145            buf[i] = tmp;
146        }
147        for e in &mut buf[mid..] {
148            *e = 0;
149        }
150        self.fixup();
151    }
152
153    // simple bit operations (non-assignment)
154    #[must_use]
155    pub fn and(&self, other: &Self) -> Self {
156        let mut tmp = self.clone();
157        tmp.and_assign(other);
158        tmp
159    }
160    #[must_use]
161    pub fn or(&self, other: &Self) -> Self {
162        let mut tmp = self.clone();
163        tmp.or_assign(other);
164        tmp
165    }
166    #[must_use]
167    pub fn ior(&self, other: &Self) -> Self { self.or(other) }
168    #[must_use]
169    pub fn xor(&self, other: &Self) -> Self {
170        let mut tmp = self.clone();
171        tmp.xor_assign(other);
172        tmp
173    }
174    #[must_use]
175    pub fn sub(&self, other: &Self) -> Self {
176        let mut tmp = self.clone();
177        tmp.sub_assign(other);
178        tmp
179    }
180    #[must_use]
181    pub fn not(&self) -> Self {
182        let mut tmp = self.clone();
183        tmp.not_assign();
184        tmp
185    }
186    #[must_use]
187    pub fn shl(&self, shl: usize) -> Self {
188        let mut tmp = self.clone();
189        tmp.shl_assign(shl);
190        tmp
191    }
192    #[must_use]
193    pub fn shr(&self, shr: usize) -> Self {
194        let mut tmp = self.clone();
195        tmp.shr_assign(shr);
196        tmp
197    }
198
199    pub fn reserve_exact(&mut self, new_capacity: usize) {
200        let new_buf_len = (new_capacity + WORD_SIZE - 1) / WORD_SIZE;
201        if self.buf.len() > new_buf_len {
202            for x in &self.buf[new_buf_len..] {
203                self.len -= x.count_ones() as usize;
204            }
205        }
206        self.buf.resize(new_buf_len, 0);
207        self.capacity = new_capacity;
208        self.fixup_last();
209    }
210
211    pub fn reserve(&mut self, at_least: usize) {
212        if self.capacity < at_least {
213            self.reserve_exact(at_least);
214        }
215    }
216
217    pub fn autofix(&mut self, enable: bool) {
218        let fix_now = !self.autofix && enable;
219        self.autofix = enable;
220        if fix_now {
221            self.fixup();
222        }
223    }
224
225    fn fixup(&mut self) {
226        if !self.autofix {
227            return;
228        }
229        self.fixup_count();
230        self.fixup_last();
231    }
232
233    fn fixup_last(&mut self) {
234        let rem = self.capacity % WORD_SIZE;
235        if rem == 0 {
236            return;
237        }
238        // `rem != 0` implies `self.buf.len() > 0`
239        let last = self.buf.last_mut().unwrap();
240        self.len -= last.count_ones() as usize;
241        *last &= !(!0 << rem);
242        self.len += last.count_ones() as usize;
243    }
244
245    fn fixup_count(&mut self) {
246        self.len = self.buf.iter().map(|x| x.count_ones() as usize).sum();
247    }
248
249    #[must_use]
250    pub fn words(&self, range: impl RangeBounds<usize>) -> Words<'_> {
251        let range = bounds_within(range, self.capacity);
252        Words::new(self, range)
253    }
254
255    #[must_use]
256    pub fn indices(&self, range: impl RangeBounds<usize>) -> Indices<'_> {
257        let range = bounds_within(range, self.capacity);
258        Indices::new(self, range)
259    }
260
261    fn single_word(&self, start: usize, end: usize) -> Word {
262        // [start..end] の bit からなる word を返す。範囲外は 0 でうめる。
263        // 0 <= end - start <= WORD_SIZE と end <= self.capacity は仮定する。
264        if start == end {
265            return 0;
266        }
267
268        let (ws, bs) = (start / WORD_SIZE, start % WORD_SIZE);
269        let (we, be) = (end / WORD_SIZE, end % WORD_SIZE);
270        let len = end - start;
271        let w = self.buf[ws];
272        if be == 0 {
273            if bs == 0 { w } else { w >> (WORD_SIZE - len) }
274        } else if ws == we {
275            (w >> bs) & !(!0 << len)
276        } else {
277            // e.g.: (LSB) _____xxx xx______ (MSB); bs: 5, be: 2
278            (w >> bs) | (self.buf[we] & !(!0 << be)) << (WORD_SIZE - bs)
279        }
280    }
281
282    fn single_word_bsf(&self, start: usize, end: usize) -> Option<usize> {
283        let w = self.single_word(start, end);
284        if w == 0 { None } else { Some(start + bsf(w)) }
285    }
286
287    fn single_word_bsr(&self, start: usize, end: usize) -> Option<usize> {
288        let w = self.single_word(start, end);
289        if w == 0 { None } else { Some(start + bsr(w)) }
290    }
291
292    #[must_use]
293    pub fn find_first(&self, range: impl RangeBounds<usize>) -> Option<usize> {
294        let Range { start, end } = bounds_within(range, self.capacity);
295        if start >= end {
296            return None;
297        }
298
299        let s_ceil = (start + WORD_SIZE - 1) / WORD_SIZE;
300        let e_floor = end / WORD_SIZE;
301        if s_ceil > e_floor {
302            return self.single_word_bsf(start, end);
303        }
304
305        let first = self.single_word_bsf(start, s_ceil * WORD_SIZE);
306        let middle = self.buf[s_ceil..e_floor]
307            .iter()
308            .zip(s_ceil..e_floor)
309            .filter(|&(&w, _)| w != 0)
310            .map(|(&w, i)| i * WORD_SIZE + bsf(w));
311        let last = self.single_word_bsf(e_floor * WORD_SIZE, end);
312        first.into_iter().chain(middle).chain(last).next()
313    }
314
315    #[must_use]
316    pub fn find_last(&self, range: impl RangeBounds<usize>) -> Option<usize> {
317        let Range { start, end } = bounds_within(range, self.capacity);
318        if start >= end {
319            return None;
320        }
321
322        let s_ceil = (start + WORD_SIZE - 1) / WORD_SIZE;
323        let e_floor = end / WORD_SIZE;
324        if s_ceil > e_floor {
325            return self.single_word_bsr(start, end);
326        }
327
328        let first = self.single_word_bsr(start, s_ceil * WORD_SIZE);
329        let middle = self.buf[s_ceil..e_floor]
330            .iter()
331            .zip(s_ceil..e_floor)
332            .filter(|&(&w, _)| w != 0)
333            .map(|(&w, i)| i * WORD_SIZE + bsr(w))
334            .rev();
335        let last = self.single_word_bsr(e_floor * WORD_SIZE, end);
336        last.into_iter().chain(middle).chain(first).next()
337    }
338}
339
340pub struct Words<'a> {
341    range: (usize, usize),
342    bit_set: &'a BitSet,
343}
344
345impl<'a> Words<'a> {
346    fn new(bit_set: &'a BitSet, Range { start, end }: Range<usize>) -> Self {
347        Self { range: (start, end), bit_set }
348    }
349}
350
351impl Iterator for Words<'_> {
352    type Item = Word;
353
354    fn next(&mut self) -> Option<Self::Item> {
355        let (start, end) = self.range;
356        if start >= end {
357            return None;
358        }
359        self.range.0 = end.min(start + WORD_SIZE);
360        Some(self.bit_set.single_word(start, self.range.0))
361    }
362}
363
364impl DoubleEndedIterator for Words<'_> {
365    fn next_back(&mut self) -> Option<Self::Item> {
366        let (start, end) = self.range;
367        if start >= end {
368            return None;
369        }
370        self.range.1 = start.max(end.saturating_sub(WORD_SIZE));
371        Some(self.bit_set.single_word(self.range.1, end))
372    }
373}
374
375pub struct Indices<'a> {
376    range: (usize, usize),
377    bit_set: &'a BitSet,
378}
379
380impl<'a> Indices<'a> {
381    pub fn new(
382        bit_set: &'a BitSet,
383        Range { start, end }: Range<usize>,
384    ) -> Self {
385        Self { bit_set, range: (start, end) }
386    }
387}
388
389impl Iterator for Indices<'_> {
390    type Item = usize;
391
392    fn next(&mut self) -> Option<Self::Item> {
393        let (start, end) = self.range;
394        if start >= end {
395            return None;
396        }
397        let res = self.bit_set.find_first(start..end);
398        self.range.0 = res.map(|i| i + 1).unwrap_or(self.range.1);
399        res
400    }
401}
402
403impl DoubleEndedIterator for Indices<'_> {
404    fn next_back(&mut self) -> Option<Self::Item> {
405        let (start, end) = self.range;
406        if start >= end {
407            return None;
408        }
409        let res = self.bit_set.find_last(start..end);
410        self.range.1 = res.unwrap_or(self.range.0);
411        res
412    }
413}
414
415impl Ord for BitSet {
416    fn cmp(&self, other: &Self) -> Ordering {
417        // iter() の辞書順で比較することにしたい。
418        // words().map(|x| x.reverse_bits()) の辞書順で比較すると、
419        // 0 の word がたくさんあって buf.len() の差があるときに相違が出る。
420        //
421        // [..self.buf.len().min(other.buf.len())] までで比較して、
422        // 同じなら、len の辞書順で比較すればよさそう。
423        let min_len = self.buf.len().min(other.buf.len());
424        (self.buf[..min_len].iter().map(|x| x.reverse_bits()))
425            .cmp(other.buf[..min_len].iter().map(|x| x.reverse_bits()))
426            .then_with(|| self.len.cmp(&other.len))
427    }
428}
429
430impl PartialOrd for BitSet {
431    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
432        Some(self.cmp(&other))
433    }
434}
435
436impl PartialEq for BitSet {
437    fn eq(&self, other: &Self) -> bool { self.cmp(&other) == Ordering::Equal }
438}
439
/// Expands to statements computing `lhs[i] = bitop(lhs[i], (rhs << shl)[i])`
/// word by word, followed by `$self.fixup()`.
///
/// `$lhs` and `$rhs` may name the same buffer: words are processed from the
/// top, so every source word is read before it is overwritten.
macro_rules! fused_shl_bitop {
    ($self:ident, $lhs:expr, $rhs:expr, $shl:expr, $bitop:expr) => {
        let word_shift = $shl / WORD_SIZE;
        let bit_shift = $shl % WORD_SIZE;
        // One past the highest word index that can receive a bit of the
        // shifted `rhs`.
        let rhs_upper = $rhs.len() + word_shift + bit_shift.min(1);
        let touched_end = $lhs.len().min(rhs_upper);
        for dst in (word_shift..touched_end).rev() {
            let src = dst - word_shift;
            let mut shifted = 0;
            if src < $rhs.len() {
                shifted |= $rhs[src] << bit_shift;
            }
            if bit_shift > 0 && src >= 1 {
                // Bits carried over from the next lower word of `rhs`.
                shifted |= $rhs[src - 1] >> (WORD_SIZE - bit_shift);
            }
            $lhs[dst] = $bitop($lhs[dst], shifted);
        }
        // When combining with 0 can change a word (e.g. `&`), the words of
        // `lhs` that received no bits must be combined with 0 as well.
        if $bitop(0b01, 0) != 0b01 {
            if word_shift >= touched_end {
                for word in &mut $lhs {
                    *word = $bitop(*word, 0);
                }
            } else {
                if word_shift < $lhs.len() {
                    for word in &mut $lhs[..word_shift] {
                        *word = $bitop(*word, 0);
                    }
                }
                if $lhs.len() > rhs_upper {
                    for word in &mut $lhs[rhs_upper..] {
                        *word = $bitop(*word, 0);
                    }
                }
            }
        }
        $self.fixup();
    };
}
/// Expands to statements computing `lhs[i] = bitop(lhs[i], (rhs >> shr)[i])`
/// word by word, followed by `$self.fixup()`.
///
/// `$lhs` and `$rhs` may name the same buffer: words are processed from the
/// bottom, so every source word is read before it is overwritten.
///
/// A shift that spans more words than `rhs` has is treated as shifting
/// everything out (the shifted operand is all zero) instead of underflowing.
macro_rules! fused_shr_bitop {
    ($self:ident, $lhs:expr, $rhs:expr, $shr:expr, $bitop:expr) => {
        let (quot, rem) = ($shr / WORD_SIZE, $shr % WORD_SIZE);
        // Number of `lhs` words that receive bits of `rhs`. Saturate so that
        // `quot > rhs.len()` yields 0 rather than an arithmetic underflow;
        // not every caller validates the shift amount beforehand.
        let mid = $lhs.len().min($rhs.len().saturating_sub(quot));
        for i in 0..mid {
            let mut tmp = $rhs[i + quot] >> rem;
            if rem > 0 && i + quot + 1 < $rhs.len() {
                // Bits carried over from the next higher word of `rhs`.
                tmp |= $rhs[i + quot + 1] << (WORD_SIZE - rem);
            }
            $lhs[i] = $bitop($lhs[i], tmp);
        }
        // When combining with 0 can change a word (e.g. `&`), the words of
        // `lhs` that received no bits must be combined with 0 as well.
        if $bitop(0b01, 0) != 0b01 {
            for lhs in &mut $lhs[mid..] {
                *lhs = $bitop(*lhs, 0);
            }
        }
        $self.fixup();
    };
}
494
/// Generates the fused shift-and-combine methods of `BitSet`.
///
/// Accepted entries:
/// - `(name, self OP self << amount)` / `(name, self OP self >> amount)`:
///   in-place method combining `self` with a shifted view of itself;
/// - `(name, self OP other << amount)` / `(name, self OP other >> amount)`:
///   in-place method combining `self` with a shifted view of `other`;
/// - `(op, op_assign, op_self, op_self_assign)`: non-mutating wrappers that
///   clone `self` and call the two named in-place methods;
/// - a comma-terminated list of such entries.
macro_rules! impl_fused {
    ( ($method:ident, self $operator:tt self << $amount:ident) ) => {
        pub fn $method(&mut self, $amount: usize) {
            check_bounds($amount, self.capacity);
            fused_shl_bitop! {
                self, self.buf, self.buf, $amount, |x, y| x $operator y
            }
        }
    };
    ( ($method:ident, self $operator:tt other << $amount:ident) ) => {
        pub fn $method(&mut self, $amount: usize, other: &Self) {
            // The shift amount is validated against the shifted operand.
            check_bounds($amount, other.capacity);
            fused_shl_bitop! {
                self, self.buf, other.buf, $amount, |x, y| x $operator y
            }
        }
    };
    ( ($method:ident, self $operator:tt self >> $amount:ident) ) => {
        pub fn $method(&mut self, $amount: usize) {
            check_bounds($amount, self.capacity);
            fused_shr_bitop! {
                self, self.buf, self.buf, $amount, |x, y| x $operator y
            }
        }
    };
    ( ($method:ident, self $operator:tt other >> $amount:ident) ) => {
        pub fn $method(&mut self, $amount: usize, other: &Self) {
            // The shift amount is validated against the shifted operand.
            check_bounds($amount, other.capacity);
            fused_shr_bitop! {
                self, self.buf, other.buf, $amount, |x, y| x $operator y
            }
        }
    };
    ( ($pure:ident, $in_place:ident, $pure_self:ident, $in_place_self:ident) ) => {
        pub fn $pure(&self, sh: usize, other: &Self) -> Self {
            let mut result = self.clone();
            result.$in_place(sh, other);
            result
        }
        pub fn $pure_self(&self, sh: usize) -> Self {
            let mut result = self.clone();
            result.$in_place_self(sh);
            result
        }
    };
    // Entry point: expand every entry of the list with the rules above.
    ( $( ( $( $entry:tt )* ), )* ) => { $( impl_fused!{ ( $( $entry )* ) } )* }
}
542
543// fused bit operations
544impl BitSet {
545    impl_fused! {
546        (shl_and_self_assign, self & self << x),
547        (shl_ior_self_assign, self | self << x),
548        (shl_xor_self_assign, self ^ self << x),
549        (shr_and_self_assign, self & self >> x),
550        (shr_ior_self_assign, self | self >> x),
551        (shr_xor_self_assign, self ^ self >> x),
552
553        (shl_and_assign, self & other << x),
554        (shl_ior_assign, self | other << x),
555        (shl_xor_assign, self ^ other << x),
556        (shr_and_assign, self & other >> x),
557        (shr_ior_assign, self | other >> x),
558        (shr_xor_assign, self ^ other >> x),
559
560        (shl_or_self_assign, self | self << x),
561        (shr_or_self_assign, self | self >> x),
562        (shl_or_assign,  self | other << x),
563        (shr_or_assign,  self | other >> x),
564
565        (shl_and, shl_and_assign, shl_and_self, shl_and_self_assign),
566        (shl_ior, shl_ior_assign, shl_ior_self, shl_ior_self_assign),
567        (shl_xor, shl_xor_assign, shl_xor_self, shl_xor_self_assign),
568        (shl_sub, shl_sub_assign, shl_sub_self, shl_sub_self_assign),
569        (shr_and, shr_and_assign, shr_and_self, shr_and_self_assign),
570        (shr_ior, shr_ior_assign, shr_ior_self, shr_ior_self_assign),
571        (shr_xor, shr_xor_assign, shr_xor_self, shr_xor_self_assign),
572        (shr_sub, shr_sub_assign, shr_sub_self, shr_sub_self_assign),
573
574        (shl_or, shl_or_assign, shl_or_self, shl_or_self_assign),
575        (shr_or, shr_or_assign, shr_or_self, shr_or_self_assign),
576    }
577
578    pub fn shl_sub_assign(&mut self, shl: usize, other: &Self) {
579        self.shl_op_assign(shl, other, |x, y| x & !y)
580    }
581    pub fn shl_sub_self_assign(&mut self, shl: usize) {
582        self.shl_op_self_assign(shl, |x, y| x & !y)
583    }
584    pub fn shr_sub_assign(&mut self, shr: usize, other: &Self) {
585        self.shr_op_assign(shr, other, |x, y| x & !y)
586    }
587    pub fn shr_sub_self_assign(&mut self, shr: usize) {
588        self.shr_op_self_assign(shr, |x, y| x & !y)
589    }
590
591    pub fn shl_op_assign(
592        &mut self,
593        shl: usize,
594        other: &Self,
595        f: impl Fn(Word, Word) -> Word,
596    ) {
597        fused_shl_bitop!(self, self.buf, other.buf, shl, f);
598    }
599    pub fn shl_op_self_assign(
600        &mut self,
601        shl: usize,
602        f: impl Fn(Word, Word) -> Word,
603    ) {
604        fused_shl_bitop!(self, self.buf, self.buf, shl, f);
605    }
606    pub fn shr_op_assign(
607        &mut self,
608        shr: usize,
609        other: &Self,
610        f: impl Fn(Word, Word) -> Word,
611    ) {
612        fused_shr_bitop!(self, self.buf, other.buf, shr, f);
613    }
614    pub fn shr_op_self_assign(
615        &mut self,
616        shr: usize,
617        f: impl Fn(Word, Word) -> Word,
618    ) {
619        fused_shr_bitop!(self, self.buf, self.buf, shr, f);
620    }
621
622    #[must_use]
623    pub fn shl_op(
624        &self,
625        sh: usize,
626        other: &Self,
627        f: impl Fn(Word, Word) -> Word,
628    ) -> Self {
629        let mut tmp = self.clone();
630        tmp.shl_op_assign(sh, other, f);
631        tmp
632    }
633    #[must_use]
634    pub fn shl_op_self(
635        &self,
636        sh: usize,
637        f: impl Fn(Word, Word) -> Word,
638    ) -> Self {
639        let mut tmp = self.clone();
640        tmp.shl_op_self_assign(sh, f);
641        tmp
642    }
643    #[must_use]
644    pub fn shr_op(
645        &self,
646        sh: usize,
647        other: &Self,
648        f: impl Fn(Word, Word) -> Word,
649    ) -> Self {
650        let mut tmp = self.clone();
651        tmp.shr_op_assign(sh, other, f);
652        tmp
653    }
654    #[must_use]
655    pub fn shr_op_self(
656        &self,
657        sh: usize,
658        f: impl Fn(Word, Word) -> Word,
659    ) -> Self {
660        let mut tmp = self.clone();
661        tmp.shr_op_self_assign(sh, f);
662        tmp
663    }
664}
665
666impl fmt::Binary for BitSet {
667    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
668        let n = self.capacity;
669        for &w in &self.buf[..n / WORD_SIZE] {
670            write!(f, "{0:01$b}", w.reverse_bits(), WORD_SIZE)?;
671        }
672        let rem = n % WORD_SIZE;
673        if rem != 0 {
674            let w = self.buf[n / WORD_SIZE].reverse_bits() >> (WORD_SIZE - rem);
675            write!(f, "{0:01$b}", w, rem)?;
676        }
677        Ok(())
678    }
679}
680
681impl fmt::Debug for BitSet {
682    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
683        f.debug_set().entries(self.indices(..)).finish()
684    }
685}
686
687impl Extend<usize> for BitSet {
688    fn extend<T>(&mut self, iter: T)
689    where
690        T: IntoIterator<Item = usize>,
691    {
692        for i in iter {
693            self.insert(i);
694        }
695    }
696}
697
698impl FromIterator<usize> for BitSet {
699    fn from_iter<T>(iter: T) -> Self
700    where
701        T: IntoIterator<Item = usize>,
702    {
703        let mut res = BitSet::new(0);
704        for i in iter {
705            res.reserve(i + 1);
706            res.insert(i);
707        }
708        res
709    }
710}
711
/// Implements a binary operator trait for all four owned/borrowed operand
/// combinations of `BitSet`.
///
/// Each entry is `(Trait, trait_method, in_place_method, pure_method)`. An
/// owned left operand is updated in place and returned; a borrowed one goes
/// through the cloning `pure_method`.
macro_rules! impl_binary_op {
    ( $( ($trait:ident, $method:ident, $in_place:ident, $pure:ident), )* ) => { $(
        impl $trait<&BitSet> for BitSet {
            type Output = BitSet;
            fn $method(mut self, rhs: &Self) -> Self {
                self.$in_place(rhs);
                self
            }
        }
        impl $trait<BitSet> for BitSet {
            type Output = BitSet;
            fn $method(mut self, rhs: Self) -> Self {
                self.$in_place(&rhs);
                self
            }
        }
        impl<'a> $trait<&'a BitSet> for &'a BitSet {
            type Output = BitSet;
            fn $method(self, rhs: Self) -> BitSet {
                self.$pure(rhs)
            }
        }
        impl<'a> $trait<BitSet> for &'a BitSet {
            type Output = BitSet;
            fn $method(self, rhs: BitSet) -> BitSet {
                self.$pure(&rhs)
            }
        }
    )* }
}
738
/// Implements a compound-assignment operator trait for `BitSet` with both a
/// borrowed and an owned right operand.
///
/// Each entry is `(Trait, trait_method, in_place_method)`.
macro_rules! impl_binary_op_assign {
    ( $( ($trait:ident, $method:ident, $in_place:ident), )* ) => { $(
        impl $trait<&BitSet> for BitSet {
            fn $method(&mut self, rhs: &Self) {
                self.$in_place(rhs);
            }
        }
        impl $trait<BitSet> for BitSet {
            fn $method(&mut self, rhs: Self) {
                self.$in_place(&rhs);
            }
        }
    )* }
}
753
/// Implements a shift operator trait with a `usize` amount for `BitSet` and
/// `&BitSet`.
///
/// Each entry is `(Trait, trait_method, in_place_method, pure_method)`.
macro_rules! impl_shift {
    ( $( ($trait:ident, $method:ident, $in_place:ident, $pure:ident), )* ) => { $(
        impl $trait<usize> for BitSet {
            type Output = Self;
            fn $method(mut self, amount: usize) -> BitSet {
                self.$in_place(amount);
                self
            }
        }
        impl $trait<usize> for &'_ BitSet {
            type Output = BitSet;
            fn $method(self, amount: usize) -> BitSet {
                self.$pure(amount)
            }
        }
    )* }
}
769
/// Implements a shift-assignment operator trait with a `usize` amount for
/// `BitSet`.
///
/// Each entry is `(Trait, trait_method, in_place_method)`.
macro_rules! impl_shift_assign {
    ( $( ($trait:ident, $method:ident, $in_place:ident), )* ) => { $(
        impl $trait<usize> for BitSet {
            fn $method(&mut self, amount: usize) {
                self.$in_place(amount);
            }
        }
    )* }
}
779
780impl_binary_op! {
781    (BitAnd, bitand, and_assign, and),
782    (BitOr, bitor, or_assign, or),
783    (BitXor, bitxor, xor_assign, xor),
784    (Sub, sub, sub_assign, sub),
785}
786
787impl_binary_op_assign! {
788    (BitAndAssign, bitand_assign, and_assign),
789    (BitOrAssign, bitor_assign, or_assign),
790    (BitXorAssign, bitxor_assign, xor_assign),
791    (SubAssign, sub_assign, sub_assign),
792}
793
794impl_shift! {
795    (Shl, shl, shl_assign, shl),
796    (Shr, shr, shr_assign, shr),
797}
798
799impl_shift_assign! {
800    (ShlAssign, shl_assign, shl_assign),
801    (ShrAssign, shr_assign, shr_assign),
802}
803
804impl Not for BitSet {
805    type Output = Self;
806    fn not(mut self) -> Self {
807        self.not_assign();
808        self
809    }
810}
811impl Not for &'_ BitSet {
812    type Output = BitSet;
813    fn not(self) -> BitSet { self.clone().not() }
814}
815
816fn bsf(w: Word) -> usize { w.trailing_zeros() as usize }
817fn bsr(w: Word) -> usize { WORD_SIZE - 1 - w.leading_zeros() as usize }
818
819#[cfg(test)]
820mod test {
821    use std::collections::BTreeSet;
822
823    use super::{BitSet, WORD_SIZE};
824
825    #[test]
826    fn fmt() {
827        let mut bs = BitSet::new(10);
828        assert_eq!(format!("{:b}", bs), "0000000000");
829
830        bs.insert(3);
831        assert_eq!(format!("{:b}", bs), "0001000000");
832
833        bs.insert(5);
834        assert_eq!(format!("{:b}", bs), "0001010000");
835
836        bs.reserve_exact(128);
837        assert_eq!(
838            format!("{:b}", bs),
839            "00010100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
840        );
841
842        bs.insert(126);
843        assert_eq!(
844            format!("{:b}", bs),
845            "00010100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010"
846        );
847
848        bs.reserve_exact(129);
849        assert_eq!(
850            format!("{:b}", bs),
851            "000101000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100"
852        );
853
854        bs.insert(128);
855        assert_eq!(
856            format!("{:b}", bs),
857            "000101000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000101"
858        );
859    }
860
861    #[test]
862    fn basics() {
863        let mut bs = BitSet::new(10);
864        assert!(bs.is_empty());
865        assert!(!bs.contains(3));
866
867        bs.insert(3);
868        assert_eq!(bs.len(), 1);
869        assert!(bs.contains(3));
870
871        bs.remove(2);
872        assert_eq!(bs.len(), 1);
873
874        bs.remove(3);
875        assert_eq!(bs.len(), 0);
876    }
877
878    const SET: &[usize] = &[
879        0, 1, 126, 127, // 0
880        128, 129, 253, 255, // 1
881        256, 258, 380, 383, // 2
882        384, 387, // 3
883        513, // 4
884        640, // 5
885        // 6
886        896,  // 7
887        1025, // 8
888        // 9
889        1407, // 10
890    ];
891
892    const CONSEC: &[usize] = &[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
893    const FIB: &[usize] = &[0, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144];
894    const POW2: &[usize] = &[1, 2, 4, 8, 16, 32, 64, 128];
895
896    #[test]
897    fn reserve_fix() {
898        let set = SET;
899
900        for i in 0..set.len() {
901            let mut bs: BitSet = set[..=i].iter().copied().collect();
902            assert_eq!(bs.len(), i + 1);
903
904            for j in 0..set.len() {
905                let mut bs = bs.clone();
906                bs.reserve_exact(set[j] + 1);
907                assert_eq!(bs.len(), i.min(j) + 1);
908            }
909            bs.reserve_exact(0);
910            assert_eq!(bs.len(), 0);
911        }
912    }
913
914    #[test]
915    fn indices() {
916        let set = SET;
917
918        let n = set.len();
919        let mut bs = BitSet::new(set[n - 1] + 1);
920        bs.extend(set.iter().copied());
921
922        let fwd: Vec<_> = bs.indices(..).collect();
923        let bck: Vec<_> = bs.indices(..).rev().collect();
924        let rev_bck: Vec<_> = bck.into_iter().rev().collect();
925        assert_eq!(fwd, rev_bck);
926    }
927
928    #[test]
929    fn words() {
930        let set = SET;
931
932        let n = set.len();
933        let capacity = set[n - 1] + 1;
934        let mut bs = BitSet::new(capacity);
935        bs.extend(set.iter().copied());
936
937        // fwd
938        for i in 0..WORD_SIZE {
939            let actual: Vec<_> = bs.words(i..).collect();
940            let expected: Vec<_> = {
941                let mut bs = BitSet::new(capacity);
942                bs.extend(set.iter().filter_map(|&x| x.checked_sub(i)));
943                bs.words(..).collect()
944            };
945            assert_eq!(actual, expected);
946        }
947
948        // bck
949        for i in 0..WORD_SIZE {
950            let actual: Vec<_> = bs.words(..capacity - i).rev().collect();
951            let expected: Vec<_> = {
952                let mut bs = BitSet::new(capacity);
953                bs.extend(set.iter().map(|&x| x + i).filter(|&x| x < capacity));
954                bs.words(i..).rev().collect()
955            };
956            assert_eq!(actual, expected);
957        }
958    }
959
960    #[test]
961    fn cmp() {
962        let mut bs1 = BitSet::new(10);
963        let mut bs2 = BitSet::new(1000);
964
965        assert!(bs1 == bs2);
966
967        bs2.insert(999);
968        assert!(bs1 < bs2);
969
970        bs1.insert(2);
971        assert!(bs1 > bs2);
972
973        bs2.insert(1);
974        assert!(bs1 < bs2);
975
976        bs2.reserve_exact(1);
977        assert!(bs1 > bs2);
978
979        bs1.reserve_exact(0);
980        assert!(bs1 == bs2);
981    }
982
983    #[test]
984    fn eq() {
985        let set = SET;
986        let n = set.len();
987        let m = set[n - 1] + 1;
988
989        let mut bs1 = BitSet::new(m);
990        bs1.extend(set.iter().copied());
991        let mut bs2 = BitSet::new(m + 1);
992        bs2.extend(set.iter().copied());
993        assert!(bs1 == bs2);
994
995        bs2.insert(m);
996        assert!(bs1 != bs2);
997
998        bs2.reserve_exact(m);
999        assert!(bs1 == bs2);
1000    }
1001
1002    #[test]
1003    fn shl() {
1004        let set = SET;
1005        let n = set.len();
1006        let m = set[n - 1] + 1;
1007
1008        let mut bs1 = BitSet::new(m);
1009        bs1.extend(set.iter().copied());
1010
1011        for i in 0..2 * WORD_SIZE {
1012            let mut bs2 = BitSet::new(m);
1013            bs2.extend(bs1.indices(..).map(|x| x + i).filter(|&x| x < m));
1014            assert!(&bs1 << i == bs2);
1015        }
1016    }
1017
1018    #[test]
1019    fn shr() {
1020        let set = SET;
1021        let n = set.len();
1022        let m = set[n - 1] + 1;
1023
1024        let mut bs1 = BitSet::new(m);
1025        bs1.extend(set.iter().copied());
1026
1027        for i in 0..2 * WORD_SIZE {
1028            let mut bs2 = BitSet::new(m);
1029            bs2.extend(bs1.indices(..).filter(|&x| x >= i).map(|x| x - i));
1030            assert!(&bs1 >> i == bs2);
1031        }
1032    }
1033
1034    #[test]
1035    fn not() {
1036        let bs: BitSet = [0, 1, 3, 6].iter().copied().collect();
1037        assert_eq!(bs.words(..).next(), Some(0b1001011));
1038        assert_eq!(bs.len(), 4);
1039        let not_bs = !&bs;
1040        assert_eq!(not_bs.words(..).next(), Some(0b0110100));
1041        assert_eq!(not_bs.len(), 3);
1042    }
1043
1044    #[test]
1045    fn binary_ops() {
1046        let m = 200;
1047        let (bs1, bs2, bs3) = {
1048            let mut bs1 = BitSet::new(m);
1049            bs1.extend(CONSEC.iter().copied());
1050            let mut bs2 = BitSet::new(m);
1051            bs2.extend(FIB.iter().copied());
1052            let mut bs3 = BitSet::new(m);
1053            bs3.extend(POW2.iter().copied());
1054            (bs1, bs2, bs3)
1055        };
1056
1057        let ts1: BTreeSet<_> = bs1.indices(..).collect();
1058        let ts2: BTreeSet<_> = bs2.indices(..).collect();
1059        let ts3: BTreeSet<_> = bs3.indices(..).collect();
1060
1061        fn is_eq(actual: BitSet, expected: BTreeSet<usize>) -> bool {
1062            actual.indices(..).eq(expected)
1063        }
1064
1065        assert!(is_eq(&bs1 & &bs1, &ts1 & &ts1));
1066        assert!(is_eq(&bs1 & &bs2, &ts1 & &ts2));
1067        assert!(is_eq(&bs1 & &bs3, &ts1 & &ts3));
1068        assert!(is_eq(&bs2 & &bs1, &ts2 & &ts1));
1069        assert!(is_eq(&bs2 & &bs2, &ts2 & &ts2));
1070        assert!(is_eq(&bs2 & &bs3, &ts2 & &ts3));
1071        assert!(is_eq(&bs3 & &bs1, &ts3 & &ts1));
1072        assert!(is_eq(&bs3 & &bs2, &ts3 & &ts2));
1073        assert!(is_eq(&bs3 & &bs3, &ts3 & &ts3));
1074
1075        assert!(is_eq(&bs1 | &bs1, &ts1 | &ts1));
1076        assert!(is_eq(&bs1 | &bs2, &ts1 | &ts2));
1077        assert!(is_eq(&bs1 | &bs3, &ts1 | &ts3));
1078        assert!(is_eq(&bs2 | &bs1, &ts2 | &ts1));
1079        assert!(is_eq(&bs2 | &bs2, &ts2 | &ts2));
1080        assert!(is_eq(&bs2 | &bs3, &ts2 | &ts3));
1081        assert!(is_eq(&bs3 | &bs1, &ts3 | &ts1));
1082        assert!(is_eq(&bs3 | &bs2, &ts3 | &ts2));
1083        assert!(is_eq(&bs3 | &bs3, &ts3 | &ts3));
1084
1085        assert!(is_eq(&bs1 ^ &bs1, &ts1 ^ &ts1));
1086        assert!(is_eq(&bs1 ^ &bs2, &ts1 ^ &ts2));
1087        assert!(is_eq(&bs1 ^ &bs3, &ts1 ^ &ts3));
1088        assert!(is_eq(&bs2 ^ &bs1, &ts2 ^ &ts1));
1089        assert!(is_eq(&bs2 ^ &bs2, &ts2 ^ &ts2));
1090        assert!(is_eq(&bs2 ^ &bs3, &ts2 ^ &ts3));
1091        assert!(is_eq(&bs3 ^ &bs1, &ts3 ^ &ts1));
1092        assert!(is_eq(&bs3 ^ &bs2, &ts3 ^ &ts2));
1093        assert!(is_eq(&bs3 ^ &bs3, &ts3 ^ &ts3));
1094
1095        assert!(is_eq(&bs1 - &bs1, &ts1 - &ts1));
1096        assert!(is_eq(&bs1 - &bs2, &ts1 - &ts2));
1097        assert!(is_eq(&bs1 - &bs3, &ts1 - &ts3));
1098        assert!(is_eq(&bs2 - &bs1, &ts2 - &ts1));
1099        assert!(is_eq(&bs2 - &bs2, &ts2 - &ts2));
1100        assert!(is_eq(&bs2 - &bs3, &ts2 - &ts3));
1101        assert!(is_eq(&bs3 - &bs1, &ts3 - &ts1));
1102        assert!(is_eq(&bs3 - &bs2, &ts3 - &ts2));
1103        assert!(is_eq(&bs3 - &bs3, &ts3 - &ts3));
1104    }
1105
1106    #[test]
1107    fn or_capacity() {
1108        let mut bs1 = BitSet::new(4);
1109        bs1.extend([0, 3]);
1110        let mut bs2 = BitSet::new(5);
1111        bs2.extend([0, 2, 4]);
1112
1113        assert_eq!((&bs1 | &bs2).words(..).next(), Some(0b1101));
1114        assert_eq!((&bs2 | &bs1).words(..).next(), Some(0b11101));
1115
1116        bs1.reserve_exact(5);
1117        assert_eq!((&bs1 | &bs2).words(..).next(), Some(0b11101));
1118        assert_eq!((&bs2 | &bs1).words(..).next(), Some(0b11101));
1119    }
1120
1121    #[test]
1122    fn and_capacity() {
1123        let mut bs1 = BitSet::new(10);
1124        bs1.extend([1, 2, 4, 8]);
1125        let mut bs2 = BitSet::new(1000);
1126        bs2.extend([2, 3, 4, 5, 100, 200, 500, 900]);
1127
1128        assert!((&bs1 & &bs2).indices(..).eq([2, 4]));
1129        assert!((&bs2 & &bs1).indices(..).eq([2, 4]));
1130    }
1131
1132    #[test]
1133    fn fused_shift_bitwise() {
1134        let bs0: BitSet = SET.iter().copied().collect();
1135        let bs1: BitSet = CONSEC.iter().copied().collect();
1136        let bs2: BitSet = FIB.iter().copied().collect();
1137        let bs3: BitSet = POW2.iter().copied().collect();
1138
1139        fn test_internal(lhs: &BitSet, rhs: &BitSet) -> Result<(), usize> {
1140            let rhs_x = {
1141                let mut rhs_x = rhs.clone();
1142                rhs_x.reserve(lhs.capacity());
1143                rhs_x
1144            };
1145
1146            for i in 0..rhs.capacity() {
1147                let actual = vec![
1148                    lhs.shl_and(i, &rhs),
1149                    lhs.shl_ior(i, &rhs),
1150                    lhs.shl_xor(i, &rhs),
1151                    lhs.shl_sub(i, &rhs),
1152                    lhs.shr_and(i, &rhs),
1153                    lhs.shr_ior(i, &rhs),
1154                    lhs.shr_xor(i, &rhs),
1155                    lhs.shr_sub(i, &rhs),
1156                    // shl + op
1157                    lhs.shl_op(i, &rhs, |_, _| 0), //        0000
1158                    lhs.shl_op(i, &rhs, |x, y| x & y), //    0001
1159                    lhs.shl_op(i, &rhs, |x, y| x & !y), //   0010
1160                    lhs.shl_op(i, &rhs, |x, _| x), //        0011
1161                    lhs.shl_op(i, &rhs, |x, y| !x & y), //   0100
1162                    lhs.shl_op(i, &rhs, |_, y| y), //        0101
1163                    lhs.shl_op(i, &rhs, |x, y| x ^ y), //    0110
1164                    lhs.shl_op(i, &rhs, |x, y| x | y), //    0111
1165                    lhs.shl_op(i, &rhs, |x, y| !x & !y), //  1000
1166                    lhs.shl_op(i, &rhs, |x, y| !(x ^ y)), // 1001
1167                    lhs.shl_op(i, &rhs, |_, y| !y), //       1010
1168                    lhs.shl_op(i, &rhs, |x, y| x | !y), //   1011
1169                    lhs.shl_op(i, &rhs, |x, _| !x), //       1100
1170                    lhs.shl_op(i, &rhs, |x, y| !x | y), //   1101
1171                    lhs.shl_op(i, &rhs, |x, y| !x | !y), //  1110
1172                    lhs.shl_op(i, &rhs, |_, _| !0), //       1111
1173                    // shr + op
1174                    lhs.shr_op(i, &rhs, |_, _| 0), //        0000
1175                    lhs.shr_op(i, &rhs, |x, y| x & y), //    0001
1176                    lhs.shr_op(i, &rhs, |x, y| x & !y), //   0010
1177                    lhs.shr_op(i, &rhs, |x, _| x), //        0011
1178                    lhs.shr_op(i, &rhs, |x, y| !x & y), //   0100
1179                    lhs.shr_op(i, &rhs, |_, y| y), //        0101
1180                    lhs.shr_op(i, &rhs, |x, y| x ^ y), //    0110
1181                    lhs.shr_op(i, &rhs, |x, y| x | y), //    0111
1182                    lhs.shr_op(i, &rhs, |x, y| !x & !y), //  1000
1183                    lhs.shr_op(i, &rhs, |x, y| !(x ^ y)), // 1001
1184                    lhs.shr_op(i, &rhs, |_, y| !y), //       1010
1185                    lhs.shr_op(i, &rhs, |x, y| x | !y), //   1011
1186                    lhs.shr_op(i, &rhs, |x, _| !x), //       1100
1187                    lhs.shr_op(i, &rhs, |x, y| !x | y), //   1101
1188                    lhs.shr_op(i, &rhs, |x, y| !x | !y), //  1110
1189                    lhs.shr_op(i, &rhs, |_, _| !0), //       1111
1190                ];
1191                let mut rhs_l = &rhs_x << i;
1192                let mut rhs_r = &rhs_x >> i;
1193                rhs_l.reserve_exact(lhs.capacity());
1194                rhs_r.reserve_exact(lhs.capacity());
1195                let rhs_l = &rhs_l;
1196                let rhs_r = &rhs_r;
1197
1198                let expected = vec![
1199                    lhs & rhs_l,
1200                    lhs | rhs_l,
1201                    lhs ^ rhs_l,
1202                    lhs - rhs_l,
1203                    lhs & rhs_r,
1204                    lhs | rhs_r,
1205                    lhs ^ rhs_r,
1206                    lhs - rhs_r,
1207                    // shl + op
1208                    lhs & !lhs,     // 0000
1209                    lhs & rhs_l,    // 0001
1210                    lhs & !rhs_l,   // 0010
1211                    lhs.clone(),    // 0011
1212                    !lhs & rhs_l,   // 0100
1213                    rhs_l.clone(),  // 0101
1214                    lhs ^ rhs_l,    // 0110
1215                    lhs | rhs_l,    // 0111
1216                    !lhs & !rhs_l,  // 1000
1217                    !(lhs ^ rhs_l), // 1001
1218                    !rhs_l,         // 1010
1219                    lhs | !rhs_l,   // 1011
1220                    !lhs,           // 1100
1221                    !lhs | rhs_l,   // 1101
1222                    !lhs | !rhs_l,  // 1110
1223                    lhs | !lhs,     // 1111
1224                    // shr + op
1225                    lhs & !lhs,     // 0000
1226                    lhs & rhs_r,    // 0001
1227                    lhs & !rhs_r,   // 0010
1228                    lhs.clone(),    // 0011
1229                    !lhs & rhs_r,   // 0100
1230                    rhs_r.clone(),  // 0101
1231                    lhs ^ rhs_r,    // 0110
1232                    lhs | rhs_r,    // 0111
1233                    !lhs & !rhs_r,  // 1000
1234                    !(lhs ^ rhs_r), // 1001
1235                    !rhs_r,         // 1010
1236                    lhs | !rhs_r,   // 1011
1237                    !lhs,           // 1100
1238                    !lhs | rhs_r,   // 1101
1239                    !lhs | !rhs_r,  // 1110
1240                    lhs | !lhs,     // 1111
1241                ];
1242                (actual == expected).then(|| ()).ok_or(i)?;
1243
1244                let actual_len = actual.iter().map(|b| b.len());
1245                let actual_count = actual.iter().map(|b| b.indices(..).count());
1246                let expected_len = expected.iter().map(|b| b.len());
1247                actual_len.eq(expected_len.clone()).then(|| ()).ok_or(i)?;
1248                actual_count.eq(expected_len).then(|| ()).ok_or(i)?;
1249            }
1250
1251            Ok(())
1252        }
1253
1254        assert_eq!(test_internal(&bs0, &bs0), Ok(()));
1255        assert_eq!(test_internal(&bs0, &bs1), Ok(()));
1256        assert_eq!(test_internal(&bs0, &bs2), Ok(()));
1257        assert_eq!(test_internal(&bs0, &bs3), Ok(()));
1258
1259        assert_eq!(test_internal(&bs1, &bs0), Ok(()));
1260        assert_eq!(test_internal(&bs1, &bs1), Ok(()));
1261        assert_eq!(test_internal(&bs1, &bs2), Ok(()));
1262        assert_eq!(test_internal(&bs1, &bs3), Ok(()));
1263
1264        assert_eq!(test_internal(&bs2, &bs0), Ok(()));
1265        assert_eq!(test_internal(&bs2, &bs1), Ok(()));
1266        assert_eq!(test_internal(&bs2, &bs2), Ok(()));
1267        assert_eq!(test_internal(&bs2, &bs3), Ok(()));
1268
1269        assert_eq!(test_internal(&bs3, &bs0), Ok(()));
1270        assert_eq!(test_internal(&bs3, &bs1), Ok(()));
1271        assert_eq!(test_internal(&bs3, &bs2), Ok(()));
1272        assert_eq!(test_internal(&bs3, &bs3), Ok(()));
1273    }
1274}