Skip to main content

nekolib/math/
fraction_bisect.rs

1use std::ops::{Add, AddAssign, Mul, Neg};
2
3// https://atcoder.jp/contests/abc294/editorial/6017
4pub trait FractionBisect: Sized + SbInt {
5    fn fraction_bisect(
6        self,
7        pred: impl Fn(Self, Self) -> bool,
8    ) -> ((Self, Self), (Self, Self)) {
9        let bound = self;
10
11        let fr_neg_infty = (Self::ONE.neg(), Self::ZERO);
12        let fr_zero = (Self::ZERO, Self::ONE);
13        let ztf = pred(fr_zero.0, fr_zero.1);
14        let pred = {
15            if !ztf && !Self::SIGNED {
16                return (fr_zero, fr_zero);
17            }
18            if Self::SIGNED && !ztf && !pred(fr_neg_infty.0, fr_neg_infty.1) {
19                return (fr_neg_infty, fr_neg_infty);
20            }
21            move |fr: Fraction<Self>| {
22                if ztf { pred(fr.0, fr.1) } else { !pred(fr.0.neg(), fr.1) }
23            }
24        };
25
26        let mut lower = Fraction::zero();
27        let mut upper = Fraction::infty();
28        let (small, large) = 'outer: loop {
29            let cur = lower + upper;
30            if cur.is_deeper(bound) {
31                break (lower, upper);
32            }
33
34            let tf = pred(cur);
35            let (from, to) = if tf { (lower, upper) } else { (upper, lower) };
36
37            let mut lo = Self::ONE;
38            let mut hi = lo + Self::ONE;
39            while pred(from + to * hi) == tf {
40                lo += lo;
41                hi += hi;
42                if (from + to * lo).is_deeper(bound) {
43                    let steps = bound.steps(from.into_inner(), to.into_inner());
44                    let front = from + to * steps;
45                    let res = if tf { (front, upper) } else { (lower, front) };
46                    break 'outer res;
47                }
48            }
49
50            while lo.lt1(hi) {
51                let mid = lo.avg(hi);
52                let tmp = from + to * mid;
53                let cur = pred(tmp) == tf && !tmp.is_deeper(bound);
54                *(if cur { &mut lo } else { &mut hi }) = mid;
55            }
56
57            let next = from + to * lo;
58            *(if tf { &mut lower } else { &mut upper }) = next;
59        };
60
61        let (left, right) = if ztf { (small, large) } else { (-large, -small) };
62        (left.into_inner(), right.into_inner())
63    }
64}
65
66impl<I: SbInt> FractionBisect for I {}
67
/// A fraction `numerator / denominator` stored as a raw pair and never reduced.
///
/// `Fraction(1, 0)` stands for positive infinity. `Debug` is derived so values
/// can be shown in assertion failures and diagnostics.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Fraction<I>(I, I);
70
/// Integer operations needed by [`FractionBisect`]; implemented in this file for
/// the primitive signed and unsigned integer types.
pub trait SbInt
where
    Self: Copy
        + Eq
        + PartialOrd<Self>
        + AddAssign<Self>
        + Add<Self, Output = Self>
        + Mul<Self, Output = Self>
        + std::fmt::Display,
{
    /// The value `0`.
    const ZERO: Self;
    /// The value `1`.
    const ONE: Self;
    /// Whether the type can represent negative values.
    const SIGNED: bool;

    /// Whether `self + 1 < other`, i.e. there is an integer strictly between them.
    fn lt1(self, other: Self) -> bool;
    /// Midpoint computed as `self + (other - self) / 2`.
    fn avg(self, other: Self) -> Self;
    /// Absolute value (the identity for unsigned types).
    fn abs(self) -> Self;
    /// Arithmetic negation (wrapping for unsigned types).
    fn neg(self) -> Self;
    /// Largest `k` such that `from.1 + k * to.1 <= self`, computed as
    /// `(self - from.1) / to.1`; `0` when `to.1` is `0`.
    fn steps(self, from: (Self, Self), to: (Self, Self)) -> Self;
}
89
90impl<I: SbInt> Neg for Fraction<I> {
91    type Output = Self;
92    fn neg(self) -> Self { self.neg() }
93}
94
95macro_rules! impl_uint {
96    ( $($ty:ty)* ) => { $(
97        impl SbInt for $ty {
98            const ZERO: $ty = 0;
99            const ONE: $ty = 1;
100            const SIGNED: bool = false;
101            fn lt1(self, other: Self) -> bool { self + 1 < other }
102            fn avg(self, other: Self) -> Self {
103                self + (other - self) / 2
104            }
105            fn abs(self) -> Self { self }
106            fn neg(self) -> Self { self.wrapping_neg() } // not to be called
107            fn steps(self, from: (Self, Self), to: (Self, Self)) -> Self {
108                if to.1 == 0 {
109                    Self::ZERO
110                } else {
111                    (self - from.1) / to.1
112                }
113            }
114        }
115    )* }
116}
117
118impl_uint! { u8 u16 u32 u64 u128 usize }
119
120macro_rules! impl_int {
121    ( $($ty:ty)* ) => { $(
122        impl SbInt for $ty {
123            const ZERO: $ty = 0;
124            const ONE: $ty = 1;
125            const SIGNED: bool = true;
126            fn lt1(self, other: Self) -> bool { self + 1 < other }
127            fn avg(self, other: Self) -> Self {
128                self + (other - self) / 2
129            }
130            fn abs(self) -> Self { self.abs() }
131            fn neg(self) -> Self { -self}
132            fn steps(self, from: (Self, Self), to: (Self, Self)) -> Self {
133                if to.1 == 0 {
134                    Self::ZERO
135                } else {
136                    (self - from.1) / to.1
137                }
138            }
139        }
140    )* }
141}
142
143impl_int! { i8 i16 i32 i64 i128 isize }
144
145impl<I: SbInt> Fraction<I> {
146    fn zero() -> Self { Self(I::ZERO, I::ONE) }
147    fn infty() -> Self { Self(I::ONE, I::ZERO) }
148}
149
150impl<I: SbInt> Mul<I> for Fraction<I> {
151    type Output = Self;
152    fn mul(self, a: I) -> Self { Self(self.0 * a, self.1 * a) }
153}
154
155impl<I: SbInt> Add<Fraction<I>> for Fraction<I> {
156    type Output = Self;
157    fn add(self, other: Self) -> Self {
158        // mediant
159        Self(self.0 + other.0, self.1 + other.1)
160    }
161}
162
163impl<I: SbInt> Fraction<I> {
164    fn is_deeper(self, bound: I) -> bool { self.1.abs() > bound }
165    fn neg(self) -> Self { Self(self.0.neg(), self.1) }
166    fn into_inner(self) -> (I, I) { (self.0, self.1) }
167}
168
#[test]
fn sanity_check() {
    // Closest fractions around sqrt(3) with denominators up to 5000.
    let around_sqrt3 = 5000_u64.fraction_bisect(|x, y| x * x <= 3 * y * y);
    assert_eq!(around_sqrt3, ((3691, 2131), (5042, 2911)));

    // A predicate that never holds: unsigned types stop at 0/1, signed types
    // report negative infinity.
    let unsigned_never = 10_u64.fraction_bisect(|_, _| false);
    assert_eq!(unsigned_never, ((0, 1), (0, 1)));
    let signed_never = 10_i64.fraction_bisect(|_, _| false);
    assert_eq!(signed_never, ((-1, 0), (-1, 0)));

    // A negative boundary, found through the mirrored predicate.
    let around_neg_sqrt3 =
        5000_i64.fraction_bisect(|x, y| x < 0 && x * x > 3 * y * y);
    assert_eq!(around_neg_sqrt3, ((-5042, 2911), (-3691, 2131)));

    // Strict vs. non-strict comparison decides which side 2/5 lands on.
    let strict = 5000_i64.fraction_bisect(|x, y| 5 * x < 2 * y);
    assert_eq!(strict, ((1999, 4998), (2, 5)));
    let non_strict = 5000_i64.fraction_bisect(|x, y| 5 * x <= 2 * y);
    assert_eq!(non_strict, ((2, 5), (1999, 4997)));
}
185
#[test]
fn sqrt() {
    let bound = 10_u128.pow(18);
    let (below3, above3) = bound.fraction_bisect(|x, y| x * x <= 3 * y * y);
    let (below4, above4) = bound.fraction_bisect(|x, y| x * x <= 4 * y * y);

    assert_eq!(below3, (734231055024833855, 423908497265970753));
    assert_eq!(above3, (1002978273411373057, 579069776145402304));

    // sqrt(4) = 2 is exactly representable, so it is the lower end itself.
    assert_eq!(below4, (2, 1));
    assert_eq!(above4, (999999999999999999, 499999999999999999));
}
197
#[test]
fn improper_fraction() {
    // 13/5 satisfies `x/y <= 13/5` exactly, so it is the lower end; the next
    // fraction above it with denominator at most 6 is 8/3.
    let expected = ((13, 5), (8, 3));
    assert_eq!(6_u32.fraction_bisect(|x, y| 5 * x <= 13 * y), expected);
}
203
#[test]
fn iterate() {
    // Walk the fractions in [0, 1] with denominator at most 6 in increasing
    // order: the right-hand result of bisecting `x/y <= p/q` is the smallest
    // such fraction strictly greater than `p/q`.
    let bound = 6_u32;
    let successor = |p: u32, q: u32| bound.fraction_bisect(|x, y| x * q <= p * y).1;

    let mut numerators = Vec::new();
    let mut denominators = Vec::new();
    let mut current = (0, 1);
    while current.0 <= current.1 {
        numerators.push(current.0);
        denominators.push(current.1);
        current = successor(current.0, current.1);
    }

    assert_eq!(numerators, [0, 1, 1, 1, 1, 2, 1, 3, 2, 3, 4, 5, 1]);
    assert_eq!(denominators, [1, 6, 5, 4, 3, 5, 2, 5, 3, 4, 5, 6, 1]);
}