1use super::convolution;
4use super::modint;
5
6use std::fmt::{self, Debug, Display};
7use std::ops::{
8 Add, AddAssign, BitAnd, BitAndAssign, Div, DivAssign, Mul, MulAssign, Neg,
9 Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
10};
11
12use convolution::{butterfly, butterfly_inv, convolve, NttFriendly};
13use modint::{ModIntBase, StaticModInt};
14
15#[derive(Clone, Eq, PartialEq)]
30pub struct Polynomial<M: NttFriendly>(Vec<StaticModInt<M>>);
31
32impl<M: NttFriendly> Display for Polynomial<M> {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 if self.0.is_empty() {
35 return write!(f, "0");
36 }
37
38 let mut out = false;
39 for (i, &c) in self.0.iter().enumerate().filter(|&(_, c)| c.get() > 0) {
40 if out {
41 write!(f, " + ")?;
42 }
43 match (i, c.get()) {
44 (0, c) => write!(f, "{}", c)?,
45 (1, 1) => write!(f, "x")?,
46 (1, c) => write!(f, "{}x", c)?,
47 (_, 1) => write!(f, "x^{}", i)?,
48 (_, c) => write!(f, "{}x^{}", c, i)?,
49 }
50 out = true;
51 }
52 Ok(())
53 }
54}
55
56impl<M: NttFriendly> Debug for Polynomial<M> {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 f.debug_struct("Polynomial")
59 .field("f", &self.0.iter().map(|x| x.get()).collect::<Vec<_>>())
60 .field("mod", &M::VALUE)
61 .finish()
62 }
63}
64
65impl<M: NttFriendly> Polynomial<M> {
66 pub fn new() -> Self { Self(vec![]) }
82
83 fn normalize(&mut self) {
84 if self.0.is_empty() {
85 return;
86 }
87 if let Some(i) = (0..self.0.len()).rev().find(|&i| self.0[i].get() > 0)
88 {
89 self.0.truncate(i + 1);
90 } else {
91 self.0.clear();
92 }
93 }
94
95 #[allow(dead_code)]
96 fn recip_naive(&self, len: usize) -> Self {
97 if len == 0 {
98 return Self(vec![]);
99 }
100
101 let mut res = Self(vec![self.0[0].recip()]);
102 let mut cur_len = 1;
103 while cur_len < len {
104 cur_len *= 2;
105 let mut self_: Self =
108 self.0[..self.0.len().min(cur_len)].to_vec().into();
109
110 let ftwo = Self(vec![StaticModInt::new(2); 2 * cur_len]);
111 self_.fft_butterfly(2 * cur_len);
112 res.fft_butterfly(2 * cur_len);
113 let mut tmp = (&ftwo - (&self_ & &res)) & &res;
114 tmp.fft_inv_butterfly(2 * cur_len);
115
116 tmp.truncate(cur_len);
117 res.0 = tmp.0;
118 }
119 res.truncate(len);
120 res
121 }
122
    /// Returns the multiplicative inverse of `self` modulo `x^len`.
    ///
    /// Newton iteration: each round doubles the precision `cur_len` and refines
    /// `res` with the step `g <- g - (f*g - 1)*g`. Both products are computed as
    /// cyclic convolutions of size `cur_len` via `butterfly`/`butterfly_inv`.
    ///
    /// # Panics
    /// Panics (index out of bounds) if `self` is the zero polynomial and `len > 0`.
    /// The constant term must be invertible. NOTE(review): what happens for a zero
    /// constant term depends on `StaticModInt::recip`; confirm.
    pub fn recip(&self, len: usize) -> Self {
        if len == 0 {
            return Self(vec![]);
        }

        // Inverse modulo x^1.
        let mut res = Self(vec![self.0[0].recip()]);
        let mut cur_len = 1;
        while cur_len < len {
            cur_len *= 2;

            // ff = f mod x^cur_len; gg = current inverse (valid mod x^(cur_len/2)).
            // Both are zero-padded to the transform size.
            let mut ff: Self =
                self.0[..self.0.len().min(cur_len)].to_vec().into();
            let mut gg = res.clone();
            ff.0.resize(cur_len, StaticModInt::new(0));
            gg.0.resize(cur_len, StaticModInt::new(0));
            butterfly(&mut ff.0);
            butterfly(&mut gg.0);
            for i in 0..cur_len {
                ff.0[i] *= gg.0[i];
            }
            butterfly_inv(&mut ff.0);
            // butterfly_inv leaves results unscaled, hence the explicit 1/cur_len factor.
            let iz = StaticModInt::new(cur_len).recip();
            // Discard the lower half of f*g (it is ≡ 1 and also receives cyclic
            // wrap-around). Negate and scale the upper half.
            for i in 0..cur_len / 2 {
                ff.0[i] = StaticModInt::new(0);
                ff.0[cur_len / 2 + i] = -ff.0[cur_len / 2 + i] * iz;
            }
            butterfly(&mut ff.0);
            for i in 0..cur_len {
                ff.0[i] *= gg.0[i];
            }
            butterfly_inv(&mut ff.0);
            // The lower half keeps the previous inverse; `res` has exactly cur_len/2
            // entries here. The upper half receives the scaled correction term.
            for i in 0..cur_len / 2 {
                ff.0[i] = res.0[i];
                ff.0[cur_len / 2 + i] *= iz;
            }
            res = ff;
        }
        res.truncated(len)
    }
172
173 pub fn truncated(mut self, len: usize) -> Self {
184 self.truncate(len);
185 self
186 }
187
188 pub fn ref_truncated(&self, len: usize) -> Self {
200 self.0[..len.min(self.0.len())].to_vec().into()
201 }
202
203 pub fn truncate(&mut self, len: usize) {
215 self.0.truncate(len);
216 self.normalize();
217 }
218
219 pub fn reversed(mut self) -> Self {
230 self.reverse();
231 self
232 }
233
234 pub fn reverse(&mut self) {
246 self.0.reverse();
247 self.normalize();
248 }
249
250 pub fn differential(mut self) -> Self {
271 self.differentiate();
272 self
273 }
274
275 pub fn differentiate(&mut self) {
287 if self.0.is_empty() {
288 return;
289 }
290 for i in 1..self.0.len() {
291 self.0[i] *= StaticModInt::new(i);
292 }
293 self.0.remove(0);
294 }
295
296 pub fn integral(mut self) -> Self {
328 self.integrate();
329 self
330 }
331
332 pub fn integrate(&mut self) {
344 if self.0.is_empty() {
345 return;
346 }
347 let n = self.0.len();
348 let recip = {
349 let m = M::VALUE as u64;
350 let mut dp = vec![1_u64; n + 1];
351 for i in 2..=n {
352 let (q, r) = (m / i as u64, m % i as u64);
353 dp[i as usize] = m - q * dp[r as usize] % m;
354 }
355 dp
356 };
357 for i in 0..n {
358 self.0[i] *= StaticModInt::new(recip[i + 1]);
359 }
360 self.0.insert(0, StaticModInt::new(0));
361 }
362
363 pub fn log(&self, len: usize) -> Self {
383 assert_eq!(self.0[0].get(), 1);
384
385 let mut diff = self.clone().differential();
386 diff *= self.recip(len);
387 diff.integrate();
388 diff.truncate(len);
389 diff
390 }
391
392 #[allow(dead_code)]
393 fn exp_naive(&self, len: usize) -> Self {
394 assert_eq!(self.0.get(0).map(|x| x.get()).unwrap_or(0), 0);
395
396 if len == 0 {
397 return Self(vec![]);
398 }
399
400 let mut res = Self(vec![StaticModInt::new(1)]);
401 let one = Self(vec![StaticModInt::new(1)]);
402 let mut cur_len = 1;
403 while cur_len < len {
404 cur_len *= 2;
405 let mut tmp = &one - res.log(cur_len) + self;
406 tmp *= res;
407 tmp.truncate(cur_len);
408 res = tmp;
409 }
410 res.truncate(len);
411 res
412 }
413
414 pub fn exp(&self, len: usize) -> Self {
435 let mut b = Self::from([1, self.get(1).get()]);
436 let mut c = Self::from([1]);
437 let mut z2 = Self::from([1, 1]);
438
439 let mut cur_len = 2;
440 while cur_len < len {
441 let m = cur_len;
442 cur_len *= 2;
443
444 let mut y = b.clone();
445 y.0.resize(2 * m, 0.into());
446 y.fft_butterfly(2 * m);
447 let z1 = z2;
448 let mut z = &y & &z1;
449 z.fft_inv_butterfly(m);
450 z.0.resize(m, 0.into());
451 z.0[..m / 2].fill(0.into());
452 z.fft_butterfly(m);
453 z &= -&z1;
454 z.fft_inv_butterfly(m);
455 c.0.resize(m / 2, 0.into());
456 c.0.extend_from_slice(&z.0[z.0.len().min(m / 2)..]);
457 z2 = c.clone();
458 z2.fft_butterfly(2 * m);
459 let mut x = Self::from(&self.0[..m.min(self.0.len())]);
460 x.differentiate();
461 x.fft_butterfly(m);
462 x &= &y;
463 x.fft_inv_butterfly(m);
464 x -= b.clone().differential();
465 x.0.resize(2 * m, 0.into());
466 for i in 0..m - 1 {
467 x.0[m + i] = x.0[i];
468 x.0[i] = 0.into();
469 }
470 x.fft_butterfly(2 * m);
471 x &= &z2;
472 x.fft_inv_butterfly(2 * m);
473 x.integrate();
474 x.0.resize(2 * m, 0.into());
475 for i in m..self.0.len().min(2 * m) {
476 x.0[i] += self.0[i];
477 }
478 x.0[..m].fill(0.into());
479 x.fft_butterfly(2 * m);
480 x &= &y;
481 x.fft_inv_butterfly(2 * m);
482 b.0.resize(m, 0.into());
483 b.0.extend_from_slice(&x.0[x.0.len().min(m)..]);
484 }
485
486 b.truncated(len)
487 }
488
489 pub fn pow<I: Into<StaticModInt<M>>>(&self, k: I, len: usize) -> Self {
530 let k = k.into();
531 let k_ = k.get() as usize;
532
533 if k_ == 0 {
535 return Self::from([1]).truncated(len);
536 } else if self.is_zero() {
537 return Self::new();
538 }
539
540 let l = (0..).find(|&i| self.0[i].get() != 0).unwrap();
542 let a_l = self.0[l];
543 if len <= l * k_ {
544 return Self::new();
545 }
546
547 let g = (self >> l) / a_l;
548 let g_pow = (g.log(len) * k).exp(len - l * k_);
549 (g_pow << (l * k_)) * a_l.pow(k_ as u64)
550 }
551
552 #[allow(dead_code)]
553 fn circular_naive(&self, im: &Self, len: usize) -> (Self, Self) {
554 let re = self;
555 assert_eq!(re.get(0).get(), 0);
556 assert_eq!(im.get(0).get(), 0);
557 if len == 0 {
558 return (Self::new(), Self::new());
559 }
560
561 let one = StaticModInt::new(1);
562 let mut cos = Self::from([1]);
563 let mut sin = Self::from([0]);
564 let mut cur_len = 1;
565 while cur_len < len {
566 cur_len *= 2;
567
568 let dcos = cos.clone().differential();
569 let dsin = sin.clone().differential();
570
571 let hypot = (&cos * &cos + &sin * &sin).recip(cur_len);
572 let ecos = &dcos * &cos + &dsin * &sin;
573 let esin = &dsin * &cos - &dcos * &sin;
574
575 let logcos = (ecos * &hypot).truncated(cur_len - 1).integral();
576 let logsin = (esin * &hypot).truncated(cur_len - 1).integral();
577
578 let gcos = -logcos + one + re.ref_truncated(cur_len);
579 let gsin = -logsin + im.ref_truncated(cur_len);
580 let hcos = ((&cos * &gcos) - (&sin * &gsin)).truncated(cur_len);
581 let hsin = ((&cos * &gsin) + (&sin * &gcos)).truncated(cur_len);
582
583 cos = hcos;
584 sin = hsin;
585 }
586
587 (cos.truncated(len), sin.truncated(len))
588 }
589
    /// Returns `(C, S)` such that `C + iS = exp(self + i*im) mod x^len`, with `i^2 = -1`.
    ///
    /// `self` is the real part and `im` the imaginary part of the exponent. With
    /// `self == 0` this yields `(cos(im), sin(im))`. With `im == 0` the first
    /// component is `exp(self)`.
    ///
    /// It is the NTT-accelerated form of `circular_naive`: a Newton iteration on
    /// `E = cos + i*sin` with `log E = ∫ E' * conj(E) / |E|^2`. The products are
    /// evaluated point-wise, and `fft_butterfly_double` reuses existing transforms
    /// when growing them from size `cur_len` to `2 * cur_len`.
    ///
    /// # Panics
    /// Panics if either `self` or `im` has a nonzero constant term.
    pub fn circular(&self, im: &Self, len: usize) -> (Self, Self) {
        let re = self;
        assert_eq!(re.get(0).get(), 0);
        assert_eq!(im.get(0).get(), 0);
        if len == 0 {
            return (Self::new(), Self::new());
        }

        let one = StaticModInt::new(1);
        let mut cos = Self::from([1]);
        let mut sin = Self::from([0]);
        let mut cur_len = 1;
        while cur_len < len {
            cur_len *= 2;

            // Derivatives, then move cos/sin and their derivatives to the
            // transform domain. From here on cos/sin hold transforms, not coefficients.
            let mut dcos = cos.clone().differential();
            let mut dsin = sin.clone().differential();
            cos.fft_butterfly(cur_len);
            sin.fft_butterfly(cur_len);
            dcos.fft_butterfly(cur_len);
            dsin.fft_butterfly(cur_len);

            // hypot = |E|^2; ecos/esin = real/imaginary parts of E' * conj(E).
            let mut hypot = (&cos & &cos) + (&sin & &sin);
            let mut ecos = (&dcos & &cos) + (&dsin & &sin);
            let mut esin = (&dsin & &cos) - (&dcos & &sin);
            hypot.fft_inv_butterfly(cur_len);
            hypot = hypot.recip(cur_len);
            hypot.fft_butterfly(2 * cur_len);
            ecos.fft_butterfly_double(2 * cur_len);
            esin.fft_butterfly_double(2 * cur_len);

            // Real and imaginary parts of log E, modulo x^cur_len.
            let mut logcos = &ecos & &hypot;
            let mut logsin = &esin & &hypot;
            logcos.fft_inv_butterfly(2 * cur_len);
            logsin.fft_inv_butterfly(2 * cur_len);
            logcos = logcos.truncated(cur_len - 1).integral();
            logsin = logsin.truncated(cur_len - 1).integral();

            // Newton factor G = 1 - log E + (re + i*im).
            let mut gcos = -logcos + one + re.ref_truncated(cur_len);
            let mut gsin = -logsin + im.ref_truncated(cur_len);
            gcos.fft_butterfly(2 * cur_len);
            gsin.fft_butterfly(2 * cur_len);
            cos.fft_butterfly_double(2 * cur_len);
            sin.fft_butterfly_double(2 * cur_len);

            // E <- E * G (complex product written out in real parts).
            let mut hcos = (&cos & &gcos) - (&sin & &gsin);
            let mut hsin = (&cos & &gsin) + (&sin & &gcos);
            hcos.fft_inv_butterfly(2 * cur_len);
            hsin.fft_inv_butterfly(2 * cur_len);

            cos = hcos.truncated(cur_len);
            sin = hsin.truncated(cur_len);
        }

        (cos.truncated(len), sin.truncated(len))
    }
664
665 pub fn cos(&self, len: usize) -> Self { Self::new().circular(self, len).0 }
678
679 pub fn sin(&self, len: usize) -> Self { Self::new().circular(self, len).1 }
692
693 pub fn tan(&self, len: usize) -> Self {
706 let (cos, sin) = Self::new().circular(self, len);
707 (sin * cos.recip(len)).truncated(len)
708 }
709
710 pub fn polyeqn(
781 mut self,
782 n: usize,
783 f_dfr: impl Fn(&Self, usize) -> Self, ) -> Self {
785 if self.0.is_empty() {
786 self.0.push(StaticModInt::new(0));
787 }
788 let mut d = self.0.len();
789 let mut y = self;
790 while d < n {
791 d *= 2;
792 y -= f_dfr(&y, d).truncated(d);
793 }
794 y.truncated(n)
795 }
796
797 pub fn fode(
883 mut self,
884 n: usize,
885 f_df: impl Fn(&Self, usize) -> (Self, Self),
886 ) -> Self {
887 if self.0.is_empty() {
888 self.0.push(StaticModInt::new(0));
889 }
890 let mut d = self.0.len();
891 let mut y = self;
892 while d < n {
893 d *= 2;
894 let (f, df) = f_df(&y, d);
895 let h = f - y.clone().differential();
896 let u = (-df).integral().exp(d);
897 y += (u.recip(d) * (u * h).truncated(d).integral()).truncated(d);
898 }
899 y.truncated(n)
900 }
901
902 pub fn get(&self, i: usize) -> StaticModInt<M> {
916 self.0.get(i).copied().unwrap_or(StaticModInt::new(0))
917 }
918
919 pub fn eval(&self, t: impl Into<StaticModInt<M>>) -> StaticModInt<M> {
920 let t = t.into();
921 let mut ft = StaticModInt::new(0);
922 for &a in self.0.iter().rev() {
923 ft *= t;
924 ft += a;
925 }
926 ft
927 }
928
929 pub fn into_inner(self) -> Vec<StaticModInt<M>> { self.0 }
931
932 pub fn fft_butterfly(&mut self, len: usize) {
936 let ceil_len = len.next_power_of_two();
937 self.0.resize(ceil_len, StaticModInt::new(0));
938 butterfly(&mut self.0);
939 self.normalize();
940 }
941
942 pub fn fft_inv_butterfly(&mut self, len: usize) {
944 let ceil_len = len.next_power_of_two();
945 self.0.resize(ceil_len, StaticModInt::new(0));
946 butterfly_inv(&mut self.0);
947 self.0.truncate(len);
948 let iz = StaticModInt::new(ceil_len).recip();
949 for c in &mut self.0 {
950 *c *= iz;
951 }
952 self.normalize();
953 }
954
955 pub fn fft_butterfly_double(&mut self, to_len: usize) {
958 if self.is_zero() {
959 return;
960 }
961
962 let mut dbl = self.clone();
963 let g = StaticModInt::<M>::new(M::PRIMITIVE_ROOT);
964 let zeta = g.pow((M::VALUE as u64 - 1) / (to_len as u64));
965
966 dbl.fft_inv_butterfly(to_len / 2);
967 let mut r = StaticModInt::new(1);
968 for i in 0..dbl.0.len() {
969 dbl.0[i] *= r;
970 r *= zeta;
971 }
972 dbl.fft_butterfly(to_len / 2);
973 self.0.resize(to_len / 2, StaticModInt::new(0));
974 self.0.append(&mut dbl.0);
975 }
976
977 pub fn is_zero(&self) -> bool { self.0.is_empty() }
979
980 pub fn len(&self) -> usize { self.0.len() }
982
983 pub fn div_mod(&self, other: &Polynomial<M>) -> (Self, Self) {
987 let q = self / other;
988 let r = self - &q * other;
989 (q, r)
990 }
991
992 pub fn div_nth(
995 &self,
996 other: &Polynomial<M>,
997 mut n: usize,
998 ) -> StaticModInt<M> {
999 let mut p = self.clone();
1000 let mut q = other.clone();
1001 while n > 0 {
1002 let d = (2 * q.0.len() - 1).next_power_of_two();
1003 p.fft_butterfly(d);
1004 q.fft_butterfly(d);
1005 let pq_: Vec<_> = (0..d).map(|i| p.get(i) * q.get(i ^ 1)).collect();
1006 let qq_: Vec<_> =
1007 (0..d).step_by(2).map(|i| q.get(i) * q.get(i + 1)).collect();
1008 let (mut pq_, mut qq_): (Self, Self) = (pq_.into(), qq_.into());
1009 pq_.fft_inv_butterfly(d);
1010 qq_.fft_inv_butterfly(d / 2);
1011 let u: Vec<_> = (n % 2..d).step_by(2).map(|i| pq_.get(i)).collect();
1012 p = u.into();
1013 q = qq_.into();
1014 n /= 2;
1015 }
1016 p.get(0)
1017 }
1018
1019 #[allow(dead_code)]
1020 fn sparse(&self, thresh: usize) -> Option<Vec<(usize, StaticModInt<M>)>> {
1021 let nz: Vec<_> = self
1022 .0
1023 .iter()
1024 .copied()
1025 .enumerate()
1026 .filter(|&(_, ai)| ai.get() != 0)
1027 .take(thresh + 1)
1028 .collect();
1029 (nz.len() <= thresh).then(|| nz)
1030 }
1031}
1032
1033impl<M: NttFriendly> From<Vec<StaticModInt<M>>> for Polynomial<M> {
1034 fn from(buf: Vec<StaticModInt<M>>) -> Self {
1035 let mut res = Self(buf);
1036 res.normalize();
1037 res
1038 }
1039}
1040
1041impl<'a, M: NttFriendly> From<&'a [StaticModInt<M>]> for Polynomial<M> {
1042 fn from(buf: &'a [StaticModInt<M>]) -> Self {
1043 let mut res = Self(buf.to_vec());
1044 res.normalize();
1045 res
1046 }
1047}
1048
1049impl<M: NttFriendly, const N: usize> From<[StaticModInt<M>; N]>
1050 for Polynomial<M>
1051{
1052 fn from(buf: [StaticModInt<M>; N]) -> Self {
1053 let mut res = Self(buf.to_vec());
1054 res.normalize();
1055 res
1056 }
1057}
1058
1059macro_rules! impl_from {
1060 ( $($ty:ty) * ) => { $(
1061 impl<M: NttFriendly> From<Vec<$ty>> for Polynomial<M> {
1062 fn from(buf: Vec<$ty>) -> Self {
1063 let mut res =
1064 Self(buf.into_iter().map(StaticModInt::new).collect());
1065 res.normalize();
1066 res
1067 }
1068 }
1069 impl<'a, M: NttFriendly> From<&'a [$ty]> for Polynomial<M> {
1070 fn from(buf: &'a [$ty]) -> Self {
1071 let mut res =
1072 Self(buf.iter().map(|&x| StaticModInt::new(x)).collect());
1073 res.normalize();
1074 res
1075 }
1076 }
1077 impl<M: NttFriendly, const N: usize> From<[$ty; N]> for Polynomial<M> {
1078 fn from(buf: [$ty; N]) -> Self {
1079 let mut res =
1080 Self(buf.iter().map(|&x| StaticModInt::new(x)).collect());
1081 res.normalize();
1082 res
1083 }
1084 }
1085 )* }
1086}
1087
1088impl_from! {
1089 i8 i16 i32 i64 i128 isize u8 u16 u32 u64 u128 usize
1090}
1091
1092impl<'a, M: NttFriendly> AddAssign<&'a Polynomial<M>> for Polynomial<M> {
1095 fn add_assign(&mut self, other: &'a Polynomial<M>) {
1096 let n = self.0.len().max(other.0.len());
1097 self.0.resize(n, StaticModInt::new(0));
1098 for i in 0..other.0.len() {
1099 self.0[i] += other.0[i];
1100 }
1101 self.normalize();
1102 }
1103}
1104
1105impl<M: NttFriendly> AddAssign for Polynomial<M> {
1106 fn add_assign(&mut self, other: Polynomial<M>) { self.add_assign(&other); }
1107}
1108
1109impl<'a, M: NttFriendly> SubAssign<&'a Polynomial<M>> for Polynomial<M> {
1110 fn sub_assign(&mut self, other: &'a Polynomial<M>) {
1111 let n = self.0.len().max(other.0.len());
1112 self.0.resize(n, StaticModInt::new(0));
1113 for i in 0..other.0.len() {
1114 self.0[i] -= other.0[i];
1115 }
1116 self.normalize();
1117 }
1118}
1119
1120impl<M: NttFriendly> SubAssign for Polynomial<M> {
1121 fn sub_assign(&mut self, other: Polynomial<M>) { self.sub_assign(&other); }
1122}
1123
1124impl<'a, M: NttFriendly> MulAssign<&'a Polynomial<M>> for Polynomial<M> {
1125 fn mul_assign(&mut self, other: &'a Polynomial<M>) {
1126 self.mul_assign(other.clone());
1127 }
1128}
1129
1130impl<M: NttFriendly> MulAssign for Polynomial<M> {
1131 fn mul_assign(&mut self, other: Polynomial<M>) {
1132 let conv = convolve(std::mem::take(&mut self.0), other.0);
1133 self.0 = conv;
1134 self.normalize();
1135 }
1136}
1137
1138impl<'a, M: NttFriendly> DivAssign<&'a Polynomial<M>> for Polynomial<M> {
1139 fn div_assign(&mut self, other: &'a Polynomial<M>) {
1140 self.div_assign(other.clone());
1141 }
1142}
1143
1144impl<M: NttFriendly> DivAssign for Polynomial<M> {
1145 fn div_assign(&mut self, mut other: Polynomial<M>) {
1146 let deg = self.0.len() - other.0.len();
1147 self.reverse();
1148 other.reverse();
1149 *self *= other.recip(deg + 1);
1150 self.0.resize(deg + 1, StaticModInt::new(0));
1151 self.reverse();
1152 }
1153}
1154
1155impl<'a, M: NttFriendly> RemAssign<&'a Polynomial<M>> for Polynomial<M> {
1156 fn rem_assign(&mut self, other: &'a Polynomial<M>) {
1157 self.rem_assign(other.clone());
1158 }
1159}
1160
1161impl<M: NttFriendly> RemAssign for Polynomial<M> {
1162 fn rem_assign(&mut self, other: Polynomial<M>) {
1163 let div = &*self / &other;
1164 *self -= div * &other;
1165 }
1166}
1167
1168impl<'a, M: NttFriendly> BitAndAssign<&'a Polynomial<M>> for Polynomial<M> {
1169 fn bitand_assign(&mut self, other: &'a Polynomial<M>) {
1170 self.0.truncate(other.0.len());
1171 for (ai, &bi) in self.0.iter_mut().zip(&other.0) {
1172 *ai *= bi;
1173 }
1174 self.normalize();
1175 }
1176}
1177
1178impl<M: NttFriendly> BitAndAssign for Polynomial<M> {
1179 fn bitand_assign(&mut self, other: Polynomial<M>) {
1180 self.bitand_assign(&other);
1181 }
1182}
1183
1184impl<'a, M: NttFriendly> AddAssign<&'a StaticModInt<M>> for Polynomial<M> {
1187 fn add_assign(&mut self, &other: &'a StaticModInt<M>) {
1188 if other.get() == 0 {
1189 return;
1190 }
1191 if self.0.is_empty() {
1192 self.0.push(other);
1193 } else {
1194 self.0[0] += other;
1195 }
1196 self.normalize();
1197 }
1198}
1199
1200impl<M: NttFriendly> AddAssign<StaticModInt<M>> for Polynomial<M> {
1201 fn add_assign(&mut self, other: StaticModInt<M>) {
1202 self.add_assign(&other);
1203 }
1204}
1205
1206impl<'a, M: NttFriendly> SubAssign<&'a StaticModInt<M>> for Polynomial<M> {
1207 fn sub_assign(&mut self, &other: &'a StaticModInt<M>) {
1208 if other.get() == 0 {
1209 return;
1210 }
1211 if self.0.is_empty() {
1212 self.0.push(-other);
1213 } else {
1214 self.0[0] -= other;
1215 }
1216 self.normalize();
1217 }
1218}
1219
1220impl<M: NttFriendly> SubAssign<StaticModInt<M>> for Polynomial<M> {
1221 fn sub_assign(&mut self, other: StaticModInt<M>) {
1222 self.sub_assign(&other);
1223 }
1224}
1225
1226impl<'a, M: NttFriendly> MulAssign<&'a StaticModInt<M>> for Polynomial<M> {
1227 fn mul_assign(&mut self, &other: &'a StaticModInt<M>) {
1228 if other.get() == 0 {
1229 self.0.clear();
1230 return;
1231 }
1232 if self.0.is_empty() {
1233 return;
1234 }
1235
1236 for c in &mut self.0 {
1237 *c *= other;
1238 }
1239 self.normalize();
1240 }
1241}
1242
1243impl<M: NttFriendly> MulAssign<StaticModInt<M>> for Polynomial<M> {
1244 fn mul_assign(&mut self, other: StaticModInt<M>) {
1245 self.mul_assign(&other);
1246 }
1247}
1248
1249impl<'a, M: NttFriendly> DivAssign<&'a StaticModInt<M>> for Polynomial<M> {
1250 fn div_assign(&mut self, &other: &'a StaticModInt<M>) {
1251 assert_ne!(other.get(), 0);
1252 if self.0.is_empty() {
1253 return;
1254 }
1255
1256 let other = other.recip();
1257 for c in &mut self.0 {
1258 *c *= other;
1259 }
1260 self.normalize();
1261 }
1262}
1263
1264impl<M: NttFriendly> DivAssign<StaticModInt<M>> for Polynomial<M> {
1265 fn div_assign(&mut self, other: StaticModInt<M>) {
1266 self.div_assign(&other);
1267 }
1268}
1269
1270impl<'a, M: NttFriendly> RemAssign<&'a StaticModInt<M>> for Polynomial<M> {
1271 fn rem_assign(&mut self, &other: &'a StaticModInt<M>) {
1272 assert_ne!(other.get(), 0);
1273 if self.0.is_empty() {
1274 return;
1275 }
1276
1277 self.0.clear();
1278 }
1279}
1280
1281impl<M: NttFriendly> RemAssign<StaticModInt<M>> for Polynomial<M> {
1282 fn rem_assign(&mut self, other: StaticModInt<M>) {
1283 self.rem_assign(&other);
1284 }
1285}
1286
1287impl<'a, M: NttFriendly> BitAndAssign<&'a StaticModInt<M>> for Polynomial<M> {
1288 fn bitand_assign(&mut self, &other: &'a StaticModInt<M>) {
1289 if self.0.is_empty() {
1290 return;
1291 }
1292 if other.get() == 0 {
1293 self.0.clear();
1294 } else {
1295 self.0.truncate(1);
1296 self.0[0] *= other;
1297 self.normalize();
1298 }
1299 }
1300}
1301
1302impl<M: NttFriendly> BitAndAssign<StaticModInt<M>> for Polynomial<M> {
1303 fn bitand_assign(&mut self, other: StaticModInt<M>) {
1304 self.bitand_assign(&other);
1305 }
1306}
1307
1308macro_rules! impl_binop {
1309 ( $( ($op:ident, $op_assign:ident, $op_trait:ident, $op_assign_trait:ident), )* ) => {
1310 $(
1311 impl<'a, M: NttFriendly> $op_trait<Polynomial<M>> for &'a Polynomial<M> {
1312 type Output = Polynomial<M>;
1313 fn $op(self, other: Polynomial<M>) -> Polynomial<M> {
1314 self.clone().$op(other)
1315 }
1316 }
1317 impl<'a, M: NttFriendly> $op_trait<&'a Polynomial<M>> for Polynomial<M> {
1318 type Output = Polynomial<M>;
1319 fn $op(mut self, other: &'a Polynomial<M>) -> Polynomial<M> {
1320 self.$op_assign(other);
1321 self
1322 }
1323 }
1324 impl<'a, M: NttFriendly> $op_trait<&'a Polynomial<M>> for &'a Polynomial<M> {
1325 type Output = Polynomial<M>;
1326 fn $op(self, other: &'a Polynomial<M>) -> Polynomial<M> {
1327 self.clone().$op(other)
1328 }
1329 }
1330 impl<M: NttFriendly> $op_trait for Polynomial<M> {
1331 type Output = Polynomial<M>;
1332 fn $op(mut self, other: Polynomial<M>) -> Polynomial<M> {
1333 self.$op_assign(other);
1334 self
1335 }
1336 }
1337
1338 impl<'a, M: NttFriendly> $op_trait<StaticModInt<M>> for &'a Polynomial<M> {
1339 type Output = Polynomial<M>;
1340 fn $op(self, other: StaticModInt<M>) -> Polynomial<M> {
1341 self.clone().$op(other)
1342 }
1343 }
1344 impl<'a, M: NttFriendly> $op_trait<&'a StaticModInt<M>> for Polynomial<M> {
1345 type Output = Polynomial<M>;
1346 fn $op(mut self, other: &'a StaticModInt<M>) -> Polynomial<M> {
1347 self.$op_assign(other);
1348 self
1349 }
1350 }
1351 impl<'a, M: NttFriendly> $op_trait<&'a StaticModInt<M>> for &'a Polynomial<M> {
1352 type Output = Polynomial<M>;
1353 fn $op(self, other: &'a StaticModInt<M>) -> Polynomial<M> {
1354 self.clone().$op(other)
1355 }
1356 }
1357 impl<M: NttFriendly> $op_trait<StaticModInt<M>> for Polynomial<M> {
1358 type Output = Polynomial<M>;
1359 fn $op(mut self, other: StaticModInt<M>) -> Polynomial<M> {
1360 self.$op_assign(other);
1361 self
1362 }
1363 }
1364 )*
1365 }
1366}
1367
1368impl_binop! {
1369 (add, add_assign, Add, AddAssign),
1370 (sub, sub_assign, Sub, SubAssign),
1371 (mul, mul_assign, Mul, MulAssign),
1372 (div, div_assign, Div, DivAssign),
1373 (rem, rem_assign, Rem, RemAssign),
1374 (bitand, bitand_assign, BitAnd, BitAndAssign),
1375}
1376
1377impl<M: NttFriendly> Neg for Polynomial<M> {
1378 type Output = Polynomial<M>;
1379 fn neg(mut self) -> Polynomial<M> {
1380 for c in &mut self.0 {
1381 *c = -*c;
1382 }
1383 self
1384 }
1385}
1386
1387impl<'a, M: NttFriendly> Neg for &'a Polynomial<M> {
1388 type Output = Polynomial<M>;
1389 fn neg(self) -> Polynomial<M> { -self.clone() }
1390}
1391
1392impl<M: NttFriendly> ShlAssign<usize> for Polynomial<M> {
1393 fn shl_assign(&mut self, sh: usize) {
1394 if !self.0.is_empty() {
1395 self.0.splice(0..0, (0..sh).map(|_| StaticModInt::new(0)));
1396 }
1397 }
1398}
1399
1400impl<M: NttFriendly> Shl<usize> for Polynomial<M> {
1401 type Output = Polynomial<M>;
1402 fn shl(mut self, sh: usize) -> Self::Output {
1403 self.shl_assign(sh);
1404 self
1405 }
1406}
1407
1408impl<'a, M: NttFriendly> Shl<usize> for &'a Polynomial<M> {
1409 type Output = Polynomial<M>;
1410 fn shl(self, sh: usize) -> Self::Output { self.clone().shl(sh) }
1411}
1412
1413impl<M: NttFriendly> ShrAssign<usize> for Polynomial<M> {
1414 fn shr_assign(&mut self, sh: usize) {
1415 if !self.0.is_empty() {
1416 self.0.splice(0..sh.min(self.0.len()), None);
1417 }
1418 }
1419}
1420
1421impl<M: NttFriendly> Shr<usize> for Polynomial<M> {
1422 type Output = Polynomial<M>;
1423 fn shr(mut self, sh: usize) -> Self::Output {
1424 self.shr_assign(sh);
1425 self
1426 }
1427}
1428
1429impl<'a, M: NttFriendly> Shr<usize> for &'a Polynomial<M> {
1430 type Output = Polynomial<M>;
1431 fn shr(self, sh: usize) -> Self::Output { self.clone().shr(sh) }
1432}
1433
#[test]
fn sanity_check() {
    type Poly = Polynomial<modint::Mod998244353>;

    // Plain multiplication.
    let f: Poly = vec![0, 1, 2, 3, 4].into();
    let g = Poly::from(&[0, 1, 2, 4, 8][..]);
    let expected_prod = Poly::from([0, 0, 1, 4, 11, 26, 36, 40, 32]);
    assert_eq!(&f * g, expected_prod);

    // exp(x) = sum x^k / k!, so the reciprocals of its coefficients are factorials.
    let x: Poly = [0, 1].into();
    let factorials: Vec<_> = x
        .exp(10)
        .into_inner()
        .into_iter()
        .map(|c| c.recip().get())
        .collect();
    assert_eq!(factorials, [1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880]);

    // d/dx log(1 - x) = -1 / (1 - x) = -(1 + x + x^2 + ...).
    let one_minus_x: Poly = [1, -1].into();
    assert_eq!(one_minus_x.log(10).differential(), Poly::from([-1; 9]));

    // h * h^{-1} leaves remainder 1 modulo x^9.
    let h: Poly = [1, 9, 2, 6, 8, 3].into();
    let x_pow9: Poly = (0..10)
        .map(|i| if i == 9 { 1 } else { 0 })
        .collect::<Vec<_>>()
        .into();
    assert_eq!((&h * h.recip(10)) % &x_pow9, Poly::from([1]));

    // Exact division by x followed by integration.
    assert_eq!((&f / &x).integral(), &x * Poly::from([1; 4]));

    // Powers of (1 + x), including exponents reduced modulo 998244353.
    let one_plus_x: Poly = [1; 2].into();
    assert_eq!(
        one_plus_x.pow(5, 10),
        &one_plus_x * &one_plus_x * &one_plus_x * &one_plus_x * &one_plus_x
    );
    assert_eq!(
        one_plus_x.pow(998244352, 10) * &one_plus_x % &x_pow9,
        one_plus_x.pow(998244353, 10)
    );
}
1463
#[test]
fn fft() {
    type Poly = Polynomial<modint::Mod998244353>;

    // Transform size; large enough for the triple product below.
    const N: usize = 4 + 4 + 4 + 1;
    let one: Poly = [1].into();
    let f: Poly = [0, 1, 2, 3, 4].into();
    let g: Poly = [0, 1, 2, 4, 8].into();
    let h: Poly = [0, 6, 5, 4, 3].into();

    let forward = |p: &Poly| {
        let mut t = p.clone();
        t.fft_butterfly(N);
        t
    };
    let inverse = |p: &Poly| {
        let mut t = p.clone();
        t.fft_inv_butterfly(N);
        t
    };

    // The transform of the constant 1 is 1 at every point.
    let fone: Poly = [1; N.next_power_of_two()].into();
    let (tf, tg, th) = (forward(&f), forward(&g), forward(&h));

    // Linearity.
    assert_eq!(forward(&(&f + &one)), forward(&f) + &fone);

    // Round trips and sums.
    assert_eq!(f, inverse(&tf));
    assert_eq!(&f + &h, inverse(&(&tf + &th)));

    // Point-wise products correspond to convolution.
    assert_eq!(&f * &g, inverse(&(&tf & &tg)));
    assert_eq!(&f * &g * &h, inverse(&(&tf & &tg & &th)));
    assert_eq!(f * g + h, inverse(&((tf & tg) + th)));
}
1499
#[test]
fn recip() {
    type Mi = modint::ModInt998244353;
    type Poly = Polynomial<modint::Mod998244353>;

    // The NTT-based and naive Newton iterations must agree.
    let f: Poly = [1, 2, 3, 4].into();
    assert_eq!(f.recip(10), f.recip_naive(10));

    // ∫ 1/(1 - x) = sum x^i / i, so i * [x^i] == 1 for every i >= 1.
    let n = 100;
    let series = Poly::from([1, -1]).recip(n).integral();
    for i in 1..=n {
        let scaled = series.get(i) * Mi::new(i);
        assert_eq!(scaled.get(), 1);
    }
}
1514
#[test]
fn pow() {
    type Poly = Polynomial<modint::Mod998244353>;

    // Leading zeros exercise the x^l shift inside `pow`.
    let f: Poly = [0, 0, 0, 2, 1, 3].into();

    for len in 0..100 {
        // `expected` is f^k mod x^len, built by repeated multiplication.
        let mut expected = Poly::from([1]).truncated(len);
        for k in 0..=10 {
            assert_eq!(f.pow(k, len), expected, "({})^{}", f, k);
            expected = (expected * &f).truncated(len);
        }
    }
}
1531
#[test]
fn polyeqn() {
    type Poly = Polynomial<modint::Mod998244353>;
    type Mi = modint::ModInt998244353;

    let n = 10;

    // Solve 1/y - f = 0. The Newton correction is (f - 1/y) * y^2.
    let f: Poly = [1, 2, 3, 4, 5].into();
    let inv = Poly::from([1])
        .polyeqn(n, |y, prec| (&f - y.recip(prec)) * (y * y).truncated(prec));
    assert_eq!(inv, f.recip(n));

    // Catalan generating function: x*y^2 - y + 1 = 0, F'(y) = 2xy - 1.
    let catalan = Poly::from([1]).polyeqn(n, |y, prec| {
        let value = ((y * y) << 1) - y + Mi::new(1);
        let slope = (y << 1) * Mi::new(2) - Mi::new(1);
        (value.truncated(prec) * slope.recip(prec)).truncated(prec)
    });
    assert_eq!(catalan, Poly::from([1, 1, 2, 5, 14, 42, 132, 429, 1430, 4862]));
}
1550
// Checks `fode` on three ODEs y' = f(y). Each closure returns (f(y), ∂f/∂y) mod x^n.
#[test]
fn fode() {
    type Poly = Polynomial<modint::Mod998244353>;
    type Mi = modint::ModInt998244353;

    let one = Mi::new(1);
    let two = Mi::new(2);
    let three = Mi::new(3);
    let x: Poly = [0, 1].into();

    let n = 20;
    // y' = (y - x)^2 + 1 with y(0) = 1.
    let f_df = |y: &Poly, n| {
        let d = y - &x;
        ((&d * &d + one).truncated(n), &d * two)
    };
    let y = Poly::from([1]).fode(n + 1, f_df);

    // The solution must satisfy the ODE to the requested precision.
    assert_eq!(f_df(&y, n).0, y.differential());

    // y' = (y - x)^3 + 1 with y(0) = 2. With d = y - x, d' = d^3 and d(0) = 2,
    // so (d / 2)^(-2) = 1 - 8x.
    let f_df = |y: &Poly, n| {
        let d = y - &x;
        let dd = (&d * &d).truncated(n);
        ((&dd * &d + one).truncated(n), &dd * three)
    };
    let y = Poly::from([2]).fode(n + 1, f_df);

    assert_eq!(((&y - &x) / two).recip(n).pow(2, n), Poly::from([1, -8]));
    assert_eq!(f_df(&y, n).0, y.differential());

    // Catalan generating function C satisfies C' = C^2 / (1 - 2xC).
    // ∂f/∂y = 2y(1 - xy) / (1 - 2xy)^2; `xy2r` is 1 / (1 - 2xy).
    let catalan = |y: &Poly, n| {
        let xy2r = (-((y * Mi::new(2)) << 1) + Mi::new(1)).recip(n);
        let f = ((y * y).truncated(n) * &xy2r).truncated(n);
        let df = (y * Mi::new(2) * (-(y << 1) + Mi::new(1))).truncated(n)
            * (&xy2r * &xy2r).truncated(n);
        (f, df.truncated(n))
    };
    let y = Poly::from([1]).fode(10, catalan);
    assert_eq!(y, [1, 1, 2, 5, 14, 42, 132, 429, 1430, 4862].into());
}
1596
#[test]
fn fibonacci() {
    type Poly = Polynomial<modint::Mod998244353>;

    // 1 / (1 - x - x^2) generates the Fibonacci numbers.
    let num: Poly = [1].into();
    let den: Poly = [1, -1, -1].into();

    let n = 10;
    let expected = (&num * den.recip(n)).truncated(n);

    // Coefficient-by-coefficient extraction via Bostan–Mori must match.
    let actual: Poly = (0..n)
        .map(|i| num.div_nth(&den, i))
        .collect::<Vec<_>>()
        .into();

    assert_eq!(actual, expected);
}
1612
#[test]
fn butterfly_double() {
    type Poly = Polynomial<modint::Mod998244353>;

    let f: Poly = [1, 2, 3, 4, 5].into();
    let transform = |size: usize| {
        let mut t = f.clone();
        t.fft_butterfly(size);
        t
    };

    // Doubling an 8-point transform must reproduce the 16-point transform.
    let mut doubled = transform(8);
    doubled.fft_butterfly_double(16);
    assert_eq!(doubled, transform(16));
}
1628
// Checks `circular` against exp(x) and against angle-addition identities.
#[test]
fn sin_cos() {
    type Mi = modint::ModInt998244353;
    type Poly = Polynomial<modint::Mod998244353>;

    let n = 100;
    let zero: Poly = [0].into();
    let x: Poly = [0, 1].into();

    // With a purely real exponent, circular(x, 0) = (exp(x), 0).
    let exp_x = x.exp(n);
    let (exp, o) = x.circular(&zero, n);

    assert_eq!(exp, exp_x);
    assert_eq!(o, zero);

    // cos(x) / sin(x): the even / odd coefficients of exp(x), with signs
    // alternating every two degrees.
    let (cos, sin) = zero.circular(&x, n);
    for i in 0..n {
        let sgn = Mi::new(if i / 2 % 2 == 0 { 1 } else { -1 });
        if i % 2 == 0 {
            assert_eq!(cos.get(i), sgn * exp_x.get(i));
            assert_eq!(sin.get(i).get(), 0);
        } else {
            assert_eq!(cos.get(i).get(), 0);
            assert_eq!(sin.get(i), sgn * exp_x.get(i));
        }
    }

    // z = (cos(x + x^2), sin(x + x^2)).
    let z = zero.circular(&Poly::from([0, 1, 1]), n);

    // cos2 / sin2 are cos(x^2) / sin(x^2), obtained by spreading the coefficients of
    // cos(x) / sin(x) to even degrees.
    let (cos2, sin2): (Poly, Poly) = {
        let mut cos2 = vec![Mi::new(0); n];
        let mut sin2 = vec![Mi::new(0); n];
        for i in (0..n).step_by(2) {
            cos2[i] = cos.get(i / 2);
            sin2[i] = sin.get(i / 2);
        }
        (cos2.into(), sin2.into())
    };

    // Angle addition: cos(a+b) = cos a cos b - sin a sin b, sin(a+b) = sin a cos b + cos a sin b.
    assert_eq!(z.0, (&cos * &cos2 - &sin * &sin2).truncated(n));
    assert_eq!(z.1, (&sin * &cos2 + &cos * &sin2).truncated(n));
}
1672}