// nekolib/algo/karatsuba.rs
use std::ops::{AddAssign, Mul, SubAssign};

/// Computes the linear convolution of `a` and `b` with Karatsuba multiplication.
///
/// The result `c` has length `a.len() + b.len() - 1` and satisfies
/// `c[k] = Σ a[i] * b[j]` over all `i + j = k`. If either input is empty the
/// result is empty.
///
/// Both inputs are zero-padded with `T::default()` to the longer length before
/// multiplying. So `T::default()` must behave as zero, and the running time is
/// governed by the longer input.
pub fn convolve<T>(a: &[T], b: &[T]) -> Vec<T>
where
    T: AddAssign + SubAssign + Mul<Output = T> + Default + Clone,
{
    let n = a.len();
    let m = b.len();
    // An empty operand gives an empty product; returning early also keeps
    // `n + m - 1` below from underflowing.
    if n == 0 || m == 0 {
        return Vec::new();
    }

    let nm = n.max(m);
    let mut a = a.to_vec();
    a.resize_with(nm, T::default);
    let mut b = b.to_vec();
    b.resize_with(nm, T::default);

    let mut ab = mul(&a, &b);
    // `mul` returns `2 * nm - 1` coefficients; the ones past index
    // `n + m - 2` only involve the zero padding and are dropped.
    ab.truncate(n + m - 1);
    ab
}

/// Inputs of at most this length are multiplied with the schoolbook O(n²)
/// loop instead of recursing further.
const NAIVE_THRESHOLD: usize = 32;

/// Multiplies two equal-length coefficient slices with Karatsuba's algorithm.
///
/// Returns the full product of length `2 * a.len() - 1`, or an empty vector
/// when both slices are empty.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
fn mul<T>(a: &[T], b: &[T]) -> Vec<T>
where
    T: AddAssign + SubAssign + Mul<Output = T> + Default + Clone,
{
    assert_eq!(a.len(), b.len());

    let n = a.len();
    // Guard the `2 * n - 1` length computations below.
    if n == 0 {
        return Vec::new();
    }

    if n <= NAIVE_THRESHOLD {
        let mut res = vec![T::default(); 2 * n - 1];
        for (i, ai) in a.iter().enumerate() {
            for (j, bj) in b.iter().enumerate() {
                res[i + j] += ai.clone() * bj.clone();
            }
        }
        return res;
    }

    // Split a = al + x^nl * ah (likewise b); `ah` is at least as long as `al`.
    let nl = n / 2;
    let nh = n - nl;
    let (al, ah) = a.split_at(nl);
    let (bl, bh) = b.split_at(nl);

    let t = mul(al, bl); // al * bl, length 2 * nl - 1
    let u = mul(ah, bh); // ah * bh, length 2 * nh - 1

    // (al + ah) and (bl + bh), held in the longer length `nh`.
    let mut alh = ah.to_vec();
    let mut blh = bh.to_vec();
    for (x, l) in alh.iter_mut().zip(al) {
        *x += l.clone();
    }
    for (y, r) in blh.iter_mut().zip(bl) {
        *y += r.clone();
    }

    // v = (al + ah)(bl + bh) - t - u = al * bh + ah * bl (the middle term).
    let mut v = mul(&alh, &blh);
    let mut res = vec![T::default(); 2 * n - 1];
    for (i, ti) in t.iter().enumerate() {
        v[i] -= ti.clone();
        res[i] += ti.clone();
    }
    for (i, ui) in u.iter().enumerate() {
        v[i] -= ui.clone();
        res[2 * nl + i] += ui.clone();
    }

    // The middle term has length nl + nh - 1 = n - 1. When nh = nl + 1, `v`
    // holds one extra trailing entry equal to ah[nh-1]*bh[nh-1] - u[2nh-2],
    // which is zero, so dropping it loses nothing.
    v.truncate(n - 1);
    for (i, vi) in v.into_iter().enumerate() {
        res[nl + i] += vi;
    }
    res
}

/// Checks `convolve` against the schoolbook product on pseudo-random digits.
///
/// Power-of-two sizes only ever split evenly, so odd and unequal sizes are
/// included as well. They exercise the uneven-split path inside the Karatsuba
/// recursion and the padding/truncation done by `convolve`.
#[test]
fn test() {
    let mut it =
        std::iter::successors(Some(14025256_i64), |&i| Some(i * i % 20300713))
            .map(|i| i % 10);

    let sizes: [(usize, usize); 6] =
        [(1024, 1024), (1000, 1000), (777, 333), (33, 1), (1, 33), (5, 7)];
    for &(n, m) in sizes.iter() {
        let a: Vec<_> = (0..n).map(|_| it.next().unwrap()).collect();
        let b: Vec<_> = (0..m).map(|_| it.next().unwrap()).collect();

        let mut expected = vec![0; n + m - 1];
        for (i, &ai) in a.iter().enumerate() {
            for (j, &bj) in b.iter().enumerate() {
                expected[i + j] += ai * bj;
            }
        }

        let actual = convolve(&a, &b);
        assert_eq!(actual, expected, "n = {}, m = {}", n, m);
    }
}