這題是AC自動機上的DP,第一次寫這種東西,還蠻有趣的XD
因為一個字串相對於「一堆病毒字串」的關係其實只要用「這個字串在這些病毒字串建的AC自動機上匹配到哪個節點」就可以完整的表達它了,並且要記得在建AC自動機的時候,把哪些節點是不合法的( 會含有病毒字串的 )先標記好。這樣就可以在AC自動機上DP了,設 dp[ i ][ j ][ k ] 代表過了 i 天,長度為 j 且會匹配到自動機上節點 k 的字串數量,第一種轉移是在字串後加上 a , b , c , d ,他得到的新字串就是直接沿AC自動機轉移的結果。另外一種則是他自己長度要減少1,如果他目前的長度就是 1 了,那就把他丟進回收場。如果不是,我們想要知道所有在AC自動機上特定節點的字串,把它砍掉第一個字母後會轉移到哪。而一個字串 S 被匹配到AC自動機上的節點 k 的時候,其實是代表「考慮從自動機起點走到所有節點的路徑所產生的字串,則 k 形成的字串會是 S 的後綴,且他是所有節點裡面長度最長的」。所以如果在自動機上的節點多記錄一個 len ,存這個節點代表的字串長度,那麼可以知道,當 j = len[ k ] 的時候,也就是這個字串在匹配的過程沒有經過 fail 的邊,所以如果把第一個字母砍掉,他就會落到 fail[ k ] ,而如果 j > len[ k ] 的話,砍掉第一個字母不會對他的狀態有影響,因為我們需要的是最長的在自動機上的後綴,所以他會直接轉移到 k 。另外在每天的轉移結束後要把不合法的狀態(會含有病毒字串的狀態)砍掉,丟進醫院。
code :
#include<bits/stdc++.h> #define MOD 10007 using namespace std; const int maxn=100+10 ; int ch[maxn*15][4] , len[maxn*15] , ccnt=1 ; bool val[maxn*15] ; void insert(char *t) { int l=strlen(t) , now=0 ; for(int i=0;i<l;i++) { int c=t[i]-'a' ; if(!ch[now][c]) { memset(ch[ccnt],0,sizeof(ch[ccnt])) ; ch[now][c]=ccnt ; len[ccnt]=len[now]+1 ; ccnt++ ; } now=ch[now][c] ; } val[now]=1 ; } int fail[maxn*15] ; void get_fail() { queue<int> q ; for(int i=0;i<4;i++) if(ch[0][i]) fail[ch[0][i]]=0 , q.push(ch[0][i]) ; while(!q.empty()) { int u=q.front() ; q.pop() ; for(int i=0;i<4;i++) { int v=ch[u][i] ; if(!v) { ch[u][i]=ch[fail[u]][i] ; continue ; } q.push(v) ; fail[v]=ch[fail[u]][i] ; if(val[fail[v]]) val[v]=1 ; } } } inline void up(int &x,int y) { x=(x+y)%MOD ; } int n,dp[2][maxn][maxn*15] ; char s[maxn] ; void DP() { int st=0 , L=strlen(s) ; for(int i=0;i<L;i++) { st=ch[st][s[i]-'a'] ; if(val[st]) { printf("0 1\n") ; return ; } } dp[1][L][st]=1 ; int ans1=0 , ans2=0 ; for(int i=1;i<=n;i++) { for(int j=1;j<maxn;j++) for(int k=0;k<ccnt;k++) dp[(i+1)%2][j][k]=0 ; for(int j=1;j<maxn;j++) for(int k=0;k<ccnt;k++) if(dp[i%2][j][k]) { int add=dp[i%2][j][k] ; for(int z=0;z<4;z++) up(dp[(i+1)%2][j+1][ch[k][z]],add) ; if(j==1) { up(ans1,add) ; continue ; } if(j==len[k]) up(dp[(i+1)%2][j-1][fail[k]],add) ; else up(dp[(i+1)%2][j-1][k],add) ; } for(int j=1;j<maxn;j++) for(int k=0;k<ccnt;k++) if(val[k]) up(ans2,dp[(i+1)%2][j][k]) , dp[(i+1)%2][j][k]=0 ; } printf("%d %d\n",ans1,ans2) ; } char t[maxn] ; main() { int k ; scanf("%s%d%d",s,&n,&k) ; while(k--) scanf("%s",t) , insert(t) ; get_fail() ; DP() ; }
沒有留言:
張貼留言