suffix_array/
lib.rs

1use std::{
2    cmp::Ordering::{Equal, Greater, Less},
3    collections::{BTreeMap, BTreeSet},
4    ops::Index,
5};
6
/// Marker for a suffix-array slot that has not been assigned a position yet.
///
/// It is the largest `usize`, so it is never a valid index into the text.
const NONE: usize = usize::MAX;
8
/// A text together with its suffix array.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SuffixArray<T>
where
    T: Ord,
{
    /// The indexed text, exactly as handed to the constructor.
    buf: Vec<T>,
    /// Start positions of all suffixes of `buf` in sorted order. It has
    /// `buf.len() + 1` entries because the empty suffix (position
    /// `buf.len()`) is included and always comes first.
    sa: Vec<usize>,
}
14
15impl<T: Ord> From<Vec<T>> for SuffixArray<T> {
16    fn from(buf: Vec<T>) -> Self {
17        let buf_usize = hash(&buf);
18        let sa = sa_is(&buf_usize);
19        Self { buf, sa }
20    }
21}
22
23impl From<String> for SuffixArray<char> {
24    fn from(buf: String) -> Self {
25        let buf: Vec<_> = buf.chars().collect();
26        Self::from_chars(buf)
27    }
28}
29
30impl SuffixArray<u8> {
31    pub fn from_bytes(buf: Vec<u8>) -> Self {
32        let buf_usize = hash_bytes(&buf);
33        let sa = sa_is(&buf_usize);
34        Self { buf, sa }
35    }
36}
37
38impl SuffixArray<char> {
39    pub fn from_chars(buf: Vec<char>) -> Self {
40        let buf_usize = hash_chars(&buf);
41        let sa = sa_is(&buf_usize);
42        Self { buf, sa }
43    }
44}
45
46impl SuffixArray<usize> {
47    pub fn from_hashed(buf: Vec<usize>) -> Self {
48        assert!(Self::is_hashed(&buf));
49        let buf_usize: Vec<_> =
50            buf.iter().map(|x| x + 1).chain(Some(0)).collect();
51        let sa = sa_is(&buf_usize);
52        Self { buf, sa }
53    }
54
55    fn is_hashed(buf: &[usize]) -> bool {
56        let mut count = vec![0; buf.len()];
57        for &x in buf {
58            count[x] += 1;
59        }
60        (0..buf.len())
61            .find(|&i| count[i] == 0)
62            .map(|i| (i..count.len()).all(|i| count[i] == 0))
63            .unwrap_or(true)
64    }
65}
66
/// Rank-compresses `buf` and appends the terminator.
///
/// Each element is replaced by its 1-based rank among the distinct elements
/// of `buf`; a trailing `0` (smaller than every rank) is appended.
fn hash<T: Ord>(buf: &[T]) -> Vec<usize> {
    // Distinct elements in ascending order, numbered from 1.
    let rank: BTreeMap<&T, usize> = buf
        .iter()
        .collect::<BTreeSet<_>>()
        .into_iter()
        .enumerate()
        .map(|(i, x)| (x, i + 1))
        .collect();

    let mut res = Vec::with_capacity(buf.len() + 1);
    res.extend(buf.iter().map(|x| rank[x]));
    res.push(0); // represents '$'
    res
}
77
/// Rank-compresses a `char` sequence and appends the terminator.
///
/// Each character becomes the number of distinct characters of `buf` that
/// are less than or equal to it (so the smallest one maps to 1); a trailing
/// `0` is appended. The lookup table has one entry per code point up to the
/// largest character present.
fn hash_chars(buf: &[char]) -> Vec<usize> {
    let top = match buf.iter().copied().max() {
        None => return vec![0], // "$"
        Some(c) => c as usize,
    };

    let mut present = vec![false; top + 1];
    for &c in buf {
        present[c as usize] = true;
    }

    // Running count of present code points gives each one its rank.
    let mut seen = 0;
    let rank: Vec<usize> = present
        .iter()
        .map(|&p| {
            if p {
                seen += 1;
            }
            seen
        })
        .collect();

    let mut res: Vec<usize> = buf.iter().map(|&c| rank[c as usize]).collect();
    res.push(0);
    res
}
95
/// Rank-compresses a byte string and appends the terminator.
///
/// Each byte becomes the number of distinct byte values of `buf` that are
/// less than or equal to it (so the smallest one maps to 1); a trailing `0`
/// is appended.
fn hash_bytes(buf: &[u8]) -> Vec<usize> {
    // Mark the byte values that occur ...
    let mut rank = [0_usize; 256];
    for &b in buf {
        rank[usize::from(b)] = 1;
    }
    // ... and turn the marks into running totals.
    let mut acc = 0;
    for r in rank.iter_mut() {
        acc += *r;
        *r = acc;
    }

    buf.iter()
        .map(|&b| rank[usize::from(b)])
        .chain(std::iter::once(0))
        .collect()
}
109
/// Classification of a text position as produced by `ls_classify`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LsType {
    /// The symbol here is greater than the next one, or equal to it while
    /// the next position is itself `L`.
    L,
    /// Any other position. The flag is `true` when the position directly to
    /// the left is `L`, i.e. this is a leftmost S-type position.
    S(bool),
}
115
/// Counts how often each value occurs in `buf`.
///
/// The result has `buf.len()` entries, so every value must be smaller than
/// `buf.len()`; a larger value causes an out-of-bounds panic.
fn count_freq(buf: &[usize]) -> Vec<usize> {
    buf.iter().fold(vec![0; buf.len()], |mut freq, &x| {
        freq[x] += 1;
        freq
    })
}
123
/// Inverts a permutation: the result `r` satisfies `r[buf[i]] == i`.
///
/// `buf` is expected to be a permutation of `0..buf.len()`; a value outside
/// that range causes an out-of-bounds panic.
fn inv_perm(buf: &[usize]) -> Vec<usize> {
    let mut inv = vec![0; buf.len()];
    buf.iter()
        .enumerate()
        .for_each(|(pos, &val)| inv[val] = pos);
    inv
}
131
132fn ls_classify(buf: &[usize]) -> Vec<LsType> {
133    let mut res = vec![LsType::S(false); buf.len()];
134    for i in (0..buf.len() - 1).rev() {
135        res[i] = match buf[i].cmp(&buf[i + 1]) {
136            Less => LsType::S(false),
137            Equal => res[i + 1],
138            Greater => LsType::L,
139        };
140    }
141    for i in 1..buf.len() {
142        if let (LsType::L, LsType::S(_)) = (res[i - 1], res[i]) {
143            res[i] = LsType::S(true);
144        }
145    }
146    res
147}
148
/// Returns, for every symbol, the index of the first slot of its bucket.
///
/// `count[c]` is the number of occurrences of symbol `c`; the result is the
/// exclusive prefix sum of `count`. An empty `count` yields an empty vector.
fn bucket_head(count: &[usize]) -> Vec<usize> {
    let mut next_free = 0;
    count
        .iter()
        .map(|&c| {
            let start = next_free;
            next_free += c;
            start
        })
        .collect()
}
158
/// Returns, for every symbol, the index one past the last slot of its bucket.
///
/// `count[c]` is the number of occurrences of symbol `c`; the result is the
/// inclusive prefix sum of `count`.
fn bucket_tail(count: &[usize]) -> Vec<usize> {
    count
        .iter()
        .scan(0, |end, &c| {
            *end += c;
            Some(*end)
        })
        .collect()
}
166
167fn induce(buf: &[usize], sa: &mut [usize], count: &[usize], ls: &[LsType]) {
168    let mut head = bucket_head(count);
169    for i in 0..sa.len() {
170        let j = sa[i];
171        if j <= buf.len() {
172            if j > 0 && ls[j - 1] == LsType::L {
173                sa[head[buf[j - 1]]] = j - 1;
174                head[buf[j - 1]] += 1;
175            }
176        }
177    }
178    let mut tail = bucket_tail(count);
179    for i in (1..count.len()).rev() {
180        let j = sa[i];
181        if j <= buf.len() {
182            if j > 0 && ls[j - 1] != LsType::L {
183                tail[buf[j - 1]] -= 1;
184                sa[tail[buf[j - 1]]] = j - 1;
185            }
186        }
187    }
188}
189
190fn reduce(buf: &[usize], lms: &[usize], ls: &[LsType]) -> Vec<usize> {
191    if lms.len() <= 1 {
192        return vec![0; lms.len()];
193    }
194
195    let e = |(i0, i1)| {
196        if (ls[i0], ls[i1]) == (LsType::S(true), LsType::S(true)) {
197            Some(true)
198        } else if ls[i0] != ls[i1] || buf[i0] != buf[i1] {
199            Some(false)
200        } else {
201            None
202        }
203    };
204
205    let mut map = vec![0; buf.len()]; // map[lms[0]] = 0
206    map[lms[1]] = 1;
207    let mut x = 1;
208    for i in 2..lms.len() {
209        let eq = buf[lms[i]] == buf[lms[i - 1]]
210            && (lms[i] + 1..).zip(lms[i - 1] + 1..).find_map(e).unwrap();
211        if !eq {
212            x += 1;
213        }
214        map[lms[i]] = x;
215    }
216    (0..buf.len())
217        .filter_map(|i| match ls[i] {
218            LsType::S(true) => Some(map[i]),
219            _ => None,
220        })
221        .collect()
222}
223
224fn sa_is(buf: &[usize]) -> Vec<usize> {
225    let len = buf.len();
226    let count = count_freq(buf);
227    if count.iter().all(|&x| x == 1) {
228        return inv_perm(buf);
229    }
230
231    let ls = ls_classify(buf);
232    let mut sa = vec![NONE; len];
233    let mut tail = bucket_tail(&count);
234    for i in (1..len).rev().filter(|&i| ls[i] == LsType::S(true)) {
235        tail[buf[i]] -= 1;
236        sa[tail[buf[i]]] = i;
237    }
238
239    induce(buf, &mut sa, &count, &ls);
240
241    // lexicographic order
242    let lms: Vec<_> =
243        sa.into_iter().filter(|&i| ls[i] == LsType::S(true)).collect();
244    let rs_sa = sa_is(&reduce(buf, &lms, &ls));
245
246    // appearing order
247    let lms: Vec<_> = (0..len).filter(|&i| ls[i] == LsType::S(true)).collect();
248
249    let mut tail = bucket_tail(&count);
250    let mut sa = vec![NONE; len];
251    for i in rs_sa.into_iter().rev() {
252        let j = lms[i];
253        tail[buf[j]] -= 1;
254        sa[tail[buf[j]]] = j;
255    }
256    induce(buf, &mut sa, &count, &ls);
257
258    sa.into_iter().collect()
259}
260
261impl<T: Ord> SuffixArray<T> {
262    pub fn search(&self, pat: &[T]) -> impl Iterator<Item = usize> + '_ {
263        let lo = {
264            let mut lt = 1_usize.wrapping_neg();
265            let mut ge = self.sa.len();
266            while ge.wrapping_sub(lt) > 1 {
267                let mid = lt.wrapping_add(ge.wrapping_sub(lt) / 2);
268                let pos = self.sa[mid];
269                match self.buf[pos..].cmp(pat) {
270                    Less => lt = mid,
271                    _ => ge = mid,
272                }
273            }
274            ge
275        };
276        if lo >= self.sa.len() {
277            return self.sa[lo..lo].iter().copied();
278        }
279        let hi = {
280            let mut le = lo.wrapping_sub(1);
281            let mut gt = self.sa.len();
282            while gt.wrapping_sub(le) > 1 {
283                let mid = le.wrapping_add(gt.wrapping_sub(le) / 2);
284                let pos = self.sa[mid];
285                let len = pat.len().min(self.buf[pos..].len());
286                match self.buf[pos..pos + len].cmp(pat) {
287                    Greater => gt = mid,
288                    _ => le = mid,
289                }
290            }
291            gt
292        };
293        self.sa[lo..hi].iter().copied()
294    }
295
296    pub fn lcpa(&self) -> Vec<usize> {
297        let n = self.buf.len();
298        let mut rank = vec![0; n + 1];
299        for i in 0..=n {
300            rank[self.sa[i]] = i;
301        }
302        let mut h = 0;
303        let mut res = vec![0; n + 1];
304        for i in 0..n {
305            let j = self.sa[rank[i] - 1];
306            if h > 0 {
307                h -= 1;
308            }
309            while j + h < n && i + h < n {
310                if self.buf[j + h] != self.buf[i + h] {
311                    break;
312                }
313                h += 1;
314            }
315            res[rank[i]] = h;
316        }
317        res
318    }
319
320    pub fn into_inner(self) -> Vec<usize> { self.sa }
321}
322
323impl SuffixArray<char> {
324    pub fn search_str(&self, pat: &str) -> impl Iterator<Item = usize> + '_ {
325        let pat: Vec<_> = pat.chars().collect();
326        self.search(&pat)
327    }
328}
329
330impl<T: Ord> Index<usize> for SuffixArray<T> {
331    type Output = usize;
332    fn index(&self, i: usize) -> &usize { &self.sa[i] }
333}
334
/// The suffix array of "abracadabra", with the empty suffix (position 11)
/// in front.
#[test]
fn sanity_check() {
    let text = b"abracadabra";
    let actual = SuffixArray::from_bytes(text.to_vec()).into_inner();
    assert_eq!(actual, [11, 10, 7, 0, 3, 5, 8, 1, 4, 6, 9, 2]);
}
341
/// On an empty text the empty pattern matches once (at position 0) and any
/// other pattern not at all.
#[test]
fn empty_text() {
    let sa = SuffixArray::from(String::new());
    assert_eq!(sa.search_str("").collect::<Vec<_>>(), [0]);
    assert_eq!(sa.search_str("_").count(), 0);
}
348
/// The empty pattern matches every suffix, so the search yields the whole
/// suffix array.
#[test]
fn empty_pattern() {
    let sa = SuffixArray::from(String::from("empty"));
    let hits: Vec<usize> = sa.search_str("").collect();
    assert_eq!(hits, [5, 0, 1, 2, 3, 4]);
}
354
/// Large structured input of length `2^k - 1`: element `i` (1-based) is `k`
/// minus the number of trailing zero bits of `i`. The expected suffix array
/// has a closed form in terms of bit reversal.
#[test]
fn worst_case() {
    let k = 22_u32;
    let text: Vec<u8> = (1_usize..1 << k)
        // `i & -i` isolates the lowest set bit, so this is `i.trailing_zeros()`.
        .map(|i| (k - i.trailing_zeros()) as u8)
        .collect();
    let actual = SuffixArray::from_bytes(text).into_inner();

    let shift = usize::BITS - k;
    let expected: Vec<usize> = (0_usize..1 << k)
        .map(|i| !i.reverse_bits() >> shift)
        .collect();

    assert_eq!(actual, expected);
}