// nekolib/math/convolution.rs

1use super::garner;
2use super::modint;
3use std::sync::OnceLock;
4
5use garner::{CrtMod as CrtModInternal, CrtWrapping};
6use modint::{Mod998244353, ModIntBase, Modulus, RemEuclidU32, StaticModInt};
7
8pub struct ButterflyCache<M: NttFriendly> {
9    root: Vec<StaticModInt<M>>,
10    iroot: Vec<StaticModInt<M>>,
11    rate2: Vec<StaticModInt<M>>,
12    irate2: Vec<StaticModInt<M>>,
13    rate3: Vec<StaticModInt<M>>,
14    irate3: Vec<StaticModInt<M>>,
15}
16
/// Returns the smallest primitive root modulo the prime `p` (`1` when `p == 2`).
///
/// Collects the distinct prime factors `q` of `p - 1` by trial division, then
/// tries `g = 2, 3, ...` until `g^((p - 1) / q) != 1 (mod p)` holds for every
/// `q`. Being a `const fn`, it can be evaluated at compile time, as it is for
/// `NttFriendly::PRIMITIVE_ROOT`.
///
/// NOTE(review): assumes `p` is prime; for other inputs the search is not
/// guaranteed to terminate.
const fn primitive_root(p: u32) -> u32 {
    if p == 2 {
        return 1;
    }

    // Distinct prime factors of p - 1. The product of the first ten primes,
    // 2*3*5*7*11*13*17*19*23*29, already exceeds 2^32, so ten slots suffice.
    let mut factors = [0; 10];
    let mut count = 0;
    factors[count] = 2;
    count += 1;

    let mut rest = (p - 1) >> 1;
    while (rest & 1) == 0 {
        rest >>= 1;
    }
    let mut q = 3;
    while q <= rest / q {
        if rest % q == 0 {
            factors[count] = q;
            count += 1;
            while rest % q == 0 {
                rest /= q;
            }
        }
        q += 2;
    }
    if rest > 1 {
        // What survives trial division up to its square root is prime.
        factors[count] = rest;
        count += 1;
    }

    let mut g = 2;
    loop {
        // Walk the factor list while g passes the order test; reaching the
        // end means g is a generator.
        let mut i = 0;
        while i < count && mod_pow(g, (p - 1) / factors[i], p) != 1 {
            i += 1;
        }
        if i == count {
            return g;
        }
        g += 1;
    }
}

