Skip to main content

nekolib/math/
const_div.rs

1//! 定数除算。
2
3use std::fmt::{self, Debug};
4
5/// 定数除算。
6///
7/// 除算命令は重いので、加減算や乗算で置き換えることを考える。
8/// 同じ値で何度も除算する際には、あらかじめ置き換える値を先に求めておくことで高速化できる。
9///
10/// Barrett reduction に基づく。$a\\lt n^2$ に対して、$\\lfloor a/n\\rfloor$ と $a\\bmod n$
11/// を求めることができる。ちゃんと考察すれば、[この制約は除ける][`ConstDiv`]。
12/// 実際、コンパイラは同様の最適化を行う。
13///
14/// [`ConstDiv`]: struct.ConstDiv.html
15///
16/// ```asm
17/// example::div2:
18///         mov     rax, rdi
19///         shr     rax
20///         ret
21/// ```
22/// ```asm
23/// example::div3:
24///         mov     rax, rdi
25///         movabs  rcx, -6148914691236517205
26///         mul     rcx
27///         mov     rax, rdx
28///         shr     rax
29///         ret
30/// ```
31/// ```asm
32/// example::div63:
33///         movabs  rcx, 292805461487453201
34///         mov     rax, rdi
35///         mul     rcx
36///         sub     rdi, rdx
37///         shr     rdi
38///         lea     rax, [rdi + rdx]
39///         shr     rax, 5
40///         ret
41///
42/// example::div64:
43///         mov     rax, rdi
44///         shr     rax, 6
45///         ret
46///
47/// example::div65:
48///         mov     rax, rdi
49///         movabs  rcx, 1135184250689818561
50///         mul     rcx
51///         mov     rax, rdx
52///         shr     rax, 2
53///         ret
54/// ```
55///
56/// ```
57/// fn div63(rdi: u64) -> u64 {
58///     let rdx = ((rdi as u128 * 0x410410410410411_u128) >> 64) as u64;
59///     (((rdi - rdx) >> 1) + rdx) >> 5
60/// }
61///
62/// fn div64(rdi: u64) -> u64 { rdi >> 6 }
63///
64/// fn div65(rdi: u64) -> u64 {
65///     ((rdi as u128 * 0xFC0FC0FC0FC0FC1_u128) >> 66) as u64
66/// }
67///
68/// for i in 0..=100000 {
69///     assert_eq!(div63(i), i / 63);
70///     assert_eq!(div64(i), i / 64);
71///     assert_eq!(div65(i), i / 65);
72/// }
73/// ```
74///
75/// $$ \\begin{aligned}
76/// \\lfloor n/63\\rfloor &= (((n-m)\\gg 1) + m)\\gg 5\\text{, where }
77/// m=(n\\cdot\\lceil 2^{64}/63\\rceil)\\gg 64 \\\\
78/// \\lfloor n/64\\rfloor &= n\\gg 6 \\\\
79/// \\lfloor n/65\\rfloor &= (n\\cdot\\lceil 2^{66}/65\\rceil)\\gg 66
80/// \\end{aligned} $$
81///
82/// 剰余算については、$n\\bmod d = n-\\lfloor n/d\\rfloor\\cdot d$ に基づく。
83/// $d$ を掛ける際には定数乗算の最適化(加減算とシフトを用いるなど)を行っていそう。
84///
85/// # Naming
86/// 除数の 2 乗未満の入力を仮定することから `2` をつけている。
87///
88/// # References
89/// - <https://rsk0315.hatenablog.com/entry/2021/01/18/065720#Barrett-reduction-%E3%81%AE%E8%A9%B1>
90/// - <https://godbolt.org/z/snq4nvTP6>
91#[derive(Clone, Copy, Debug, Eq, PartialEq)]
92pub struct ConstDiv2 {
93    n: u64,
94    recip: u128,
95}
96
97impl ConstDiv2 {
98    pub fn new(n: u64) -> Self {
99        let recip = 1_u64.wrapping_add(std::u64::MAX / n) as u128;
100        Self { n, recip }
101    }
102    pub fn quot(&self, z: u64) -> u64 {
103        if self.n == 1 {
104            return z;
105        }
106        let x = ((self.recip * z as u128) >> 64) as u64;
107        match x.checked_mul(self.n) {
108            Some(xn) if xn <= z => x,
109            _ => x - 1,
110        }
111    }
112    pub fn rem(&self, z: u64) -> u64 {
113        if self.n == 1 {
114            return 0;
115        }
116        let x = ((self.recip * z as u128) >> 64) as u64;
117        let v = z.wrapping_sub(x.wrapping_mul(self.n));
118        if self.n <= v {
119            v.wrapping_add(self.n)
120        } else {
121            v
122        }
123    }
124}
125
126/// 定数除算。
127///
128/// 除算命令は重いので、加減算や乗算で置き換えることを考える。
129/// 同じ値で何度も除算する際には、あらかじめ置き換える値を先に求めておくことで高速化できる。
130///
131/// 以下、$d$ による除算を行うとする。$d = 2^s$ であれば $s$ bit 右シフトするだけなので、$2$
132/// べきではないとする。magic number $M\_d$ とシフト幅 $s$
133/// を求めておき、次の式に基づいて計算する。
134/// $$ \\lfloor n/d\\rfloor
135/// = \\left\\lfloor\\frac{M\_d\\cdot n}{2^s}\\right\\rfloor. $$
136/// $M\_d$ は、ある $0\\le r\\lt d$ が存在して次の形になる。
137/// $$ M\_d = \\frac{2^s+r}{d} = 1+\\left\\lfloor\\frac{2^s-1}{d}\\right\\rfloor. $$
138///
139/// $M\_d$ と $s$ が満たすべき性質について考える。$0\\le n\\lt 2^w$
140/// に対して常に次の式が成り立ってほしい。$w$ はワードサイズで、ここでは $w=64$ とする。
141/// $$ \\lfloor n/d\\rfloor
142/// = \\left\\lfloor\\frac{2\^s+r}{d}\\cdot\\frac{n}{2^s}\\right\\rfloor
143/// = \\left\\lfloor\\frac{n\\vphantom{2^s}}{d} + \\frac{r\\cdot n}{2^s}\\right\\rfloor. $$
144///
145/// 有理数と床関数の性質から、$r\\cdot n/2^s \\lt 1/d$ が常に成り立てばよい。
146/// このとき、$0\\le M\_d\\lt 2^{w+1}$ をみたす $M\_d$ が存在することを示す。
147/// すなわち、$\\lfloor M\_d/2^w\\rfloor\\in\\{0, 1\\}$ となる。`todo!()`
148///
149/// さて、$M\_d$ が見つかったとする。$0\\le M\_d\\lt 2^w$ であれば上の式に基づいて、
150/// 直接計算できる。一方で、$2^w\\le M\_d\\lt 2^{w+1}$ の場合はワードサイズに収まらないので、
151/// 少々工夫する必要がある。$M\_d-2\^w$ はワードサイズに収まるので、それを利用する。
152/// `todo!()`
153///
154/// # References
155/// - Warren, Henry S. _Hacker's delight_. Pearson Education, 2013.
156#[derive(Clone, Copy, Debug, Eq, PartialEq)]
157pub struct ConstDiv {
158    n: u64,
159    di: DivAlgo,
160}
161
162#[derive(Clone, Copy, Eq, PartialEq)]
163enum DivAlgo {
164    Shr(u32, u64),
165    MulShr(u64, u32),
166    MulAddShr(u64, u32),
167    Ge(u64),
168}
169use DivAlgo::{Ge, MulAddShr, MulShr, Shr};
170
171impl Debug for DivAlgo {
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        let res = match self {
174            Shr(s, a) => format!("|n| n >> {}, |n| n & 0x{:016X}", s, a),
175            MulShr(m, s) => format!("|n| (n * 0x{:016X}) >> {}", m, s),
176            MulAddShr(m, s) => {
177                format!("|n| (n + ((n * 0x{:016X}) >> 64) >> 1) >> {}", m, s)
178            }
179            Ge(g) => {
180                let q =
181                    format!("|n| if n >= 0x{:016x} {{ 1 }} else {{ 0 }}", g);
182                let r = format!(
183                    "|n| if n >= 0x{0:016X} {{ n - 0x{0:016X} }} else {{ n }}",
184                    g
185                );
186                format!("{}, {}", q, r)
187            }
188        };
189        f.write_str(res.as_str())
190    }
191}
192
193impl ConstDiv {
194    pub fn new(n: u64) -> Self {
195        let ns = n.next_power_of_two().trailing_zeros();
196        if n.is_power_of_two() {
197            return Self { n, di: Shr(ns, n - 1) };
198        }
199        if n.leading_zeros() == 0 {
200            return Self { n, di: Ge(n) };
201        }
202        let nc = std::u64::MAX as u128;
203
204        for p in 63 + ns..128 {
205            let n_ = n as u128;
206            let r = ((1_u128 << p) - 1) % n_;
207            if (nc * (n_ - 1 - r)) >> p == 0 {
208                let m = 1 + ((1_u128 << p) - 1 - r) / n_;
209                return if m >> 64 == 0 {
210                    Self { n, di: MulShr(m as u64, p) }
211                } else {
212                    Self { n, di: MulAddShr(m as u64, p - 1 - 64) }
213                };
214            }
215        }
216        unreachable!()
217    }
218    pub fn quot(&self, n: u64) -> u64 {
219        match self.di {
220            Shr(s, _) => n >> s,
221            MulShr(m, s) => ((n as u128 * m as u128) >> s) as u64,
222            MulAddShr(m, s) => {
223                let tmp = ((n as u128 * m as u128) >> 64) as u64;
224                (((n - tmp) >> 1) + tmp) >> s
225            }
226            Ge(g) if n >= g => 1,
227            Ge(_) => 0,
228        }
229    }
230    pub fn rem(&self, n: u64) -> u64 {
231        match self.di {
232            Shr(_, a) => n & a,
233            Ge(g) if n >= g => n - g,
234            Ge(_) => n,
235            _ => n - self.quot(n) * self.n,
236        }
237    }
238}
239
240#[test]
241fn test_small_2() {
242    for n in 1..=500 {
243        let cd = ConstDiv2::new(n);
244        for a in 0..n * n {
245            assert_eq!(cd.quot(a), a / n);
246            assert_eq!(cd.rem(a), a % n);
247        }
248    }
249}
250
251#[test]
252fn test_small() {
253    for n in 1..=500 {
254        let cd = ConstDiv::new(n);
255        for a in 0..5 * n * n {
256            assert_eq!(cd.quot(a), a / n);
257            assert_eq!(cd.rem(a), a % n);
258        }
259        for a in 1..=5 * n * n {
260            let a = std::u64::MAX - a;
261            assert_eq!(cd.quot(a), a / n);
262            assert_eq!(cd.rem(a), a % n);
263        }
264    }
265}
266
267#[test]
268fn test_corner() {
269    for &d in &[(1 << 63) - 1, 1 << 63, (1 << 63) + 1, std::u64::MAX] {
270        let cd = ConstDiv::new(d);
271        for &n in &[0, 1, d - 1, d, d.saturating_add(1), d.saturating_mul(2)] {
272            assert_eq!(cd.quot(n), n / d);
273            assert_eq!(cd.rem(n), n % d);
274        }
275    }
276}