1use super::garner;
2use super::modint;
3use std::sync::OnceLock;
4
5use garner::{CrtMod as CrtModInternal, CrtWrapping};
6use modint::{Mod998244353, ModIntBase, Modulus, RemEuclidU32, StaticModInt};
7
8pub struct ButterflyCache<M: NttFriendly> {
9 root: Vec<StaticModInt<M>>,
10 iroot: Vec<StaticModInt<M>>,
11 rate2: Vec<StaticModInt<M>>,
12 irate2: Vec<StaticModInt<M>>,
13 rate3: Vec<StaticModInt<M>>,
14 irate3: Vec<StaticModInt<M>>,
15}
16
/// Returns the smallest primitive root modulo the prime `p`.
///
/// A candidate `g` is a primitive root iff `g^((p - 1) / q) != 1 (mod p)` for
/// every prime factor `q` of `p - 1`. Evaluated at compile time for
/// `NttFriendly::PRIMITIVE_ROOT`.
const fn primitive_root(p: u32) -> u32 {
    if p == 2 {
        return 1;
    }

    // Distinct prime factors of p - 1. A u32 has at most 9 distinct prime
    // factors, so 10 slots always suffice.
    let mut factors = [0u32; 10];
    factors[0] = 2;
    let mut count = 1;

    // Strip every factor of two, then trial-divide by odd candidates.
    let mut rest = (p - 1) / 2;
    while rest % 2 == 0 {
        rest /= 2;
    }
    let mut q = 3;
    while q <= rest / q {
        if rest % q == 0 {
            factors[count] = q;
            count += 1;
            while rest % q == 0 {
                rest /= q;
            }
        }
        q += 2;
    }
    // Whatever survives trial division is itself prime.
    if rest > 1 {
        factors[count] = rest;
        count += 1;
    }

    let mut g = 2;
    'candidates: loop {
        let mut k = 0;
        while k < count {
            if mod_pow(g, (p - 1) / factors[k], p) == 1 {
                g += 1;
                continue 'candidates;
            }
            k += 1;
        }
        return g;
    }
}

