1use std::{
2 fmt,
3 hash::{Hash, Hasher},
4 iter::{Product, Sum},
5 ops::{
6 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
7 },
8};
9
10use bin_iter::BinIter;
11
/// An integer modulo the compile-time constant `MOD`.
///
/// The wrapped `u32` is the canonical representative in `0..MOD`; the
/// constructors and operators in this module reduce their results into that
/// range. `Default` yields the residue `0`, which is always canonical.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
pub struct StaticModInt<const MOD: u32>(u32);
14
15impl<const MOD: u32> StaticModInt<MOD> {
16 pub fn new(val: impl RemEuclidU32) -> Self { Self::from(val) }
17 pub fn modulus() -> u32 { MOD }
18}
19
20impl<const MOD: u32> AddAssign for StaticModInt<MOD> {
21 fn add_assign(&mut self, rhs: Self) {
22 self.0 += rhs.0;
23 if self.0 >= MOD {
24 self.0 -= MOD;
25 }
26 }
27}
28
29impl<const MOD: u32> SubAssign for StaticModInt<MOD> {
30 fn sub_assign(&mut self, rhs: Self) {
31 if self.0 < rhs.0 {
32 self.0 += MOD;
33 }
34 self.0 -= rhs.0
35 }
36}
37
38impl<const MOD: u32> MulAssign for StaticModInt<MOD> {
39 fn mul_assign(&mut self, rhs: Self) {
40 let tmp = (self.0 as u64) * (rhs.0 as u64) % MOD as u64;
41 self.0 = tmp as u32;
42 }
43}
44
45impl<const MOD: u32> DivAssign for StaticModInt<MOD> {
46 fn div_assign(&mut self, rhs: Self) { *self *= rhs.recip() }
47}
48
49impl<const MOD: u32> StaticModInt<MOD> {
50 pub fn recip(self) -> Self { self.checked_recip().unwrap() }
51 pub fn checked_recip(self) -> Option<Self> { Some(self.pow(MOD - 2)) }
53 pub fn pow(self, exp: impl BinIter) -> Self {
54 let mut res = Self::new(1);
55 let mut dbl = self;
56 for b in exp.bin_iter() {
57 if b {
58 res *= dbl;
59 }
60 dbl *= dbl;
61 }
62 res
63 }
64}
65
// Implements one binary operator trait (e.g. `Add`) for each owned/borrowed
// operand combination listed in the input.
//
// Each entry is written in a pseudo-impl syntax, for example:
//     impl<_> Add<&'_ Self> for Self { fn add(..) -> _ { self.add_assign() } }
// An optional `&'lt` before `Self` (on the trait argument and/or the
// implementing type) selects a borrowed operand on that side. The generated
// method copies the left operand into a fresh `StaticModInt` and delegates to
// the named `*_assign` method, which must accept the right operand exactly as
// it was passed (owned or borrowed).
macro_rules! impl_bin_op_inner {
    ( $(
        impl<_> $op_trait:ident<$(&$ltr:lifetime)? Self> for $(&$ltl:lifetime)? Self {
            fn $op:ident(..) -> _ { self.$op_assign:ident() }
        }
    )* ) => { $(
        impl<const MOD: u32> $op_trait<$(&$ltr)? StaticModInt<MOD>> for $(&$ltl)? StaticModInt<MOD> {
            type Output = StaticModInt<MOD>;
            fn $op(self, rhs: $(&$ltr)? StaticModInt<MOD>) -> Self::Output {
                // `to_owned` produces a `StaticModInt` whether `self` is owned or borrowed.
                let mut tmp = self.to_owned();
                tmp.$op_assign(rhs);
                tmp
            }
        }
    )* };
}
82
// For each `(op, OpTrait, op_assign, OpAssignTrait)` tuple:
//   * implements `OpTrait` for all four owned/borrowed operand combinations
//     via `impl_bin_op_inner!`;
//   * implements `OpAssignTrait<&Self>`, which copies the borrowed right
//     operand and forwards to the by-value `OpAssignTrait` impl.
// The by-value `OpAssignTrait<Self>` impl is not generated here and must be
// written by hand.
macro_rules! impl_bin_op {
    ( $( ($op:ident, $op_trait:ident, $op_assign:ident, $op_assign_trait:ident), )* ) => { $(
        impl_bin_op_inner! {
            impl<_> $op_trait<Self> for Self { fn $op(..) -> _ { self.$op_assign() } }
            impl<_> $op_trait<&'_ Self> for Self { fn $op(..) -> _ { self.$op_assign() } }
            impl<_> $op_trait<Self> for &'_ Self { fn $op(..) -> _ { self.$op_assign() } }
            impl<_> $op_trait<&'_ Self> for &'_ Self { fn $op(..) -> _ { self.$op_assign() } }
        }
        impl<const MOD: u32> $op_assign_trait<&Self> for StaticModInt<MOD> {
            fn $op_assign(&mut self, rhs: &Self) { self.$op_assign(rhs.to_owned()) }
        }
    )* }
}
96
97impl_bin_op! {
98 ( add, Add, add_assign, AddAssign ),
99 ( sub, Sub, sub_assign, SubAssign ),
100 ( mul, Mul, mul_assign, MulAssign ),
101 ( div, Div, div_assign, DivAssign ),
102}
103
104impl<const MOD: u32> Neg for StaticModInt<MOD> {
105 type Output = StaticModInt<MOD>;
106 fn neg(self) -> Self::Output {
107 if self.0 == 0 { self } else { StaticModInt(MOD - self.0) }
108 }
109}
110
111impl<const MOD: u32> Neg for &StaticModInt<MOD> {
112 type Output = StaticModInt<MOD>;
113 fn neg(self) -> Self::Output {
114 if self.0 == 0 { *self } else { StaticModInt(MOD - self.0) }
115 }
116}
117
118impl<const MOD: u32> fmt::Display for StaticModInt<MOD> {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) }
120}
121
122impl<const MOD: u32> fmt::Debug for StaticModInt<MOD> {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 write!(f, "{} mod {}", self, MOD)
125 }
126}
127
/// Integers that can be reduced into `0..rem` and returned as a `u32`.
pub trait RemEuclidU32 {
    /// Returns the Euclidean remainder of `self` by `rem`: the unique value in
    /// `0..rem` congruent to `self`. Negative inputs wrap around, so `-1`
    /// maps to `rem - 1`.
    ///
    /// # Panics
    ///
    /// The implementations in this module panic if `rem == 0`.
    fn rem_euclid_u32(&self, rem: u32) -> u32;
}

