Skip to main content

nekolib/math/
is_close_float.rs

/// Approximate equality for floating-point values, using a relative and an
/// absolute tolerance.
pub trait IsCloseFloat {
    /// Returns `true` if `self` and `other` are equal or close to each other.
    ///
    /// Two values are close when their absolute difference is
    /// - at most `rel_tol` times the magnitude of *both* values, or
    /// - at most `abs_tol`.
    ///
    /// # Panics
    ///
    /// Panics if `rel_tol` or `abs_tol` is negative or NaN.
    fn is_close(self, other: Self, rel_tol: Self, abs_tol: Self) -> bool;
}

impl IsCloseFloat for f64 {
    // See also: <https://github.com/scikit-hep/scikit-hep/blob/207cf827851d98c453c655e56bd0ee36f8f2b045/skhep/math/isclose.py#L32>
    fn is_close(self, other: f64, rel_tol: f64, abs_tol: f64) -> bool {
        // `NaN >= 0.0` is false, so this also rejects NaN tolerances.
        assert!(
            rel_tol >= 0.0 && abs_tol >= 0.0,
            "error tolerances must be >= 0.0"
        );

        if self == other {
            // Short-circuit; this includes oo == oo and -oo == -oo.
            return true;
        }

        match (self.is_nan(), other.is_nan()) {
            // If both values are NaN, a judge should accept the answer.
            // cf. <https://atcoder.jp/contests/abc280/tasks/abc280_f>
            (true, true) => return true,
            (false, false) => {}
            // A number is never close to NaN.
            _ => return false,
        }

        if self.is_infinite() && other.is_infinite() {
            // Equal infinities returned above, so these have opposite signs.
            // Without this guard, `rel_tol * oo` is infinite for every positive
            // `rel_tol`, which would make oo and -oo "close" under any non-zero
            // relative tolerance. Only an infinite tolerance can bridge them.
            return rel_tol.is_infinite() || abs_tol.is_infinite();
        }

        let diff = (self - other).abs();
        let within_rel =
            diff <= (rel_tol * other).abs() && diff <= (rel_tol * self).abs();
        within_rel || diff <= abs_tol
    }
}
31
// Exercises `is_close` on ordinary numbers, infinities and NaNs.
#[test]
fn sanity_check() {
    // Named constants instead of dividing by zero to obtain the special values.
    let oo = f64::INFINITY;
    let neg_oo = f64::NEG_INFINITY;
    let nan = f64::NAN;

    // numbers
    assert!(2.0_f64.is_close(3.0, 0.5, 0.0)); // |3.0 - 2.0| / 2.0
    assert!(2.0_f64.is_close(3.0, 0.0, 1.0)); // |3.0 - 2.0|
    assert!(!2.0_f64.is_close(3.0, 0.499, 0.0));
    assert!(!2.0_f64.is_close(3.0, 0.0, 0.999));

    // infinities
    assert!(oo.is_close(oo, 0.0, 0.0));
    assert!(!oo.is_close(neg_oo, 0.0, 0.0));
    assert!(neg_oo.is_close(neg_oo, 0.0, 0.0));
    assert!(oo.is_close(neg_oo, oo, oo));
    assert!(oo.is_close(2.0, oo, oo));
    assert!(!oo.is_close(2.0, 0.0, 0.0));

    // nans
    assert!(nan.is_close(nan, 0.0, 0.0));
    assert!(!nan.is_close(0.0, oo, oo));
    assert!(!nan.is_close(oo, 0.0, 0.0));
    assert!(!nan.is_close(oo, oo, oo));
}
58
// NaN tolerances must be rejected by the tolerance precondition. The expected
// message is pinned so that an unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "error tolerances must be >= 0.0")]
fn panic_nan_tol() {
    let nan = f64::NAN;
    // Ordered comparisons with NaN are false, so a `tol < 0.0` style check
    // would let NaN through; print the comparison result to show it.
    eprintln!("nan < 0.0: {:?}?", nan < 0.0);
    assert!(0.0_f64.is_close(0.0, nan, nan));
}