/// Computes `a^e mod m` by square-and-multiply, widening to `u64` so the
/// product of two residues never overflows.
const fn mod_pow(a: u32, mut e: u32, m: u32) -> u32 {
    let modulus = m as u64;
    let mut acc: u64 = 1;
    let mut square = a as u64;
    while e != 0 {
        if e % 2 == 1 {
            acc = acc * square % modulus;
        }
        square = square * square % modulus;
        e /= 2;
    }
    acc as u32
}
77
78pub trait NttFriendly: Modulus {
79 const PRIMITIVE_ROOT: u32 = primitive_root(Self::VALUE);
80 const EXP: u32 = (Self::VALUE - 1).trailing_zeros();
82 const ODD: u32 = Self::VALUE >> Self::EXP;
83
84 fn cache() -> &'static OnceLock<ButterflyCache<Self>>;
85 fn butterfly_cache() -> &'static ButterflyCache<Self> {
86 Self::cache().get_or_init(precompute_butterfly)
87 }
88}
89
90static MOD998244353_CACHE: OnceLock<ButterflyCache<Mod998244353>> =
91 OnceLock::new();
92
93impl NttFriendly for Mod998244353 {
94 fn cache() -> &'static OnceLock<ButterflyCache<Self>> {
95 &MOD998244353_CACHE
96 }
97}
98
99fn precompute_butterfly<M: NttFriendly>() -> ButterflyCache<M> {
100 let g = StaticModInt::<M>::new(M::PRIMITIVE_ROOT);
101 let rank2 = M::EXP as usize;
102
103 let mut root = vec![StaticModInt::new(0); rank2 + 1];
104 let mut iroot = vec![StaticModInt::new(0); rank2 + 1];
105 root[rank2] = g.pow(M::ODD.into());
106 iroot[rank2] = root[rank2].recip();
107 for i in (0..rank2).rev() {
108 root[i] = root[i + 1] * root[i + 1];
109 iroot[i] = iroot[i + 1] * iroot[i + 1];
110 }
111
112 let mut rate2 = vec![StaticModInt::new(0); rank2];
113 let mut irate2 = vec![StaticModInt::new(0); rank2];
114 {
115 let mut prod = StaticModInt::new(1);
116 let mut iprod = StaticModInt::new(1);
117 for i in 0..=rank2 - 2 {
118 rate2[i] = root[i + 2] * prod;
119 irate2[i] = iroot[i + 2] * iprod;
120 prod *= iroot[i + 2];
121 iprod *= root[i + 2];
122 }
123 }
124
125 let mut rate3 = vec![StaticModInt::new(0); rank2];
126 let mut irate3 = vec![StaticModInt::new(0); rank2];
127 {
128 let mut prod = StaticModInt::new(1);
129 let mut iprod = StaticModInt::new(1);
130 for i in 0..=rank2 - 3 {
131 rate3[i] = root[i + 3] * prod;
132 irate3[i] = iroot[i + 3] * iprod;
133 prod *= iroot[i + 3];
134 iprod *= root[i + 3];
135 }
136 }
137
138 ButterflyCache { root, iroot, rate2, irate2, rate3, irate3 }
139}
140
141pub fn butterfly<M: NttFriendly>(a: &mut [StaticModInt<M>]) {
142 let n = a.len();
143 let h = ceil_pow2(n as u32);
144
145 let ButterflyCache { root, rate2, rate3, .. } = M::butterfly_cache();
146
147 let mut len = 0;
149 while len < h {
150 if h - len == 1 {
151 let p = 1 << (h - len - 1);
152 let mut rot = StaticModInt::new(1);
153 for s in 0..1 << len {
154 let offset = s << (h - len);
155 for i in 0..p {
156 let l = a[i + offset];
157 let r = a[i + offset + p] * rot;
158 a[i + offset] = l + r;
159 a[i + offset + p] = l - r;
160 }
161 if s + 1 != 1 << len {
162 rot *= rate2[(!s).trailing_zeros() as usize];
163 }
164 }
165 len += 1;
166 } else {
167 let p = 1 << (h - len - 2);
169 let imag_u64 = root[2].get() as u64;
170 let mut rot = StaticModInt::new(1);
171
172 for s in 0..1 << len {
173 let rot2 = rot * rot;
174 let rot3 = rot2 * rot;
175
176 let rot_u64 = rot.get() as u64;
177 let rot2_u64 = rot2.get() as u64;
178 let rot3_u64 = rot3.get() as u64;
179
180 let offset = s << (h - len);
181 for i in 0..p {
182 let mod2 = (M::VALUE as u64).pow(2);
183 let a0 = a[i + offset].get() as u64;
184 let a1 = a[i + offset + p].get() as u64 * rot_u64;
185 let a2 = a[i + offset + 2 * p].get() as u64 * rot2_u64;
186 let a3 = a[i + offset + 3 * p].get() as u64 * rot3_u64;
187
188 let a1na3 = StaticModInt::<M>::new(a1 + mod2 - a3);
189 let a1na3imag = a1na3.get() as u64 * imag_u64;
190 let na2 = mod2 - a2;
191
192 a[i + offset] = StaticModInt::new(a0 + a2 + a1 + a3);
193 a[i + offset + p] =
194 StaticModInt::new(a0 + a2 + (2 * mod2 - (a1 + a3)));
195 a[i + offset + 2 * p] =
196 StaticModInt::new(a0 + na2 + a1na3imag);
197 a[i + offset + 3 * p] =
198 StaticModInt::new(a0 + na2 + (mod2 - a1na3imag));
199 }
200
201 if s + 1 != 1 << len {
202 rot *= rate3[(!s).trailing_zeros() as usize];
203 }
204 }
205 len += 2;
206 }
207 }
208}
209
210pub fn butterfly_inv<M: NttFriendly>(a: &mut [StaticModInt<M>]) {
211 let n = a.len();
212 let h = ceil_pow2(n as u32);
213
214 let ButterflyCache { iroot, irate2, irate3, .. } = M::butterfly_cache();
215
216 let mut len = h;
218 while len > 0 {
219 if len == 1 {
220 let p = 1 << (h - len);
221 let mut irot = StaticModInt::new(1);
222 for s in 0..1 << (len - 1) {
223 let offset = s << (h - len + 1);
224 for i in 0..p {
225 let l = a[i + offset];
226 let r = a[i + offset + p];
227 a[i + offset] = l + r;
228 a[i + offset + p] = (l - r) * irot
229 }
230
231 if s + 1 != 1 << (len - 1) {
232 irot *= irate2[(!s).trailing_zeros() as usize];
233 }
234 }
235 len -= 1;
236 } else {
237 let p = 1 << (h - len);
239 let mod1 = M::VALUE as u64;
240 let iimag_u64 = iroot[2].get() as u64;
241
242 let mut irot = StaticModInt::new(1);
243 for s in 0..1 << (len - 2) {
244 let irot2 = irot * irot;
245 let irot3 = irot2 * irot;
246
247 let irot_u64 = irot.get() as u64;
248 let irot2_u64 = irot2.get() as u64;
249 let irot3_u64 = irot3.get() as u64;
250
251 let offset = s << (h - len + 2);
252 for i in 0..p {
253 let a0 = a[i + offset].get() as u64;
254 let a1 = a[i + offset + p].get() as u64;
255 let a2 = a[i + offset + 2 * p].get() as u64;
256 let a3 = a[i + offset + 3 * p].get() as u64;
257
258 let a2na3_u64 =
259 StaticModInt::<M>::new(mod1 + a2 - a3).get() as u64;
260 let a2na3iimag =
261 StaticModInt::<M>::new(a2na3_u64 * iimag_u64);
262 let a2na3iimag_u64 = a2na3iimag.get() as u64;
263
264 a[i + offset] = StaticModInt::new(a0 + a1 + a2 + a3);
265 a[i + offset + p] = StaticModInt::new(
266 (a0 + (mod1 - a1) + a2na3iimag_u64) * irot_u64,
267 );
268 a[i + offset + 2 * p] = StaticModInt::new(
269 (a0 + a1 + (mod1 - a2) + (mod1 - a3)) * irot2_u64,
270 );
271 a[i + offset + 3 * p] = StaticModInt::new(
272 (a0 + (mod1 - a1) + (mod1 - a2na3iimag_u64))
273 * irot3_u64,
274 );
275 }
276 if s + 1 != 1 << (len - 2) {
277 irot *= irate3[(!s).trailing_zeros() as usize];
278 }
279 }
280 len -= 2;
281 }
282 }
283}
284
285pub fn convolve<M: NttFriendly>(
286 a: Vec<StaticModInt<M>>,
287 b: Vec<StaticModInt<M>>,
288) -> Vec<StaticModInt<M>> {
289 if a.is_empty() || b.is_empty() {
290 return vec![];
291 }
292 let (n, m) = (a.len(), b.len());
293
294 if n.min(m) <= 60 {
295 convolve_naive(&a, &b)
296 } else if (n + m - 2).is_power_of_two() {
297 convolve_pow2p1(a, b)
298 } else {
299 convolve_fft(a, b)
300 }
301}
302
303fn convolve_naive<M: NttFriendly>(
304 a: &[StaticModInt<M>],
305 b: &[StaticModInt<M>],
306) -> Vec<StaticModInt<M>> {
307 let (n, m) = (a.len(), b.len());
308 let (n, m, a, b) = if n < m { (m, n, b, a) } else { (n, m, a, b) };
309 let mut res = vec![StaticModInt::new(0); n + m - 1];
310 for i in 0..n {
311 for j in 0..m {
312 res[i + j] += a[i] * b[j];
313 }
314 }
315 res
316}
317
318fn convolve_pow2p1<M: NttFriendly>(
319 a: Vec<StaticModInt<M>>,
320 b: Vec<StaticModInt<M>>,
321) -> Vec<StaticModInt<M>> {
322 let n = a.len();
323 let m = b.len();
324 let len = n + m - 1;
325 assert!((len - 1).is_power_of_two());
326
327 let mut res = convolve_fft(a[1..].to_vec(), b[1..].to_vec());
333 res.splice(0..0, (0..2).map(|_| StaticModInt::new(0)));
334
335 res[0] += a[0] * b[0];
336 for i in 1..n {
337 res[i] += a[i] * b[0];
338 }
339 for j in 1..m {
340 res[j] += a[0] * b[j];
341 }
342 res
343}
344
345fn convolve_fft<M: NttFriendly>(
346 mut a: Vec<StaticModInt<M>>,
347 mut b: Vec<StaticModInt<M>>,
348) -> Vec<StaticModInt<M>> {
349 let (n, m) = (a.len(), b.len());
350 let z = (n + m - 1).next_power_of_two();
351 a.resize(z, StaticModInt::new(0));
352 b.resize(z, StaticModInt::new(0));
353
354 butterfly(&mut a);
355 butterfly(&mut b);
356
357 for (ai, bi) in a.iter_mut().zip(&mut b) {
358 *ai *= *bi;
359 }
360 butterfly_inv(&mut a);
361
362 a.truncate(n + m - 1);
363 let iz = StaticModInt::new(z).recip();
364 for ai in &mut a {
365 *ai *= iz;
366 }
367
368 a
369}
370
/// Declares NTT-friendly prime moduli.
///
/// For each `(Name, value, CACHE)` it emits a zero-sized `Modulus` type
/// `Name` with `VALUE = value`, plus the `static` `CACHE` that stores its
/// twiddle tables.
macro_rules! impl_modint_ntt {
    ( $( ($name:ident, $value:expr, $cell:ident), )* ) => {
        $(
            #[derive(Clone, Copy, Eq, PartialEq)]
            struct $name;

            impl Modulus for $name {
                const VALUE: u32 = $value;
            }

            static $cell: OnceLock<ButterflyCache<$name>> = OnceLock::new();

            impl NttFriendly for $name {
                fn cache() -> &'static OnceLock<ButterflyCache<Self>> {
                    &$cell
                }
            }
        )*
    };
}
384
385impl_modint_ntt! {
386 (Mod45e24p1, 45 << 24 | 1, MOD45E24P1_CACHE),
387 (Mod73e24p1, 73 << 24 | 1, MOD73E24P1_CACHE),
388 (Mod127e24p1, 127 << 24 | 1, MOD127E24P1_CACHE),
389 (Mod5e25p1, 5 << 25 | 1, MOD5E25P1_CACHE),
390 (Mod33e25p1, 33 << 25 | 1, MOD33E25P1_CACHE),
391 (Mod51e25p1, 51 << 25 | 1, MOD51E25P1_CACHE),
392 (Mod63e25p1, 63 << 25 | 1, MOD63E25P1_CACHE),
393 (Mod7e26p1, 7 << 26 | 1, MOD7E26P1_CACHE),
394 (Mod27e26p1, 27 << 26 | 1, MOD27E26P1_CACHE),
395 (Mod15e27p1, 15 << 27 | 1, MOD15E27P1_CACHE),
396}
397
398type Mod0 = Mod15e27p1;
399type Mod1 = Mod27e26p1;
400type Mod2 = Mod7e26p1;
401type Mod3 = Mod63e25p1;
402type Mod4 = Mod51e25p1;
403type Mod5 = Mod33e25p1;
404type Mod6 = Mod5e25p1;
405type Mod7 = Mod127e24p1;
406type Mod8 = Mod73e24p1;
407type Mod9 = Mod45e24p1;
408
409const MOD0: u32 = Mod0::VALUE;
410const MOD1: u32 = Mod1::VALUE;
411const MOD2: u32 = Mod2::VALUE;
412const MOD3: u32 = Mod3::VALUE;
413const MOD4: u32 = Mod4::VALUE;
414const MOD5: u32 = Mod5::VALUE;
415const MOD6: u32 = Mod6::VALUE;
416const MOD7: u32 = Mod7::VALUE;
417const MOD8: u32 = Mod8::VALUE;
418const MOD9: u32 = Mod9::VALUE;
419
420fn convolve_from<M: NttFriendly, I: RemEuclidU32 + Copy>(
421 a: &[I],
422 b: &[I],
423) -> Vec<u32> {
424 let a: Vec<_> = a.iter().map(|&ai| StaticModInt::new(ai)).collect();
425 let b: Vec<_> = b.iter().map(|&bi| StaticModInt::new(bi)).collect();
426 convolve(a, b).into_iter().map(|x: StaticModInt<M>| x.get()).collect()
427}
428
429pub fn convolve_u64_acl(a: &[u64], b: &[u64]) -> Vec<u64> {
430 if a.is_empty() || b.is_empty() {
431 return vec![];
432 }
433 let n = a.len();
434 let m = b.len();
435
436 let mod1 = Mod45e24p1::VALUE as u64;
437 let mod2 = Mod5e25p1::VALUE as u64;
438 let mod3 = Mod7e26p1::VALUE as u64;
439 let m2m3 = mod2 * mod3;
440 let m1m3 = mod1 * mod3;
441 let m1m2 = mod1 * mod2;
442 let m1m2m3 = m1m2.wrapping_mul(mod3);
443
444 type ModInt754974721 = StaticModInt<Mod45e24p1>;
445 type ModInt167772161 = StaticModInt<Mod5e25p1>;
446 type ModInt469762049 = StaticModInt<Mod7e26p1>;
447
448 let i1 = ModInt754974721::new(m2m3).recip().get() as u64;
449 let i2 = ModInt167772161::new(m1m3).recip().get() as u64;
450 let i3 = ModInt469762049::new(m1m2).recip().get() as u64;
451
452 let max_bit = 24;
453 assert_eq!(mod1 % (1 << max_bit), 1);
454 assert_eq!(mod2 % (1 << max_bit), 1);
455 assert_eq!(mod3 % (1 << max_bit), 1);
456 assert!(n + m - 1 <= (1 << max_bit));
457
458 let c1 = convolve_from::<Mod45e24p1, _>(&a, &b);
459 let c2 = convolve_from::<Mod5e25p1, _>(&a, &b);
460 let c3 = convolve_from::<Mod7e26p1, _>(&a, &b);
461
462 c1.into_iter()
463 .zip(c2)
464 .zip(c3)
465 .map(|((c1i, c2i), c3i)| {
466 let c1i = c1i as u64;
467 let c2i = c2i as u64;
468 let c3i = c3i as u64;
469
470 let mut x = 0;
471 x += (c1i * i1) % mod1 * m2m3;
472 x += (c2i * i2) % mod2 * m1m3;
473 x += (c3i * i3) % mod3 * m1m2;
474 let rem = x.rem_euclid(mod1);
475 let diff = if c1i >= rem { c1i - rem } else { mod1 - (rem - c1i) };
476 let offset = [0, 0, m1m2m3, 2 * m1m2m3, 3 * m1m2m3];
477 x - offset[diff as usize % 5]
478 })
479 .collect()
480}
481
// Uninhabited markers that select a CRT reconstruction for `Crt`.
enum CrtU64 {}
enum CrtWrappingU64 {}
enum CrtU128 {}
enum CrtWrappingU128 {}

