1use std::{
2 cmp::Ordering::{Equal, Greater, Less},
3 collections::{BTreeMap, BTreeSet},
4 ops::Index,
5};
6
/// Marker for an unfilled suffix-array slot (all bits set, i.e. the largest `usize`).
const NONE: usize = !0;
8
/// A suffix array over a sequence of `T`.
///
/// `sa` lists the starting positions of all `buf.len() + 1` suffixes of `buf`
/// (including the empty suffix at position `buf.len()`, which sorts first) in
/// increasing lexicographic order.
///
/// The `T: Ord` requirement lives on the impl blocks that need ordering, not on
/// the type itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SuffixArray<T> {
    /// The original sequence.
    buf: Vec<T>,
    /// Suffix start positions in sorted order; `sa.len() == buf.len() + 1`.
    sa: Vec<usize>,
}
14
15impl<T: Ord> From<Vec<T>> for SuffixArray<T> {
16 fn from(buf: Vec<T>) -> Self {
17 let buf_usize = hash(&buf);
18 let sa = sa_is(&buf_usize);
19 Self { buf, sa }
20 }
21}
22
23impl From<String> for SuffixArray<char> {
24 fn from(buf: String) -> Self {
25 let buf: Vec<_> = buf.chars().collect();
26 Self::from_chars(buf)
27 }
28}
29
30impl SuffixArray<u8> {
31 pub fn from_bytes(buf: Vec<u8>) -> Self {
32 let buf_usize = hash_bytes(&buf);
33 let sa = sa_is(&buf_usize);
34 Self { buf, sa }
35 }
36}
37
38impl SuffixArray<char> {
39 pub fn from_chars(buf: Vec<char>) -> Self {
40 let buf_usize = hash_chars(&buf);
41 let sa = sa_is(&buf_usize);
42 Self { buf, sa }
43 }
44}
45
46impl SuffixArray<usize> {
47 pub fn from_hashed(buf: Vec<usize>) -> Self {
48 assert!(Self::is_hashed(&buf));
49 let buf_usize: Vec<_> =
50 buf.iter().map(|x| x + 1).chain(Some(0)).collect();
51 let sa = sa_is(&buf_usize);
52 Self { buf, sa }
53 }
54
55 fn is_hashed(buf: &[usize]) -> bool {
56 let mut count = vec![0; buf.len()];
57 for &x in buf {
58 count[x] += 1;
59 }
60 (0..buf.len())
61 .find(|&i| count[i] == 0)
62 .map(|i| (i..count.len()).all(|i| count[i] == 0))
63 .unwrap_or(true)
64 }
65}
66
/// Compresses `buf` to dense ranks and appends a `0` sentinel.
///
/// The smallest distinct value maps to `1`, the next to `2`, and so on. The
/// result has length `buf.len() + 1` and ends in a unique `0`.
fn hash<T: Ord>(buf: &[T]) -> Vec<usize> {
    let distinct: BTreeSet<&T> = buf.iter().collect();
    let rank: BTreeMap<&T, usize> = distinct
        .into_iter()
        .enumerate()
        .map(|(i, x)| (x, i + 1))
        .collect();

    let mut out = Vec::with_capacity(buf.len() + 1);
    out.extend(buf.iter().map(|x| rank[x]));
    out.push(0);
    out
}
77
/// Compresses `buf` to dense ranks (`1..`) by code point and appends a `0` sentinel.
///
/// Allocates a rank table of `max_char + 1` entries. Empty input yields `[0]`.
fn hash_chars(buf: &[char]) -> Vec<usize> {
    let top = match buf.iter().copied().max() {
        Some(c) => c as usize,
        None => return vec![0],
    };

    // Mark present code points, then turn the marks into running counts:
    // rank[c] = number of distinct chars <= c.
    let mut rank = vec![0_usize; top + 1];
    for &c in buf {
        rank[c as usize] = 1;
    }
    let mut distinct = 0;
    for slot in rank.iter_mut() {
        distinct += *slot;
        *slot = distinct;
    }

    buf.iter()
        .map(|&c| rank[c as usize])
        .chain(std::iter::once(0))
        .collect()
}
95
/// Compresses `buf` to dense ranks (`1..`) by byte value and appends a `0` sentinel.
fn hash_bytes(buf: &[u8]) -> Vec<usize> {
    let mut present = [false; 256];
    for &b in buf {
        present[usize::from(b)] = true;
    }

    // rank[b] = number of distinct bytes <= b.
    let mut rank = [0_usize; 256];
    let mut distinct = 0;
    for (slot, &seen) in rank.iter_mut().zip(present.iter()) {
        if seen {
            distinct += 1;
        }
        *slot = distinct;
    }

    buf.iter()
        .map(|&b| rank[usize::from(b)])
        .chain(std::iter::once(0))
        .collect()
}
108}
109
/// SA-IS suffix type of a position.
///
/// A position is `L` when its suffix is larger than the following suffix and `S`
/// when it is smaller. Ties in the current symbol inherit the type of the next
/// position.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LsType {
    /// Larger than the following suffix.
    L,
    /// Smaller than the following suffix. The flag is `true` for a leftmost-S
    /// (LMS) position, i.e. an S position immediately preceded by an L position.
    S(bool),
}
115
/// Histogram of `buf`: `freq[v]` is how many times `v` occurs.
///
/// Every value must be `< buf.len()`; otherwise this panics on indexing.
fn count_freq(buf: &[usize]) -> Vec<usize> {
    buf.iter().fold(vec![0; buf.len()], |mut freq, &x| {
        freq[x] += 1;
        freq
    })
}
123
/// Inverse of the permutation `buf`: `inverse[buf[i]] == i`.
///
/// `buf` is expected to be a permutation of `0..buf.len()`.
fn inv_perm(buf: &[usize]) -> Vec<usize> {
    let mut inverse = vec![0; buf.len()];
    for (pos, &val) in (0..).zip(buf) {
        inverse[val] = pos;
    }
    inverse
}
131
132fn ls_classify(buf: &[usize]) -> Vec<LsType> {
133 let mut res = vec![LsType::S(false); buf.len()];
134 for i in (0..buf.len() - 1).rev() {
135 res[i] = match buf[i].cmp(&buf[i + 1]) {
136 Less => LsType::S(false),
137 Equal => res[i + 1],
138 Greater => LsType::L,
139 };
140 }
141 for i in 1..buf.len() {
142 if let (LsType::L, LsType::S(_)) = (res[i - 1], res[i]) {
143 res[i] = LsType::S(true);
144 }
145 }
146 res
147}
148
/// Start index of each symbol's bucket in the suffix array.
///
/// `head[c]` is the number of symbols strictly smaller than `c`, i.e. the
/// exclusive prefix sum of `count`. Empty input yields an empty vector; the
/// previous `count[..n - 1]` form underflowed and panicked there.
fn bucket_head(count: &[usize]) -> Vec<usize> {
    let mut start = 0;
    count
        .iter()
        .map(|&c| {
            let head = start;
            start += c;
            head
        })
        .collect()
}
158
/// One-past-the-end index of each symbol's bucket in the suffix array.
///
/// `tail[c]` is the number of symbols `<= c`, i.e. the inclusive prefix sum of `count`.
fn bucket_tail(count: &[usize]) -> Vec<usize> {
    count
        .iter()
        .scan(0, |running, &c| {
            *running += c;
            Some(*running)
        })
        .collect()
}
166
167fn induce(buf: &[usize], sa: &mut [usize], count: &[usize], ls: &[LsType]) {
168 let mut head = bucket_head(count);
169 for i in 0..sa.len() {
170 let j = sa[i];
171 if j <= buf.len() {
172 if j > 0 && ls[j - 1] == LsType::L {
173 sa[head[buf[j - 1]]] = j - 1;
174 head[buf[j - 1]] += 1;
175 }
176 }
177 }
178 let mut tail = bucket_tail(count);
179 for i in (1..count.len()).rev() {
180 let j = sa[i];
181 if j <= buf.len() {
182 if j > 0 && ls[j - 1] != LsType::L {
183 tail[buf[j - 1]] -= 1;
184 sa[tail[buf[j - 1]]] = j - 1;
185 }
186 }
187 }
188}
189
190fn reduce(buf: &[usize], lms: &[usize], ls: &[LsType]) -> Vec<usize> {
191 if lms.len() <= 1 {
192 return vec![0; lms.len()];
193 }
194
195 let e = |(i0, i1)| {
196 if (ls[i0], ls[i1]) == (LsType::S(true), LsType::S(true)) {
197 Some(true)
198 } else if ls[i0] != ls[i1] || buf[i0] != buf[i1] {
199 Some(false)
200 } else {
201 None
202 }
203 };
204
205 let mut map = vec![0; buf.len()]; map[lms[1]] = 1;
207 let mut x = 1;
208 for i in 2..lms.len() {
209 let eq = buf[lms[i]] == buf[lms[i - 1]]
210 && (lms[i] + 1..).zip(lms[i - 1] + 1..).find_map(e).unwrap();
211 if !eq {
212 x += 1;
213 }
214 map[lms[i]] = x;
215 }
216 (0..buf.len())
217 .filter_map(|i| match ls[i] {
218 LsType::S(true) => Some(map[i]),
219 _ => None,
220 })
221 .collect()
222}
223
224fn sa_is(buf: &[usize]) -> Vec<usize> {
225 let len = buf.len();
226 let count = count_freq(buf);
227 if count.iter().all(|&x| x == 1) {
228 return inv_perm(buf);
229 }
230
231 let ls = ls_classify(buf);
232 let mut sa = vec![NONE; len];
233 let mut tail = bucket_tail(&count);
234 for i in (1..len).rev().filter(|&i| ls[i] == LsType::S(true)) {
235 tail[buf[i]] -= 1;
236 sa[tail[buf[i]]] = i;
237 }
238
239 induce(buf, &mut sa, &count, &ls);
240
241 let lms: Vec<_> =
243 sa.into_iter().filter(|&i| ls[i] == LsType::S(true)).collect();
244 let rs_sa = sa_is(&reduce(buf, &lms, &ls));
245
246 let lms: Vec<_> = (0..len).filter(|&i| ls[i] == LsType::S(true)).collect();
248
249 let mut tail = bucket_tail(&count);
250 let mut sa = vec![NONE; len];
251 for i in rs_sa.into_iter().rev() {
252 let j = lms[i];
253 tail[buf[j]] -= 1;
254 sa[tail[buf[j]]] = j;
255 }
256 induce(buf, &mut sa, &count, &ls);
257
258 sa.into_iter().collect()
259}
260
261impl<T: Ord> SuffixArray<T> {
262 pub fn search(&self, pat: &[T]) -> impl Iterator<Item = usize> + '_ {
263 let lo = {
264 let mut lt = 1_usize.wrapping_neg();
265 let mut ge = self.sa.len();
266 while ge.wrapping_sub(lt) > 1 {
267 let mid = lt.wrapping_add(ge.wrapping_sub(lt) / 2);
268 let pos = self.sa[mid];
269 match self.buf[pos..].cmp(pat) {
270 Less => lt = mid,
271 _ => ge = mid,
272 }
273 }
274 ge
275 };
276 if lo >= self.sa.len() {
277 return self.sa[lo..lo].iter().copied();
278 }
279 let hi = {
280 let mut le = lo.wrapping_sub(1);
281 let mut gt = self.sa.len();
282 while gt.wrapping_sub(le) > 1 {
283 let mid = le.wrapping_add(gt.wrapping_sub(le) / 2);
284 let pos = self.sa[mid];
285 let len = pat.len().min(self.buf[pos..].len());
286 match self.buf[pos..pos + len].cmp(pat) {
287 Greater => gt = mid,
288 _ => le = mid,
289 }
290 }
291 gt
292 };
293 self.sa[lo..hi].iter().copied()
294 }
295
296 pub fn lcpa(&self) -> Vec<usize> {
297 let n = self.buf.len();
298 let mut rank = vec![0; n + 1];
299 for i in 0..=n {
300 rank[self.sa[i]] = i;
301 }
302 let mut h = 0;
303 let mut res = vec![0; n + 1];
304 for i in 0..n {
305 let j = self.sa[rank[i] - 1];
306 if h > 0 {
307 h -= 1;
308 }
309 while j + h < n && i + h < n {
310 if self.buf[j + h] != self.buf[i + h] {
311 break;
312 }
313 h += 1;
314 }
315 res[rank[i]] = h;
316 }
317 res
318 }
319
320 pub fn into_inner(self) -> Vec<usize> { self.sa }
321}
322
323impl SuffixArray<char> {
324 pub fn search_str(&self, pat: &str) -> impl Iterator<Item = usize> + '_ {
325 let pat: Vec<_> = pat.chars().collect();
326 self.search(&pat)
327 }
328}
329
330impl<T: Ord> Index<usize> for SuffixArray<T> {
331 type Output = usize;
332 fn index(&self, i: usize) -> &usize { &self.sa[i] }
333}
334
#[test]
fn sanity_check() {
    // Classic example; position 11 is the empty suffix and sorts first.
    let expected: [usize; 12] = [11, 10, 7, 0, 3, 5, 8, 1, 4, 6, 9, 2];
    let sa = SuffixArray::from_bytes(b"abracadabra".to_vec());
    assert_eq!(sa.into_inner(), expected);
}
341
#[test]
fn empty_text() {
    let sa = SuffixArray::from(String::new());
    // Only the empty suffix exists: "" matches it, anything longer matches nothing.
    let hits: Vec<usize> = sa.search_str("").collect();
    assert_eq!(hits, vec![0]);
    assert_eq!(sa.search_str("_").count(), 0);
}
348
#[test]
fn empty_pattern() {
    let sa = SuffixArray::from(String::from("empty"));
    // An empty pattern matches every suffix, reported in suffix-array order.
    let hits: Vec<usize> = sa.search_str("").collect();
    let expected: &[usize] = &[5, 0, 1, 2, 3, 4];
    assert_eq!(hits, expected);
}
354
#[test]
fn worst_case() {
    // Ruler-like text of length 2^k - 1: element i - 1 is `k - ctz(i)` for i in 1..2^k.
    let k: u32 = 22;
    let text: Vec<u8> = (1_usize..1 << k)
        .map(|i| (k - i.trailing_zeros()) as u8)
        .collect();
    let actual = SuffixArray::from_bytes(text);

    // Expected rank r -> (2^k - 1) - rev_k(r), where rev_k reverses the low k bits.
    let word_bits = 0_usize.count_zeros();
    let expected: Vec<usize> = (0_usize..1 << k)
        .map(|r| !r.reverse_bits() >> (word_bits - k))
        .collect();

    assert_eq!(actual.sa, expected);
}