Skip to main content

nekolib/ds/
incremental_line_set.rs

1//! 直線の集合。
2
3use super::btree_bimap;
4
5use std::collections::BTreeMap;
6use std::fmt::{self, Debug};
7
8use btree_bimap::BTreeBimap;
9
10/// 直線の集合。
11///
12/// 以下のクエリを処理する。
13/// - 集合 $S \\gets \\emptyset$ で初期化する。
14/// - 集合 $S$ に 1 次関数 $\\lambda x.\\; ax+b$ を追加する。
15/// - 集合 $S$ 中の関数における、$x=x\_0$ での最小値を返す。
16///
17/// 言い換えると、直線の追加クエリと、特定の $x$ 座標での $y$
18/// 座標の最小値を求めるクエリを捌く。いわゆる CHT。
19///
20/// # Idea
21/// 次の二つの連想配列を管理する。
22/// - $a$ を与えると、$\\lambda x.\\; ax+b \\in S$ なる $b$ を返す。
23/// - $a$ を与えると、$\\lambda x.\\; ax+b \\in S$ が最小となる $x$ の最大値 $x\_a$ を返す。
24///     - こちらは双方向で管理しておく。すなわち、$x\_a\\mapsto a$ の連想配列も持つ。
25///
26/// 保持しておく必要がある直線を対応する区間の昇順に並べると、傾きの降順に並ぶことに気づく。
27/// そこで、追加したい直線の傾きより小さい最大の傾きの直線と、大きい最小の直線と比較し、
28/// 新しい直線が必要かどうかをまず確かめる。
29/// それが必要なら、追加する直線に近い方から順にすでにある直線を見ていき、
30/// 必要なものが見つかるまで削除する。
31///
32/// クエリを整数とすると、以下が成り立つ。
33///
34/// $$ \\begin{aligned}
35/// f(\\lambda x.\\; a\_l x+b\_l, \\lambda x.\\; a\_r x+b\_r)
36/// &= \\max\\,\\{k \\mid a\_l k+b\_l \\le a\_r k+b\_r \\} \\\\
37/// &= \\left\\lfloor\\frac{b\_r-b\_l}{a\_l-a\_r}\\right\\rfloor.
38/// \\end{aligned} $$
39///
40/// # Complexity
41/// |演算|時間計算量|
42/// |---|---|
43/// |`new`|$O(1)$|
44/// |`push`|$O(\\log(\|S\'\|))$|
45/// |`min`|$O(\\log(\|S\'\|))$|
46///
47/// ここで、$S\'$ は $S$ から必要のない直線を除いたものからなる集合である。
48///
49/// # Applications
50/// 次の形式の DP の高速化に使える。
51/// $$ \\mathrm{dp}\[i\] = \\min\_{0\\le j\\lt i} (p(j)+q(j)\\cdot r(i)) +s(i). $$
52/// $\\min\_{0\\le j\\lt i} (\\bullet)$ の部分が、直線 $y=q(j)\\cdot x+p(j)$ の $x=r(i)$
53/// における最小値に相当するためである。$\\mathrm{dp}\[i\]$ の値を求めた後、直線
54/// $y=q(i)\\cdot x+p(i)$ を追加していけばよい。ここで、$p(j)$ や $q(j)$ は
55/// $\\mathrm{dp}\[j\]$ を含んでもよいし含まなくてもよい。どちらにも $\\mathrm{dp}\[j\]$
56/// が含まれない場合には、特に DP 配列のようなものを用意する必要はない。
57///
58/// たとえば、次のようなものが当てはまる。
59/// $$ \\begin{aligned}
60/// \\mathrm{dp}\[i\] &= \\min\_{0\\le j\\lt i} (\\mathrm{dp}\[j\]+(a\_j-a\_i)^2) \\\\
61/// &= \\min\_{0\\le j\\lt i} ((\\mathrm{dp}\[j\]+a\_j^2) + (-2\\cdot a\_j)\\cdot a\_i)+a\_i^2.
62/// \\end{aligned} $$
63///
64/// お気に入りの例として、[次のような問題](https://codeforces.com/contest/660/problem/F)
65/// も解ける:
66/// > 整数列 $a = (a\_0, a\_1, \\dots, a\_{n-1})$ が与えられる。
67/// > これの空でもよい区間 $(a\_l, a\_{l+1}, \\dots, a\_{r-1})$
68/// > に対し、次の値を考える。
69/// > $$ \\sum\_{i=l}^{r-1} (i-l+1)\\cdot a\_i
70/// > = 1\\cdot a\_l+2\\cdot a\_{l+1} + \\dots + (r-l)\\cdot a\_{r-1}. $$
71/// > 全ての区間の選び方におけるこの値の最大値を求めよ。
72/// >
73/// > ### Sample
74/// > ```text
75/// > [5, -1000, 1, -3, 7, -8]
76/// >           [ ------ ] => 1 * 1 + (-3) * 2 + 7 * 3 = 16
77/// > ```
78///
79/// $\\sigma(r) = \\sum\_{i=0}^{r-1} a\_i$、$\\tau(r) = \\sum\_{i=0}^{r-1} (i+1)\\cdot a\_i$
80/// とおくと、次のように変形できる。
81/// $$ \\begin{aligned} \\sum\_{i=l}^{r-1} (i-l+1)\\cdot a\_i &=
82/// \\sum\_{i=l}^{r-1} (i+1)\\cdot a\_i - \\sum\_{i=l}^{r-1} l\\cdot a\_i \\\\
83/// &= (\\tau(r)-\\tau(l)) - l\\cdot (\\sigma(r) - \\sigma(l))
84/// . \\end{aligned} $$
85///
86/// 右端 $r$ を固定したときの最大値を $\\mathrm{dp}\[r\]$ とおくと、
87/// $$ \\begin{aligned} \\mathrm{dp}\[r\] &=
88/// \\max\_{0\\le l\\lt r} (\\tau(r)-\\tau(l)) - l\\cdot(\\sigma(r)-\\sigma(l)) \\\\
89/// &= \\max\_{0\\le l\\lt r} (l\\cdot\\sigma(l)-\\tau(l) - l\\cdot\\sigma(r))+\\tau(r) \\\\
90/// &= -\\min\_{0\\le l\\lt r}(\\tau(l)-l\\cdot\\sigma(l) + l\\cdot\\sigma(r))+\\tau(r)
91/// \\end{aligned} $$
92/// とできる。よって、上記の枠組みで $p(j) = \\tau(j)-j\\cdot\\sigma(j)$、$q(j)=j$、
93/// $r(i)=\\sigma(i)$、$s(i)=\\tau(i)$ としたものと見なせ、$\\sigma(\\bullet)$ や $\\tau(\\bullet)$
94/// の計算を適切に高速化すれば、$O(n\\log(n))$ 時間で解ける。
95///
96/// # Examples
97/// ```
98/// use nekolib::ds::IncrementalLineSet;
99///
100/// let mut ls = IncrementalLineSet::new();
101/// assert_eq!(ls.min(0), None);
102///
103/// ls.push((2, 2));
104/// assert_eq!(ls.min(0), Some(2));
105/// assert_eq!(ls.min(2), Some(6));
106///
107/// ls.push((1, 3));
108/// assert_eq!(ls.min(0), Some(2));
109/// assert_eq!(ls.min(2), Some(5));
110/// assert_eq!(ls.min(5), Some(8));
111///
112/// ls.push((-1, 10));
113/// assert_eq!(ls.min(2), Some(5));
114/// assert_eq!(ls.min(5), Some(5));
115///
116/// assert_eq!(
117///     format!("{:?}", ls),
118///     r"{\x. 2x+2: ..=1, \x. x+3: ..=3, \x. -x+10: ..=2147483647}"
119/// );
120/// ```
121///
122/// ```
123/// use nekolib::ds::IncrementalLineSet;
124///
125/// let a = vec![5, -1000, 1, -3, 7, -8];
126/// let n = a.len();
127///
128/// let sigma = {
129///     let mut sigma = vec![0; n + 1];
130///     for i in 0..n {
131///         sigma[i + 1] = sigma[i] + a[i];
132///     }
133///     sigma
134/// };
135/// let tau = {
136///     let mut tau = vec![0; n + 1];
137///     for i in 0..n {
138///         tau[i + 1] = tau[i] + a[i] * (i + 1) as i64;
139///     }
140///     tau
141/// };
142/// let p = |j: usize| tau[j] - j as i64 * sigma[j];
143/// let q = |j: usize| j as i64;
144/// let r = |i: usize| sigma[i];
145/// let s = |i: usize| tau[i];
146///
147/// let mut ls = IncrementalLineSet::new();
148/// let mut dp = vec![0; n + 1];
149/// ls.push((q(0), p(0)));
150/// for i in 1..=n {
151///     dp[i] = -ls.min(r(i)).unwrap() + s(i);
152///     ls.push((q(i), p(i)));
153/// }
154/// let res = *dp.iter().max().unwrap();
155/// assert_eq!(res, 1 * 1 + (-3) * 2 + 7 * 3);
156/// ```
157///
158/// # References
159/// - <https://noshi91.hatenablog.com/entry/2021/03/23/200810>
160#[derive(Clone, Default)]
161pub struct IncrementalLineSet<I: Ord> {
162    f: BTreeMap<I, I>,
163    range: BTreeBimap<I, I>,
164}
165
166impl<I: ChtInt> IncrementalLineSet<I> {
167    pub fn new() -> Self { Self::default() }
168    pub fn push(&mut self, (a, b): (I, I)) {
169        if self.f.is_empty() {
170            let max = I::oo();
171            self.f.insert(a, b);
172            self.range.insert(a, max);
173            return;
174        }
175        if self.unused((a, b)) {
176            return;
177        }
178        self.remove_unused((a, b));
179        self.insert((a, b));
180    }
181    pub fn min(&self, x: I) -> Option<I> {
182        let a = *self.range.range_right(x..).next()?.1;
183        let b = self.f[&a];
184        Some(x.on_line((a, b)))
185    }
186    pub fn inner_len(&self) -> usize { self.f.len() }
187
188    fn unused(&self, (a, b): (I, I)) -> bool {
189        let (&al, &bl) = match self.f.range(a..).next() {
190            Some((&al, &bl)) if a == al => return bl <= b,
191            Some(s) => s,
192            None => return false,
193        };
194        let (&ar, &br) = match self.f.range(..a).next_back() {
195            Some(s) => s,
196            None => return false,
197        };
198        al.right(bl, (a, b)) >= a.right(b, (ar, br))
199    }
200    fn remove_unused(&mut self, (a, b): (I, I)) {
201        self.f.remove(&a);
202        self.range.remove_left(&a);
203
204        let mut rm = vec![];
205        for ((&all, &bll), (&al, &bl)) in
206            self.f.range(a..).skip(1).zip(self.f.range(a..))
207        {
208            if all.right(bll, (al, bl)) >= al.right(bl, (a, b)) {
209                rm.push(al);
210            } else {
211                break;
212            }
213        }
214        for ((&arr, &brr), (&ar, &br)) in
215            self.f.range(..a).rev().skip(1).zip(self.f.range(..a).rev())
216        {
217            if a.right(b, (ar, br)) >= ar.right(br, (arr, brr)) {
218                rm.push(ar);
219            } else {
220                break;
221            }
222        }
223        for ar in &rm {
224            self.f.remove(ar);
225            self.range.remove_left(ar);
226        }
227    }
228    fn insert(&mut self, (a, b): (I, I)) {
229        if let Some((&al, &bl)) = self.f.range(a..).next() {
230            self.range.insert(al, al.right(bl, (a, b)));
231        };
232        if let Some((&ar, &br)) = self.f.range(..a).next_back() {
233            self.range.insert(a, a.right(b, (ar, br)));
234        } else {
235            self.range.insert(a, I::oo());
236        }
237
238        self.f.insert(a, b);
239    }
240}
241
242struct LineDebugHelper<I>(I, I);
243
244impl<I: ChtInt> Debug for LineDebugHelper<I> {
245    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
246        let s = match (self.0.simplify(), self.1.simplify()) {
247            (0, _) => format!("\\x. {:?}", self.1),
248            (1, _) => format!("\\x. x{:+?}", self.1),
249            (-1, _) => format!("\\x. -x{:+?}", self.1),
250            (_, 0) => format!("\\x. {:?}x", self.0),
251            _ => format!("\\x. {:?}x{:+?}", self.0, self.1),
252        };
253        f.write_str(&s)
254    }
255}
256
257impl<I: ChtInt> Debug for IncrementalLineSet<I> {
258    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
259        f.debug_map()
260            .entries(
261                self.f
262                    .iter()
263                    .rev()
264                    .zip(self.range.range_left(..).rev())
265                    .map(|((&a, &b), (&_, &r))| (LineDebugHelper(a, b), ..=r)),
266            )
267            .finish()
268    }
269}
270
271pub trait ChtInt: Copy + Ord + Default + Debug {
272    fn oo() -> Self;
273    fn right(self, b: Self, line1: (Self, Self)) -> Self;
274    fn on_line(self, line: (Self, Self)) -> Self;
275    fn simplify(self) -> i8;
276}
277
278macro_rules! impl_cht_int {
279    ( $($tt:tt)* ) => { $(
280        impl ChtInt for $tt {
281            // std::$tt::MAX が 1.43.0 で AtCoder は 1.42.0 なのがつらい。
282            fn oo() -> $tt {
283                let w = (0 as $tt).count_zeros();
284                ((1 as $tt) << (w - 1)).wrapping_sub(1)
285            }
286            fn right(self, b: Self, (ar, br): (Self, Self)) -> Self {
287                // a > ar
288                let a = self;
289                (br - b).div_euclid(a - ar)
290            }
291            fn on_line(self, (a, b): (Self, Self)) -> Self { a * self + b }
292            fn simplify(self) -> i8 {
293                match self {
294                    0 => 0,
295                    1 => 1,
296                    -1 => -1,
297                    _ => 2,
298                }
299            }
300        }
301    )* };
302}
303
304impl_cht_int! { i8 i16 i32 i64 i128 isize }
305
306#[test]
307fn test_simple() {
308    let mut ls = IncrementalLineSet::new();
309    assert_eq!(ls.min(1), None);
310
311    let mut f = std::iter::successors(Some(185_i32), |&x| {
312        Some((x * 291 + 748) % 93739)
313    })
314    .map(|x| x % 300 - 150);
315
316    let mut naive = vec![];
317    for _ in 0..5000 {
318        let a = f.next().unwrap();
319        let b = f.next().unwrap();
320        ls.push((a, b));
321        naive.push((a, b));
322        for x in -100..=100 {
323            let expected = naive.iter().map(|&(a, b)| a * x + b).min();
324            let got = ls.min(x);
325            assert_eq!(got, expected);
326        }
327    }
328}
329
330#[test]
331fn test_cross() {
332    // 一点でたくさんの直線が交差する場合のテストを書く
333    let mut ls = IncrementalLineSet::new();
334    // (0, 0) でたくさん交わるようにする
335    ls.push((0, 0));
336    for a in 1..1000 {
337        ls.push((a, 0));
338        assert_eq!(ls.inner_len(), 2);
339    }
340    for a in 1..1000 {
341        ls.push((-a, 0));
342        assert_eq!(ls.inner_len(), 2);
343    }
344}
345
346#[test]
347fn test_many() {
348    // 傾きが 1 ずつ異なる直線がたくさん使われる場合のテストを書く
349    let mut ls = IncrementalLineSet::new();
350    // (0, 0), (1, -1), (2, -3), (3, -6), (4, -10), ...
351    let mut y = 0;
352    let x_max = 1000;
353    for x in 0..=x_max {
354        let a = -x;
355        y += a;
356        // (x, y) を通り、傾きが a
357        // Y - y = a (X - x)
358        // Y = a X - a x + y
359        ls.push((a, -a * x + y));
360        // (-x-1, y) を通り、傾きが -a
361        ls.push((-a, -a * x + y - a));
362        assert_eq!(ls.inner_len(), (2 * x + 1) as usize);
363    }
364    for x in -x_max..=x_max {
365        let y = -x * (x + 1) / 2;
366        assert_eq!(ls.min(x), Some(y));
367    }
368}
369
370#[test]
371fn test_frac() {
372    // ある直線が最小となる区間が格子点を含まない場合のテストを書く
373    let mut ls = IncrementalLineSet::new();
374    ls.push((2, 1)); // [..., -1, 1, 3, ...]
375    ls.push((-5, 6)); // [..., 11, 6, 1, ...]
376    ls.push((0, 3)); // [..., 3, 3, 3, ...]
377    assert_eq!(ls.inner_len(), 2);
378}
379
380#[test]
381fn test_below() {
382    let mut ls = IncrementalLineSet::new();
383    ls.push((0, 2));
384    assert_eq!(ls.min(10), Some(2));
385    ls.push((0, 4));
386    assert_eq!(ls.min(10), Some(2));
387    ls.push((0, 1));
388    assert_eq!(ls.min(10), Some(1));
389    assert_eq!(ls.inner_len(), 1);
390}
391
392#[cfg(test)]
393fn test_cf660_f_internal(a: &[i64], expected: i64) {
394    let n = a.len();
395    let sigma = {
396        let mut sigma = vec![0; n + 1];
397        for i in 0..n {
398            sigma[i + 1] = sigma[i] + a[i];
399        }
400        sigma
401    };
402    let tau = {
403        let mut tau = vec![0; n + 1];
404        for i in 0..n {
405            tau[i + 1] = tau[i] + a[i] * (i + 1) as i64;
406        }
407        tau
408    };
409    let p = |j: usize| tau[j] - j as i64 * sigma[j];
410    let q = |j: usize| j as i64;
411    let r = |i: usize| sigma[i];
412    let s = |i: usize| tau[i];
413
414    let mut ls = IncrementalLineSet::new();
415    let mut dp = vec![0; n + 1];
416    ls.push((q(0), p(0)));
417    for i in 1..=n {
418        dp[i] = -ls.min(r(i)).unwrap() + s(i);
419        ls.push((q(i), p(i)));
420    }
421    let actual = *dp.iter().max().unwrap();
422    assert_eq!(actual, expected);
423}
424
425#[test]
426fn test_cf660_f() {
427    test_cf660_f_internal(&[5, -1000, 1, -3, 7, -8], 16);
428    test_cf660_f_internal(&[1000, 1000, 1001, 1000, 1000], 15003);
429    test_cf660_f_internal(&[-60, -70, -80], 0);
430    test_cf660_f_internal(&[-4], 0);
431    test_cf660_f_internal(&[-3, 6], 9);
432    test_cf660_f_internal(&[8, 1, -6], 10);
433    test_cf660_f_internal(&[9, 2, -5, 1], 13);
434    test_cf660_f_internal(&[10, -3, -3, 8, 2], 37);
435    test_cf660_f_internal(&[3, 1, -9, 1, 2, -10], 5);
436    test_cf660_f_internal(&[-3, -7, -7, -9, -3, 7, -9], 11);
437    test_cf660_f_internal(&[-2, 1, -5, -2, 1, -9, 0, 2], 4);
438    test_cf660_f_internal(&[-1, 10, -8, -9, -7, 8, 6, -6, 7], 38);
439    test_cf660_f_internal(&[-9, -10, -9, 4, 6, 8, 3, -8, 0, 10], 100);
440    test_cf660_f_internal(
441        &[
442            349, -152, -35, -353, -647, -702, 64, 299, -431, -11, -185, 437,
443            237, -103, 1, 448, 23, -308, -689, 329, -409, 309, 424, -93, -192,
444            0, 257, -90, -394, -512, -148, 376, -394, -528, 212, -215, -255,
445            -684, -321, 503, -72, -227, -583, -537, -65, 444, -332, 465, -547,
446            291, -663, -235, 542, -89, -450, -212, 438, 12, 139, -558, -87,
447            433, -462, 79, 35,
448        ],
449        6676,
450    );
451    test_cf660_f_internal(&[7, -5, 3, -9, 8], 10);
452    test_cf660_f_internal(&[-7, 0, 10, 1, -1, -5, 6], 34);
453    test_cf660_f_internal(&[3, -10, -2, 5, 2, -7, 7], 21);
454    test_cf660_f_internal(&[0, -7, 1, -9], 1);
455    test_cf660_f_internal(&[4, -6, 3, 3], 13);
456    test_cf660_f_internal(&[-9, 8, 0, -4, -4, -3, -5, 9, -6, -9], 14);
457    test_cf660_f_internal(&[3, -5, -5, 1, -6, -2], 3);
458    test_cf660_f_internal(&[8, -2, -8, 4, -8, 8, -3, -8, 0], 12);
459    test_cf660_f_internal(&[3, 3, 0, -7, 6, -6], 11);
460    test_cf660_f_internal(&[5, -6, -2, 6, -2, -4, -3], 11);
461}