1use super::gcd_recip;
2use std::fmt::{self, Debug, Display};
3use std::hash::{Hash, Hasher};
4use std::iter::{Product, Sum};
5use std::marker::PhantomData;
6use std::ops::{
7 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign,
8};
9use std::sync::atomic::{self, AtomicU32, AtomicU64};
10
11use gcd_recip::GcdRecip;
12
/// A residue modulo the compile-time constant chosen by the marker type `M`
/// (see [`Modulus`]).
///
/// `PhantomData<fn() -> M>` ties the type to `M` without storing one, and keeps
/// the struct covariant in `M` and `Send`/`Sync` regardless of `M`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct StaticModInt<M> {
    /// The stored representative.
    val: u32,
    _phd: PhantomData<fn() -> M>,
}

/// A residue modulo a runtime-configurable modulus, shared by every value with
/// the same id type `I` (see [`DynamicModIntId`]).
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DynamicModInt<I> {
    /// The stored representative.
    val: u32,
    _phd: PhantomData<fn() -> I>,
}
24
/// Marker types that fix a [`StaticModInt`]'s modulus at compile time.
pub trait Modulus: Copy + Eq + 'static {
    /// The modulus every value of `StaticModInt<Self>` is reduced by.
    const VALUE: u32;

    // Disabled: depends on the (also disabled) `is_prime_32` const fn.
    #[cfg(ignore)]
    const IS_PRIME: bool = is_prime_32(Self::VALUE);
}
30
31pub trait ModIntBase:
32 Copy
33 + Eq
34 + Hash
35 + Add<Output = Self>
36 + Sub<Output = Self>
37 + Mul<Output = Self>
38 + Div<Output = Self>
39 + Neg
40 + AddAssign
41 + SubAssign
42 + MulAssign
43 + DivAssign
44{
45 fn modulus() -> u32;
46 fn get(self) -> u32;
47 fn new(n: impl RemEuclidU32) -> Self {
48 let n = n.rem_euclid_u32(Self::modulus());
49 unsafe { Self::new_unchecked(n) }
50 }
51 unsafe fn new_unchecked(n: u32) -> Self;
52 fn recip(self) -> Self { self.checked_recip().unwrap() }
53 fn checked_recip(self) -> Option<Self> {
54 let (g, r) = (self.get() as u64).gcd_recip(Self::modulus() as u64);
55 let r = r as u32;
56 if g == 1 { Some(unsafe { Self::new_unchecked(r) }) } else { None }
57 }
58 fn pow(self, mut iexp: u64) -> Self {
59 let mut res = Self::new(1);
60 let mut a = self;
61 while iexp > 0 {
62 if iexp & 1 != 0 {
63 res *= a;
64 }
65 a *= a;
66 iexp >>= 1;
67 }
68 res
69 }
70}
71
72trait InternalImpls: ModIntBase {
73 fn hash_impl(&self, state: &mut impl Hasher) { self.get().hash(state) }
74 fn display_impl(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 fmt::Display::fmt(&self.get(), f)
76 }
77 fn debug_impl(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 write!(f, "{} (mod {})", self.get(), Self::modulus())
79 }
80 fn neg_impl(self) -> Self {
81 let v = if self.get() == 0 { 0 } else { Self::modulus() - self.get() };
82 unsafe { Self::new_unchecked(v) }
83 }
84}
85
86impl<M: Modulus> StaticModInt<M> {
87 fn modulus() -> u32 { M::VALUE }
88 fn zero() -> Self { unsafe { Self::new_unchecked(0) } }
89 fn add_impl(self, rhs: Self) -> Self {
90 let mut tmp = self;
91 tmp += rhs;
92 tmp
93 }
94 fn sub_impl(self, rhs: Self) -> Self {
95 let mut tmp = self;
96 tmp -= rhs;
97 tmp
98 }
99 fn mul_impl(self, rhs: Self) -> Self {
100 let v = ((self.val as u64 * rhs.val as u64) % Self::modulus() as u64)
101 as u32;
102 unsafe { Self::new_unchecked(v) }
103 }
104 fn div_impl(self, rhs: Self) -> Self { self.mul_impl(rhs.recip()) }
105 fn add_assign_impl(&mut self, rhs: Self) {
106 self.val += rhs.val;
107 if self.val >= Self::modulus() {
108 self.val -= Self::modulus()
109 }
110 }
111 fn sub_assign_impl(&mut self, rhs: Self) {
112 if self.val < rhs.val {
113 self.val += Self::modulus()
114 }
115 self.val -= rhs.val
116 }
117 fn mul_assign_impl(&mut self, rhs: Self) { *self = self.mul_impl(rhs) }
118 fn div_assign_impl(&mut self, rhs: Self) { *self = self.div_impl(rhs) }
119}
120
121impl<M: Modulus> ModIntBase for StaticModInt<M> {
122 fn modulus() -> u32 { Self::modulus() }
123 fn get(self) -> u32 { self.val }
124 unsafe fn new_unchecked(val: u32) -> Self {
125 Self { val, _phd: PhantomData }
126 }
127}
128
129impl<M: Modulus> InternalImpls for StaticModInt<M> {}
130
131impl<I: RemEuclidU32, M: Modulus> From<I> for StaticModInt<M> {
132 fn from(x: I) -> Self { Self::new(x) }
133}
134
/// Strong-probable-prime test of odd `n` to base `a`.
///
/// Writes `n - 1 = d * 2^s` with `d` odd. Then checks whether `a^d == 1` or
/// `a^(d * 2^r) == n - 1` for some `0 <= r < s`. Currently compiled out via
/// `cfg(ignore)`.
#[cfg(ignore)]
const fn is_sprp_32(n: u32, a: u32) -> bool {
    let n = n as u64;
    let s = (n - 1).trailing_zeros();
    let d = n >> s;

    // x = a^d mod n, by square-and-multiply.
    let mut x = 1;
    let mut base = a as u64;
    let mut e = d;
    while e != 0 {
        if e % 2 == 1 {
            x = x * base % n;
        }
        base = base * base % n;
        e /= 2;
    }

    if x == 1 {
        return true;
    }
    let mut r = 0;
    while r < s {
        if x == n - 1 {
            return true;
        }
        x = x * x % n;
        r += 1;
    }
    false
}
166
/// Deterministic primality test for `u32`.
///
/// Uses trial division by 2, 3, 5 and 7, then strong-probable-prime tests to
/// bases 2, 7 and 61. Currently compiled out via `cfg(ignore)`.
#[cfg(ignore)]
const fn is_prime_32(n: u32) -> bool {
    match n {
        2 | 3 | 5 | 7 => true,
        _ if n % 2 == 0 || n % 3 == 0 || n % 5 == 0 || n % 7 == 0 => false,
        // Below 11^2 every composite has a factor <= 7, so survivors are prime
        // except 1.
        0..=120 => n > 1,
        _ => is_sprp_32(n, 2) && is_sprp_32(n, 7) && is_sprp_32(n, 61),
    }
}
179
180pub trait DynamicModIntId: 'static + Copy + Eq {
181 fn barrett() -> &'static Barrett;
182}
183
184#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
185pub enum DefaultId {}
186
187pub struct Barrett {
188 m: AtomicU32,
189 im: AtomicU64,
190}
191
192impl DynamicModIntId for DefaultId {
193 fn barrett() -> &'static Barrett {
194 static BARRETT: Barrett = Barrett::default();
195 &BARRETT
196 }
197}
198
199impl Barrett {
200 pub const fn new(m: u32) -> Self {
201 Self {
202 m: AtomicU32::new(m),
203 im: AtomicU64::new(Self::im(m)),
204 }
205 }
206
207 pub const fn default() -> Self { Self::new(1) }
208
209 const fn im(m: u32) -> u64 {
210 (0_u64.wrapping_sub(1) / m as u64).wrapping_add(1)
211 }
212
213 fn set(&self, m: u32) {
214 self.m.store(m, atomic::Ordering::SeqCst);
215 self.im.store(Self::im(m), atomic::Ordering::SeqCst);
216 }
217
218 fn modulus(&self) -> u32 { self.m.load(atomic::Ordering::SeqCst) }
219
220 fn mul(&self, a: u32, b: u32) -> u32 {
221 let m = self.m.load(atomic::Ordering::SeqCst);
222 let im = self.im.load(atomic::Ordering::SeqCst);
223
224 let z = a as u64 * b as u64;
225 let x = ((z as u128 * im as u128) >> 64) as u64;
226 let y = x.wrapping_mul(m as u64);
227 let v = z.wrapping_sub(y) as u32;
228 v.wrapping_add(if z < y { m } else { 0 })
229 }
230}
231
232impl Default for Barrett {
233 fn default() -> Self { Self::default() }
234}
235
236impl<I: DynamicModIntId> DynamicModInt<I> {
237 pub fn modulus() -> u32 { I::barrett().modulus() }
238 fn zero() -> Self { unsafe { Self::new_unchecked(0) } }
239 fn add_impl(self, rhs: Self) -> Self {
240 let mut tmp = self;
241 tmp += rhs;
242 tmp
243 }
244 fn sub_impl(self, rhs: Self) -> Self {
245 let mut tmp = self;
246 tmp -= rhs;
247 tmp
248 }
249 fn mul_impl(self, rhs: Self) -> Self {
250 let v = I::barrett().mul(self.val, rhs.val);
251 unsafe { Self::new_unchecked(v) }
252 }
253 fn div_impl(self, rhs: Self) -> Self { self.mul_impl(rhs.recip()) }
254 fn add_assign_impl(&mut self, rhs: Self) {
255 self.val += rhs.val;
256 if self.val >= Self::modulus() {
257 self.val -= Self::modulus()
258 }
259 }
260 fn sub_assign_impl(&mut self, rhs: Self) {
261 if self.val < rhs.val {
262 self.val += Self::modulus()
263 }
264 self.val -= rhs.val
265 }
266 fn mul_assign_impl(&mut self, rhs: Self) { *self = self.mul_impl(rhs) }
267 fn div_assign_impl(&mut self, rhs: Self) { *self = self.div_impl(rhs) }
268
269 pub fn set_modulus(m: u32) {
270 if !(1..=1 << 31).contains(&m) {
273 panic!("the modulus must be in range (1, 2**31)");
274 }
275 I::barrett().set(m);
276 }
277}
278
279impl<I: DynamicModIntId> ModIntBase for DynamicModInt<I> {
280 fn modulus() -> u32 { Self::modulus() }
281 fn get(self) -> u32 { self.val }
282 unsafe fn new_unchecked(val: u32) -> Self {
283 Self { val, _phd: PhantomData }
284 }
285}
286
287impl<I: DynamicModIntId> InternalImpls for DynamicModInt<I> {}
288
289impl<J: RemEuclidU32, I: DynamicModIntId> From<J> for DynamicModInt<I> {
290 fn from(x: J) -> Self { Self::new(x) }
291}
292
293macro_rules! impl_modint {
294 ( $( ($mod:ident, $val:literal, $modint:ident), )* ) => { $(
295 #[derive(Clone, Copy, Eq, PartialEq)]
296 pub struct $mod;
297 impl Modulus for $mod {
298 const VALUE: u32 = $val;
299 }
300 pub type $modint = StaticModInt<$mod>;
301 )* }
302}
303
304impl_modint! {
305 (Mod998244353, 998244353, ModInt998244353),
306 (Mod1000000007, 1000000007, ModInt1000000007),
307}
308
/// Implements `Add`/`Sub`/`Mul`/`Div` for one `lhs @ rhs -> out` combination
/// per entry. Each operator forwards to the `*_impl` inherent methods.
///
/// A trailing `*` after `self @` dereferences a borrowed right-hand side first.
macro_rules! impl_bin_ops {
    () => {};
    (
        for<$($g:ident : $b:tt),*>
        <$lhs:ty> @ <$rhs:ty> -> $out:ty
        { self @ $($deref:tt)? } $($tail:tt)*
    ) => {
        impl<$($g: $b),*> Add<$rhs> for $lhs {
            type Output = $out;
            fn add(self, rhs: $rhs) -> $out {
                self.add_impl($($deref)? rhs)
            }
        }

        impl<$($g: $b),*> Sub<$rhs> for $lhs {
            type Output = $out;
            fn sub(self, rhs: $rhs) -> $out {
                self.sub_impl($($deref)? rhs)
            }
        }

        impl<$($g: $b),*> Mul<$rhs> for $lhs {
            type Output = $out;
            fn mul(self, rhs: $rhs) -> $out {
                self.mul_impl($($deref)? rhs)
            }
        }

        impl<$($g: $b),*> Div<$rhs> for $lhs {
            type Output = $out;
            fn div(self, rhs: $rhs) -> $out {
                self.div_impl($($deref)? rhs)
            }
        }

        impl_bin_ops!($($tail)*);
    };
}
335
/// Implements `AddAssign`/`SubAssign`/`MulAssign`/`DivAssign` for one
/// `lhs @= rhs` combination per entry. Each operator forwards to the
/// `*_assign_impl` inherent methods.
///
/// A trailing `*` after `self @=` dereferences a borrowed right-hand side first.
macro_rules! impl_assign_ops {
    () => {};
    (
        for<$($g:ident : $b:tt),*>
        <$lhs:ty> @= <$rhs:ty>
        { self @= $($deref:tt)? } $($tail:tt)*
    ) => {
        impl<$($g: $b),*> AddAssign<$rhs> for $lhs {
            fn add_assign(&mut self, rhs: $rhs) {
                self.add_assign_impl($($deref)? rhs);
            }
        }

        impl<$($g: $b),*> SubAssign<$rhs> for $lhs {
            fn sub_assign(&mut self, rhs: $rhs) {
                self.sub_assign_impl($($deref)? rhs);
            }
        }

        impl<$($g: $b),*> MulAssign<$rhs> for $lhs {
            fn mul_assign(&mut self, rhs: $rhs) {
                self.mul_assign_impl($($deref)? rhs);
            }
        }

        impl<$($g: $b),*> DivAssign<$rhs> for $lhs {
            fn div_assign(&mut self, rhs: $rhs) {
                self.div_assign_impl($($deref)? rhs);
            }
        }

        impl_assign_ops!($($tail)*);
    };
}
366
/// Implements a folding trait (`Sum`/`Product`) for both owned and borrowed
/// items. Each impl folds the iterator from `unit` with `op`.
macro_rules! impl_folding {
    () => {};
    (
        impl <$g:ident : $b:tt> $trait:ident<_>
        for $self:ty
        {
            fn $method:ident(_) -> _ { _($unit:expr, $op:expr) }
        }
        $($tail:tt)*
    ) => {
        impl<$g: $b> $trait<Self> for $self {
            fn $method<It>(iter: It) -> Self
            where
                It: Iterator<Item = Self>,
            {
                iter.fold($unit, $op)
            }
        }

        impl<'a, $g: $b> $trait<&'a Self> for $self {
            fn $method<It>(iter: It) -> Self
            where
                It: Iterator<Item = &'a Self>,
            {
                iter.fold($unit, $op)
            }
        }

        impl_folding!($($tail)*);
    };
}
396
/// Implements `Default`, `Hash`, `Display`, `Debug` and `Neg` for each listed
/// type. Each impl forwards to `zero()` or the matching `*_impl` method.
macro_rules! impl_basic_traits {
    () => {};
    (impl<$g:ident : $b:tt> _ for $self:ty; $($tail:tt)*) => {
        impl<$g: $b> Neg for $self {
            type Output = $self;
            fn neg(self) -> Self {
                self.neg_impl()
            }
        }

        impl<$g: $b> Default for $self {
            fn default() -> Self {
                Self::zero()
            }
        }

        impl<$g: $b> Display for $self {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                self.display_impl(f)
            }
        }

        impl<$g: $b> Debug for $self {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                self.debug_impl(f)
            }
        }

        impl<$g: $b> Hash for $self {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.hash_impl(state)
            }
        }

        impl_basic_traits!($($tail)*);
    };
}
423
424impl_bin_ops! {
425 for<M: Modulus> <StaticModInt<M>> @ <StaticModInt<M>> -> StaticModInt<M> { self @ }
426 for<M: Modulus> <StaticModInt<M>> @ <&'_ StaticModInt<M>> -> StaticModInt<M> { self @ * }
427 for<M: Modulus> <&'_ StaticModInt<M>> @ <StaticModInt<M>> -> StaticModInt<M> { self @ }
428 for<M: Modulus> <&'_ StaticModInt<M>> @ <&'_ StaticModInt<M>> -> StaticModInt<M> { self @ * }
429 for<I: DynamicModIntId> <DynamicModInt<I>> @ <DynamicModInt<I>> -> DynamicModInt<I> { self @ }
430 for<I: DynamicModIntId> <DynamicModInt<I>> @ <&'_ DynamicModInt<I>> -> DynamicModInt<I> { self @ * }
431 for<I: DynamicModIntId> <&'_ DynamicModInt<I>> @ <DynamicModInt<I>> -> DynamicModInt<I> { self @ }
432 for<I: DynamicModIntId> <&'_ DynamicModInt<I>> @ <&'_ DynamicModInt<I>> -> DynamicModInt<I> { self @ * }
433}
434
435impl_assign_ops! {
436 for<M: Modulus> <StaticModInt<M>> @= <StaticModInt<M>> { self @= }
437 for<M: Modulus> <StaticModInt<M>> @= <&'_ StaticModInt<M>> { self @= * }
438 for<I: DynamicModIntId> <DynamicModInt<I>> @= <DynamicModInt<I>> { self @= }
439 for<I: DynamicModIntId> <DynamicModInt<I>> @= <&'_ DynamicModInt<I>> { self @= * }
440}
441
442impl_folding! {
443 impl<M: Modulus> Sum<_> for StaticModInt<M> { fn sum(_) -> _ { _(Self::new(0), Add::add)} }
444 impl<M: Modulus> Product<_> for StaticModInt<M> { fn product(_) -> _ { _(Self::new(1), Mul::mul)} }
445 impl<I: DynamicModIntId> Sum<_> for DynamicModInt<I> { fn sum(_) -> _ { _(Self::new(0), Add::add)} }
446 impl<I: DynamicModIntId> Product<_> for DynamicModInt<I> { fn product(_) -> _ { _(Self::new(1), Mul::mul)} }
447}
448
449impl_basic_traits! {
450 impl<M: Modulus> _ for StaticModInt<M>;
451 impl<I: DynamicModIntId> _ for DynamicModInt<I>;
452}
453
/// Integers that can be reduced modulo a `u32` into the range `0..n`.
pub trait RemEuclidU32 {
    /// Returns the least non-negative residue of `self` modulo `n`.
    fn rem_euclid_u32(self, n: u32) -> u32;
}

