1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
//! Lagrange 補間。

use super::const_div;
use super::gcd_recip;

use const_div::ConstDiv;
use gcd_recip::GcdRecip;

/// Lagrange interpolation.
///
/// Given $\\langle f(0), f(1), \\dots, f(n-1)\\rangle$, let $\\hat{f}$ be the unique
/// polynomial of degree at most $n-1$ with $\\hat{f}(i)=f(i)$ for
/// $i=0,1,\\dots,n-1$. For a query point $x$, this computes $\\hat{f}(x)\\bmod p$.
///
/// The modulus $p$ must make $(n-1)!$ invertible (e.g. a prime greater than
/// $n-1$). Residues are multiplied as plain `u64`s before being reduced, so $p$
/// is assumed to be below $2^{32}$.
///
/// # Idea
/// In Lagrange form,
/// $\\hat{f}(x) = \\sum_{i=0}^{n-1} f(i) \\prod_{j\\ne i} \\frac{x-j}{i-j}$.
/// Because the nodes are the consecutive integers $0, 1, \\dots, n-1$, the
/// denominator simplifies to
/// $\\prod_{j\\ne i} (i-j) = (-1)^{n-1-i}\\, i!\\, (n-1-i)!$,
/// so only the inverse factorials $1/0!, 1/1!, \\dots, 1/(n-1)!$ are needed.
/// They are obtained from a single modular inverse of $(n-1)!$ via
/// $1/(i-1)! = i \\cdot 1/i!$.
///
/// # Complexity
/// $O(n)$ time for preprocessing (plus one modular inverse), and $O(n)$ time
/// to evaluate $\\hat{f}(x)$.
///
/// # Examples
/// ```
/// use nekolib::math::Interpolation;
///
/// let f = Interpolation::with(vec![0, 1, 3], 998244353);
/// assert_eq!(f.interpolate(0), 0);
/// assert_eq!(f.interpolate(3), 6);
/// assert_eq!(f.interpolate(4), 10);
/// assert_eq!(f.interpolate(100000000), 722404071);
/// ```
///
/// # See also
/// - <https://rsk0315.hatenablog.com/entry/2019/04/25/141012>
pub struct Interpolation {
    /// Sample values $f(0), f(1), \\dots, f(n-1)$.
    first: Vec<u64>,
    /// `fact_recip[i]` holds $1/i! \\bmod p$.
    fact_recip: Vec<u64>,
    /// Helper used for every reduction modulo $p$ (`cd.rem`).
    cd: ConstDiv,
    /// The modulus $p$.
    modulo: u64,
}

impl Interpolation {
    /// Precomputes what is needed to interpolate through the points
    /// $(i, \\mathtt{first}[i])$ for $i = 0, 1, \\dots, n-1$, modulo `modulo`.
    ///
    /// The values in `first` are reduced modulo `modulo`. An empty `first`
    /// describes the zero polynomial.
    ///
    /// `modulo` must make $(n-1)!$ invertible (e.g. a prime greater than $n-1$);
    /// otherwise the results are not meaningful. It is assumed to be below
    /// $2^{32}$ so that the product of two residues fits in `u64`.
    pub fn with(first: Vec<u64>, modulo: u64) -> Self {
        let cd = ConstDiv::new(modulo);
        // Reduce the samples once so every later product stays below modulo^2.
        let first: Vec<u64> = first.into_iter().map(|y| cd.rem(y)).collect();
        let n = first.len();

        // fact_recip[i] = 1 / i! (mod modulo), for 0 <= i < n.
        let mut fact_recip = vec![1; n];
        if let Some(last) = n.checked_sub(1) {
            // (n-1)! mod modulo
            let fact_last: u64 = (2..=last as u64).fold(1_u64, |acc, i| cd.rem(acc * i));
            fact_recip[last] = fact_last.gcd_recip(modulo).1;
            // 1/(i-1)! = i * (1/i!)
            for i in (2..=last).rev() {
                fact_recip[i - 1] = cd.rem(fact_recip[i] * i as u64);
            }
        }
        Self { first, fact_recip, cd, modulo }
    }

    /// Returns $\\hat{f}(x) \\bmod p$, where $\\hat{f}$ is the interpolating
    /// polynomial of the stored samples.
    ///
    /// `x` may be any `u64`, including values not smaller than $p$: it is
    /// reduced modulo $p$ first, which is valid because $\\hat{f}$ has
    /// coefficients in $\\mathbb{Z}/p\\mathbb{Z}$. Returns `0` when no samples
    /// were given.
    ///
    /// Runs in $O(n)$ time with $O(n)$ temporary space; no modular inverse is
    /// computed per query.
    ///
    /// # Examples
    /// ```
    /// use nekolib::math::Interpolation;
    ///
    /// let p = 998244353;
    /// let f = Interpolation::with(vec![0, 1, 3], p);
    /// assert_eq!(f.interpolate(4), 10);
    /// // x >= p is reduced first: f(p + 4) ≡ f(4).
    /// assert_eq!(f.interpolate(p + 4), 10);
    /// ```
    pub fn interpolate(&self, x: u64) -> u64 {
        if self.first.is_empty() {
            return 0;
        }
        let cd = self.cd;
        let modulo = self.modulo;
        let len = self.first.len();
        let last = len - 1;

        // Reducing x up front keeps `x + modulo` below from overflowing, and maps
        // any x >= p whose residue is a sample point directly to that sample.
        let x = cd.rem(x);
        if x < len as u64 {
            return self.first[x as usize];
        }

        // (x - j) mod p for a sample point j.
        let diff = |j: usize| cd.rem(x + modulo - cd.rem(j as u64));

        // suffix[i] = prod_{j = i+1}^{last} (x - j) (mod p)
        let mut suffix = vec![1_u64; len];
        for i in (1..len).rev() {
            suffix[i - 1] = cd.rem(suffix[i] * diff(i));
        }

        // `prefix` tracks prod_{j = 0}^{i-1} (x - j) (mod p).
        let mut prefix: u64 = 1;
        let mut sum: u64 = 0;
        for (i, (&y, &suf)) in self.first.iter().zip(&suffix).enumerate() {
            // 1 / prod_{j != i} (i - j) = (-1)^{last-i} / (i! (last-i)!)
            let w = cd.rem(self.fact_recip[i] * self.fact_recip[last - i]);
            let w = if (last - i) % 2 == 1 { modulo - w } else { w };
            let numer = cd.rem(prefix * suf);
            // Reduce each term before accumulating so the sum never overflows.
            sum = cd.rem(sum + cd.rem(cd.rem(y * w) * numer));
            prefix = cd.rem(prefix * diff(i));
        }
        sum
    }
}