Skip to main content

nekolib/math/
modint.rs

1use super::gcd_recip;
2use std::fmt::{self, Debug, Display};
3use std::hash::{Hash, Hasher};
4use std::iter::{Product, Sum};
5use std::marker::PhantomData;
6use std::ops::{
7    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
8};
9use std::sync::atomic::{self, AtomicU32, AtomicU64};
10
11use gcd_recip::GcdRecip;
12
/// An integer modulo the compile-time modulus selected by the marker type `M`
/// (the arithmetic impls require `M: Modulus`).
///
/// `val` holds the representative; the operations in this module keep it below the modulus.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct StaticModInt<M> {
    val: u32,
    // `fn() -> M` keeps `M` purely type-level without affecting auto traits via an owned `M`.
    _phd: PhantomData<fn() -> M>,
}
18
/// An integer modulo a runtime modulus shared by every value with the same id type `I`
/// (the arithmetic impls require `I: DynamicModIntId`).
///
/// `val` holds the representative; the operations in this module keep it below the modulus.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DynamicModInt<I> {
    val: u32,
    // Type-level tag only; no `I` value is ever stored.
    _phd: PhantomData<fn() -> I>,
}
24
/// Marker types naming a compile-time modulus for [`StaticModInt`].
pub trait Modulus: 'static + Copy + Eq {
    /// The modulus itself.
    const VALUE: u32;

    // Primality flag, compiled out together with the `is_prime_32` helper.
    #[cfg(ignore)]
    const IS_PRIME: bool = is_prime_32(Self::VALUE);
}
30
31pub trait ModIntBase:
32    Copy
33    + Eq
34    + Hash
35    + Add<Output = Self>
36    + Sub<Output = Self>
37    + Mul<Output = Self>
38    + Div<Output = Self>
39    + Neg
40    + AddAssign
41    + SubAssign
42    + MulAssign
43    + DivAssign
44{
45    fn modulus() -> u32;
46    fn get(self) -> u32;
47    fn new(n: impl RemEuclidU32) -> Self {
48        let n = n.rem_euclid_u32(Self::modulus());
49        unsafe { Self::new_unchecked(n) }
50    }
51    unsafe fn new_unchecked(n: u32) -> Self;
52    fn recip(self) -> Self { self.checked_recip().unwrap() }
53    fn checked_recip(self) -> Option<Self> {
54        let (g, r) = (self.get() as u64).gcd_recip(Self::modulus() as u64);
55        let r = r as u32;
56        if g == 1 { Some(unsafe { Self::new_unchecked(r) }) } else { None }
57    }
58    fn pow(self, mut iexp: u64) -> Self {
59        let mut res = Self::new(1);
60        let mut a = self;
61        while iexp > 0 {
62            if iexp & 1 != 0 {
63                res *= a;
64            }
65            a *= a;
66            iexp >>= 1;
67        }
68        res
69    }
70}
71
72trait InternalImpls: ModIntBase {
73    fn hash_impl(&self, state: &mut impl Hasher) { self.get().hash(state) }
74    fn display_impl(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        fmt::Display::fmt(&self.get(), f)
76    }
77    fn debug_impl(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        write!(f, "{} (mod {})", self.get(), Self::modulus())
79    }
80    fn neg_impl(self) -> Self {
81        let v = if self.get() == 0 { 0 } else { Self::modulus() - self.get() };
82        unsafe { Self::new_unchecked(v) }
83    }
84}
85
86impl<M: Modulus> StaticModInt<M> {
87    fn modulus() -> u32 { M::VALUE }
88    fn zero() -> Self { unsafe { Self::new_unchecked(0) } }
89    fn add_impl(self, rhs: Self) -> Self {
90        let mut tmp = self;
91        tmp += rhs;
92        tmp
93    }
94    fn sub_impl(self, rhs: Self) -> Self {
95        let mut tmp = self;
96        tmp -= rhs;
97        tmp
98    }
99    fn mul_impl(self, rhs: Self) -> Self {
100        let v = ((self.val as u64 * rhs.val as u64) % Self::modulus() as u64)
101            as u32;
102        unsafe { Self::new_unchecked(v) }
103    }
104    fn div_impl(self, rhs: Self) -> Self { self.mul_impl(rhs.recip()) }
105    fn add_assign_impl(&mut self, rhs: Self) {
106        self.val += rhs.val;
107        if self.val >= Self::modulus() {
108            self.val -= Self::modulus()
109        }
110    }
111    fn sub_assign_impl(&mut self, rhs: Self) {
112        if self.val < rhs.val {
113            self.val += Self::modulus()
114        }
115        self.val -= rhs.val
116    }
117    fn mul_assign_impl(&mut self, rhs: Self) { *self = self.mul_impl(rhs) }
118    fn div_assign_impl(&mut self, rhs: Self) { *self = self.div_impl(rhs) }
119}
120
121impl<M: Modulus> ModIntBase for StaticModInt<M> {
122    fn modulus() -> u32 { Self::modulus() }
123    fn get(self) -> u32 { self.val }
124    unsafe fn new_unchecked(val: u32) -> Self {
125        Self { val, _phd: PhantomData }
126    }
127}
128
129impl<M: Modulus> InternalImpls for StaticModInt<M> {}
130
131impl<I: RemEuclidU32, M: Modulus> From<I> for StaticModInt<M> {
132    fn from(x: I) -> Self { Self::new(x) }
133}
134
/// Strong probable-prime test of `n` to base `a`.
///
/// Writes `n - 1 = d * 2**s` and accepts when `a^d ≡ 1` or `a^(d * 2**r) ≡ -1 (mod n)` for
/// some `0 <= r < s`. Callers pass odd `n`, for which `n >> s == (n - 1) >> s`.
#[cfg(ignore)]
const fn is_sprp_32(n: u32, a: u32) -> bool {
    let n = n as u64;
    let s = (n - 1).trailing_zeros();
    let d = n >> s;

    // x = a^d mod n; every factor stays below 2**32, so products fit in u64.
    let mut x: u64 = 1;
    let mut base = a as u64;
    let mut e = d;
    while e != 0 {
        if e % 2 == 1 {
            x = x * base % n;
        }
        base = base * base % n;
        e /= 2;
    }

    if x == 1 {
        return true;
    }
    let mut round = 0;
    while round < s {
        if x == n - 1 {
            return true;
        }
        x = x * x % n;
        round += 1;
    }
    false
}
166
/// Primality test for 32-bit `n`: trial division by 2, 3, 5 and 7, then strong
/// probable-prime tests to bases 2, 7 and 61 (a base set published as deterministic below
/// 4,759,123,141 — re-verify if the bases are ever changed).
#[cfg(ignore)]
const fn is_prime_32(n: u32) -> bool {
    match n {
        2 | 3 | 5 | 7 => true,
        _ if n % 2 == 0 || n % 3 == 0 || n % 5 == 0 || n % 7 == 0 => false,
        // No composite below 11 * 11 escapes the trial division above.
        _ if n < 121 => n > 1,
        _ => is_sprp_32(n, 2) && is_sprp_32(n, 7) && is_sprp_32(n, 61),
    }
}
179
180pub trait DynamicModIntId: 'static + Copy + Eq {
181    fn barrett() -> &'static Barrett;
182}
183
/// Id for `DynamicModInt<DefaultId>`; uninhabited, it exists only at the type level.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum DefaultId {}
186
/// Runtime modulus `m` together with its Barrett factor `im = ceil(2**64 / m)` (which wraps
/// to `0` for `m == 1`).
///
/// Both live in atomics so the state can sit in a `static` and be replaced through `&self`.
pub struct Barrett {
    m: AtomicU32,
    im: AtomicU64,
}
191
192impl DynamicModIntId for DefaultId {
193    fn barrett() -> &'static Barrett {
194        static BARRETT: Barrett = Barrett::default();
195        &BARRETT
196    }
197}
198
199impl Barrett {
200    pub const fn new(m: u32) -> Self {
201        Self {
202            m: AtomicU32::new(m),
203            im: AtomicU64::new(Self::im(m)),
204        }
205    }
206
207    pub const fn default() -> Self { Self::new(1) }
208
209    const fn im(m: u32) -> u64 {
210        (0_u64.wrapping_sub(1) / m as u64).wrapping_add(1)
211    }
212
213    fn set(&self, m: u32) {
214        self.m.store(m, atomic::Ordering::SeqCst);
215        self.im.store(Self::im(m), atomic::Ordering::SeqCst);
216    }
217
218    fn modulus(&self) -> u32 { self.m.load(atomic::Ordering::SeqCst) }
219
220    fn mul(&self, a: u32, b: u32) -> u32 {
221        let m = self.m.load(atomic::Ordering::SeqCst);
222        let im = self.im.load(atomic::Ordering::SeqCst);
223
224        let z = a as u64 * b as u64;
225        let x = ((z as u128 * im as u128) >> 64) as u64;
226        let y = x.wrapping_mul(m as u64);
227        let v = z.wrapping_sub(y) as u32;
228        v.wrapping_add(if z < y { m } else { 0 })
229    }
230}
231
232impl Default for Barrett {
233    fn default() -> Self { Self::default() }
234}
235
236impl<I: DynamicModIntId> DynamicModInt<I> {
237    pub fn modulus() -> u32 { I::barrett().modulus() }
238    fn zero() -> Self { unsafe { Self::new_unchecked(0) } }
239    fn add_impl(self, rhs: Self) -> Self {
240        let mut tmp = self;
241        tmp += rhs;
242        tmp
243    }
244    fn sub_impl(self, rhs: Self) -> Self {
245        let mut tmp = self;
246        tmp -= rhs;
247        tmp
248    }
249    fn mul_impl(self, rhs: Self) -> Self {
250        let v = I::barrett().mul(self.val, rhs.val);
251        unsafe { Self::new_unchecked(v) }
252    }
253    fn div_impl(self, rhs: Self) -> Self { self.mul_impl(rhs.recip()) }
254    fn add_assign_impl(&mut self, rhs: Self) {
255        self.val += rhs.val;
256        if self.val >= Self::modulus() {
257            self.val -= Self::modulus()
258        }
259    }
260    fn sub_assign_impl(&mut self, rhs: Self) {
261        if self.val < rhs.val {
262            self.val += Self::modulus()
263        }
264        self.val -= rhs.val
265    }
266    fn mul_assign_impl(&mut self, rhs: Self) { *self = self.mul_impl(rhs) }
267    fn div_assign_impl(&mut self, rhs: Self) { *self = self.div_impl(rhs) }
268
269    pub fn set_modulus(m: u32) {
270        // (m - 1) + (m - 1) < 2 ** 32
271        // m <= 2 ** 31
272        if !(1..=1 << 31).contains(&m) {
273            panic!("the modulus must be in range (1, 2**31)");
274        }
275        I::barrett().set(m);
276    }
277}
278
279impl<I: DynamicModIntId> ModIntBase for DynamicModInt<I> {
280    fn modulus() -> u32 { Self::modulus() }
281    fn get(self) -> u32 { self.val }
282    unsafe fn new_unchecked(val: u32) -> Self {
283        Self { val, _phd: PhantomData }
284    }
285}
286
287impl<I: DynamicModIntId> InternalImpls for DynamicModInt<I> {}
288
289impl<J: RemEuclidU32, I: DynamicModIntId> From<J> for DynamicModInt<I> {
290    fn from(x: J) -> Self { Self::new(x) }
291}
292
293macro_rules! impl_modint {
294    ( $( ($mod:ident, $val:literal, $modint:ident), )* ) => { $(
295        #[derive(Clone, Copy, Eq, PartialEq)]
296        pub struct $mod;
297        impl Modulus for $mod {
298            const VALUE: u32 = $val;
299        }
300        pub type $modint = StaticModInt<$mod>;
301    )* }
302}
303
304impl_modint! {
305    (Mod998244353, 998244353, ModInt998244353),
306    (Mod1000000007, 1000000007, ModInt1000000007),
307}
308
/// Generates `Add`, `Sub`, `Mul` and `Div` impls.
///
/// Each entry `for<P: Bound> <Lhs> @ <Rhs> -> Out { self @ [*] }` yields four impls whose
/// methods forward to the inherent `add_impl` / `sub_impl` / `mul_impl` / `div_impl`; the
/// optional `*` dereferences a borrowed `rhs` first. The `@op` arm is an internal helper
/// that emits one impl.
macro_rules! impl_bin_ops {
    () => {};
    (@op $tr:ident, $method:ident, $inner:ident;
        [$($p:ident : $b:tt),*] $lhs:ty, $rhs:ty, $out:ty; [$($star:tt)?]
    ) => {
        impl<$($p: $b),*> $tr<$rhs> for $lhs {
            type Output = $out;
            fn $method(self, rhs: $rhs) -> $out {
                self.$inner($($star)? rhs)
            }
        }
    };
    (
        for<$($p:ident : $b:tt),*>
            <$lhs:ty> @ <$rhs:ty> -> $out:ty
        { self @ $($star:tt)? } $($rest:tt)*
    ) => {
        impl_bin_ops!(@op Add, add, add_impl; [$($p: $b),*] $lhs, $rhs, $out; [$($star)?]);
        impl_bin_ops!(@op Sub, sub, sub_impl; [$($p: $b),*] $lhs, $rhs, $out; [$($star)?]);
        impl_bin_ops!(@op Mul, mul, mul_impl; [$($p: $b),*] $lhs, $rhs, $out; [$($star)?]);
        impl_bin_ops!(@op Div, div, div_impl; [$($p: $b),*] $lhs, $rhs, $out; [$($star)?]);
        impl_bin_ops!($($rest)*);
    };
}
335
/// Generates `AddAssign`, `SubAssign`, `MulAssign` and `DivAssign` impls.
///
/// Each entry `for<P: Bound> <Lhs> @= <Rhs> { self @= [*] }` yields four impls forwarding
/// to the inherent `*_assign_impl` methods; the optional `*` dereferences a borrowed `rhs`.
/// The `@op` arm is an internal helper that emits one impl.
macro_rules! impl_assign_ops {
    () => {};
    (@op $tr:ident, $method:ident, $inner:ident;
        [$($p:ident : $b:tt),*] $lhs:ty, $rhs:ty; [$($star:tt)?]
    ) => {
        impl<$($p: $b),*> $tr<$rhs> for $lhs {
            fn $method(&mut self, rhs: $rhs) {
                self.$inner($($star)? rhs);
            }
        }
    };
    (
        for<$($p:ident : $b:tt),*>
            <$lhs:ty> @= <$rhs:ty>
        { self @= $($star:tt)? } $($rest:tt)*
    ) => {
        impl_assign_ops!(@op AddAssign, add_assign, add_assign_impl; [$($p: $b),*] $lhs, $rhs; [$($star)?]);
        impl_assign_ops!(@op SubAssign, sub_assign, sub_assign_impl; [$($p: $b),*] $lhs, $rhs; [$($star)?]);
        impl_assign_ops!(@op MulAssign, mul_assign, mul_assign_impl; [$($p: $b),*] $lhs, $rhs; [$($star)?]);
        impl_assign_ops!(@op DivAssign, div_assign, div_assign_impl; [$($p: $b),*] $lhs, $rhs; [$($star)?]);
        impl_assign_ops!($($rest)*);
    };
}
366
/// Generates a folding trait (`Sum`, `Product`, …) for both owned and borrowed items.
///
/// `impl<P: Bound> Trait<_> for Ty { fn method(_) -> _ { _(unit, op) } }` expands to
/// `Trait<Self>` and `Trait<&Self>` impls that compute `iter.fold(unit, op)`.
macro_rules! impl_folding {
    () => {};
    (
        impl <$gp:ident : $gb:tt> $tr:ident<_>
            for $target:ty
        {
            fn $method:ident(_) -> _ { _($unit:expr, $op:expr) }
        }
        $($rest:tt)*
    ) => {
        impl<$gp: $gb> $tr<Self> for $target {
            fn $method<It>(iter: It) -> Self
            where
                It: Iterator<Item = Self>,
            {
                iter.fold($unit, $op)
            }
        }

        impl<'r, $gp: $gb> $tr<&'r Self> for $target {
            fn $method<It>(iter: It) -> Self
            where
                It: Iterator<Item = &'r Self>,
            {
                iter.fold($unit, $op)
            }
        }

        impl_folding!($($rest)*);
    };
}
396
/// Generates `Display`, `Debug`, `Hash`, `Default` and `Neg` for each
/// `impl<P: Bound> _ for Ty;` entry, delegating to the `InternalImpls` helpers and the
/// inherent `zero`.
macro_rules! impl_basic_traits {
    () => {};
    (impl<$gp:ident : $gb:tt> _ for $target:ty; $($rest:tt)*) => {
        impl<$gp: $gb> Display for $target {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                self.display_impl(f)
            }
        }

        impl<$gp: $gb> Debug for $target {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                self.debug_impl(f)
            }
        }

        impl<$gp: $gb> Hash for $target {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.hash_impl(state)
            }
        }

        impl<$gp: $gb> Default for $target {
            fn default() -> Self {
                Self::zero()
            }
        }

        impl<$gp: $gb> Neg for $target {
            type Output = $target;
            fn neg(self) -> Self {
                self.neg_impl()
            }
        }

        impl_basic_traits!($($rest)*);
    };
}
423
424impl_bin_ops! {
425    for<M: Modulus> <StaticModInt<M>> @ <StaticModInt<M>> -> StaticModInt<M> { self @ }
426    for<M: Modulus> <StaticModInt<M>> @ <&'_ StaticModInt<M>> -> StaticModInt<M> { self @ * }
427    for<M: Modulus> <&'_ StaticModInt<M>> @ <StaticModInt<M>> -> StaticModInt<M> { self @ }
428    for<M: Modulus> <&'_ StaticModInt<M>> @ <&'_ StaticModInt<M>> -> StaticModInt<M> { self @ * }
429    for<I: DynamicModIntId> <DynamicModInt<I>> @ <DynamicModInt<I>> -> DynamicModInt<I> { self @ }
430    for<I: DynamicModIntId> <DynamicModInt<I>> @ <&'_ DynamicModInt<I>> -> DynamicModInt<I> { self @ * }
431    for<I: DynamicModIntId> <&'_ DynamicModInt<I>> @ <DynamicModInt<I>> -> DynamicModInt<I> { self @ }
432    for<I: DynamicModIntId> <&'_ DynamicModInt<I>> @ <&'_ DynamicModInt<I>> -> DynamicModInt<I> { self @ * }
433}
434
435impl_assign_ops! {
436    for<M: Modulus> <StaticModInt<M>> @= <StaticModInt<M>> { self @= }
437    for<M: Modulus> <StaticModInt<M>> @= <&'_ StaticModInt<M>> { self @= * }
438    for<I: DynamicModIntId> <DynamicModInt<I>> @= <DynamicModInt<I>> { self @= }
439    for<I: DynamicModIntId> <DynamicModInt<I>> @= <&'_ DynamicModInt<I>> { self @= * }
440}
441
442impl_folding! {
443    impl<M: Modulus> Sum<_> for StaticModInt<M> { fn sum(_) -> _ { _(Self::new(0), Add::add)} }
444    impl<M: Modulus> Product<_> for StaticModInt<M> { fn product(_) -> _ { _(Self::new(1), Mul::mul)} }
445    impl<I: DynamicModIntId> Sum<_> for DynamicModInt<I> { fn sum(_) -> _ { _(Self::new(0), Add::add)} }
446    impl<I: DynamicModIntId> Product<_> for DynamicModInt<I> { fn product(_) -> _ { _(Self::new(1), Mul::mul)} }
447}
448
449impl_basic_traits! {
450    impl<M: Modulus> _ for StaticModInt<M>;
451    impl<I: DynamicModIntId> _ for DynamicModInt<I>;
452}
453
/// Integers that can be reduced modulo a `u32`.
pub trait RemEuclidU32 {
    /// Returns the Euclidean remainder of `self` modulo `n`: the unique `r` in `0..n`
    /// congruent to `self`, so negative inputs map into `0..n` as well.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0` (division by zero).
    fn rem_euclid_u32(self, n: u32) -> u32;
}

