Skip to main content

nekolib/algo/
bisect_.rs

1//! 二分探索。
2
3use std::ops::Range;
4
5/// 二分探索で境界を探す。
6///
7/// `pred(&buf[i])` が `false` となる最小の `i` を返す。
8/// ただし `i < buf.len()` なる全ての `i` で `true` の場合は `buf.len()` を返す。
9/// 先頭から `i` 個の要素が条件を満たすと考えるのがよい。
10///
11/// # Requirements
12/// `pred(&buf[i])` が `false` となる `i` が存在するとき、`i < j` なる全ての `j` で
13/// `pred(&buf[j])` が `false` となる。
14///
15/// # Complexity
16/// `buf.len()` を $n$ として、高々 $\\lceil\\log_2(n+1)\\rceil$ 回の `pred` の呼び出しを行う。
17///
18/// # Examples
19/// ```
20/// use nekolib::algo::bisect_slice;
21///
22/// assert_eq!(bisect_slice(&[2, 4, 7], |&x| x < 4), 1);
23/// assert_eq!(bisect_slice(&[2, 4, 7], |&x| x <= 4), 2);
24/// assert_eq!(bisect_slice(&[2, 4, 7], |&_| false), 0);
25/// assert_eq!(bisect_slice(&[2, 4, 7], |&_| true), 3);
26/// ```
27pub fn bisect_slice<T>(buf: &[T], mut pred: impl FnMut(&T) -> bool) -> usize {
28    bisect(0..buf.len(), |i| pred(&buf[i]))
29}
30
31/// 二分探索で境界を探す。
32///
33/// `pred(i)` が `false` となる最小の `i` を返す。
34/// ただし `start..end` 内の全ての `i` で `true` の場合は `end` を返す。
35///
36/// # Requirements
37/// `pred(i)` が `false` となる `i` が存在するとき、`i < j` なる全ての `j` で
38/// `pred(j)` が `false` となる。
39///
40/// # Complexity
41/// 区間の長さを $n$ として、高々 $\\lceil\\log_2(n+1)\\rceil$ 回の `pred` の呼び出しを行う。
42///
43/// # Suggestions
44/// 範囲の型を `PrimInt` なり `Ord` なりにしたい気もする。区間長と中間値の取得をよしなにできると助かる。
45///
46/// # Examples
47/// ```
48/// use nekolib::algo::bisect;
49///
50/// let floor_sqrt = |i| if i <= 1 { i } else {
51///     bisect(0..i, |j| match (j + 1).overflowing_pow(2) {
52///         (x, false) if x <= i => true,
53///         _ => false
54///     })
55/// };
56/// assert_eq!(floor_sqrt(8), 2);
57/// assert_eq!(floor_sqrt(9), 3);
58/// assert_eq!(floor_sqrt(10), 3);
59/// assert_eq!(floor_sqrt(1 << 60), 1 << 30);
60/// ```
61pub fn bisect(
62    Range { start, end }: Range<usize>,
63    mut pred: impl FnMut(usize) -> bool,
64) -> usize {
65    if start == end {
66        return start;
67    }
68    let mut ok = start;
69    let mut bad = end;
70    while bad - ok > 1 {
71        let mid = ok + (bad - ok) / 2;
72        if pred(mid) {
73            ok = mid;
74        } else {
75            bad = mid;
76        }
77    }
78    if ok == start && !pred(start) {
79        start
80    } else {
81        bad
82    }
83}
84
85#[test]
86fn bisect_count() {
87    for n in 0..=128 {
88        for k in 0..=n {
89            let mut count = 0;
90            let f = |i| {
91                count += 1;
92                i < k
93            };
94            let res = bisect(0..n, f);
95            assert!(count <= (n + 1).next_power_of_two().trailing_zeros());
96            assert_eq!(res, k);
97        }
98    }
99}