/// Computes `a^e mod m` by binary exponentiation with `u64` intermediates.
///
/// The result is not reduced when `e == 0`, so `mod_pow(x, 0, 1)` returns `1`.
const fn mod_pow(a: u32, mut e: u32, m: u32) -> u32 {
    let modulus = m as u64;
    let mut base = a as u64;
    let mut acc: u64 = 1;
    while e != 0 {
        if e % 2 == 1 {
            acc = acc * base % modulus;
        }
        base = base * base % modulus;
        e /= 2;
    }
    acc as u32
}
77
78pub trait NttFriendly: Modulus {
79    const PRIMITIVE_ROOT: u32 = primitive_root(Self::VALUE);
80    // ODD << EXP | 1 == MOD
81    const EXP: u32 = (Self::VALUE - 1).trailing_zeros();
82    const ODD: u32 = Self::VALUE >> Self::EXP;
83
84    fn cache() -> &'static OnceLock<ButterflyCache<Self>>;
85    fn butterfly_cache() -> &'static ButterflyCache<Self> {
86        Self::cache().get_or_init(precompute_butterfly)
87    }
88}
89
90static MOD998244353_CACHE: OnceLock<ButterflyCache<Mod998244353>> =
91    OnceLock::new();
92
93impl NttFriendly for Mod998244353 {
94    fn cache() -> &'static OnceLock<ButterflyCache<Self>> {
95        &MOD998244353_CACHE
96    }
97}
98
99fn precompute_butterfly<M: NttFriendly>() -> ButterflyCache<M> {
100    let g = StaticModInt::<M>::new(M::PRIMITIVE_ROOT);
101    let rank2 = M::EXP as usize;
102
103    let mut root = vec![StaticModInt::new(0); rank2 + 1];
104    let mut iroot = vec![StaticModInt::new(0); rank2 + 1];
105    root[rank2] = g.pow(M::ODD.into());
106    iroot[rank2] = root[rank2].recip();
107    for i in (0..rank2).rev() {
108        root[i] = root[i + 1] * root[i + 1];
109        iroot[i] = iroot[i + 1] * iroot[i + 1];
110    }
111
112    let mut rate2 = vec![StaticModInt::new(0); rank2];
113    let mut irate2 = vec![StaticModInt::new(0); rank2];
114    {
115        let mut prod = StaticModInt::new(1);
116        let mut iprod = StaticModInt::new(1);
117        for i in 0..=rank2 - 2 {
118            rate2[i] = root[i + 2] * prod;
119            irate2[i] = iroot[i + 2] * iprod;
120            prod *= iroot[i + 2];
121            iprod *= root[i + 2];
122        }
123    }
124
125    let mut rate3 = vec![StaticModInt::new(0); rank2];
126    let mut irate3 = vec![StaticModInt::new(0); rank2];
127    {
128        let mut prod = StaticModInt::new(1);
129        let mut iprod = StaticModInt::new(1);
130        for i in 0..=rank2 - 3 {
131            rate3[i] = root[i + 3] * prod;
132            irate3[i] = iroot[i + 3] * iprod;
133            prod *= iroot[i + 3];
134            iprod *= root[i + 3];
135        }
136    }
137
138    ButterflyCache { root, iroot, rate2, irate2, rate3, irate3 }
139}
140
/// In-place forward number-theoretic transform of `a` modulo `M`, using the
/// radix-2 / radix-4 butterfly scheme of the AtCoder Library.
///
/// The output is not in natural order (presumably bit-reversed, as in ACL).
/// It is meant to be multiplied pointwise with another transformed vector and
/// then passed to `butterfly_inv`, which is how `convolve_fft` uses it.
///
/// NOTE(review): assumes `a.len()` is a power of two no larger than
/// `2^M::EXP`. For other lengths, `h = ceil_pow2(n)` makes the touched index
/// range exceed `a.len()` and the slice accesses panic.
pub fn butterfly<M: NttFriendly>(a: &mut [StaticModInt<M>]) {
    let n = a.len();
    // log2 of the transform length.
    let h = ceil_pow2(n as u32);

    let ButterflyCache { root, rate2, rate3, .. } = M::butterfly_cache();

    // a[i, i + (n >> len), i + 2 * (n >> len), ...] is transformed
    let mut len = 0;
    while len < h {
        if h - len == 1 {
            // Exactly one level remains: finish with a single radix-2 pass.
            let p = 1 << (h - len - 1);
            let mut rot = StaticModInt::new(1);
            for s in 0..1 << len {
                let offset = s << (h - len);
                for i in 0..p {
                    let l = a[i + offset];
                    let r = a[i + offset + p] * rot;
                    a[i + offset] = l + r;
                    a[i + offset + p] = l - r;
                }
                // Step the twiddle to the next block; the table index is the
                // number of trailing one bits of `s`.
                if s + 1 != 1 << len {
                    rot *= rate2[(!s).trailing_zeros() as usize];
                }
            }
            len += 1;
        } else {
            // 4-base
            // Two levels per pass. Operands are combined as raw u64 and only
            // reduced when stored back: a1..a3 are products of two values
            // below the modulus, so every sum here stays under 4 * mod^2,
            // which fits in u64 as long as the modulus is below 2^31.
            let p = 1 << (h - len - 2);
            // root[2] has order 4 (the "imaginary unit" of the radix-4 step).
            let imag_u64 = root[2].get() as u64;
            let mut rot = StaticModInt::new(1);

            for s in 0..1 << len {
                let rot2 = rot * rot;
                let rot3 = rot2 * rot;

                let rot_u64 = rot.get() as u64;
                let rot2_u64 = rot2.get() as u64;
                let rot3_u64 = rot3.get() as u64;

                let offset = s << (h - len);
                for i in 0..p {
                    let mod2 = (M::VALUE as u64).pow(2);
                    let a0 = a[i + offset].get() as u64;
                    let a1 = a[i + offset + p].get() as u64 * rot_u64;
                    let a2 = a[i + offset + 2 * p].get() as u64 * rot2_u64;
                    let a3 = a[i + offset + 3 * p].get() as u64 * rot3_u64;

                    // Adding mod2 before subtracting keeps the u64 values
                    // non-negative.
                    let a1na3 = StaticModInt::<M>::new(a1 + mod2 - a3);
                    let a1na3imag = a1na3.get() as u64 * imag_u64;
                    let na2 = mod2 - a2;

                    a[i + offset] = StaticModInt::new(a0 + a2 + a1 + a3);
                    a[i + offset + p] =
                        StaticModInt::new(a0 + a2 + (2 * mod2 - (a1 + a3)));
                    a[i + offset + 2 * p] =
                        StaticModInt::new(a0 + na2 + a1na3imag);
                    a[i + offset + 3 * p] =
                        StaticModInt::new(a0 + na2 + (mod2 - a1na3imag));
                }

                // Step the twiddle to the next block (skipped after the last).
                if s + 1 != 1 << len {
                    rot *= rate3[(!s).trailing_zeros() as usize];
                }
            }
            len += 2;
        }
    }
}
209
/// In-place inverse of `butterfly` modulo `M`, without the final `1/len`
/// normalization. The caller must scale the result by the inverse of the
/// length, as `convolve_fft` does.
///
/// NOTE(review): like `butterfly`, assumes `a.len()` is a power of two no
/// larger than `2^M::EXP`.
pub fn butterfly_inv<M: NttFriendly>(a: &mut [StaticModInt<M>]) {
    let n = a.len();
    // log2 of the transform length.
    let h = ceil_pow2(n as u32);

    let ButterflyCache { iroot, irate2, irate3, .. } = M::butterfly_cache();

    // a[i, i + (n >> len), i + 2 * (n >> len), ...] is transformed
    // Levels are undone in the reverse order of `butterfly`.
    let mut len = h;
    while len > 0 {
        if len == 1 {
            // One level left: radix-2 pass.
            let p = 1 << (h - len);
            let mut irot = StaticModInt::new(1);
            for s in 0..1 << (len - 1) {
                let offset = s << (h - len + 1);
                for i in 0..p {
                    let l = a[i + offset];
                    let r = a[i + offset + p];
                    a[i + offset] = l + r;
                    a[i + offset + p] = (l - r) * irot
                }

                // Table index = number of trailing one bits of `s`.
                if s + 1 != 1 << (len - 1) {
                    irot *= irate2[(!s).trailing_zeros() as usize];
                }
            }
            len -= 1;
        } else {
            // 4-base
            // Operands are below the modulus, so every u64 product below is
            // under 4 * mod^2 and fits while the modulus is below 2^31.
            let p = 1 << (h - len);
            let mod1 = M::VALUE as u64;
            // Inverse of the order-4 root used by the forward transform.
            let iimag_u64 = iroot[2].get() as u64;

            let mut irot = StaticModInt::new(1);
            for s in 0..1 << (len - 2) {
                let irot2 = irot * irot;
                let irot3 = irot2 * irot;

                let irot_u64 = irot.get() as u64;
                let irot2_u64 = irot2.get() as u64;
                let irot3_u64 = irot3.get() as u64;

                let offset = s << (h - len + 2);
                for i in 0..p {
                    let a0 = a[i + offset].get() as u64;
                    let a1 = a[i + offset + p].get() as u64;
                    let a2 = a[i + offset + 2 * p].get() as u64;
                    let a3 = a[i + offset + 3 * p].get() as u64;

                    // (a2 - a3) * iimag, with `mod1 +` keeping it non-negative.
                    let a2na3_u64 =
                        StaticModInt::<M>::new(mod1 + a2 - a3).get() as u64;
                    let a2na3iimag =
                        StaticModInt::<M>::new(a2na3_u64 * iimag_u64);
                    let a2na3iimag_u64 = a2na3iimag.get() as u64;

                    a[i + offset] = StaticModInt::new(a0 + a1 + a2 + a3);
                    a[i + offset + p] = StaticModInt::new(
                        (a0 + (mod1 - a1) + a2na3iimag_u64) * irot_u64,
                    );
                    a[i + offset + 2 * p] = StaticModInt::new(
                        (a0 + a1 + (mod1 - a2) + (mod1 - a3)) * irot2_u64,
                    );
                    a[i + offset + 3 * p] = StaticModInt::new(
                        (a0 + (mod1 - a1) + (mod1 - a2na3iimag_u64))
                            * irot3_u64,
                    );
                }
                // Step the inverse twiddle to the next block.
                if s + 1 != 1 << (len - 2) {
                    irot *= irate3[(!s).trailing_zeros() as usize];
                }
            }
            len -= 2;
        }
    }
}
284
285pub fn convolve<M: NttFriendly>(
286    a: Vec<StaticModInt<M>>,
287    b: Vec<StaticModInt<M>>,
288) -> Vec<StaticModInt<M>> {
289    if a.is_empty() || b.is_empty() {
290        return vec![];
291    }
292    let (n, m) = (a.len(), b.len());
293
294    if n.min(m) <= 60 {
295        convolve_naive(&a, &b)
296    } else if (n + m - 2).is_power_of_two() {
297        convolve_pow2p1(a, b)
298    } else {
299        convolve_fft(a, b)
300    }
301}
302
303fn convolve_naive<M: NttFriendly>(
304    a: &[StaticModInt<M>],
305    b: &[StaticModInt<M>],
306) -> Vec<StaticModInt<M>> {
307    let (n, m) = (a.len(), b.len());
308    let (n, m, a, b) = if n < m { (m, n, b, a) } else { (n, m, a, b) };
309    let mut res = vec![StaticModInt::new(0); n + m - 1];
310    for i in 0..n {
311        for j in 0..m {
312            res[i + j] += a[i] * b[j];
313        }
314    }
315    res
316}
317
318fn convolve_pow2p1<M: NttFriendly>(
319    a: Vec<StaticModInt<M>>,
320    b: Vec<StaticModInt<M>>,
321) -> Vec<StaticModInt<M>> {
322    let n = a.len();
323    let m = b.len();
324    let len = n + m - 1;
325    assert!((len - 1).is_power_of_two());
326
327    // n + m - 1 == 2^k + 1
328    //
329    // (a[0] + x a[1..]) (b[0] + x b[1..])
330    // a[0] b[0] + x (a[0] b[1..] + b[0] a[1..]) + x^2 a[1..] b[1..]
331
332    let mut res = convolve_fft(a[1..].to_vec(), b[1..].to_vec());
333    res.splice(0..0, (0..2).map(|_| StaticModInt::new(0)));
334
335    res[0] += a[0] * b[0];
336    for i in 1..n {
337        res[i] += a[i] * b[0];
338    }
339    for j in 1..m {
340        res[j] += a[0] * b[j];
341    }
342    res
343}
344
345fn convolve_fft<M: NttFriendly>(
346    mut a: Vec<StaticModInt<M>>,
347    mut b: Vec<StaticModInt<M>>,
348) -> Vec<StaticModInt<M>> {
349    let (n, m) = (a.len(), b.len());
350    let z = (n + m - 1).next_power_of_two();
351    a.resize(z, StaticModInt::new(0));
352    b.resize(z, StaticModInt::new(0));
353
354    butterfly(&mut a);
355    butterfly(&mut b);
356
357    for (ai, bi) in a.iter_mut().zip(&mut b) {
358        *ai *= *bi;
359    }
360    butterfly_inv(&mut a);
361
362    a.truncate(n + m - 1);
363    let iz = StaticModInt::new(z).recip();
364    for ai in &mut a {
365        *ai *= iz;
366    }
367
368    a
369}
370
/// Declares NTT-friendly modulus marker types.
///
/// Each entry `(Name, value, CACHE)` expands to a unit struct `Name`, its
/// `Modulus` impl with `VALUE = value`, a `static CACHE` holding the lazily
/// built twiddle tables, and an `NttFriendly` impl pointing at that static.
macro_rules! impl_modint_ntt {
    ( $( ($name:ident, $value:expr, $cache_name:ident), )* ) => {
        $(
            #[derive(Clone, Copy, PartialEq, Eq)]
            struct $name;

            static $cache_name: OnceLock<ButterflyCache<$name>> = OnceLock::new();

            impl Modulus for $name {
                const VALUE: u32 = $value;
            }

            impl NttFriendly for $name {
                fn cache() -> &'static OnceLock<ButterflyCache<Self>> {
                    &$cache_name
                }
            }
        )*
    };
}
384
385impl_modint_ntt! {
386    (Mod45e24p1, 45 << 24 | 1, MOD45E24P1_CACHE),
387    (Mod73e24p1, 73 << 24 | 1, MOD73E24P1_CACHE),
388    (Mod127e24p1, 127 << 24 | 1, MOD127E24P1_CACHE),
389    (Mod5e25p1, 5 << 25 | 1, MOD5E25P1_CACHE),
390    (Mod33e25p1, 33 << 25 | 1, MOD33E25P1_CACHE),
391    (Mod51e25p1, 51 << 25 | 1, MOD51E25P1_CACHE),
392    (Mod63e25p1, 63 << 25 | 1, MOD63E25P1_CACHE),
393    (Mod7e26p1, 7 << 26 | 1, MOD7E26P1_CACHE),
394    (Mod27e26p1, 27 << 26 | 1, MOD27E26P1_CACHE),
395    (Mod15e27p1, 15 << 27 | 1, MOD15E27P1_CACHE),
396}
397
398type Mod0 = Mod15e27p1;
399type Mod1 = Mod27e26p1;
400type Mod2 = Mod7e26p1;
401type Mod3 = Mod63e25p1;
402type Mod4 = Mod51e25p1;
403type Mod5 = Mod33e25p1;
404type Mod6 = Mod5e25p1;
405type Mod7 = Mod127e24p1;
406type Mod8 = Mod73e24p1;
407type Mod9 = Mod45e24p1;
408
409const MOD0: u32 = Mod0::VALUE;
410const MOD1: u32 = Mod1::VALUE;
411const MOD2: u32 = Mod2::VALUE;
412const MOD3: u32 = Mod3::VALUE;
413const MOD4: u32 = Mod4::VALUE;
414const MOD5: u32 = Mod5::VALUE;
415const MOD6: u32 = Mod6::VALUE;
416const MOD7: u32 = Mod7::VALUE;
417const MOD8: u32 = Mod8::VALUE;
418const MOD9: u32 = Mod9::VALUE;
419
420fn convolve_from<M: NttFriendly, I: RemEuclidU32 + Copy>(
421    a: &[I],
422    b: &[I],
423) -> Vec<u32> {
424    let a: Vec<_> = a.iter().map(|&ai| StaticModInt::new(ai)).collect();
425    let b: Vec<_> = b.iter().map(|&bi| StaticModInt::new(bi)).collect();
426    convolve(a, b).into_iter().map(|x: StaticModInt<M>| x.get()).collect()
427}
428
429pub fn convolve_u64_acl(a: &[u64], b: &[u64]) -> Vec<u64> {
430    if a.is_empty() || b.is_empty() {
431        return vec![];
432    }
433    let n = a.len();
434    let m = b.len();
435
436    let mod1 = Mod45e24p1::VALUE as u64;
437    let mod2 = Mod5e25p1::VALUE as u64;
438    let mod3 = Mod7e26p1::VALUE as u64;
439    let m2m3 = mod2 * mod3;
440    let m1m3 = mod1 * mod3;
441    let m1m2 = mod1 * mod2;
442    let m1m2m3 = m1m2.wrapping_mul(mod3);
443
444    type ModInt754974721 = StaticModInt<Mod45e24p1>;
445    type ModInt167772161 = StaticModInt<Mod5e25p1>;
446    type ModInt469762049 = StaticModInt<Mod7e26p1>;
447
448    let i1 = ModInt754974721::new(m2m3).recip().get() as u64;
449    let i2 = ModInt167772161::new(m1m3).recip().get() as u64;
450    let i3 = ModInt469762049::new(m1m2).recip().get() as u64;
451
452    let max_bit = 24;
453    assert_eq!(mod1 % (1 << max_bit), 1);
454    assert_eq!(mod2 % (1 << max_bit), 1);
455    assert_eq!(mod3 % (1 << max_bit), 1);
456    assert!(n + m - 1 <= (1 << max_bit));
457
458    let c1 = convolve_from::<Mod45e24p1, _>(&a, &b);
459    let c2 = convolve_from::<Mod5e25p1, _>(&a, &b);
460    let c3 = convolve_from::<Mod7e26p1, _>(&a, &b);
461
462    c1.into_iter()
463        .zip(c2)
464        .zip(c3)
465        .map(|((c1i, c2i), c3i)| {
466            let c1i = c1i as u64;
467            let c2i = c2i as u64;
468            let c3i = c3i as u64;
469
470            let mut x = 0;
471            x += (c1i * i1) % mod1 * m2m3;
472            x += (c2i * i2) % mod2 * m1m3;
473            x += (c3i * i3) % mod3 * m1m2;
474            let rem = x.rem_euclid(mod1);
475            let diff = if c1i >= rem { c1i - rem } else { mod1 - (rem - c1i) };
476            let offset = [0, 0, m1m2m3, 2 * m1m2m3, 3 * m1m2m3];
477            x - offset[diff as usize % 5]
478        })
479        .collect()
480}
481
/// Marker: rebuild a `u64` from three residues (`MOD0..=MOD2`).
enum CrtU64 {}
/// Marker: rebuild a `u64` from six residues (`MOD0..=MOD5`).
enum CrtWrappingU64 {}
/// Marker: rebuild a `u128` from five residues (`MOD0..=MOD4`).
enum CrtU128 {}
/// Marker: rebuild a `u128` from ten residues (`MOD0..=MOD9`).
enum CrtWrappingU128 {}

