HDU 2896：AC 自动机——多个模式串匹配多个文本串

这是一道很经典的多模式串、多文本串匹配题：AC 自动机匹配一个文本串的复杂度是 $O(n)$ 的，所以直接套用 AC 自动机即可。需要注意判重：在每个结点上另开一个 vis 字段，记录该结点在第 turn 趟匹配中是否已经访问过。

使用静态分配的结点池（不做动态内存分配），运行时间为 156 ms。

我的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>

using namespace std;

const int MAX = 101100;  // capacity of the static trie-node pool `node`
const int MAXK = 128;    // alphabet size: one child slot per byte value 0..127
const int MAXANS = 520;  // capacity of `ans` (find reports each pattern id at most once per text)

/* One Aho-Corasick automaton node. */
struct Node {
    Node* ne[MAXK];  // goto edges indexed by byte value; NULL means "no child"
    Node* fail;      // failure link; the root's stays NULL
    int cnt, vis;    // cnt: id of the pattern ending at this node (0 = none)
                     // vis: `turn` stamp of the last scan that walked this node
} node[MAX], *root;
Node* que[MAX];      // BFS queue used by bfs(); every node is enqueued at most once
char s[11000];       // shared input line buffer for patterns and texts
// K: number of pool nodes handed out; ans[0..top-1]: ids matched in the current text.
// ans was previously sized 5, which overflowed on a text with more than five
// distinct matching patterns.
int K, ans[MAXANS], top;

/*
 * Hand out the next unused node of the static pool `node` and reset it
 * (no children, no failure link, no pattern id, never visited).
 * The pool index K used to be incremented unchecked; if more than MAX nodes
 * are requested we now stop with a diagnostic instead of writing past the
 * end of `node`.
 */
Node* New() {
    if (K >= MAX) {
        fprintf(stderr, "trie node pool exhausted (%d nodes)\n", MAX);
        exit(EXIT_FAILURE);
    }
    Node* ret = &node[K++];
    ret->fail = NULL;
    ret->cnt = 0;
    ret->vis = 0;
    for (int i = 0; i < MAXK; i++) {
        ret->ne[i] = NULL;
    }
    return ret;
}

void init() {
K = 0;
root = New();
}

void insert(char* s, int num) {
Node* ptr = root;
char* p = s;
int id;
while (*p) {
id = *(p++);
if (ptr->ne[id] == NULL) {
ptr->ne[id] = New();
}
ptr = ptr->ne[id];
}
ptr->cnt = num;
}

void bfs() {
int b = 0, f = 0;
Node* now;
Node* ptr;
que[b++] = root;
while (f != b) {
now = que[f++];
for (int i = 0; i < MAXK; i++) {
if (now->ne[i] != NULL) {
if (now == root) {
now->ne[i]->fail = root;
} else {
ptr = now->fail;
while (ptr != NULL) {
if (ptr->ne[i] != NULL) {
now->ne[i]->fail = ptr->ne[i];
break;
} else {
ptr = ptr->fail;
}
}
if (ptr == NULL) {
now->ne[i]->fail = root;
}
}
que[b++] = now->ne[i];
}
}
}
}

void find(char* s, int turn) {
bool vis[520] = {0};
Node* ptr = root;
Node* next;
char* p = s;
int id;
vis[0] = true;
top = 0;
while (*p) {
id = *(p++);
while (ptr->ne[id] == NULL && ptr != root) {
ptr = ptr->fail;
}
ptr = ptr->ne[id];
if (ptr == NULL) {
ptr = root;
}
next = ptr;
while (next != NULL && next->vis != turn) {
if (!vis[next->cnt]) {
ans[top++] = next->cnt;
vis[next->cnt] = true;
}
next->vis = turn;
next = next->fail;
}
}
}

/*
 * Read one line from stdin into buf (capacity `size`) and drop a trailing
 * "\n" or "\r\n". This replaces gets(), which writes without any bound and
 * was removed from C11/C++14. Stripping '\r' keeps CRLF input from leaking
 * a carriage return into every pattern.
 * Returns false on EOF/error, leaving buf empty.
 */
static bool readLine(char* buf, int size) {
    if (fgets(buf, size, stdin) == NULL) {
        buf[0] = '\0';
        return false;
    }
    size_t len = strlen(buf);
    while (len > 0 && (buf[len - 1] == '\n' || buf[len - 1] == '\r')) {
        buf[--len] = '\0';
    }
    return true;
}

/*
 * Each test case is read as follows:
 *   - a count n, then n pattern lines, which get ids 1..n;
 *   - a count m, then m text lines.
 * For every text containing at least one pattern, print "web <i>:" followed
 * by the matched ids in ascending order. Finish the case with
 * "total: <number of such texts>".
 *
 * Input that stops mid-case ends the program. gets() used to leave the
 * previous line in the buffer at EOF, which was then processed again.
 */
int main() {
    int n;
    // == 1 rather than ~scanf: a matching failure (return 0) used to loop forever.
    while (scanf("%d", &n) == 1) {
        init();
        readLine(s, (int)sizeof s);  // consume the rest of the count line
        for (int i = 1; i <= n; i++) {
            if (!readLine(s, (int)sizeof s)) {
                return 0;  // truncated input
            }
            insert(s, i);
        }
        bfs();

        int m;
        if (scanf("%d", &m) != 1) {
            return 0;  // truncated input
        }
        readLine(s, (int)sizeof s);  // consume the rest of the count line
        int tol = 0;
        for (int i = 1; i <= m; i++) {
            if (!readLine(s, (int)sizeof s)) {
                return 0;  // truncated input
            }
            find(s, i);
            if (top) {
                printf("web %d:", i);
                sort(ans, ans + top);
                for (int j = 0; j < top; j++) {
                    printf(" %d", ans[j]);
                }
                puts("");
                tol++;
            }
        }
        printf("total: %d\n", tol);
    }
    return 0;
}