/// Implements [`RemEuclidU32`] by reducing in the type's own width; intended
/// for types at least as wide as `u32`.
macro_rules! impl_rem_euclid_u32 {
    ( $($ty:ty)* ) => {
        $(
            impl RemEuclidU32 for $ty {
                fn rem_euclid_u32(self, n: u32) -> u32 {
                    let modulus = n as $ty;
                    <$ty>::rem_euclid(self, modulus) as u32
                }
            }
        )*
    };
}
467
/// Implements [`RemEuclidU32`] for integer types narrower than 32 bits, signed
/// or unsigned.
///
/// Each value is widened losslessly to `i64` before reducing. A plain
/// `self as u32` would sign-extend negative `i8`/`i16` values (for example
/// `-1i8 as u32 == u32::MAX`) and yield the wrong residue.
macro_rules! impl_rem_euclid_u32_small {
    ( $($ty:ty)* ) => { $(
        impl RemEuclidU32 for $ty {
            fn rem_euclid_u32(self, n: u32) -> u32 {
                // The result lies in `0..n`, so narrowing back to u32 is lossless.
                i64::from(self).rem_euclid(i64::from(n)) as u32
            }
        }
    )* };
}
477
478impl_rem_euclid_u32! {
479 i64 i128 isize u32 u64 u128 usize
480}
481
482impl_rem_euclid_u32_small! {
483 i8 i16 u8 u16
484}
485
486impl RemEuclidU32 for i32 {
487 fn rem_euclid_u32(self, n: u32) -> u32 {
488 if self >= 0 {
489 (self as u32).rem_euclid(n)
490 } else {
491 (self as i64).rem_euclid(n as i64) as u32
492 }
493 }
494}
495
#[test]
fn sanity_check() {
    type Fixed = ModInt998244353;

    // Wrap-around on addition and inverse of 2.
    assert_eq!(Fixed::new(1) + Fixed::new(998244352), Fixed::new(0));
    let half = Fixed::new(1) / Fixed::new(2);
    assert_eq!(half.get(), (Fixed::modulus() + 1) / 2);

    // Folding via Sum / Product.
    let total = (1..=10).map(Fixed::new).sum::<Fixed>();
    assert_eq!(total, Fixed::new(55));
    let factorial = (1..=10).map(Fixed::new).product::<Fixed>();
    assert_eq!(factorial, Fixed::new(3628800));

    // Runtime modulus, changed mid-test.
    type Dyn = DynamicModInt<DefaultId>;
    Dyn::set_modulus(10);
    assert_eq!(Dyn::new(5) + Dyn::new(7), Dyn::new(2));
    assert_eq!(Dyn::new(3) * Dyn::new(4), Dyn::new(2));

    Dyn::set_modulus(4);
    assert_eq!(Dyn::new(5) + Dyn::new(7), Dyn::new(0));

    let total = (1..=10).map(Dyn::new).sum::<Dyn>();
    assert_eq!(total, Dyn::new(55));
    assert_eq!(total.val, 55 % 4);
}
523
#[test]
fn negative() {
    // -1 must map to modulus - 1, not to a truncated cast.
    let minus_one = ModInt998244353::new(-1);
    assert_eq!(minus_one.get(), 998244352);
}
528
#[test]
fn fmt() {
    let x = ModInt998244353::new(123);
    // Display shows the bare value; Debug appends the modulus.
    assert_eq!(x.to_string(), "123");
    assert_eq!(format!("{:?}", x), "123 (mod 998244353)");
}