「Codeforces 204E」Little Elephant and Strings【后缀数组 + 二分】

Description

给定 $n$ 个字符串,询问每个字符串有多少个非空子串是 $n$ 个字符串中至少 $k$ 个字符串的子串。

Input

第一行包含两个整数 $n, k$。
接下来 $n$ 行,每行包含一个字符串。

Output

输出一行 $n$ 个整数,第 $i$ 个整数表示第 $i$ 个字符串的答案。

Sample Input

1
2
3
4
5
6
7
8
7 4
rubik
furik
abab
baba
aaabbbababa
abababababa
zero

Sample Output

1
1 0 9 9 21 30 0

Solution

首先将所有字符串连在一起,求出它们的后缀数组。枚举每个串的每个后缀,然后求出有几个该后缀的前缀符合条件,那么就需要判定区间里不同的数的个数是否 $\geqslant k$。

对于每个位置 $x$,设 $L(x)$ 表示 $[L(x), x]$ 中恰好有 $k$ 个不同的数,且满足 $L(x)$ 最大。

思路 1: 可对每个后缀二分出长度 (请参考 Editorial),时间复杂度为 $O(n \log^2 n)$。

思路 2: 在枚举后缀时,如果后缀 $c + S$ ($c$ 为一个字符) 有 $n$ 个前缀合法,那么后缀 $S$ 至少有 $n - 1$ 个前缀合法,可以用类似于 height 数组的求法。时间复杂度为 $O(n \log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include <bits/stdc++.h>
using namespace std;

// Capacity shared by every global array below.
const int maxn = 200010;
// n   - number of input strings (read in main).
// m   - length of the concatenated text s[1..m], separators included.
// K   - threshold read in main: a substring counts when it occurs in at
//       least K of the input strings.
// lg  - lg[i] = floor(log2(i)), filled in build_sa for the sparse table.
// sa  - suffix array: sa[r] is the start position of the suffix of rank r.
// st, ed - st[i] / ed[i] are the first / last position of string i in s.
// bel - bel[p] is the index of the string owning position p; it stays 0 for
//       the separator positions.
// num - scratch buffer for the sorted distinct characters inside build_sa;
//       main then clears it and stores, for each rank i, the left end of
//       the sliding window computed there (0 when no window exists).
// cnt - cnt[i] is the number of suffixes of string i inside main's window.
int n, m, K, lg[maxn], sa[maxn], st[maxn], ed[maxn], bel[maxn], num[maxn], cnt[maxn];
// tmp - positions ordered by the second sort key during one doubling round.
// rk  - rk[p] is the rank of the suffix starting at p (inverse of sa once
//       build_sa finishes).
// ht  - ht[r] is the length of the common prefix of the suffixes of rank
//       r - 1 and r.
// a   - characters of s compressed to dense ids starting at 1.
// buc - counting-sort buckets.
// fir, sec - first / second sort key of each position in a doubling round.
// f   - sparse table: f[i][j] = min(ht[i .. i + 2^j - 1]).
int tmp[maxn], rk[maxn], ht[maxn], a[maxn], buc[maxn], fir[maxn], sec[maxn], f[maxn][20];
// str - buffer for one input string; s - the 1-indexed concatenated text.
char str[maxn], s[maxn];

/**
 * Builds the suffix structures for the 1-indexed text s[1..n].
 *
 * Fills, in order:
 *   a[]  - characters compressed to dense ids starting at 1,
 *   sa[] / rk[] - suffix array and its inverse, via prefix doubling with two
 *                 counting-sort passes per round,
 *   ht[] - ht[r] = common-prefix length of the suffixes of rank r - 1 and r,
 *   lg[] / f[][] - floor(log2) table and sparse table of minima over ht[].
 *
 * Index 0 of sa, rk, fir, sec and a is read as a sentinel and is never
 * written here; the code relies on those globals being zero there.
 *
 * @param n length of the text stored in s[1..n]
 */
void build_sa(int n) {
    // Coordinate-compress the alphabet: num[] holds the sorted distinct
    // character codes, a[pos] the 1-based index of s[pos] among them.
    for (int pos = 1; pos <= n; ++pos) {
        num[pos] = s[pos];
    }
    std::sort(num + 1, num + n + 1);
    int *distinctEnd = std::unique(num + 1, num + n + 1);
    for (int pos = 1; pos <= n; ++pos) {
        a[pos] = std::lower_bound(num + 1, distinctEnd, s[pos]) - num;
    }

    // Initial rank by single character: 1 + number of strictly smaller ones.
    memset(buc, 0, sizeof(buc));
    for (int pos = 1; pos <= n; ++pos) {
        ++buc[a[pos]];
    }
    for (int v = 1; v <= n; ++v) {
        buc[v] += buc[v - 1];
    }
    for (int pos = 1; pos <= n; ++pos) {
        rk[pos] = buc[a[pos] - 1] + 1;
    }

    for (int len = 1; len <= n; len *= 2) {
        // Sort keys of this round: rank of the first `len` characters and
        // rank of the following `len` characters (0 when past the end).
        for (int pos = 1; pos <= n; ++pos) {
            fir[pos] = rk[pos];
            sec[pos] = (pos + len > n) ? 0 : rk[pos + len];
        }

        // Pass 1: counting sort by the second key. Slots are filled from the
        // back, so tmp[] lists positions by descending second key.
        memset(buc, 0, sizeof(buc));
        for (int pos = 1; pos <= n; ++pos) {
            ++buc[sec[pos]];
        }
        for (int v = 1; v <= n; ++v) {
            buc[v] += buc[v - 1];
        }
        for (int pos = 1; pos <= n; ++pos) {
            --buc[sec[pos]];
            tmp[n - buc[sec[pos]]] = pos;
        }

        // Pass 2: counting sort by the first key, walking tmp[] and filling
        // each bucket from its last slot downwards.
        memset(buc, 0, sizeof(buc));
        for (int pos = 1; pos <= n; ++pos) {
            ++buc[fir[pos]];
        }
        for (int v = 1; v <= n; ++v) {
            buc[v] += buc[v - 1];
        }
        for (int idx = 1; idx <= n; ++idx) {
            const int pos = tmp[idx];
            sa[buc[fir[pos]]] = pos;
            --buc[fir[pos]];
        }

        // Re-rank: equal key pairs share a rank. For idx == 1 the predecessor
        // is the sentinel sa[0], whose keys and rank are read as stored at
        // index 0.
        bool allDistinct = true;
        rk[sa[1]] = 1;
        for (int idx = 1; idx <= n; ++idx) {
            const int cur = sa[idx];
            const int prev = sa[idx - 1];
            const bool sameKey = fir[cur] == fir[prev] && sec[cur] == sec[prev];
            rk[cur] = rk[prev];
            if (sameKey) {
                allDistinct = false;
            } else {
                ++rk[cur];
            }
        }
        if (allDistinct) {
            break;  // every suffix already has its own rank
        }
    }

    // Height array: walk positions in text order and carry over the previous
    // match length minus one as the starting point.
    for (int pos = 1, match = 0; pos <= n; ++pos) {
        if (match != 0) {
            --match;
        }
        const int prev = sa[rk[pos] - 1];
        while (pos + match <= n && prev + match <= n &&
               a[pos + match] == a[prev + match]) {
            ++match;
        }
        ht[rk[pos]] = match;
    }

    // Sparse table of range minima over ht[].
    for (int i = 2; i <= n; ++i) {
        lg[i] = lg[i / 2] + 1;
    }
    for (int i = 1; i <= n; ++i) {
        f[i][0] = ht[i];
    }
    for (int lvl = 1; lvl <= lg[n]; ++lvl) {
        const int half = 1 << (lvl - 1);
        for (int i = 1; i + 2 * half - 1 <= n; ++i) {
            f[i][lvl] = std::min(f[i][lvl - 1], f[i + half][lvl - 1]);
        }
    }
}

/**
 * Range-minimum query on the height array through the sparse table f.
 *
 * @param x first rank of the range (callers pass x <= y)
 * @param y last rank of the range
 * @return min(ht[x], ..., ht[y]), obtained from two overlapping blocks of
 *         length 2^lg[y - x + 1]
 */
int lcp(int x, int y) {
    const int level = lg[y - x + 1];
    const int leftBlock = f[x][level];
    const int rightBlock = f[y - (1 << level) + 1][level];
    return std::min(leftBlock, rightBlock);
}

/**
 * Decides whether the length-y prefix of the suffix of rank x is accepted.
 *
 * Finds the maximal rank interval [lowRank, highRank] around x in which all
 * neighbouring heights are at least y, then tests num[highRank] >= lowRank.
 * main() stores in num[i] the left end of its sliding window ending at rank
 * i (0 when there is none).
 *
 * @param x rank of the suffix in sa
 * @param y prefix length to test
 * @return true when num[highRank] >= lowRank
 */
bool check(int x, int y) {
    // Extend to the right: largest rank whose heights back to x + 1 are all
    // at least y.
    int highRank = x;
    if (ht[x + 1] >= y) {
        int lo = x + 1;
        int hi = m;
        while (lo <= hi) {
            const int mid = (lo + hi) >> 1;
            if (lcp(x + 1, mid) >= y) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        highRank = hi;
    }

    // Extend to the left: smallest rank r such that the heights on
    // r + 1 .. x are all at least y.
    int lowRank = x;
    if (ht[x] >= y) {
        int lo = 1;
        int hi = x - 1;
        while (lo <= hi) {
            const int mid = (lo + hi) >> 1;
            if (lcp(mid + 1, x) >= y) {
                hi = mid - 1;
            } else {
                lo = mid + 1;
            }
        }
        lowRank = lo;
    }

    return num[highRank] >= lowRank;
}

int main() {
scanf("%d %d", &n, &K);
for (int i = 1; i <= n; i++) {
scanf("%s", str), st[i] = m + 1;
for (int j = 0; str[j]; j++) bel[++m] = i, s[m] = str[j];
ed[i] = m, s[++m] = ' ';
}
build_sa(m);
memset(num, 0, sizeof(num));
for (int i = 1, j = 0, k = 1; i <= m; i++) {
if (!bel[sa[i]]) continue;
if (!cnt[bel[sa[i]]]++) j++;
if (j >= K) {
for (; j - (cnt[bel[sa[k]]] == 1) >= K; j -= !--(cnt[bel[sa[k++]]]));
num[i] = k;
}
}
for (int i = 1; i <= n; i++) {
long long ans = 0;
for (int j = st[i], k = 0; j <= ed[i]; j++) {
for (k ? k-- : 0; k <= ed[i] - j && check(rk[j], k + 1); k++);
ans += k;
}
printf("%lld ", ans);
}
return 0;
}