1use std::ops::{Add, AddAssign, Mul, Neg};
2
3pub trait FractionBisect: Sized + SbInt {
5 fn fraction_bisect(
6 self,
7 pred: impl Fn(Self, Self) -> bool,
8 ) -> ((Self, Self), (Self, Self)) {
9 let bound = self;
10
11 let fr_neg_infty = (Self::ONE.neg(), Self::ZERO);
12 let fr_zero = (Self::ZERO, Self::ONE);
13 let ztf = pred(fr_zero.0, fr_zero.1);
14 let pred = {
15 if !ztf && !Self::SIGNED {
16 return (fr_zero, fr_zero);
17 }
18 if Self::SIGNED && !ztf && !pred(fr_neg_infty.0, fr_neg_infty.1) {
19 return (fr_neg_infty, fr_neg_infty);
20 }
21 move |fr: Fraction<Self>| {
22 if ztf { pred(fr.0, fr.1) } else { !pred(fr.0.neg(), fr.1) }
23 }
24 };
25
26 let mut lower = Fraction::zero();
27 let mut upper = Fraction::infty();
28 let (small, large) = 'outer: loop {
29 let cur = lower + upper;
30 if cur.is_deeper(bound) {
31 break (lower, upper);
32 }
33
34 let tf = pred(cur);
35 let (from, to) = if tf { (lower, upper) } else { (upper, lower) };
36
37 let mut lo = Self::ONE;
38 let mut hi = lo + Self::ONE;
39 while pred(from + to * hi) == tf {
40 lo += lo;
41 hi += hi;
42 if (from + to * lo).is_deeper(bound) {
43 let steps = bound.steps(from.into_inner(), to.into_inner());
44 let front = from + to * steps;
45 let res = if tf { (front, upper) } else { (lower, front) };
46 break 'outer res;
47 }
48 }
49
50 while lo.lt1(hi) {
51 let mid = lo.avg(hi);
52 let tmp = from + to * mid;
53 let cur = pred(tmp) == tf && !tmp.is_deeper(bound);
54 *(if cur { &mut lo } else { &mut hi }) = mid;
55 }
56
57 let next = from + to * lo;
58 *(if tf { &mut lower } else { &mut upper }) = next;
59 };
60
61 let (left, right) = if ztf { (small, large) } else { (-large, -small) };
62 (left.into_inner(), right.into_inner())
63 }
64}
65
66impl<I: SbInt> FractionBisect for I {}
67
/// A fraction stored as `(numerator, denominator)`.
///
/// No normalization is performed by the arithmetic on this type. A zero
/// denominator encodes an infinite endpoint: `1/0` is used as +infinity and
/// `-1/0` as -infinity.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Fraction<I>(I, I);
70
/// Integer operations needed to walk the Stern–Brocot tree.
///
/// A value of this type serves as a numerator, a denominator, and the
/// denominator bound of a [`FractionBisect`] search. The `impl_uint!` and
/// `impl_int!` macros implement it for the built-in integer types.
pub trait SbInt:
    Copy
    + Eq
    + PartialOrd
    + Add<Output = Self>
    + AddAssign
    + Mul<Output = Self>
    + std::fmt::Display
{
    /// The value `0`.
    const ZERO: Self;
    /// The value `1`.
    const ONE: Self;
    /// Whether the type can hold negative values (enables negative thresholds).
    const SIGNED: bool;

    /// `true` iff `self + 1 < other`, i.e. some integer lies strictly between.
    fn lt1(self, other: Self) -> bool;
    /// Midpoint used by the binary search; the built-in impls compute
    /// `self + (other - self) / 2`, rounding toward `self` when `self <= other`.
    fn avg(self, other: Self) -> Self;
    /// Absolute value; the identity for unsigned types.
    fn abs(self) -> Self;
    /// Additive inverse; the unsigned impls wrap.
    fn neg(self) -> Self;
    /// How many `to.1` denominator steps fit between `from.1` and `self`; the
    /// built-in impls compute `(self - from.1) / to.1`, or zero when `to.1` is zero.
    fn steps(self, from: (Self, Self), to: (Self, Self)) -> Self;
}
89
90impl<I: SbInt> Neg for Fraction<I> {
91 type Output = Self;
92 fn neg(self) -> Self { self.neg() }
93}
94
95macro_rules! impl_uint {
96 ( $($ty:ty)* ) => { $(
97 impl SbInt for $ty {
98 const ZERO: $ty = 0;
99 const ONE: $ty = 1;
100 const SIGNED: bool = false;
101 fn lt1(self, other: Self) -> bool { self + 1 < other }
102 fn avg(self, other: Self) -> Self {
103 self + (other - self) / 2
104 }
105 fn abs(self) -> Self { self }
106 fn neg(self) -> Self { self.wrapping_neg() } fn steps(self, from: (Self, Self), to: (Self, Self)) -> Self {
108 if to.1 == 0 {
109 Self::ZERO
110 } else {
111 (self - from.1) / to.1
112 }
113 }
114 }
115 )* }
116}
117
118impl_uint! { u8 u16 u32 u64 u128 usize }
119
120macro_rules! impl_int {
121 ( $($ty:ty)* ) => { $(
122 impl SbInt for $ty {
123 const ZERO: $ty = 0;
124 const ONE: $ty = 1;
125 const SIGNED: bool = true;
126 fn lt1(self, other: Self) -> bool { self + 1 < other }
127 fn avg(self, other: Self) -> Self {
128 self + (other - self) / 2
129 }
130 fn abs(self) -> Self { self.abs() }
131 fn neg(self) -> Self { -self}
132 fn steps(self, from: (Self, Self), to: (Self, Self)) -> Self {
133 if to.1 == 0 {
134 Self::ZERO
135 } else {
136 (self - from.1) / to.1
137 }
138 }
139 }
140 )* }
141}
142
143impl_int! { i8 i16 i32 i64 i128 isize }
144
145impl<I: SbInt> Fraction<I> {
146 fn zero() -> Self { Self(I::ZERO, I::ONE) }
147 fn infty() -> Self { Self(I::ONE, I::ZERO) }
148}
149
150impl<I: SbInt> Mul<I> for Fraction<I> {
151 type Output = Self;
152 fn mul(self, a: I) -> Self { Self(self.0 * a, self.1 * a) }
153}
154
155impl<I: SbInt> Add<Fraction<I>> for Fraction<I> {
156 type Output = Self;
157 fn add(self, other: Self) -> Self {
158 Self(self.0 + other.0, self.1 + other.1)
160 }
161}
162
163impl<I: SbInt> Fraction<I> {
164 fn is_deeper(self, bound: I) -> bool { self.1.abs() > bound }
165 fn neg(self) -> Self { Self(self.0.neg(), self.1) }
166 fn into_inner(self) -> (I, I) { (self.0, self.1) }
167}
168
/// Spot checks: an irrational threshold, constant-false predicates, a
/// negative threshold, and strict vs. non-strict comparison at a rational one.
#[test]
fn sanity_check() {
    let below_sqrt3 = |x: u64, y: u64| x * x <= 3 * y * y;
    assert_eq!(
        5000_u64.fraction_bisect(below_sqrt3),
        ((3691, 2131), (5042, 2911))
    );

    // Never satisfied: unsigned collapses to 0/1, signed to the -1/0 sentinel.
    assert_eq!(10_u64.fraction_bisect(|_, _| false), ((0, 1), (0, 1)));
    assert_eq!(10_i64.fraction_bisect(|_, _| false), ((-1, 0), (-1, 0)));

    let below_neg_sqrt3 = |x: i64, y: i64| x < 0 && x * x > 3 * y * y;
    assert_eq!(
        5000_i64.fraction_bisect(below_neg_sqrt3),
        ((-5042, 2911), (-3691, 2131))
    );

    // With `<` the threshold 2/5 is the upper end; with `<=` it is the lower end.
    let (strict, non_strict) = (
        5000_i64.fraction_bisect(|x, y| 5 * x < 2 * y),
        5000_i64.fraction_bisect(|x, y| 5 * x <= 2 * y),
    );
    assert_eq!(strict, ((1999, 4998), (2, 5)));
    assert_eq!(non_strict, ((2, 5), (1999, 4997)));
}
185
/// Brackets sqrt(3) and sqrt(4) with denominators up to 10^18 (u128).
#[test]
fn sqrt() {
    let bound = 10_u128.pow(18);
    let (lo3, hi3) = bound.fraction_bisect(|x, y| x * x <= 3 * y * y);
    let (lo4, hi4) = bound.fraction_bisect(|x, y| x * x <= 4 * y * y);

    assert_eq!(lo3, (734231055024833855, 423908497265970753));
    assert_eq!(hi3, (1002978273411373057, 579069776145402304));

    // sqrt(4) = 2 is itself representable, so it is the lower endpoint.
    assert_eq!(lo4, (2, 1));
    assert_eq!(hi4, (999999999999999999, 499999999999999999));
}
197
/// The threshold 13/5 exceeds 1, so the search has to climb past 1/1.
#[test]
fn improper_fraction() {
    let (below, above) = 6_u32.fraction_bisect(|x, y| x * 5 <= y * 13);
    assert_eq!(below, (13, 5));
    assert_eq!(above, (8, 3));
}
203
/// Walks 0/1 ..= 1/1 by repeatedly taking the upper bracket of `x/y <= p/q`,
/// which yields the Farey sequence of order `bound`.
#[test]
fn iterate() {
    let bound = 6_u32;
    // The `.1` bracket is the smallest fraction strictly above `p/q` with
    // denominator at most `bound`.
    let successor = |p: u32, q: u32| bound.fraction_bisect(|x, y| x * q <= p * y).1;

    let mut numerators = Vec::new();
    let mut denominators = Vec::new();
    let mut term = (0, 1);
    while term.0 <= term.1 {
        numerators.push(term.0);
        denominators.push(term.1);
        term = successor(term.0, term.1);
    }

    assert_eq!(numerators, [0, 1, 1, 1, 1, 2, 1, 3, 2, 3, 4, 5, 1]);
    assert_eq!(denominators, [1, 6, 5, 4, 3, 5, 2, 5, 3, 4, 5, 6, 1]);
}