Skip to main content

nekolib/math/
garner.rs

1use super::gcd_recip;
2use gcd_recip::GcdRecip;
3
4pub trait CrtMod {
5    type I;
6    fn crt_mod(&self, m: Self::I) -> Self::I;
7}
8
9macro_rules! impl_crt_mod {
10    ( $($ty:ty)* ) => { $(
11        impl CrtMod for [($ty, $ty)] {
12            type I = $ty;
13            fn crt_mod(&self, mu: $ty) -> $ty {
14                let n = self.len();
15                let mut s = vec![0; n];
16                for i in 0..n {
17                    let (ri, mi) = self[i];
18                    let mut prod = 1;
19                    let mut sum = 0;
20                    for j in 0..i {
21                        let mj = self[j].1;
22                        sum = (sum + s[j] * prod % mi) % mi;
23                        prod = (prod * mj) % mi;
24                    }
25                    let left = (ri + (mi - sum)) % mi;
26                    let right = prod.gcd_recip(mi).1;
27                    s[i] = (left * right) % mi;
28                }
29
30                let mut prod = 1;
31                let mut sum = 0;
32                for j in 0..n {
33                    let mj = self[j].1;
34                    sum = (sum + s[j] * prod % mu) % mu;
35                    prod = (prod * mj) % mu;
36                }
37                sum
38            }
39        }
40    )* };
41}
42
43impl_crt_mod! { u8 u16 u32 u64 u128 usize }
44
45pub trait CrtWrapping {
46    type I;
47    fn crt_wrapping(&self) -> Self::I;
48}
49
50macro_rules! impl_crt_wrapping {
51    ( $($ty:ty)* ) => { $(
52        impl CrtWrapping for [($ty, $ty)] {
53            type I = $ty;
54            fn crt_wrapping(&self) -> $ty {
55                let n = self.len();
56                let mut s = vec![0; n];
57                for i in 0..n {
58                    let (ri, mi) = self[i];
59                    let mut prod = 1;
60                    let mut sum = 0;
61                    for j in 0..i {
62                        let mj = self[j].1;
63                        sum = (sum + s[j] * prod % mi) % mi;
64                        prod = (prod * mj) % mi;
65                    }
66                    let left = (ri + (mi - sum)) % mi;
67                    let right = prod.gcd_recip(mi).1;
68                    s[i] = (left * right) % mi;
69                }
70
71                let mut prod: $ty = 1;
72                let mut sum: $ty = 0;
73                for j in 0..n {
74                    let mj = self[j].1;
75                    sum = sum.wrapping_add(s[j].wrapping_mul(prod));
76                    prod = prod.wrapping_mul(mj);
77                }
78                sum
79            }
80        }
81    )* };
82}
83
84impl_crt_wrapping! { u8 u16 u32 u64 u128 usize }
85
#[test]
fn sanity_check_mod() {
    // Three NTT-friendly primes shared by both cases below.
    let moduli: [u64; 3] = [7 << 26 | 1, 27 << 26 | 1, 15 << 27 | 1];
    let with_residues = |residues: [u64; 3]| -> Vec<(u64, u64)> {
        residues.iter().copied().zip(moduli.iter().copied()).collect()
    };

    // Residues of 2^80.
    let two_pow_80 = with_residues([254739770, 1481734260, 1038248692]);
    assert_eq!(two_pow_80.crt_mod(998244353), 382013690);

    // All-zero residues reconstruct to zero.
    let zero = with_residues([0; 3]);
    assert_eq!(zero.crt_mod(998244353), 0);
}
98
#[test]
fn sanity_check_wrapping() {
    // Three NTT-friendly primes shared by both cases below.
    let moduli: [u64; 3] = [7 << 26 | 1, 27 << 26 | 1, 15 << 27 | 1];
    let with_residues = |residues: [u64; 3]| -> Vec<(u64, u64)> {
        residues.iter().copied().zip(moduli.iter().copied()).collect()
    };

    // Residues of 3^55, which fits in u64 without wrapping.
    let three_pow_55 = with_residues([285021974, 723309387, 1219762234]);
    assert_eq!(three_pow_55.crt_wrapping(), 12511015583298303947);

    // All-zero residues reconstruct to zero.
    let zero = with_residues([0; 3]);
    assert_eq!(zero.crt_wrapping(), 0);
}
111
#[test]
fn large() {
    // Pairs each residue with the modulus at the same position.
    fn pair_up<T: Copy>(residues: &[T], moduli: &[T]) -> Vec<(T, T)> {
        assert_eq!(residues.len(), moduli.len());
        residues.iter().copied().zip(moduli.iter().copied()).collect()
    }

    // (2^64 - 1)^2 * 2^25, which wraps to 2^25 in u64.
    let large_u64 = pair_up::<u64>(
        &[867145189, 1121462194, 567952613, 292122917, 1550969568, 1001085957],
        &[1107296257, 1711276033, 2113929217, 469762049, 1811939329, 2013265921],
    );
    assert_eq!(large_u64.crt_mod(998244353), 2258058);
    assert_eq!(large_u64.crt_wrapping(), 1 << 25);

    // (2^128 - 1)^2 * 2^24, which wraps to 2^24 in u128.
    let large_u128 = pair_up::<u128>(
        &[
            305535025, 1105782392, 1129452415, 42581335, 736341624, 937167787, 218059526,
            100381360, 1394496118, 1096317127,
        ],
        &[
            754974721, 1224736769, 2130706433, 167772161, 1107296257, 1711276033, 2113929217,
            469762049, 1811939329, 2013265921,
        ],
    );
    assert_eq!(large_u128.crt_mod(998244353), 577639010);
    assert_eq!(large_u128.crt_wrapping(), 1 << 24);
}