Skip to main content

nekolib/algo/
karatsuba.rs

1//! Karatsuba 法。
2
3use std::ops::{AddAssign, Mul, SubAssign};
4
/// Convolution (polynomial multiplication) by Karatsuba's algorithm.
///
/// Computes the product $a * b$ of the sequences $a = (a\_i)$ and $b = (b\_i)$,
/// where $a * b$ is defined by
/// $$ (a * b)\_i = \\sum\_{j=0}^i a\_j \\cdot b\_{i-j}, $$
/// with out-of-range terms treated as zero.
///
/// The result has length `a.len() + b.len() - 1`, or is empty when either
/// input is empty. `T::default()` is used as the additive zero (for padding
/// and as the initial value of every accumulator).
///
/// # Idea
/// Split each operand of length $n$ at $n\_l = \\lfloor n/2 \\rfloor$:
/// $a = a\_l + x^{n\_l} a\_h$ and $b = b\_l + x^{n\_l} b\_h$. Then
/// $$ a b = a\_l b\_l + x^{n\_l} (a\_l b\_h + a\_h b\_l) + x^{2 n\_l} a\_h b\_h, $$
/// and the middle term equals $(a\_l + a\_h)(b\_l + b\_h) - a\_l b\_l - a\_h b\_h$.
/// Only three half-size products are needed instead of four. Operands of length
/// at most `NAIVE_THRESHOLD` are multiplied with the naive quadratic loop.
///
/// # Complexity
/// $O(n^{\\log\_2(3)}) \\subset O(n^{1.585})$ time, where $n = \\max(|a|, |b|)$
/// (the shorter operand is zero-padded to the length of the longer one).
///
/// # Overflow
/// The intermediate product $(a\_l + a\_h)(b\_l + b\_h)$ can be larger than any
/// coefficient of the final result. With fixed-width integer `T`, arithmetic
/// overflow (a panic in debug builds) can therefore occur even when the result
/// itself fits in `T`.
///
/// # Examples
/// ```
/// use nekolib::algo::convolve;
///
/// let a = vec![0_i32, 1, 2, 3, 4];
/// let b = vec![0, 1, 2, 4, 8];
/// assert_eq!(convolve(&a, &b), [0, 0, 1, 4, 11, 26, 36, 40, 32]);
///
/// let empty: Vec<i32> = vec![];
/// assert!(convolve(&empty, &b).is_empty());
/// ```
pub fn convolve<T>(a: &[T], b: &[T]) -> Vec<T>
where
    T: AddAssign + SubAssign + Mul<Output = T> + Default + Clone,
{
    // The convolution with an empty sequence is empty. Returning early also
    // keeps `n + m - 1` below from underflowing.
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }

    let n = a.len();
    let m = b.len();

    // `mul` requires equal-length operands, so zero-pad the shorter one.
    let nm = n.max(m);
    let mut a = a.to_vec();
    a.resize_with(nm, T::default);
    let mut b = b.to_vec();
    b.resize_with(nm, T::default);

    let mut ab = mul(&a, &b);
    // The padding only contributes zero coefficients past index n + m - 2.
    ab.truncate(n + m - 1);
    ab
}

/// Operands of length at most this are multiplied by the naive quadratic loop
/// instead of recursing further.
const NAIVE_THRESHOLD: usize = 32;

/// Multiplies two equal-length sequences of length `n` by Karatsuba's method.
///
/// Returns all `2 * n - 1` coefficients of the product, or an empty vector
/// when `n == 0`.
///
/// # Panics
/// Panics if `a.len() != b.len()`.
fn mul<T>(a: &[T], b: &[T]) -> Vec<T>
where
    T: AddAssign + SubAssign + Mul<Output = T> + Default + Clone,
{
    assert_eq!(a.len(), b.len());

    let n = a.len();
    if n == 0 {
        return Vec::new();
    }

    if n <= NAIVE_THRESHOLD {
        let mut res = vec![T::default(); 2 * n - 1];
        for (i, ai) in a.iter().enumerate() {
            // res[i..] is at least n long, so the zip covers all of b.
            for (rij, bj) in res[i..].iter_mut().zip(b) {
                *rij += ai.clone() * bj.clone();
            }
        }
        return res;
    }

    // a = al + x^nl * ah (likewise for b); ah is the longer half when n is odd.
    let nl = n / 2;
    let nh = n - nl;
    let (al, ah) = a.split_at(nl);
    let (bl, bh) = b.split_at(nl);

    let t = mul(al, bl); // al * bl: 2 * nl - 1 coefficients
    let u = mul(ah, bh); // ah * bh: 2 * nh - 1 coefficients

    // al + ah and bl + bh, formed in length nh (al, bl implicitly zero-extended).
    let mut alh = ah.to_vec();
    let mut blh = bh.to_vec();
    for (x, y) in alh.iter_mut().zip(al) {
        *x += y.clone();
    }
    for (x, y) in blh.iter_mut().zip(bl) {
        *x += y.clone();
    }

    // v becomes (al + ah)(bl + bh) - al*bl - ah*bh = al*bh + ah*bl.
    let mut v = mul(&alh, &blh);
    let mut res = vec![T::default(); 2 * n - 1];
    for (i, ti) in t.into_iter().enumerate() {
        v[i] -= ti.clone();
        res[i] += ti;
    }
    for (i, ui) in u.into_iter().enumerate() {
        v[i] -= ui.clone();
        res[2 * nl + i] += ui;
    }

    // For odd n, v has one extra top coefficient. It arises only from the
    // zero-extension of al/bl and cancels against u, so drop it.
    if nl != nh {
        v.pop();
    }

    // Add the middle term, shifted by x^nl.
    for (r, vi) in res[nl..].iter_mut().zip(v) {
        *r += vi;
    }
    res
}
96
/// Checks `convolve` against the naive O(nm) convolution on pseudo-random
/// digit sequences.
///
/// Besides the original power-of-two case, it covers odd lengths above the
/// naive threshold (these take the uneven-split branch of the recursion) and
/// unequal lengths (these take the zero-padding path in `convolve`).
#[test]
fn test() {
    // Deterministic pseudo-random digits in 0..10.
    let mut it =
        std::iter::successors(Some(14025256_i64), |&i| Some(i * i % 20300713))
            .map(|i| i % 10);

    fn naive(a: &[i64], b: &[i64]) -> Vec<i64> {
        let mut res = vec![0; a.len() + b.len() - 1];
        for (i, &ai) in a.iter().enumerate() {
            for (j, &bj) in b.iter().enumerate() {
                res[i + j] += ai * bj;
            }
        }
        res
    }

    // (1024, 1024) is the original case and is drawn first, so it sees the
    // same inputs as before.
    let cases = [(1024, 1024), (1, 1), (33, 33), (1023, 1000), (100, 37), (5, 300)];
    for &(n, m) in &cases {
        let a: Vec<_> = (0..n).map(|_| it.next().unwrap()).collect();
        let b: Vec<_> = (0..m).map(|_| it.next().unwrap()).collect();
        assert_eq!(convolve(&a, &b), naive(&a, &b), "n = {}, m = {}", n, m);
    }
}