// Target moduli for `CrtMod`: the reconstructed value is reduced by the
// wrapped number.
#[derive(Clone, Copy)]
struct CrtU32Mod(u32);

#[derive(Clone, Copy)]
struct CrtU64Mod(u64);

#[derive(Clone, Copy)]
struct CrtU128Mod(u128);
492
// Left-nested tuples, as produced by `v0.zip(v1).zip(v2)...` over the
// per-modulus result vectors.
type U32x3 = ((u32, u32), u32);
type U32x5 = ((((u32, u32), u32), u32), u32);
type U32x6 = (U32x5, u32);
type U32x10 = ((((U32x6, u32), u32), u32), u32);

/// Flattens a left-nested tuple of `u32` into a fixed-size array.
trait ToArray {
    type Output;
    fn to_array(self) -> Self::Output;
}

impl ToArray for U32x3 {
    type Output = [u32; 3];
    fn to_array(self) -> [u32; 3] {
        let (pair, x2) = self;
        let (x0, x1) = pair;
        [x0, x1, x2]
    }
}

impl ToArray for U32x5 {
    type Output = [u32; 5];
    fn to_array(self) -> [u32; 5] {
        let ((head, x3), x4) = self;
        let [x0, x1, x2] = head.to_array();
        [x0, x1, x2, x3, x4]
    }
}

