Skip to main content

nekolib/math/
miller_rabin.rs

/// Deterministic primality test for unsigned integers.
///
/// Implemented in this module for `u8`, `u16`, `u32` and `u64`. No randomness
/// is involved:
/// - `u8` uses trial division.
/// - `u16`/`u32` run a single Miller–Rabin round whose base is picked by
///   hashing the input.
/// - `u64` runs Miller–Rabin against a fixed list of seven bases.
pub trait MillerRabin {
    /// Returns `true` if and only if `self` is prime (`0` and `1` are not prime).
    #[must_use]
    fn is_prime(self) -> bool;
}
4
5impl MillerRabin for u8 {
6    fn is_prime(self) -> bool {
7        let x = self;
8        if x == 2 || x == 3 || x == 5 || x == 7 || x == 11 || x == 13 {
9            return true;
10        }
11        x > 1
12            && x % 2 > 0
13            && x % 3 > 0
14            && x % 5 > 0
15            && x % 7 > 0
16            && x % 11 > 0
17            && x % 13 > 0
18    }
19}
20
21impl MillerRabin for u16 {
22    fn is_prime(self) -> bool { (self as u32).is_prime() }
23}
24
25impl MillerRabin for u32 {
26    fn is_prime(self) -> bool {
27        let x = self;
28        if x == 2 || x == 3 || x == 5 || x == 7 {
29            return true;
30        }
31        if x % 2 == 0 || x % 3 == 0 || x % 5 == 0 || x % 7 == 0 {
32            return false;
33        }
34        if x < 121 {
35            return x > 1;
36        }
37        let h = x as u64;
38        let h = ((h >> 16) ^ h).wrapping_mul(0x45d9f3b);
39        let h = ((h >> 16) ^ h).wrapping_mul(0x45d9f3b);
40        let h = ((h >> 16) ^ h) & 255;
41        is_sprp_32(x, BASES[h as usize] as u32)
42    }
43}
44
45impl MillerRabin for u64 {
46    fn is_prime(self) -> bool {
47        let x = self;
48        if x == 2 || x == 3 || x == 5 || x == 7 {
49            return true;
50        }
51        if x % 2 == 0 || x % 3 == 0 || x % 5 == 0 || x % 7 == 0 {
52            return false;
53        }
54        if x < 121 {
55            return x > 1;
56        }
57
58        [2, 325, 9375, 28178, 450775, 9780504, 1795265022]
59            .iter()
60            .all(|&b| b % x == 0 || is_sprp_64(x, b % x))
61    }
62}
63
// Forišek, Michal, and Jakub Jancina.
// "Fast Primality Testing for Integers That Fit into a Machine Word." (2015).

