如果我們能計算出一個二維陣列$ok$,其中$ok[i][j]$代表$S[i,...,j]$是否可以恰好被完整刪除,那麼就可以用簡單的DP求出最後答案了。首先我們知道$S$中有許多子字串是可以「直接被刪除的」,也就是他剛好等於給定字串集合裡的其中一個字串。當我們把這種子字串刪掉時,有可能會造成新的子字串可以被刪掉。例如原字串為$caabbd$,並且可以刪掉$ab$,那麼在第二層就可以把$aabb$刪掉了,也就是$ok[2][5]=1$。我們可以概念上的稱「可直接被刪除的子字串」為一級的子字串,需要經過兩層才能刪除的為二級的子字串,那麼可知一個有辦法被刪除的子字串至多$n$層。因此我們要想辦法從$i$級的子字串的$ok$陣列推出$i+1$級的陣列,重複$n$次就可以得到我們要的$ok$陣列了。假設我們現在想知道$S[L,...,R]$這個子字串是否能在第$k+1$層時被刪掉,那麼考慮所有$S[L,...,R]$的可在第$k$層時被刪掉的子字串們,那麼$S[L,...,R]$可以在第$k+1$層時被刪掉若且唯若:存在$(l_1,r_1),...,(l_t,r_t)$滿足$L\leq l_1\leq r_1 < l_2 \leq r_2 < ... < l_t \leq r_t \leq R$,並且$S[l_i,...,r_i]$均可以在第$k$層被刪掉,還有把$S[L,...,R]$扣掉所有$S[l_i,...,r_i]$後所得到的字串是可以直接刪除的。有了這個之後就可以來算$S[L,...,R]$是否可以在第$k+1$層時刪除了。我們從$S[L]$開始往右邊掃,當遇到$S[i]$時,一個決策是:選一個$r$使得$S[i,...,r]$可以被刪除,那麼轉移到$S[r+1]$的狀態。另一個決策則是將$S[i]$放進一個「暫存字串」中,轉移到$S[i+1]$,並且我們想要在處理完$S[R]$時暫存字串是可以一次被刪除的(就和前面的條件等價了)。因此我們可以用一個 trie 節點搭配當前的 index 來做 DP 狀態,設 $dp[x][i]$ 代表只考慮字串$S[L,...,i]$時,是否可以讓「暫存字串」在 trie 上的節點為$x$,這樣就可以$O(1)$轉移了,總複雜度為$O(n^5)$。
code :
#include<bits/stdc++.h> using namespace std; const int maxn=100+10,maxc=3000 ; int ch[maxc][26],ccnt ; bool have[maxc] ; void insert(char *t) { int now=0 ; for(int i=0;t[i];i++) { int c=t[i]-'a' ; if(!ch[now][c]) { ccnt++ ; memset(ch[ccnt],0,sizeof(ch[ccnt])) ; ch[now][c]=ccnt ; } now=ch[now][c] ; } have[now]=1 ; } int n ; char s[maxn],t[maxn] ; bool ok[maxn][maxn] ; bool dp[maxc][maxn] ; vector<int> ri[maxn] ; bool process() { for(int i=1;i<=n;i++) { ri[i].clear() ; for(int j=i;j<=n;j++) if(ok[i][j]) ri[i].push_back(j) ; } bool ret=0 ; for(int i=1;i<=n;i++) { memset(dp,0,sizeof(dp)) ; dp[0][i-1]=1 ; for(int j=0;j<=ccnt;j++) for(int k=i-1;k<n;k++) if(dp[j][k]) { int c=s[k+1]-'a' ; if(ch[j][c]) dp[ch[j][c]][k+1]=1 ; for(auto it : ri[k+1]) dp[j][it]=1 ; } for(int j=1;j<=ccnt;j++) if(have[j]) for(int k=i;k<=n;k++) if(dp[j][k] && !ok[i][k]) {ok[i][k]=1 ; ret=1 ;} } return ret ; } int dp2[maxn] ; void solve(int m) { memset(have,0,sizeof(have)) ; memset(ok,0,sizeof(ok)) ; memset(dp,0,sizeof(dp)) ; ccnt=0 ; memset(ch[0],0,sizeof(ch[0])) ; scanf("%s",s+1) ; n=strlen(s+1) ; while(m--) scanf("%s",t) , insert(t) ; while(process()) ; memset(dp2,0,sizeof(dp2)) ; for(int i=1;i<=n;i++) { dp2[i]=max(dp2[i],dp2[i-1]) ; for(int j=i;j<=n;j++) if(ok[i][j]) dp2[j]=max(dp2[j],dp2[i-1]+(j-i+1)) ; } printf("%d\n",n-dp2[n]) ; } main() { int m ; while(scanf("%d",&m)==1 && m) solve(m) ; }
沒有留言:
張貼留言