nekolib/math/const_div.rs
1//! 定数除算。
2
3use std::fmt::{self, Debug};
4
5/// 定数除算。
6///
7/// 除算命令は重いので、加減算や乗算で置き換えることを考える。
8/// 同じ値で何度も除算する際には、あらかじめ置き換える値を先に求めておくことで高速化できる。
9///
10/// Barrett reduction に基づく。$a\\lt n^2$ に対して、$\\lfloor a/n\\rfloor$ と $a\\bmod n$
11/// を求めることができる。ちゃんと考察すれば、[この制約は除ける][`ConstDiv`]。
12/// 実際、コンパイラは同様の最適化を行う。
13///
14/// [`ConstDiv`]: struct.ConstDiv.html
15///
16/// ```asm
17/// example::div2:
18/// mov rax, rdi
19/// shr rax
20/// ret
21/// ```
22/// ```asm
23/// example::div3:
24/// mov rax, rdi
25/// movabs rcx, -6148914691236517205
26/// mul rcx
27/// mov rax, rdx
28/// shr rax
29/// ret
30/// ```
31/// ```asm
32/// example::div63:
33/// movabs rcx, 292805461487453201
34/// mov rax, rdi
35/// mul rcx
36/// sub rdi, rdx
37/// shr rdi
38/// lea rax, [rdi + rdx]
39/// shr rax, 5
40/// ret
41///
42/// example::div64:
43/// mov rax, rdi
44/// shr rax, 6
45/// ret
46///
47/// example::div65:
48/// mov rax, rdi
49/// movabs rcx, 1135184250689818561
50/// mul rcx
51/// mov rax, rdx
52/// shr rax, 2
53/// ret
54/// ```
55///
56/// ```
57/// fn div63(rdi: u64) -> u64 {
58/// let rdx = ((rdi as u128 * 0x410410410410411_u128) >> 64) as u64;
59/// (((rdi - rdx) >> 1) + rdx) >> 5
60/// }
61///
62/// fn div64(rdi: u64) -> u64 { rdi >> 6 }
63///
64/// fn div65(rdi: u64) -> u64 {
65/// ((rdi as u128 * 0xFC0FC0FC0FC0FC1_u128) >> 66) as u64
66/// }
67///
68/// for i in 0..=100000 {
69/// assert_eq!(div63(i), i / 63);
70/// assert_eq!(div64(i), i / 64);
71/// assert_eq!(div65(i), i / 65);
72/// }
73/// ```
74///
75/// $$ \\begin{aligned}
76/// \\lfloor n/63\\rfloor &= (((n-m)\\gg 1) + m)\\gg 5\\text{, where }
77/// m=(n\\cdot\\lceil 2^{64}/63\\rceil)\\gg 64 \\\\
78/// \\lfloor n/64\\rfloor &= n\\gg 6 \\\\
79/// \\lfloor n/65\\rfloor &= (n\\cdot\\lceil 2^{66}/65\\rceil)\\gg 66
80/// \\end{aligned} $$
81///
82/// 剰余算については、$n\\bmod d = n-\\lfloor n/d\\rfloor\\cdot d$ に基づく。
83/// $d$ を掛ける際には定数乗算の最適化(加減算とシフトを用いるなど)を行っていそう。
84///
85/// # Naming
86/// 除数の 2 乗未満の入力を仮定することから `2` をつけている。
87///
88/// # References
89/// - <https://rsk0315.hatenablog.com/entry/2021/01/18/065720#Barrett-reduction-%E3%81%AE%E8%A9%B1>
90/// - <https://godbolt.org/z/snq4nvTP6>
91#[derive(Clone, Copy, Debug, Eq, PartialEq)]
92pub struct ConstDiv2 {
93 n: u64,
94 recip: u128,
95}
96
97impl ConstDiv2 {
98 pub fn new(n: u64) -> Self {
99 let recip = 1_u64.wrapping_add(std::u64::MAX / n) as u128;
100 Self { n, recip }
101 }
102 pub fn quot(&self, z: u64) -> u64 {
103 if self.n == 1 {
104 return z;
105 }
106 let x = ((self.recip * z as u128) >> 64) as u64;
107 match x.checked_mul(self.n) {
108 Some(xn) if xn <= z => x,
109 _ => x - 1,
110 }
111 }
112 pub fn rem(&self, z: u64) -> u64 {
113 if self.n == 1 {
114 return 0;
115 }
116 let x = ((self.recip * z as u128) >> 64) as u64;
117 let v = z.wrapping_sub(x.wrapping_mul(self.n));
118 if self.n <= v {
119 v.wrapping_add(self.n)
120 } else {
121 v
122 }
123 }
124}
125
126/// 定数除算。
127///
128/// 除算命令は重いので、加減算や乗算で置き換えることを考える。
129/// 同じ値で何度も除算する際には、あらかじめ置き換える値を先に求めておくことで高速化できる。
130///
131/// 以下、$d$ による除算を行うとする。$d = 2^s$ であれば $s$ bit 右シフトするだけなので、$2$
132/// べきではないとする。magic number $M\_d$ とシフト幅 $s$
133/// を求めておき、次の式に基づいて計算する。
134/// $$ \\lfloor n/d\\rfloor
135/// = \\left\\lfloor\\frac{M\_d\\cdot n}{2^s}\\right\\rfloor. $$
136/// $M\_d$ は、ある $0\\le r\\lt d$ が存在して次の形になる。
137/// $$ M\_d = \\frac{2^s+r}{d} = 1+\\left\\lfloor\\frac{2^s-1}{d}\\right\\rfloor. $$
138///
139/// $M\_d$ と $s$ が満たすべき性質について考える。$0\\le n\\lt 2^w$
140/// に対して常に次の式が成り立ってほしい。$w$ はワードサイズで、ここでは $w=64$ とする。
141/// $$ \\lfloor n/d\\rfloor
142/// = \\left\\lfloor\\frac{2\^s+r}{d}\\cdot\\frac{n}{2^s}\\right\\rfloor
143/// = \\left\\lfloor\\frac{n\\vphantom{2^s}}{d} + \\frac{r\\cdot n}{2^s}\\right\\rfloor. $$
144///
145/// 有理数と床関数の性質から、$r\\cdot n/2^s \\lt 1/d$ が常に成り立てばよい。
146/// このとき、$0\\le M\_d\\lt 2^{w+1}$ をみたす $M\_d$ が存在することを示す。
147/// すなわち、$\\lfloor M\_d/2^w\\rfloor\\in\\{0, 1\\}$ となる。`todo!()`
148///
149/// さて、$M\_d$ が見つかったとする。$0\\le M\_d\\lt 2^w$ であれば上の式に基づいて、
150/// 直接計算できる。一方で、$2^w\\le M\_d\\lt 2^{w+1}$ の場合はワードサイズに収まらないので、
151/// 少々工夫する必要がある。$M\_d-2\^w$ はワードサイズに収まるので、それを利用する。
152/// `todo!()`
153///
154/// # References
155/// - Warren, Henry S. _Hacker's delight_. Pearson Education, 2013.
156#[derive(Clone, Copy, Debug, Eq, PartialEq)]
157pub struct ConstDiv {
158 n: u64,
159 di: DivAlgo,
160}
161
162#[derive(Clone, Copy, Eq, PartialEq)]
163enum DivAlgo {
164 Shr(u32, u64),
165 MulShr(u64, u32),
166 MulAddShr(u64, u32),
167 Ge(u64),
168}
169use DivAlgo::{Ge, MulAddShr, MulShr, Shr};
170
171impl Debug for DivAlgo {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 let res = match self {
174 Shr(s, a) => format!("|n| n >> {}, |n| n & 0x{:016X}", s, a),
175 MulShr(m, s) => format!("|n| (n * 0x{:016X}) >> {}", m, s),
176 MulAddShr(m, s) => {
177 format!("|n| (n + ((n * 0x{:016X}) >> 64) >> 1) >> {}", m, s)
178 }
179 Ge(g) => {
180 let q =
181 format!("|n| if n >= 0x{:016x} {{ 1 }} else {{ 0 }}", g);
182 let r = format!(
183 "|n| if n >= 0x{0:016X} {{ n - 0x{0:016X} }} else {{ n }}",
184 g
185 );
186 format!("{}, {}", q, r)
187 }
188 };
189 f.write_str(res.as_str())
190 }
191}
192
193impl ConstDiv {
194 pub fn new(n: u64) -> Self {
195 let ns = n.next_power_of_two().trailing_zeros();
196 if n.is_power_of_two() {
197 return Self { n, di: Shr(ns, n - 1) };
198 }
199 if n.leading_zeros() == 0 {
200 return Self { n, di: Ge(n) };
201 }
202 let nc = std::u64::MAX as u128;
203
204 for p in 63 + ns..128 {
205 let n_ = n as u128;
206 let r = ((1_u128 << p) - 1) % n_;
207 if (nc * (n_ - 1 - r)) >> p == 0 {
208 let m = 1 + ((1_u128 << p) - 1 - r) / n_;
209 return if m >> 64 == 0 {
210 Self { n, di: MulShr(m as u64, p) }
211 } else {
212 Self { n, di: MulAddShr(m as u64, p - 1 - 64) }
213 };
214 }
215 }
216 unreachable!()
217 }
218 pub fn quot(&self, n: u64) -> u64 {
219 match self.di {
220 Shr(s, _) => n >> s,
221 MulShr(m, s) => ((n as u128 * m as u128) >> s) as u64,
222 MulAddShr(m, s) => {
223 let tmp = ((n as u128 * m as u128) >> 64) as u64;
224 (((n - tmp) >> 1) + tmp) >> s
225 }
226 Ge(g) if n >= g => 1,
227 Ge(_) => 0,
228 }
229 }
230 pub fn rem(&self, n: u64) -> u64 {
231 match self.di {
232 Shr(_, a) => n & a,
233 Ge(g) if n >= g => n - g,
234 Ge(_) => n,
235 _ => n - self.quot(n) * self.n,
236 }
237 }
238}
239
240#[test]
241fn test_small_2() {
242 for n in 1..=500 {
243 let cd = ConstDiv2::new(n);
244 for a in 0..n * n {
245 assert_eq!(cd.quot(a), a / n);
246 assert_eq!(cd.rem(a), a % n);
247 }
248 }
249}
250
251#[test]
252fn test_small() {
253 for n in 1..=500 {
254 let cd = ConstDiv::new(n);
255 for a in 0..5 * n * n {
256 assert_eq!(cd.quot(a), a / n);
257 assert_eq!(cd.rem(a), a % n);
258 }
259 for a in 1..=5 * n * n {
260 let a = std::u64::MAX - a;
261 assert_eq!(cd.quot(a), a / n);
262 assert_eq!(cd.rem(a), a % n);
263 }
264 }
265}
266
267#[test]
268fn test_corner() {
269 for &d in &[(1 << 63) - 1, 1 << 63, (1 << 63) + 1, std::u64::MAX] {
270 let cd = ConstDiv::new(d);
271 for &n in &[0, 1, d - 1, d, d.saturating_add(1), d.saturating_mul(2)] {
272 assert_eq!(cd.quot(n), n / d);
273 assert_eq!(cd.rem(n), n % d);
274 }
275 }
276}