如果我們能計算出一個二維陣列ok,其中ok[i][j]代表S[i,...,j]是否可以恰好被完整刪除,那麼就可以用簡單的DP求出最後答案了。首先我們知道S中有許多子字串是可以「直接被刪除的」,也就是他剛好等於給定字串集合裡的其中一個字串。當我們把這種子字串刪掉時,有可能會造成新的子字串可以被刪掉。例如原字串為caabbd,並且可以刪掉ab,那麼在第二層就可以把aabb刪掉了,也就是ok[2][5]=1。我們可以概念上的稱「可直接被刪除的子字串」為一級的子字串,需要經過兩層才能刪除的為二級的子字串,那麼可知一個有辦法被刪除的子字串至多n層。因此我們要想辦法從i級的子字串的ok陣列推出i+1級的陣列,重複n次就可以得到我們要的ok陣列了。假設我們現在想知道S[L,...,R]這個子字串是否能在第k+1層時被刪掉,那麼考慮所有S[L,...,R]的可在第k層時被刪掉的子字串們,那麼S[L,...,R]可以在第k+1層時被刪掉若且唯若:存在(l_1,r_1),...,(l_t,r_t)滿足L\leq l_1\leq r_1 < l_2 \leq r_2 < ... < l_t \leq r_t \leq R,並且S[l_i,...,r_i]均可以在第k層被刪掉,還有把S[L,...,R]扣掉所有S[l_i,...,r_i]後所得到的字串是可以直接刪除的。有了這個之後就可以來算S[L,...,R]是否可以在第k+1層時刪除了。我們從S[L]開始往右邊掃,當遇到S[i]時,一個決策是:選一個r使得S[i,...,r]可以被刪除,那麼轉移到S[r+1]的狀態。另一個決策則是將S[i]放進一個「暫存字串」中,轉移到S[i+1],並且我們想要在處理完S[R]時暫存字串是可以一次被刪除的(就和前面的條件等價了)。因此我們可以用一個 trie 節點搭配當前的 index 來做 DP 狀態,設 dp[x][i] 代表只考慮字串S[L,...,i]時,是否可以讓「暫存字串」在 trie 上的節點為x,這樣就可以O(1)轉移了,總複雜度為O(n^5)。
code :
#include<bits/stdc++.h> using namespace std; const int maxn=100+10,maxc=3000 ; int ch[maxc][26],ccnt ; bool have[maxc] ; void insert(char *t) { int now=0 ; for(int i=0;t[i];i++) { int c=t[i]-'a' ; if(!ch[now][c]) { ccnt++ ; memset(ch[ccnt],0,sizeof(ch[ccnt])) ; ch[now][c]=ccnt ; } now=ch[now][c] ; } have[now]=1 ; } int n ; char s[maxn],t[maxn] ; bool ok[maxn][maxn] ; bool dp[maxc][maxn] ; vector<int> ri[maxn] ; bool process() { for(int i=1;i<=n;i++) { ri[i].clear() ; for(int j=i;j<=n;j++) if(ok[i][j]) ri[i].push_back(j) ; } bool ret=0 ; for(int i=1;i<=n;i++) { memset(dp,0,sizeof(dp)) ; dp[0][i-1]=1 ; for(int j=0;j<=ccnt;j++) for(int k=i-1;k<n;k++) if(dp[j][k]) { int c=s[k+1]-'a' ; if(ch[j][c]) dp[ch[j][c]][k+1]=1 ; for(auto it : ri[k+1]) dp[j][it]=1 ; } for(int j=1;j<=ccnt;j++) if(have[j]) for(int k=i;k<=n;k++) if(dp[j][k] && !ok[i][k]) {ok[i][k]=1 ; ret=1 ;} } return ret ; } int dp2[maxn] ; void solve(int m) { memset(have,0,sizeof(have)) ; memset(ok,0,sizeof(ok)) ; memset(dp,0,sizeof(dp)) ; ccnt=0 ; memset(ch[0],0,sizeof(ch[0])) ; scanf("%s",s+1) ; n=strlen(s+1) ; while(m--) scanf("%s",t) , insert(t) ; while(process()) ; memset(dp2,0,sizeof(dp2)) ; for(int i=1;i<=n;i++) { dp2[i]=max(dp2[i],dp2[i-1]) ; for(int j=i;j<=n;j++) if(ok[i][j]) dp2[j]=max(dp2[j],dp2[i-1]+(j-i+1)) ; } printf("%d\n",n-dp2[n]) ; } main() { int m ; while(scanf("%d",&m)==1 && m) solve(m) ; }
沒有留言:
張貼留言