/// Strong probable-prime (single-round Miller–Rabin) test of `n` to base `a`.
///
/// Write `n - 1 = d * 2^s` with `d` odd. Returns `true` iff `a^d ≡ 1 (mod n)`
/// or `a^(d * 2^r) ≡ n - 1 (mod n)` for some `0 <= r < s`. Every odd prime
/// `n` passes for any base not divisible by `n`. A composite that passes is a
/// strong pseudoprime to base `a`.
///
/// Precondition: `n` is odd and at least 3 (checked in debug builds).
/// `a` need not be reduced modulo `n`. If `a` is a multiple of `n`, the
/// result is `false`. All products are formed in `u64`, so nothing overflows.
fn is_sprp_32(n: u32, a: u32) -> bool {
    debug_assert!(n >= 3 && n % 2 == 1, "is_sprp_32 requires an odd n >= 3");
    let m = u64::from(n);
    let s = (n - 1).trailing_zeros();
    let d = (n - 1) >> s;

    // cur = a^d mod n, by right-to-left square-and-multiply.
    let mut cur = 1_u64;
    let mut base = u64::from(a);
    let mut exp = d;
    while exp > 0 {
        if exp & 1 != 0 {
            cur = cur * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }

    if cur == 1 {
        return true;
    }
    // Look for -1 among a^(d * 2^r) for r = 0, 1, ..., s - 1.
    for _ in 0..s {
        if cur == m - 1 {
            return true;
        }
        cur = cur * cur % m;
    }
    false
}
93
94#[rustfmt::skip]
95const BASES: [u16; 256] = [
96    0x3ce7, 0x07e2, 0x00a6, 0x1d05, 0x1f80, 0x3ead, 0x2907, 0x112f,
97    0x079d, 0x050f, 0x0ad8, 0x0e24, 0x0230, 0x0c38, 0x145c, 0x0a61,
98    0x08fc, 0x07e5, 0x122c, 0x05bf, 0x2478, 0x0fb2, 0x095e, 0x4fee,
99    0x2825, 0x1f5c, 0x08a5, 0x184b, 0x026c, 0x0eb3, 0x12f4, 0x1394,
100    0x0c71, 0x0535, 0x1853, 0x14b2, 0x0432, 0x0957, 0x13f9, 0x1b95,
101    0x0323, 0x04f5, 0x0f23, 0x01a6, 0x02ef, 0x0244, 0x1279, 0x27ff,
102    0x02ea, 0x0b87, 0x022c, 0x089e, 0x0ec2, 0x01e1, 0x05f2, 0x0d94,
103    0x01e1, 0x09b7, 0x0cc2, 0x1601, 0x01e8, 0x0d2d, 0x1929, 0x0d10,
104    0x0011, 0x3b01, 0x05d2, 0x103a, 0x07f4, 0x075a, 0x0715, 0x01d3,
105    0x0ceb, 0x36da, 0x18e3, 0x0292, 0x03ed, 0x0387, 0x02e1, 0x075f,
106    0x1d17, 0x0760, 0x0b20, 0x06f8, 0x1d87, 0x0d48, 0x03b7, 0x3691,
107    0x10d0, 0x00b1, 0x0029, 0x4da3, 0x0c26, 0x33a5, 0x2216, 0x023b,
108    0x1b83, 0x1b1f, 0x04af, 0x0160, 0x1923, 0x00a5, 0x0491, 0x0cf3,
109    0x03d2, 0x00e9, 0x0bbb, 0x0a02, 0x0bb2, 0x295b, 0x272e, 0x0949,
110    0x076e, 0x14ea, 0x115f, 0x0613, 0x0107, 0x6993, 0x08eb, 0x0131,
111    0x029d, 0x0778, 0x0259, 0x182a, 0x01ad, 0x078a, 0x3a19, 0x06f8,
112    0x067d, 0x020c, 0x0df9, 0x00ec, 0x0938, 0x1802, 0x0b22, 0xd955,
113    0x06d9, 0x1052, 0x2112, 0x00de, 0x0a13, 0x0ab7, 0x07ef, 0x08b2,
114    0x08e4, 0x0176, 0x0854, 0x032d, 0x5cec, 0x064a, 0x1146, 0x1427,
115    0x06bd, 0x0e0d, 0x0d26, 0x3800, 0x0243, 0x00a5, 0x055f, 0x2722,
116    0x3148, 0x2658, 0x055b, 0x0218, 0x074b, 0x2a70, 0x0359, 0x089e,
117    0x169c, 0x01b2, 0x1f95, 0x44d2, 0x02d7, 0x0e37, 0x063b, 0x1350,
118    0x0851, 0x07ed, 0x2003, 0x2098, 0x1858, 0x23df, 0x1fbe, 0x074e,
119    0x0ce0, 0x1d1f, 0x22f3, 0x61b9, 0x021d, 0x4aab, 0x0170, 0x0236,
120    0x162a, 0x019b, 0x020a, 0x0403, 0x2017, 0x0802, 0x1990, 0x2741,
121    0x0266, 0x0306, 0x091d, 0x0bbf, 0x8981, 0x1262, 0x0480, 0x06f9,
122    0x0404, 0x0604, 0x0e9f, 0x01ed, 0x117a, 0x09d9, 0x68dd, 0x20a2,
123    0x0360, 0x49e3, 0x1559, 0x098f, 0x002a, 0x119f, 0x067c, 0x00a6,
124    0x04e1, 0x1873, 0x09f9, 0x0130, 0x0110, 0x1c76, 0x0049, 0x199a,
125    0x0383, 0x0b00, 0x144d, 0x3412, 0x1b8e, 0x0b02, 0x0c7f, 0x032b,
126    0x039a, 0x015e, 0x1d5a, 0x1164, 0x0d79, 0x0a67, 0x1264, 0x01a2,
127    0x0655, 0x0493, 0x0d8f, 0x0058, 0x2c51, 0x019c, 0x0617, 0x00c2,
128];
129
// http://miller-rabin.appspot.com/
// http://web.archive.org/web/20220921163920/http://www.janfeitsma.nl/math/psp2/index

/// Strong probable-prime (single-round Miller–Rabin) test of `n` to base `a`.
///
/// Write `n - 1 = d * 2^s` with `d` odd. Returns `true` iff `a^d ≡ 1 (mod n)`
/// or `a^(d * 2^r) ≡ n - 1 (mod n)` for some `0 <= r < s`. Every odd prime
/// `n` passes for any base not divisible by `n`. A composite that passes is a
/// strong pseudoprime to base `a`.
///
/// Precondition: `n` is odd and at least 3 (checked in debug builds).
/// If `a` is a multiple of `n`, the result is `false`. Products are widened
/// to `u128`, so nothing overflows.
fn is_sprp_64(n: u64, a: u64) -> bool {
    debug_assert!(n >= 3 && n % 2 == 1, "is_sprp_64 requires an odd n >= 3");
    let m = u128::from(n);
    let s = (n - 1).trailing_zeros();
    let d = (n - 1) >> s;

    // cur = a^d mod n, by right-to-left square-and-multiply.
    let mut cur = 1_u128;
    let mut base = u128::from(a);
    let mut exp = d;
    while exp > 0 {
        if exp & 1 != 0 {
            cur = cur * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }

    if cur == 1 {
        return true;
    }
    // Look for -1 among a^(d * 2^r) for r = 0, 1, ..., s - 1.
    for _ in 0..s {
        if cur == m - 1 {
            return true;
        }
        cur = cur * cur % m;
    }
    false
}
159
#[test]
fn exhaustive_u8() {
    // Reference oracle: trial division by every candidate divisor below x.
    fn trial_division(x: u8) -> bool {
        x >= 2 && !(2..x).any(|d| x % d == 0)
    }

    for i in u8::MIN..=u8::MAX {
        assert_eq!(i.is_prime(), trial_division(i), "{i}");
    }
}
168
/// Sieve of Eratosthenes over `0..=n`, packed 64 flags per `u64` word.
///
/// For every `i <= n`, bit `i % 64` of word `i / 64` is set exactly when `i`
/// is prime. Bits for indices past `n` in the final word stay set and must
/// not be read.
#[cfg(test)]
fn sieve_bits(n: usize) -> Vec<u64> {
    const W: usize = u64::BITS as usize;
    let mut bits = vec![!0_u64; n / W + 1];
    bits[0] &= !0 << 2; // 0 and 1 are not prime
    // A prime i > sqrt(n) has no multiple i * j <= n with j >= i, so stop at sqrt(n).
    for i in (2..=n).take_while(|&i| i <= n / i) {
        if bits[i / W] >> (i % W) & 1 == 0 {
            continue; // composite: its multiples were cleared by a smaller prime
        }
        for j in i..=n / i {
            let m = i * j;
            bits[m / W] &= !(1 << (m % W));
        }
    }
    bits
}

/// Asserts that `is_prime(i)` agrees with the sieve for every `i` in `2..n`.
#[cfg(test)]
fn assert_agrees_with_sieve(n: usize, is_prime: impl Fn(usize) -> bool) {
    const W: usize = u64::BITS as usize;
    let bits = sieve_bits(n);
    for i in 2..n {
        let expected = bits[i / W] >> (i % W) & 1 != 0;
        assert_eq!(is_prime(i), expected, "{i}");
    }
}

#[test]
fn exhaustive_u16() {
    // Every u16 value from 2 through u16::MAX.
    assert_agrees_with_sieve(64 << 10, |i| (i as u16).is_prime());
}

#[test]
fn exhaustive_u32() {
    // Every u32 value from 2 through u32::MAX. Takes ~30s and a 512 MiB bitset.
    assert_agrees_with_sieve(64 << 26, |i| (i as u32).is_prime());
}

#[test]
fn small_u64() {
    assert_agrees_with_sieve(64 << 18, |i| (i as u64).is_prime());
}
249
/// Products of the prime factors of the `u64` witnesses must be composite.
///
/// 325 = 5^2·13, 9375 = 3·5^5, 28178 = 2·73·193, 450775 = 5^2·13·19·73,
/// 9780504 = 2^3·3·407521 and 1795265022 = 2·3·299210837. Any composite `n`
/// that divides a witness is therefore a product of these primes. This test
/// enumerates every such product (with repetition) that fits in `u64`. That
/// covers each case where `is_prime` skips a base because it reduces to 0.
#[test]
fn mul_u64() {
    const PRIMES: [u64; 9] = [2, 3, 5, 13, 19, 73, 193, 407521, 299210837];

    // Depth-first enumeration of multisets: each stack entry is a partial
    // product together with the smallest prime index it may still multiply by.
    let mut products = Vec::new();
    let mut stack: Vec<(u64, usize)> = PRIMES.iter().copied().zip(0..).collect();
    while let Some((acc, from)) = stack.pop() {
        for (k, &p) in PRIMES.iter().enumerate().skip(from) {
            if let Some(next) = acc.checked_mul(p) {
                stack.push((next, k));
                products.push(next);
            }
        }
    }

    assert!(products.iter().all(|&x| !x.is_prime()));
}