// gcd_recip/lib.rs

/// Combined gcd / modular-reciprocal computation.
///
/// `x.gcd_recip(m)` returns a pair `(g, r)` where `g = gcd(x, m)` and `r` is a
/// multiplier with `x * r ≡ g (mod m)` and `0 <= r < m / g`. When `g == 1`, `r`
/// is the modular inverse of `x` modulo `m`.
///
/// The implementations in this file require `m > 0` and assert it.
pub trait GcdRecip
where
    Self: Sized,
{
    /// Returns `(gcd(self, other), r)` with `self * r ≡ gcd (mod other)`.
    fn gcd_recip(self, other: Self) -> (Self, Self);
}

/// Implements [`GcdRecip`] for unsigned primitive integers.
///
/// Accepts one type or a whitespace-separated list of types.
macro_rules! impl_uint {
    ($t:ty) => {
        impl GcdRecip for $t {
            /// Returns `(g, r)` where `g = gcd(self, other)` and `r` is the value in
            /// `0..other / g` with `self * r ≡ g (mod other)`.
            ///
            /// Runs the extended Euclidean algorithm on `(other, self % other)`.
            /// The algorithm's Bézout coefficients alternate in sign, so only their
            /// magnitudes are stored, together with the sign of `m0`. The magnitudes
            /// never exceed `other`, so no intermediate value overflows, even for
            /// moduli near the type's maximum.
            ///
            /// # Panics
            ///
            /// Panics if `other == 0`.
            fn gcd_recip(self, other: Self) -> (Self, Self) {
                assert!(other > 0);
                let a = self % other;
                if a == 0 {
                    return (other, 0);
                }

                // With signed coefficients m0/m1, the loop keeps
                //   s ≡ m0 * a (mod other),  t ≡ m1 * a (mod other),
                //   s * |m1| + t * |m0| <= other.
                // Only |m0| and |m1| are stored below.
                let mut s = other;
                let mut t = a;
                let mut m0: Self = 0;
                let mut m1: Self = 1;
                // Sign of the true m0. m0 starts at 0, so this is only a seed.
                // m0 and m1 always have opposite signs, so the flag flips every
                // time the coefficients are swapped.
                let mut m0_negative = true;
                while t > 0 {
                    let u = s / t;
                    s -= t * u;
                    // Signed update `m0 -= m1 * u`. Because m0 and m1 have opposite
                    // signs, the magnitudes add. |m1| * u <= |m1| * s <= other, and
                    // the sum is bounded by other / t, so neither step overflows.
                    m0 += m1 * u;
                    std::mem::swap(&mut s, &mut t);
                    std::mem::swap(&mut m0, &mut m1);
                    m0_negative = !m0_negative;
                }
                // Now s == g and 0 < |m0| < other / g, so a negative coefficient is
                // lifted into range by a single addition of other / g.
                let r = if m0_negative { other / s - m0 } else { m0 };
                (s, r)
            }
        }
    };
    ( $($t:ty)* ) => { $(impl_uint!($t);)* };
}

/// Implements [`GcdRecip`] for signed primitive integers.
///
/// Accepts one type or a whitespace-separated list of types.
macro_rules! impl_int {
    ($t:ty) => {
        impl GcdRecip for $t {
            /// Returns `(g, r)` where `g = gcd(self, other)` and `r` is the value in
            /// `0..other / g` with `self * r ≡ g (mod other)`. A negative `self` is
            /// first reduced to its Euclidean remainder modulo `other`.
            ///
            /// # Panics
            ///
            /// Panics unless `other > 0`.
            fn gcd_recip(self, other: Self) -> (Self, Self) {
                assert!(other > 0);
                let rem = self.rem_euclid(other);
                if rem == 0 {
                    return (other, 0);
                }

                // Extended Euclid on (other, rem) with signed Bézout coefficients:
                // prev ≡ prev_coef * rem and cur ≡ cur_coef * rem (mod other).
                let (mut prev, mut cur) = (other, rem);
                let (mut prev_coef, mut cur_coef): (Self, Self) = (0, 1);
                while cur != 0 {
                    let q = prev / cur;
                    let next = prev - q * cur;
                    let next_coef = prev_coef - q * cur_coef;
                    prev = cur;
                    cur = next;
                    prev_coef = cur_coef;
                    cur_coef = next_coef;
                }

                // `prev` is the gcd; shift a negative coefficient into 0..other / gcd.
                let period = other / prev;
                let coef = if prev_coef < 0 { prev_coef + period } else { prev_coef };
                (prev, coef)
            }
        }
    };
    ( $($t:ty)* ) => { $(impl_int!($t);)* };
}

66impl_uint!(u8 u16 u32 u64 u128 usize);
67impl_int!(i8 i16 i32 i64 i128 isize);
68
#[test]
fn test() {
    // Every residue `a` of every modulus `b` in 1..=1000, signed then unsigned.
    let signed_pairs = (1_i32..=1000).flat_map(|b| (0..b).map(move |a| (a, b)));
    for (a, b) in signed_pairs {
        let (g, r) = a.gcd_recip(b);
        assert!(0 <= r && r < b / g);
        assert_eq!(a * r % b, g % b);
    }

    let unsigned_pairs = (1_u32..=1000).flat_map(|b| (0..b).map(move |a| (a, b)));
    for (a, b) in unsigned_pairs {
        let (g, r) = a.gcd_recip(b);
        assert!(r < b / g);
        assert_eq!(a * r % b, g % b);
    }
}