Skip to main content

nekolib/math/
interpolation.rs

1//! Lagrange 補間。
2
3use super::const_div;
4use super::gcd_recip;
5
6use const_div::ConstDiv;
7use gcd_recip::GcdRecip;
8
9/// Lagrange 補間。
10///
11/// 与えられた $\\langle f(0), f(1), \\dots, f(n-1)\\rangle$ に対して $\\hat{f}(i)=f(i)$
12/// ($i=0,1,\\dots,n-1$) なる $n$ 次多項式 $\\hat{f}$ を考える。
13/// この $\\hat{f}$ に対して $\\hat{f}(x)\\bmod p$ を返す。
14///
15/// # Idea
16/// `todo!()`
17///
18/// # Complexity
19/// 前処理に $O(n)$ 時間、$\\hat{f}(x)$ を求めるのに $O(n)$ 時間。
20///
21/// # Examples
22/// ```
23/// use nekolib::math::Interpolation;
24///
25/// let f = Interpolation::with(vec![0, 1, 3], 998244353);
26/// assert_eq!(f.interpolate(0), 0);
27/// assert_eq!(f.interpolate(3), 6);
28/// assert_eq!(f.interpolate(4), 10);
29/// assert_eq!(f.interpolate(100000000), 722404071);
30/// ```
31///
32/// # See also
33/// - <https://rsk0315.hatenablog.com/entry/2019/04/25/141012>
34pub struct Interpolation {
35    first: Vec<u64>,
36    fact_recip: Vec<u64>,
37    cd: ConstDiv,
38    modulo: u64,
39}
40
41impl Interpolation {
42    pub fn with(first: Vec<u64>, modulo: u64) -> Self {
43        let n = first.len();
44        let cd = ConstDiv::new(modulo);
45        let r = (2..n as u64).reduce(|x, y| cd.rem(x * y)).unwrap_or(1);
46        let mut fact_recip = vec![1; n];
47        fact_recip[n - 1] = r.gcd_recip(modulo).1;
48        for i in (2..n).rev() {
49            fact_recip[i - 1] = cd.rem(fact_recip[i] * i as u64);
50        }
51        Self { first, fact_recip, cd, modulo }
52    }
53    pub fn interpolate(&self, x: u64) -> u64 {
54        if (x as usize) < self.first.len() {
55            return self.first[x as usize];
56        }
57        let cd = self.cd;
58        let modulo = self.modulo;
59        let n = self.first.len() - 1;
60        // omega = (x-0) * ... (x-n)
61        let omega = (0..=n as u64)
62            .map(|i| cd.rem(x + modulo - i))
63            .reduce(|acc, x| cd.rem(acc * x))
64            .unwrap();
65        let sigma: u64 = (0..=n)
66            .map(|i| {
67                let wi = cd.rem(self.fact_recip[i] * self.fact_recip[n - i]);
68                let sgn = if (n - i) % 2 != 0 { modulo - wi } else { wi };
69                let tmp = cd.rem(self.first[i] * sgn);
70                tmp * (x + modulo - i as u64).gcd_recip(modulo).1
71            })
72            .sum();
73        cd.rem(omega * cd.rem(sigma))
74    }
75}