modint/
lib.rs

1use std::{
2    fmt,
3    hash::{Hash, Hasher},
4    iter::{Product, Sum},
5    ops::{
6        Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
7    },
8};
9
10use bin_iter::BinIter;
11
/// An integer modulo the compile-time constant `MOD`.
///
/// The wrapped `u32` is the residue itself: constructors reduce their input
/// into `0..MOD` and the operator impls keep results there, so the derived
/// `PartialEq`/`Eq` compare residues directly. `MOD` is expected to be
/// non-zero (reducing modulo zero panics).
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct StaticModInt<const MOD: u32>(u32);
14
15impl<const MOD: u32> StaticModInt<MOD> {
16    pub fn new(val: impl RemEuclidU32) -> Self { Self::from(val) }
17    pub fn modulus() -> u32 { MOD }
18}
19
20impl<const MOD: u32> AddAssign for StaticModInt<MOD> {
21    fn add_assign(&mut self, rhs: Self) {
22        self.0 += rhs.0;
23        if self.0 >= MOD {
24            self.0 -= MOD;
25        }
26    }
27}
28
29impl<const MOD: u32> SubAssign for StaticModInt<MOD> {
30    fn sub_assign(&mut self, rhs: Self) {
31        if self.0 < rhs.0 {
32            self.0 += MOD;
33        }
34        self.0 -= rhs.0
35    }
36}
37
38impl<const MOD: u32> MulAssign for StaticModInt<MOD> {
39    fn mul_assign(&mut self, rhs: Self) {
40        let tmp = (self.0 as u64) * (rhs.0 as u64) % MOD as u64;
41        self.0 = tmp as u32;
42    }
43}
44
45impl<const MOD: u32> DivAssign for StaticModInt<MOD> {
46    fn div_assign(&mut self, rhs: Self) { *self *= rhs.recip() }
47}
48
49impl<const MOD: u32> StaticModInt<MOD> {
50    pub fn recip(self) -> Self { self.checked_recip().unwrap() }
51    // XXX use Euclidean algorithm
52    pub fn checked_recip(self) -> Option<Self> { Some(self.pow(MOD - 2)) }
53    pub fn pow(self, exp: impl BinIter) -> Self {
54        let mut res = Self::new(1);
55        let mut dbl = self;
56        for b in exp.bin_iter() {
57            if b {
58                res *= dbl;
59            }
60            dbl *= dbl;
61        }
62        res
63    }
64}
65
// Internal helper for `impl_bin_op!`: expands each pseudo-`impl` item into a
// real binary-operator impl for `StaticModInt<MOD>`.
//
// The input is a small DSL shaped like the generated impl: `$(&$ltl)?` and
// `$(&$ltr)?` select whether the left and right operand are taken by
// reference, `$op` is the operator method (e.g. `add`) and `$op_assign` is the
// compound-assignment method (e.g. `add_assign`) the operator is built on.
macro_rules! impl_bin_op_inner {
    ( $(
        impl<_> $op_trait:ident<$(&$ltr:lifetime)? Self> for $(&$ltl:lifetime)? Self {
            fn $op:ident(..) -> _ { self.$op_assign:ident() }
        }
    )* ) => { $(
        impl<const MOD: u32> $op_trait<$(&$ltr)? StaticModInt<MOD>> for $(&$ltl)? StaticModInt<MOD> {
            type Output = StaticModInt<MOD>;
            fn $op(self, rhs: $(&$ltr)? StaticModInt<MOD>) -> Self::Output {
                // `to_owned` yields an owned `StaticModInt` whether `self` is a
                // value or a reference, so every operand form shares this body.
                let mut tmp = self.to_owned();
                // For a by-reference `rhs` this dispatches to the
                // `XxxAssign<&Self>` impl generated by `impl_bin_op!`.
                tmp.$op_assign(rhs);
                tmp
            }
        }
    )* };
}
82
// Generates, for each `(op, OpTrait, op_assign, OpAssignTrait)` tuple, all four
// owned/borrowed operand combinations of `OpTrait` (via `impl_bin_op_inner!`)
// plus `OpAssignTrait<&Self>`. Everything is built on the by-value
// `OpAssignTrait` impls written by hand earlier in this file.
macro_rules! impl_bin_op {
    ( $( ($op:ident, $op_trait:ident, $op_assign:ident, $op_assign_trait:ident), )* ) => { $(
        // value ∘ value, value ∘ &value, &value ∘ value, &value ∘ &value
        impl_bin_op_inner! {
            impl<_> $op_trait<Self> for Self { fn $op(..) -> _ { self.$op_assign() } }
            impl<_> $op_trait<&'_ Self> for Self { fn $op(..) -> _ { self.$op_assign() } }
            impl<_> $op_trait<Self> for &'_ Self { fn $op(..) -> _ { self.$op_assign() } }
            impl<_> $op_trait<&'_ Self> for &'_ Self { fn $op(..) -> _ { self.$op_assign() } }
        }
        // `x op= &y`: copy the referent and defer to the by-value impl.
        impl<const MOD: u32> $op_assign_trait<&Self> for StaticModInt<MOD> {
            fn $op_assign(&mut self, rhs: &Self) { self.$op_assign(rhs.to_owned()) }
        }
    )* }
}
96
97impl_bin_op! {
98    ( add, Add, add_assign, AddAssign ),
99    ( sub, Sub, sub_assign, SubAssign ),
100    ( mul, Mul, mul_assign, MulAssign ),
101    ( div, Div, div_assign, DivAssign ),
102}
103
104impl<const MOD: u32> Neg for StaticModInt<MOD> {
105    type Output = StaticModInt<MOD>;
106    fn neg(self) -> Self::Output {
107        if self.0 == 0 { self } else { StaticModInt(MOD - self.0) }
108    }
109}
110
111impl<const MOD: u32> Neg for &StaticModInt<MOD> {
112    type Output = StaticModInt<MOD>;
113    fn neg(self) -> Self::Output {
114        if self.0 == 0 { *self } else { StaticModInt(MOD - self.0) }
115    }
116}
117
118impl<const MOD: u32> fmt::Display for StaticModInt<MOD> {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) }
120}
121
122impl<const MOD: u32> fmt::Debug for StaticModInt<MOD> {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        write!(f, "{} mod {}", self, MOD)
125    }
126}
127
/// Reduction of a primitive integer to its least non-negative residue modulo
/// a `u32` modulus.
///
/// This is the conversion behind `StaticModInt::new` and
/// `From<I> for StaticModInt<MOD>`.
pub trait RemEuclidU32 {
    /// Returns `self` reduced modulo `rem`, as a value in `0..rem`
    /// (negative inputs wrap around, e.g. `-1` becomes `rem - 1`).
    ///
    /// # Panics
    ///
    /// Panics if `rem` is zero.
    fn rem_euclid_u32(&self, rem: u32) -> u32;
}

