1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#![allow(dead_code)]

use std::ops::{Range, RangeBounds, RangeInclusive};

use rs01_dict::Rs01Dict;
use usize_bounds::UsizeBounds;

pub struct WaveletMatrix<I> {
    len: usize,
    bitlen: usize,
    buf: Vec<Rs01Dict>,
    zeros: Vec<usize>,
    orig: Vec<I>,
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Count3wayResult {
    lt: usize,
    eq: usize,
    gt: usize,
}

impl Count3wayResult {
    fn new(lt: usize, eq: usize, gt: usize) -> Self { Self { lt, eq, gt } }
    pub fn lt(self) -> usize { self.lt }
    pub fn le(self) -> usize { self.lt + self.eq }
    pub fn eq(self) -> usize { self.eq }
    pub fn ge(self) -> usize { self.eq + self.gt }
    pub fn gt(self) -> usize { self.gt }
    pub fn ne(self) -> usize { self.gt + self.lt }
}

impl<I: WmInt> From<Vec<I>> for WaveletMatrix<I> {
    fn from(orig: Vec<I>) -> Self {
        let len = orig.len();
        let bitlen =
            orig.iter().map(|ai| ai.bitlen()).max().unwrap_or(0) as usize;
        let mut whole = orig.clone();
        let mut zeros = vec![0; bitlen];
        let mut buf = vec![];
        for i in (0..bitlen).rev() {
            let mut zero = vec![];
            let mut one = vec![];
            let mut vb = vec![false; len];
            for (j, aj) in whole.into_iter().enumerate() {
                (if aj.test(i) { &mut one } else { &mut zero }).push(aj);
                vb[j] = aj.test(i);
            }
            zeros[i] = zero.len();
            buf.push(Rs01Dict::new(&vb));
            whole = zero;
            whole.append(&mut one);
        }
        buf.reverse();
        Self { len, bitlen, buf, zeros, orig }
    }
}

impl<I: WmInt> WaveletMatrix<I> {
    pub fn count<R: WmIntRange<Int = I>>(
        &self,
        range: impl RangeBounds<usize>,
        value: R,
    ) -> usize {
        self.count_3way(range, value).eq()
    }
    pub fn count_3way<R: WmIntRange<Int = I>>(
        &self,
        range: impl RangeBounds<usize>,
        value: R,
    ) -> Count3wayResult {
        let Range { start: il, end: ir } = range.to_range(self.len);
        let value = value.to_inclusive_range();
        let vl = *value.start();
        let vr = *value.end();
        let (lt, gt) = if vl == vr {
            self.count_3way_internal(il..ir, vl)
        } else {
            let lt = self.count_3way_internal(il..ir, vl).0;
            let gt = self.count_3way_internal(il..ir, vr).1;
            (lt, gt)
        };
        let eq = (ir - il) - (lt + gt);
        Count3wayResult::new(lt, eq, gt)
    }
    fn count_3way_internal(
        &self,
        Range { mut start, mut end }: Range<usize>,
        value: I,
    ) -> (usize, usize) {
        if start == end {
            return (0, 0);
        }
        if value.bitlen() > self.bitlen {
            return (end - start, 0);
        }
        let mut lt = 0;
        let mut gt = 0;
        for i in (0..self.bitlen).rev() {
            let tmp = end - start;
            if !value.test(i) {
                start = self.buf[i].count0(..start);
                end = self.buf[i].count0(..end);
            } else {
                start = self.zeros[i] + self.buf[i].count1(..start);
                end = self.zeros[i] + self.buf[i].count1(..end);
            }
            let len = end - start;
            *(if value.test(i) { &mut lt } else { &mut gt }) += tmp - len;
        }
        (lt, gt)
    }

    pub fn quantile(
        &self,
        range: impl RangeBounds<usize>,
        mut n: usize,
    ) -> Option<I> {
        let Range { mut start, mut end } = range.to_range(self.len);
        if end - start <= n {
            return None;
        }
        let mut res = I::zero();
        for i in (0..self.bitlen).rev() {
            let z = self.buf[i].count0(start..end);
            if n < z {
                start = self.buf[i].count0(..start);
                end = self.buf[i].count0(..end);
            } else {
                res.set(i);
                start = self.zeros[i] + self.buf[i].count1(..start);
                end = self.zeros[i] + self.buf[i].count1(..end);
                n -= z;
            }
        }
        Some(res)
    }
}

pub trait WmInt: Copy + Eq {
    fn test(self, i: usize) -> bool;
    fn set(&mut self, i: usize);
    fn bitlen(self) -> usize;
    fn zero() -> Self;
}

pub trait WmIntRange {
    type Int;
    fn to_inclusive_range(self) -> RangeInclusive<Self::Int>;
}

macro_rules! impl_uint {
    ( $($ty:ty)* ) => { $(
        impl WmInt for $ty {
            fn test(self, i: usize) -> bool { self >> i & 1 != 0 }
            fn set(&mut self, i: usize) { *self |= 1 << i; }
            fn bitlen(self) -> usize {
                let bits = <$ty>::BITS;
                (if self == 0 { 1 } else { bits - self.leading_zeros() }) as _
            }
            fn zero() -> $ty { 0 }
        }
        impl WmIntRange for $ty {
            type Int = $ty;
            fn to_inclusive_range(self) -> RangeInclusive<$ty> { self..=self }
        }
        impl WmIntRange for RangeInclusive<$ty> {
            type Int = $ty;
            fn to_inclusive_range(self) -> RangeInclusive<$ty> { self }
        }
    )* }
}

impl_uint! { u8 u16 u32 u64 u128 usize }