// Types that can hold every `n: u32` on all targets: reduce natively in `$ty`.
macro_rules! impl_rem_euclid_u32 {
    ( $($ty:ty)* ) => { $(
        impl RemEuclidU32 for $ty {
            fn rem_euclid_u32(self, n: u32) -> u32 {
                // The remainder lies in `0..n`, so narrowing back to `u32` is lossless.
                self.rem_euclid(n as $ty) as u32
            }
        }
    )* }
}

// Types that cannot hold every `n: u32` (`i8`/`i16`/`u8`/`u16`, and `isize`/`usize` on
// targets narrower than 64 bits): widen losslessly to `$wide`, then reuse its impl.
// Reinterpreting a negative value as `u32` instead (e.g. `-1i8` as `u32::MAX`) or casting `n`
// down to a narrow signed type gives wrong remainders.
macro_rules! impl_rem_euclid_u32_widening {
    ( $($ty:ty => $wide:ty),* $(,)? ) => { $(
        impl RemEuclidU32 for $ty {
            fn rem_euclid_u32(self, n: u32) -> u32 {
                // Sign-/zero-extension: every `$ty` value is representable in `$wide`.
                (self as $wide).rem_euclid_u32(n)
            }
        }
    )* }
}

impl_rem_euclid_u32! {
    i64 i128 u32 u64 u128
}