// Implements `RemEuclidU32` for each `$lhs` by converting both `self` and the
// modulus to `$common` and calling `rem_euclid` there. `$common` must represent
// every `$lhs` value and every `u32` exactly; otherwise `rem as $common` wraps
// (e.g. into a negative divisor) and the residue comes out wrong.
macro_rules! impl_rem_euclid_u32 {
    ( $( ($lhs:ty, $common:ty) ),* ) => { $(
        impl RemEuclidU32 for $lhs {
            fn rem_euclid_u32(&self, rem: u32) -> u32 {
                // The result lies in `0..rem`, so narrowing to `u32` is lossless.
                (*self as $common).rem_euclid(rem as $common) as u32
            }
        }
    )* }
}

// Unsigned types no wider than `u32` are computed directly in `u32`.
macro_rules! impl_rem_euclid_u32_small {
    ( $($lhs:ty)* ) => { impl_rem_euclid_u32! { $( ($lhs, u32) ),* } }
}

// Types at least as wide as `i64` already hold every `u32` exactly.
macro_rules! impl_rem_euclid_u32_large {
    ( $($lhs:ty)* ) => { impl_rem_euclid_u32! { $( ($lhs, $lhs) ),* } }
}

impl_rem_euclid_u32_small! { u8 u16 u32 }
impl_rem_euclid_u32_large! { i64 i128 u64 u128 }
// Narrow signed types are widened to `i64` (not `i32`) so moduli above
// `i32::MAX` stay positive. The pointer-sized types are widened to 64 bits so
// the result does not depend on the target's pointer width.
impl_rem_euclid_u32! { (i8, i64), (i16, i64), (i32, i64), (isize, i64), (usize, u64) }
153
154impl<const MOD: u32, I: RemEuclidU32> From<I> for StaticModInt<MOD> {
155    fn from(val: I) -> Self { Self(val.rem_euclid_u32(MOD)) }
156}
157
// Internal helper for `impl_folding!`: expands each pseudo-`impl` item into an
// iterator-folding trait impl (e.g. `Sum`/`Product`) for `StaticModInt<MOD>`.
//
// `$lt` is an optional impl lifetime, and `$(&$ltr)?` chooses whether the
// iterator yields values or references. `$unit` is the fold's starting value,
// and `$op_assign` is the compound-assignment method applied to each item.
// `$deref` is accepted by the pattern but not used in the expansion.
macro_rules! impl_folding_inner {
    ( $(
        impl<$($lt:lifetime,)? _> $op_trait:ident<$(&$ltr:lifetime)? Self> for Self {
            fn $op:ident(..) -> _ { $unit:literal; self.$op_assign:ident($($deref:tt)?) }
        }
    )* ) => { $(
        impl<$($lt,)? const MOD: u32> $op_trait<$(&$ltr)? StaticModInt<MOD>> for StaticModInt<MOD> {
            fn $op<I>(iter: I) -> StaticModInt<MOD>
            where
                I: Iterator<Item = $(&$ltr)? StaticModInt<MOD>>,
            {
                // Start from `$unit`; the invocations pass the identity of
                // the operation (`0` for sum, `1` for product).
                let mut res = StaticModInt::new($unit);
                for x in iter {
                    // For reference items this uses the `XxxAssign<&Self>`
                    // impls generated by `impl_bin_op!`.
                    res.$op_assign(x);
                }
                res
            }
        }
    )* };
}
178
// Generates, for each `(op, OpTrait, op_assign, unit)` tuple, a by-value impl
// (`OpTrait<StaticModInt>`) and a by-reference impl (`OpTrait<&'a StaticModInt>`)
// of a folding trait via `impl_folding_inner!`.
macro_rules! impl_folding {
    ( $( ($op:ident, $op_trait:ident, $op_assign:ident, $unit:literal), )* ) => { $(
        impl_folding_inner! {
            impl<_> $op_trait<Self> for Self { fn $op(..) -> _ { $unit; self.$op_assign() } }
            impl<'a, _> $op_trait<&'a Self> for Self { fn $op(..) -> _ { $unit; self.$op_assign() } }
        }
    )* }
}
187
188impl_folding! {
189    ( sum, Sum, add_assign, 0 ),
190    ( product, Product, mul_assign, 1 ),
191}
192
193impl<const MOD: u32> Hash for StaticModInt<MOD> {
194    fn hash<H: Hasher>(&self, state: &mut H) { self.0.hash(state) }
195}
196
197pub type ModInt998244353 = StaticModInt<998244353>;
198pub type ModInt1000000007 = StaticModInt<1000000007>;
199
#[test]
fn arithmetic() {
    type Mi = ModInt998244353;

    let (zero, one, two) = (Mi::new(0), Mi::new(1), Mi::new(2));
    // 499122177 is the inverse of 2, and 748683265 the inverse of 4, mod 998244353.
    let half = Mi::new(499122177);
    let quarter = Mi::new(748683265);

    // The modulus itself reduces to zero.
    assert_eq!(zero, Mi::new(Mi::modulus()));

    // Additive structure.
    assert_eq!(one, half + half);
    assert_eq!(-half, zero - half);
    assert_eq!(half, one - half);

    // Multiplicative structure.
    assert_eq!(one, half * two);
    assert_eq!(quarter, half * half);
    assert_eq!(half, one / two);

    // Fermat's little theorem: 2^(p-1) = 1 for the prime p = 998244353.
    assert_eq!(one, two.pow(998244352_u64));
}
218
#[test]
fn folding() {
    type Mi = ModInt998244353;

    let values: Vec<Mi> = (1..=4).map(Mi::new).collect();
    let (expected_sum, expected_product) = (Mi::new(10), Mi::new(24));

    // Folding over borrowed residues (`Sum<&Mi>` / `Product<&Mi>`).
    assert_eq!(values.iter().sum::<Mi>(), expected_sum);
    assert_eq!(values.iter().product::<Mi>(), expected_product);

    // Folding over owned residues (`Sum<Mi>` / `Product<Mi>`).
    assert_eq!(values.iter().copied().sum::<Mi>(), expected_sum);
    assert_eq!(values.iter().copied().product::<Mi>(), expected_product);

    // Repeated multiplication agrees with exponentiation.
    let two = Mi::new(2);
    assert_eq!([two; 100].iter().product::<Mi>(), two.pow(100_u32));
}
234
#[test]
fn fmt() {
    type Mi = ModInt998244353;

    let one = Mi::new(1);
    let pair = [one; 2];

    // `Display` prints only the residue.
    assert_eq!(one.to_string(), "1");
    // `Debug` also names the modulus, including inside collections.
    assert_eq!(format!("{:?}", one), "1 mod 998244353");
    assert_eq!(format!("{:?}", pair), "[1 mod 998244353, 1 mod 998244353]");
}
244
#[test]
fn conversion() {
    type Mi = ModInt998244353;

    // Checks that `Mi::new(value)` stores exactly the residue `expected`.
    macro_rules! assert_residue {
        ( $( $value:expr => $expected:expr ),* $(,)? ) => {
            $( assert_eq!(Mi::new($value).0, $expected); )*
        };
    }

    // Negative inputs of every signed width wrap around to MOD - 1.
    assert_residue! {
        -1_i8 => 998244352,
        -1_i16 => 998244352,
        -1_i32 => 998244352,
        -1_i64 => 998244352,
        -1_i128 => 998244352,
    }

    // Signed inputs just above MOD reduce to 1.
    assert_residue! {
        998244354_i32 => 1,
        998244354_i64 => 1,
        998244354_i128 => 1,
    }

    // Unsigned inputs just above MOD reduce to 1.
    assert_residue! {
        998244354_u32 => 1,
        998244354_u64 => 1,
        998244354_u128 => 1,
    }

    // Small values of every integer type are stored unchanged.
    assert_residue! {
        10_i8 => 10,
        10_i16 => 10,
        10_i32 => 10,
        10_i64 => 10,
        10_i128 => 10,
        10_isize => 10,
        10_u8 => 10,
        10_u16 => 10,
        10_u32 => 10,
        10_u64 => 10,
        10_u128 => 10,
        10_usize => 10,
    }
}