// Implements `RemEuclidU32` for `$lhs` by casting both `self` and the `u32`
// divisor to `$common` and taking `rem_euclid` there. `$common` must represent
// every `$lhs` value AND every `u32` divisor exactly; otherwise the `as` casts
// silently wrap and the remainder is wrong.
macro_rules! impl_rem_euclid_u32 {
    ( $( ($lhs:ty, $common:ty) ),* ) => { $(
        impl RemEuclidU32 for $lhs {
            fn rem_euclid_u32(&self, rem: u32) -> u32 {
                // The remainder lies in `0..rem`, so it fits back into `u32`.
                (*self as $common).rem_euclid(rem as $common) as u32
            }
        }
    )* }
}

// Unsigned types no wider than `u32`: widening to `u32` is lossless.
macro_rules! impl_rem_euclid_u32_small {
    ( $($lhs:ty)* ) => { impl_rem_euclid_u32! { $( ($lhs, u32) ),* } }
}

// Types of at least 64 bits, which already hold every `u32` divisor.
macro_rules! impl_rem_euclid_u32_large {
    ( $($lhs:ty)* ) => { impl_rem_euclid_u32! { $( ($lhs, $lhs) ),* } }
}

impl_rem_euclid_u32_small! { u8 u16 u32 }
impl_rem_euclid_u32_large! { i64 i128 u64 u128 }
// Small signed and pointer-sized types go through a 64-bit type:
// - `i32` cannot hold divisors above `i32::MAX` (`3_000_000_000_u32 as i32` is
//   negative, which yields a wrong remainder).
// - `isize`/`usize` may be only 16 or 32 bits wide.
// Rust targets have pointers of at most 64 bits, so these casts are lossless.
impl_rem_euclid_u32! { (i8, i64), (i16, i64), (i32, i64), (isize, i64), (usize, u64) }
153
154impl<const MOD: u32, I: RemEuclidU32> From<I> for StaticModInt<MOD> {
155 fn from(val: I) -> Self { Self(val.rem_euclid_u32(MOD)) }
156}
157
// Implements a folding trait (`Sum` / `Product`) for `StaticModInt` over an
// iterator of either owned or borrowed items.
//
// Each entry uses a pseudo-impl syntax, for example:
//     impl<'a, _> Sum<&'a Self> for Self { fn sum(..) -> _ { 0; self.add_assign() } }
// The optional leading lifetime is declared on the generated impl; an optional
// `&'lt` before the `Self` type argument makes the iterator item a reference.
// `$unit` is the starting value of the fold and `$op_assign` is the in-place
// operator applied to each item. The optional `$deref` token is accepted by
// the pattern but not used in the expansion.
macro_rules! impl_folding_inner {
    ( $(
        impl<$($lt:lifetime,)? _> $op_trait:ident<$(&$ltr:lifetime)? Self> for Self {
            fn $op:ident(..) -> _ { $unit:literal; self.$op_assign:ident($($deref:tt)?) }
        }
    )* ) => { $(
        impl<$($lt,)? const MOD: u32> $op_trait<$(&$ltr)? StaticModInt<MOD>> for StaticModInt<MOD> {
            fn $op<I>(iter: I) -> StaticModInt<MOD>
            where
                I: Iterator<Item = $(&$ltr)? StaticModInt<MOD>>,
            {
                // Start from `$unit` and fold every item in place.
                let mut res = StaticModInt::new($unit);
                for x in iter {
                    res.$op_assign(x);
                }
                res
            }
        }
    )* };
}
178
// For each `(op, OpTrait, op_assign, unit)` tuple, implements `OpTrait` for
// `StaticModInt` twice via `impl_folding_inner!`: once over iterators of owned
// items and once over iterators of `&'a` borrowed items. Both fold from `unit`
// using `op_assign`.
macro_rules! impl_folding {
    ( $( ($op:ident, $op_trait:ident, $op_assign:ident, $unit:literal), )* ) => { $(
        impl_folding_inner! {
            impl<_> $op_trait<Self> for Self { fn $op(..) -> _ { $unit; self.$op_assign() } }
            impl<'a, _> $op_trait<&'a Self> for Self { fn $op(..) -> _ { $unit; self.$op_assign() } }
        }
    )* }
}
187
188impl_folding! {
189 ( sum, Sum, add_assign, 0 ),
190 ( product, Product, mul_assign, 1 ),
191}
192
193impl<const MOD: u32> Hash for StaticModInt<MOD> {
194 fn hash<H: Hasher>(&self, state: &mut H) { self.0.hash(state) }
195}
196
197pub type ModInt998244353 = StaticModInt<998244353>;
198pub type ModInt1000000007 = StaticModInt<1000000007>;
199
#[test]
fn arithmetic() {
    type M = ModInt998244353;

    let (zero, one, two) = (M::new(0), M::new(1), M::new(2));
    // 499122177 is the inverse of 2 and 748683265 the inverse of 4
    // modulo 998244353.
    let half = M::new(499122177);
    let quarter = M::new(748683265);

    // Reducing the modulus itself gives zero.
    assert_eq!(M::new(M::modulus()), zero);

    // Addition, subtraction and negation.
    assert_eq!(half + half, one);
    assert_eq!(zero - half, -half);
    assert_eq!(one - half, half);

    // Multiplication and division.
    assert_eq!(half * two, one);
    assert_eq!(half * half, quarter);
    assert_eq!(one / two, half);

    // Fermat's little theorem: 2^(p-1) = 1 for the prime p = 998244353.
    assert_eq!(two.pow(998244352_u64), one);
}
218
#[test]
fn folding() {
    type M = ModInt998244353;

    let values: Vec<M> = (1..=4).map(M::new).collect();
    let expected_sum = M::new(10);
    let expected_product = M::new(24);

    // Folding over borrowed items.
    assert_eq!(values.iter().sum::<M>(), expected_sum);
    assert_eq!(values.iter().product::<M>(), expected_product);
    // Folding over owned items.
    assert_eq!(values.iter().copied().sum::<M>(), expected_sum);
    assert_eq!(values.iter().copied().product::<M>(), expected_product);

    // A product of one hundred twos agrees with exponentiation.
    let two = M::new(2);
    let repeated = [two; 100];
    assert_eq!(two.pow(100_u32), repeated.iter().product());
}
234
#[test]
fn fmt() {
    type M = ModInt998244353;

    let one = M::new(1);
    let pair = [one; 2];

    // `Display` prints just the representative; `Debug` appends the modulus.
    assert_eq!(one.to_string(), "1");
    assert_eq!(format!("{:?}", one), "1 mod 998244353");
    assert_eq!(format!("{:?}", pair), "[1 mod 998244353, 1 mod 998244353]");
}
244
#[test]
fn conversion() {
    type M = ModInt998244353;

    // Each entry is `input => expected canonical representative`.
    macro_rules! check {
        ( $( $input:expr => $expected:expr ),* $(,)? ) => {
            $( assert_eq!(M::new($input).0, $expected); )*
        };
    }

    // -1 wraps to MOD - 1 for every signed width.
    check! {
        -1_i8 => 998244352,
        -1_i16 => 998244352,
        -1_i32 => 998244352,
        -1_i64 => 998244352,
        -1_i128 => 998244352,
    }

    // MOD + 1 reduces to 1 for every type wide enough to hold it.
    check! {
        998244354_i32 => 1,
        998244354_i64 => 1,
        998244354_i128 => 1,
        998244354_u32 => 1,
        998244354_u64 => 1,
        998244354_u128 => 1,
    }

    // Small values pass through unchanged.
    check! {
        10_i8 => 10,
        10_i16 => 10,
        10_i32 => 10,
        10_i64 => 10,
        10_i128 => 10,
        10_isize => 10,
        10_u8 => 10,
        10_u16 => 10,
        10_u32 => 10,
        10_u64 => 10,
        10_u128 => 10,
        10_usize => 10,
    }
}