impl_rem_euclid_u32_widening! {
    i8 => i32,
    i16 => i32,
    u8 => u32,
    u16 => u32,
    // Lossless on every target whose pointers are at most 64 bits wide.
    isize => i64,
    usize => u64,
}

impl RemEuclidU32 for i32 {
    fn rem_euclid_u32(self, n: u32) -> u32 {
        if self >= 0 {
            // Non-negative: stay in cheaper `u32` arithmetic.
            (self as u32).rem_euclid(n)
        } else {
            // `i32` cannot hold every `n`; `i64` holds both operands.
            (self as i64).rem_euclid(n as i64) as u32
        }
    }
}
495
#[test]
fn sanity_check() {
    // `Mod998244353::IS_PRIME` / `Mod1000000007::IS_PRIME` checks stay disabled while the
    // primality helpers remain behind `#[cfg(ignore)]`.

    type Fp = ModInt998244353;
    assert_eq!(Fp::new(1) + Fp::new(998244352), Fp::new(0));
    let half = Fp::new(1) / Fp::new(2);
    assert_eq!(half.get(), (Fp::modulus() + 1) / 2);

    let fp_sum: Fp = (1..=10).map(Fp::new).sum();
    let fp_product: Fp = (1..=10).map(Fp::new).product();
    assert_eq!(fp_sum, Fp::new(55));
    assert_eq!(fp_product, Fp::new(3628800));

    // The dynamic modulus is global state: each assertion below depends on the most
    // recent `set_modulus` call.
    type Dm = DynamicModInt<DefaultId>;

    Dm::set_modulus(10);
    assert_eq!(Dm::new(5) + Dm::new(7), Dm::new(2));
    assert_eq!(Dm::new(3) * Dm::new(4), Dm::new(2));

    Dm::set_modulus(4);
    assert_eq!(Dm::new(5) + Dm::new(7), Dm::new(0));
    let dm_sum: Dm = (1..=10).map(Dm::new).sum();
    assert_eq!(dm_sum, Dm::new(55));
    assert_eq!(dm_sum.val, 55 % 4);
}
523
#[test]
fn negative() {
    // -1 must wrap to modulus - 1.
    let reduced = ModInt998244353::new(-1);
    assert_eq!(reduced.get(), 998244352);
}
528
#[test]
fn fmt() {
    let x = ModInt998244353::new(123);
    let shown = format!("{}", x);
    let debugged = format!("{:?}", x);
    assert_eq!(shown, "123");
    assert_eq!(debugged, "123 (mod 998244353)");
}