/// Output modulus for CRT reconstruction from three residues.
#[derive(Clone, Copy)]
struct CrtU32Mod(u32);

/// Output modulus for CRT reconstruction from six residues.
#[derive(Clone, Copy)]
struct CrtU64Mod(u64);

/// Output modulus for CRT reconstruction from ten residues.
#[derive(Clone, Copy)]
struct CrtU128Mod(u128);
492
/// Three residues as produced by zipping three `Vec<u32>`s: `((x0, x1), x2)`.
type U32x3 = ((u32, u32), u32);
/// Five residues: `((((x0, x1), x2), x3), x4)`.
type U32x5 = ((U32x3, u32), u32);
/// Six residues.
type U32x6 = (U32x5, u32);
/// Ten residues.
type U32x10 = ((((U32x6, u32), u32), u32), u32);

/// Flattens a left-nested tuple built by repeated `Iterator::zip` into an array.
trait ToArray {
    type Output;
    fn to_array(self) -> Self::Output;
}

impl ToArray for U32x3 {
    type Output = [u32; 3];
    fn to_array(self) -> Self::Output {
        [self.0.0, self.0.1, self.1]
    }
}

impl ToArray for U32x5 {
    type Output = [u32; 5];
    fn to_array(self) -> Self::Output {
        let ((head, x3), x4) = self;
        let [x0, x1, x2] = head.to_array();
        [x0, x1, x2, x3, x4]
    }
}