impl ToArray for U32x6 {
    type Output = [u32; 6];
    fn to_array(self) -> [u32; 6] {
        let (head, x5) = self;
        let [x0, x1, x2, x3, x4] = head.to_array();
        [x0, x1, x2, x3, x4, x5]
    }
}

impl ToArray for U32x10 {
    type Output = [u32; 10];
    fn to_array(self) -> [u32; 10] {
        let ((((head, x6), x7), x8), x9) = self;
        let [x0, x1, x2, x3, x4, x5] = head.to_array();
        [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9]
    }
}
534
/// Reconstructs an integer from its residues modulo a fixed set of primes.
trait Crt {
    /// The residues, typically a nested tuple from zipped convolutions.
    type Input;
    /// The reconstructed value.
    type Output;
    /// Combines `residues` into a single value.
    fn crt(residues: Self::Input) -> Self::Output;
}
540
541impl Crt for CrtU64 {
542 type Input = U32x3;
543 type Output = u64;
544 fn crt(xs: Self::Input) -> u64 {
545 let [x0, x1, x2] = xs.to_array();
546 [
547 (x0 as u64, MOD0 as u64),
548 (x1 as u64, MOD1 as u64),
549 (x2 as u64, MOD2 as u64),
550 ]
551 .crt_wrapping()
552 }
553}
554
555impl Crt for CrtWrappingU64 {
556 type Input = U32x6;
557 type Output = u64;
558 fn crt(xs: Self::Input) -> u64 {
559 let [x0, x1, x2, x3, x4, x5] = xs.to_array();
560 [
561 (x0 as u64, MOD0 as u64),
562 (x1 as u64, MOD1 as u64),
563 (x2 as u64, MOD2 as u64),
564 (x3 as u64, MOD3 as u64),
565 (x4 as u64, MOD4 as u64),
566 (x5 as u64, MOD5 as u64),
567 ]
568 .crt_wrapping()
569 }
570}
571
572impl Crt for CrtU128 {
573 type Input = U32x5;
574 type Output = u128;
575 fn crt(xs: Self::Input) -> u128 {
576 let [x0, x1, x2, x3, x4] = xs.to_array();
577 [
578 (x0 as u128, MOD0 as u128),
579 (x1 as u128, MOD1 as u128),
580 (x2 as u128, MOD2 as u128),
581 (x3 as u128, MOD3 as u128),
582 (x4 as u128, MOD4 as u128),
583 ]
584 .crt_wrapping()
585 }
586}
587
588impl Crt for CrtWrappingU128 {
589 type Input = U32x10;
590 type Output = u128;
591 fn crt(xs: Self::Input) -> u128 {
592 let [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9] = xs.to_array();
593 [
594 (x0 as u128, MOD0 as u128),
595 (x1 as u128, MOD1 as u128),
596 (x2 as u128, MOD2 as u128),
597 (x3 as u128, MOD3 as u128),
598 (x4 as u128, MOD4 as u128),
599 (x5 as u128, MOD5 as u128),
600 (x6 as u128, MOD6 as u128),
601 (x7 as u128, MOD7 as u128),
602 (x8 as u128, MOD8 as u128),
603 (x9 as u128, MOD9 as u128),
604 ]
605 .crt_wrapping()
606 }
607}
608
/// Reconstructs an integer from its residues and reduces it by the target
/// modulus carried in `self`.
trait CrtMod {
    /// The residues, typically a nested tuple from zipped convolutions.
    type Input;
    /// The reduced value.
    type Output;
    /// Combines `residues` and reduces the result by `self`'s modulus.
    fn crt_mod(self, residues: Self::Input) -> Self::Output;
}
614
615impl CrtU32Mod {
616 fn new(m: u32) -> Self { Self(m) }
617}
618
619impl CrtU64Mod {
620 fn new(m: u64) -> Self { Self(m) }
621}
622
623impl CrtU128Mod {
624 fn new(m: u128) -> Self { Self(m) }
625}
626
627impl CrtMod for CrtU32Mod {
628 type Input = U32x3;
629 type Output = u32;
630 fn crt_mod(self, xs: Self::Input) -> Self::Output {
631 let [x0, x1, x2] = xs.to_array();
632 [
633 (x0 as u64, MOD0 as u64),
634 (x1 as u64, MOD1 as u64),
635 (x2 as u64, MOD2 as u64),
636 ]
637 .crt_mod(self.0 as u64) as u32
638 }
639}
640
641impl CrtMod for CrtU64Mod {
642 type Input = U32x6;
643 type Output = u64;
644 fn crt_mod(self, xs: Self::Input) -> Self::Output {
645 let [x0, x1, x2, x3, x4, x5] = xs.to_array();
646 [
647 (x0 as u64, MOD0 as u64),
648 (x1 as u64, MOD1 as u64),
649 (x2 as u64, MOD2 as u64),
650 (x3 as u64, MOD3 as u64),
651 (x4 as u64, MOD4 as u64),
652 (x5 as u64, MOD5 as u64),
653 ]
654 .crt_mod(self.0)
655 }
656}
657
658impl CrtMod for CrtU128Mod {
659 type Input = U32x10;
660 type Output = u128;
661 fn crt_mod(self, xs: Self::Input) -> Self::Output {
662 let [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9] = xs.to_array();
663 [
664 (x0 as u128, MOD0 as u128),
665 (x1 as u128, MOD1 as u128),
666 (x2 as u128, MOD2 as u128),
667 (x3 as u128, MOD3 as u128),
668 (x4 as u128, MOD4 as u128),
669 (x5 as u128, MOD5 as u128),
670 (x6 as u128, MOD6 as u128),
671 (x7 as u128, MOD7 as u128),
672 (x8 as u128, MOD8 as u128),
673 (x9 as u128, MOD9 as u128),
674 ]
675 .crt_mod(self.0)
676 }
677}
678
679macro_rules! impl_convolve {
680 ( $( ($fn:ident, $ty:ty, $crt:path, [$mod1:ty, $( $mod:ty ),*]), )* ) => { $(
681 pub fn $fn(a: &[$ty], b: &[$ty]) -> Vec<$ty> {
682 if a.is_empty() || b.is_empty() {
683 return vec![];
684 }
685 let n = a.len();
686 let m = b.len();
687 assert!(n + m - 1 <= 1_usize << <$mod1>::EXP);
688 $( assert!(n + m - 1 <= 1_usize << <$mod>::EXP) );*;
689 convolve_from::<$mod1, _>(&a, &b)
690 .into_iter()
691 $( .zip(convolve_from::<$mod, _>(&a, &b)) )*
692 .map($crt)
693 .collect()
694 }
695 )* }
696}
697
698macro_rules! impl_convolve_mod {
699 ( $( ($fn:ident, $ty:ty, $crt:ident, [$mod1:ty, $( $mod:ty ),*]), )* ) => { $(
700 pub fn $fn(a: &[$ty], b: &[$ty], mm: $ty) -> Vec<$ty> {
701 if a.is_empty() || b.is_empty() {
702 return vec![];
703 }
704 let n = a.len();
705 let m = b.len();
706 assert!(n + m - 1 <= 1_usize << <$mod1>::EXP);
707 $( assert!(n + m - 1 <= 1_usize << <$mod>::EXP) );*;
708 let crt = $crt::new(mm);
709 convolve_from::<$mod1, _>(&a, &b)
710 .into_iter()
711 $( .zip(convolve_from::<$mod, _>(&a, &b)) )*
712 .map(|x| crt.crt_mod(x))
713 .collect()
714 }
715 )* }
716}
717
718impl_convolve! {
719 (convolve_u64, u64, CrtU64::crt, [Mod0, Mod1, Mod2]),
720 (convolve_u128, u128, CrtU128::crt, [Mod0, Mod1, Mod2, Mod3, Mod4]),
721 (convolve_wrapping_u64, u64, CrtWrappingU64::crt, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5]),
722 (convolve_wrapping_u128, u128, CrtWrappingU128::crt, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5, Mod6, Mod7, Mod8, Mod9]),
723}
724
725impl_convolve_mod! {
726 (convolve_u32_mod, u32, CrtU32Mod, [Mod0, Mod1, Mod2]),
727 (convolve_u64_mod, u64, CrtU64Mod, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5]),
728 (convolve_u128_mod, u128, CrtU128Mod, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5, Mod6, Mod7, Mod8, Mod9]),
729}
730
731fn ceil_pow2(n: u32) -> u32 { 32 - n.saturating_sub(1).leading_zeros() }
732
733#[test]
734fn sanity_check() {
735 type Mi = modint::ModInt998244353;
736
737 let a: Vec<_> = [0, 1, 2, 3, 4].iter().map(|&x| Mi::new(x)).collect();
738 let b: Vec<_> = [0, 1, 2, 4, 8].iter().map(|&x| Mi::new(x)).collect();
739 let c: Vec<_> = convolve_fft(a, b).iter().map(|x| x.get()).collect();
740
741 assert_eq!(c, [0, 0, 1, 4, 11, 26, 36, 40, 32]);
742}
743
744#[test]
745fn proot() {
746 assert_eq!(Mod45e24p1::PRIMITIVE_ROOT, 11);
747 assert_eq!(Mod73e24p1::PRIMITIVE_ROOT, 3);
748 assert_eq!(Mod127e24p1::PRIMITIVE_ROOT, 3);
749 assert_eq!(Mod5e25p1::PRIMITIVE_ROOT, 3);
750 assert_eq!(Mod33e25p1::PRIMITIVE_ROOT, 10);
751 assert_eq!(Mod51e25p1::PRIMITIVE_ROOT, 29);
752 assert_eq!(Mod63e25p1::PRIMITIVE_ROOT, 5);
753 assert_eq!(Mod7e26p1::PRIMITIVE_ROOT, 3);
754 assert_eq!(Mod27e26p1::PRIMITIVE_ROOT, 13);
755 assert_eq!(Mod15e27p1::PRIMITIVE_ROOT, 31);
756}
757
758#[test]
759fn large() {
760 let max32 = u32::MAX as u64;
761 assert_eq!(convolve_u64(&[max32], &[max32]), [max32 * max32]);
762
763 let max64 = u64::MAX as u128;
764 assert_eq!(convolve_u128(&[max64], &[max64]), [max64 * max64]);
765}
766
767#[test]
768fn long_wrapping_u64() {
769 let max32 = u32::MAX as u64;
770 let n = 1 << 24;
771 let long32 = vec![max32; n];
772 let a = convolve_wrapping_u64(&long32, &long32);
773 for i in 0..n {
774 assert_eq!(a[i], a[n + n - 2 - i]);
775 assert_eq!(a[i], (max32 * max32).wrapping_mul(i as u64 + 1));
776 }
777}
778
779#[test]
780fn long_wrapping_u128() {
781 let max64 = u64::MAX as u128;
782 let n = 1 << 23;
783 let long64 = vec![max64; n];
784 let a = convolve_wrapping_u128(&long64, &long64);
785 for i in 0..n {
786 assert_eq!(a[i], a[n + n - 2 - i]);
787 assert_eq!(a[i], (max64 * max64).wrapping_mul(i as u128 + 1));
788 }
789}
790
791#[test]
792fn long_u32_mod() {
793 let max32 = u32::MAX;
794 let n = 1 << 24;
795 let p = 998244353;
796 let long32 = vec![max32; n];
797 let a = convolve_u32_mod(&long32, &long32, p as u32);
798 for i in 0..n {
799 assert_eq!(a[i], a[n + n - 2 - i]);
800 let expected = (max32 as u64 % p).pow(2) % p * (i as u64 + 1) % p;
801 assert_eq!(a[i], expected as u32);
802 }
803}
804
805#[test]
806fn long_u64_mod() {
807 let max32 = u32::MAX as u64;
808 let n = 1 << 24;
809 let p = 998244353;
810 let long32 = vec![max32; n];
811 let a = convolve_u64_mod(&long32, &long32, p);
812 for i in 0..n {
813 assert_eq!(a[i], a[n + n - 2 - i]);
814 assert_eq!(a[i], (max32 * max32 % p) * (i as u64 + 1) % p);
815 }
816}
817
818#[test]
819fn long_u128_mod() {
820 let max64 = u64::MAX as u128;
821 let n = 1 << 23;
822 let p = 998244353;
823 let long64 = vec![max64; n];
824 let a = convolve_u128_mod(&long64, &long64, p);
825 for i in 0..n {
826 assert_eq!(a[i], a[n + n - 2 - i]);
827 assert_eq!(a[i], (max64 * max64 % p) * (i as u128 + 1) % p);
828 }
829}
830
831#[test]
832fn pow2p1() {
833 type Mi = modint::ModInt998244353;
834
835 let n = 1 << 6 | 1;
836 let a: Vec<_> = (0..n).map(|x| Mi::new(x + 1)).collect();
837 let b = a.clone();
838 let expected = convolve_naive(&a, &b);
839 assert_eq!(convolve(a, b), expected);
840}