impl ToArray for U32x6 {
    type Output = [u32; 6];
    fn to_array(self) -> Self::Output {
        let (head, x5) = self;
        let [x0, x1, x2, x3, x4] = head.to_array();
        [x0, x1, x2, x3, x4, x5]
    }
}

impl ToArray for U32x10 {
    type Output = [u32; 10];
    fn to_array(self) -> Self::Output {
        let ((((head, x6), x7), x8), x9) = self;
        let [x0, x1, x2, x3, x4, x5] = head.to_array();
        [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9]
    }
}
534
535trait Crt {
536    type Input;
537    type Output;
538    fn crt(i: Self::Input) -> Self::Output;
539}
540
541impl Crt for CrtU64 {
542    type Input = U32x3;
543    type Output = u64;
544    fn crt(xs: Self::Input) -> u64 {
545        let [x0, x1, x2] = xs.to_array();
546        [
547            (x0 as u64, MOD0 as u64),
548            (x1 as u64, MOD1 as u64),
549            (x2 as u64, MOD2 as u64),
550        ]
551        .crt_wrapping()
552    }
553}
554
555impl Crt for CrtWrappingU64 {
556    type Input = U32x6;
557    type Output = u64;
558    fn crt(xs: Self::Input) -> u64 {
559        let [x0, x1, x2, x3, x4, x5] = xs.to_array();
560        [
561            (x0 as u64, MOD0 as u64),
562            (x1 as u64, MOD1 as u64),
563            (x2 as u64, MOD2 as u64),
564            (x3 as u64, MOD3 as u64),
565            (x4 as u64, MOD4 as u64),
566            (x5 as u64, MOD5 as u64),
567        ]
568        .crt_wrapping()
569    }
570}
571
572impl Crt for CrtU128 {
573    type Input = U32x5;
574    type Output = u128;
575    fn crt(xs: Self::Input) -> u128 {
576        let [x0, x1, x2, x3, x4] = xs.to_array();
577        [
578            (x0 as u128, MOD0 as u128),
579            (x1 as u128, MOD1 as u128),
580            (x2 as u128, MOD2 as u128),
581            (x3 as u128, MOD3 as u128),
582            (x4 as u128, MOD4 as u128),
583        ]
584        .crt_wrapping()
585    }
586}
587
588impl Crt for CrtWrappingU128 {
589    type Input = U32x10;
590    type Output = u128;
591    fn crt(xs: Self::Input) -> u128 {
592        let [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9] = xs.to_array();
593        [
594            (x0 as u128, MOD0 as u128),
595            (x1 as u128, MOD1 as u128),
596            (x2 as u128, MOD2 as u128),
597            (x3 as u128, MOD3 as u128),
598            (x4 as u128, MOD4 as u128),
599            (x5 as u128, MOD5 as u128),
600            (x6 as u128, MOD6 as u128),
601            (x7 as u128, MOD7 as u128),
602            (x8 as u128, MOD8 as u128),
603            (x9 as u128, MOD9 as u128),
604        ]
605        .crt_wrapping()
606    }
607}
608
609trait CrtMod {
610    type Input;
611    type Output;
612    fn crt_mod(self, i: Self::Input) -> Self::Output;
613}
614
615impl CrtU32Mod {
616    fn new(m: u32) -> Self { Self(m) }
617}
618
619impl CrtU64Mod {
620    fn new(m: u64) -> Self { Self(m) }
621}
622
623impl CrtU128Mod {
624    fn new(m: u128) -> Self { Self(m) }
625}
626
627impl CrtMod for CrtU32Mod {
628    type Input = U32x3;
629    type Output = u32;
630    fn crt_mod(self, xs: Self::Input) -> Self::Output {
631        let [x0, x1, x2] = xs.to_array();
632        [
633            (x0 as u64, MOD0 as u64),
634            (x1 as u64, MOD1 as u64),
635            (x2 as u64, MOD2 as u64),
636        ]
637        .crt_mod(self.0 as u64) as u32
638    }
639}
640
641impl CrtMod for CrtU64Mod {
642    type Input = U32x6;
643    type Output = u64;
644    fn crt_mod(self, xs: Self::Input) -> Self::Output {
645        let [x0, x1, x2, x3, x4, x5] = xs.to_array();
646        [
647            (x0 as u64, MOD0 as u64),
648            (x1 as u64, MOD1 as u64),
649            (x2 as u64, MOD2 as u64),
650            (x3 as u64, MOD3 as u64),
651            (x4 as u64, MOD4 as u64),
652            (x5 as u64, MOD5 as u64),
653        ]
654        .crt_mod(self.0)
655    }
656}
657
658impl CrtMod for CrtU128Mod {
659    type Input = U32x10;
660    type Output = u128;
661    fn crt_mod(self, xs: Self::Input) -> Self::Output {
662        let [x0, x1, x2, x3, x4, x5, x6, x7, x8, x9] = xs.to_array();
663        [
664            (x0 as u128, MOD0 as u128),
665            (x1 as u128, MOD1 as u128),
666            (x2 as u128, MOD2 as u128),
667            (x3 as u128, MOD3 as u128),
668            (x4 as u128, MOD4 as u128),
669            (x5 as u128, MOD5 as u128),
670            (x6 as u128, MOD6 as u128),
671            (x7 as u128, MOD7 as u128),
672            (x8 as u128, MOD8 as u128),
673            (x9 as u128, MOD9 as u128),
674        ]
675        .crt_mod(self.0)
676    }
677}
678
/// Defines `pub fn $fn(a: &[$ty], b: &[$ty]) -> Vec<$ty>`.
///
/// The generated function convolves `a` and `b` modulo every listed NTT prime
/// and recombines each coefficient's residues with `$crt`. It panics if the
/// result length exceeds any listed prime's `2^EXP`.
macro_rules! impl_convolve {
    ( $( ($fn:ident, $ty:ty, $crt:path, [$mod1:ty, $( $mod:ty ),*]), )* ) => {
        $(
            pub fn $fn(a: &[$ty], b: &[$ty]) -> Vec<$ty> {
                let (n, m) = (a.len(), b.len());
                if n == 0 || m == 0 {
                    return vec![];
                }
                assert!(n + m - 1 <= 1_usize << <$mod1>::EXP);
                $( assert!(n + m - 1 <= 1_usize << <$mod>::EXP); )*
                convolve_from::<$mod1, _>(a, b)
                    .into_iter()
                    $( .zip(convolve_from::<$mod, _>(a, b)) )*
                    .map($crt)
                    .collect()
            }
        )*
    };
}
697
/// Defines `pub fn $fn(a: &[$ty], b: &[$ty], mm: $ty) -> Vec<$ty>`.
///
/// The generated function convolves `a` and `b` modulo every listed NTT prime
/// and maps each coefficient's residues through `$crt::new(mm).crt_mod`. It
/// panics if the result length exceeds any listed prime's `2^EXP`.
macro_rules! impl_convolve_mod {
    ( $( ($fn:ident, $ty:ty, $crt:ident, [$mod1:ty, $( $mod:ty ),*]), )* ) => {
        $(
            pub fn $fn(a: &[$ty], b: &[$ty], mm: $ty) -> Vec<$ty> {
                let (n, m) = (a.len(), b.len());
                if n == 0 || m == 0 {
                    return vec![];
                }
                assert!(n + m - 1 <= 1_usize << <$mod1>::EXP);
                $( assert!(n + m - 1 <= 1_usize << <$mod>::EXP); )*
                let crt = $crt::new(mm);
                convolve_from::<$mod1, _>(a, b)
                    .into_iter()
                    $( .zip(convolve_from::<$mod, _>(a, b)) )*
                    .map(move |residues| crt.crt_mod(residues))
                    .collect()
            }
        )*
    };
}
717
718impl_convolve! {
719    (convolve_u64, u64, CrtU64::crt, [Mod0, Mod1, Mod2]),
720    (convolve_u128, u128, CrtU128::crt, [Mod0, Mod1, Mod2, Mod3, Mod4]),
721    (convolve_wrapping_u64, u64, CrtWrappingU64::crt, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5]),
722    (convolve_wrapping_u128, u128, CrtWrappingU128::crt, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5, Mod6, Mod7, Mod8, Mod9]),
723}
724
725impl_convolve_mod! {
726    (convolve_u32_mod, u32, CrtU32Mod, [Mod0, Mod1, Mod2]),
727    (convolve_u64_mod, u64, CrtU64Mod, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5]),
728    (convolve_u128_mod, u128, CrtU128Mod, [Mod0, Mod1, Mod2, Mod3, Mod4, Mod5, Mod6, Mod7, Mod8, Mod9]),
729}
730
/// Returns `⌈log2 n⌉`, the smallest `k` with `2^k >= n` (`0` for `n <= 1`).
fn ceil_pow2(n: u32) -> u32 {
    match n {
        0 | 1 => 0,
        _ => u32::BITS - (n - 1).leading_zeros(),
    }
}
732
#[test]
fn sanity_check() {
    type Mi = modint::ModInt998244353;

    // (x + 2x^2 + 3x^3 + 4x^4) * (x + 2x^2 + 4x^3 + 8x^4)
    let a = [0, 1, 2, 3, 4].map(Mi::new).to_vec();
    let b = [0, 1, 2, 4, 8].map(Mi::new).to_vec();
    let got: Vec<_> = convolve_fft(a, b).iter().map(|x| x.get()).collect();

    assert_eq!(got, [0, 0, 1, 4, 11, 26, 36, 40, 32]);
}
743
#[test]
fn proot() {
    let cases = [
        (Mod45e24p1::PRIMITIVE_ROOT, 11),
        (Mod73e24p1::PRIMITIVE_ROOT, 3),
        (Mod127e24p1::PRIMITIVE_ROOT, 3),
        (Mod5e25p1::PRIMITIVE_ROOT, 3),
        (Mod33e25p1::PRIMITIVE_ROOT, 10),
        (Mod51e25p1::PRIMITIVE_ROOT, 29),
        (Mod63e25p1::PRIMITIVE_ROOT, 5),
        (Mod7e26p1::PRIMITIVE_ROOT, 3),
        (Mod27e26p1::PRIMITIVE_ROOT, 13),
        (Mod15e27p1::PRIMITIVE_ROOT, 31),
    ];
    for (actual, expected) in cases {
        assert_eq!(actual, expected);
    }
}
757
#[test]
fn large() {
    // Largest single products that still fit the exact (non-wrapping) variants.
    let x = u64::from(u32::MAX);
    assert_eq!(convolve_u64(&[x], &[x]), [x * x]);

    let y = u128::from(u64::MAX);
    assert_eq!(convolve_u128(&[y], &[y]), [y * y]);
}
766
#[test]
fn long_wrapping_u64() {
    let v = u64::from(u32::MAX);
    let n = 1 << 24;
    let input = vec![v; n];
    let out = convolve_wrapping_u64(&input, &input);
    let square = v * v;
    // Coefficient i (and its mirror) is (i + 1) * v^2, taken mod 2^64.
    for (i, &c) in out.iter().enumerate().take(n) {
        assert_eq!(c, out[2 * n - 2 - i]);
        assert_eq!(c, square.wrapping_mul(i as u64 + 1));
    }
}
778
#[test]
fn long_wrapping_u128() {
    let v = u128::from(u64::MAX);
    let n = 1 << 23;
    let input = vec![v; n];
    let out = convolve_wrapping_u128(&input, &input);
    let square = v * v;
    // Coefficient i (and its mirror) is (i + 1) * v^2, taken mod 2^128.
    for (i, &c) in out.iter().enumerate().take(n) {
        assert_eq!(c, out[2 * n - 2 - i]);
        assert_eq!(c, square.wrapping_mul(i as u128 + 1));
    }
}
790
#[test]
fn long_u32_mod() {
    let v = u32::MAX;
    let n = 1 << 24;
    let p: u64 = 998244353;
    let input = vec![v; n];
    let out = convolve_u32_mod(&input, &input, p as u32);
    let square_mod = (u64::from(v) % p).pow(2) % p;
    for (i, &c) in out.iter().enumerate().take(n) {
        assert_eq!(c, out[2 * n - 2 - i]);
        let expected = square_mod * (i as u64 + 1) % p;
        assert_eq!(c, expected as u32);
    }
}
803}
804
805#[test]
806fn long_u64_mod() {
807    let max32 = u32::MAX as u64;
808    let n = 1 << 24;
809    let p = 998244353;
810    let long32 = vec![max32; n];
811    let a = convolve_u64_mod(&long32, &long32, p);
812    for i in 0..n {
813        assert_eq!(a[i], a[n + n - 2 - i]);
814        assert_eq!(a[i], (max32 * max32 % p) * (i as u64 + 1) % p);
815    }
816}
817
#[test]
fn long_u128_mod() {
    let v = u128::from(u64::MAX);
    let n = 1 << 23;
    let p: u128 = 998244353;
    let input = vec![v; n];
    let out = convolve_u128_mod(&input, &input, p);
    let square_mod = v * v % p;
    for (i, &c) in out.iter().enumerate().take(n) {
        assert_eq!(c, out[2 * n - 2 - i]);
        assert_eq!(c, square_mod * (i as u128 + 1) % p);
    }
}
830
#[test]
fn pow2p1() {
    type Mi = modint::ModInt998244353;

    // 65 + 65 - 1 = 2^7 + 1 coefficients, so `convolve` takes the pow2p1 path.
    let n = 1 << 6 | 1;
    let a: Vec<_> = (1..=n).map(Mi::new).collect();
    let expected = convolve_naive(&a, &a);
    assert_eq!(convolve(a.clone(), a), expected);
}