考慮對三個字串都建個後綴自動機,那麼假設對於某個出現在三字串中的子字串$T$來說,令$u_1,u_2,u_3$分別是三個自動機吃了字串$T$之後達到的狀態,而我們知道$u_i$其實會代表某個範圍長度的子字串(也就是這篇中講到的 min 值和 max 值),因此這樣可以得到我們必須要在答案陣列中的某個區間同時加上一個數,他的值是這三個節點的 right 集合大小的乘積(right 集合的定義也在上面那篇中有提到)(因為 right 集合就是子字串的出現位置集合,取他的大小就是這個子字串出現的幾次)。先不談要怎麼求每個節點的 right 集合大小,事實上只有這三個後綴自動機是不夠的,因為這三個節點的 min 值和 max 值不盡相同,也就是這三個節點代表的子字串集合不一樣,沒辦法知道要在答案的哪個區間加上 right 集合大小的積。因此我們需要第四個後綴自動機,他是將三個字串中間用沒看過的字元串起來得到的,那麼就可以同時對這四個後綴自動機DFS並求答案了。具體來說,當DFS到某個節點時,設四個節點分別為$u_1,u_2,u_3,u$,那麼我們就知道要用$u_1,u_2,u_3$的 right 集合大小乘積去更新$u$的 min 值到 max 值這個區間的答案,對這四台自動機 DFS 一次即可(每次只走前三台有轉移邊的節點,至於第四台當前三台都有轉移邊的話他也必定會有)。
於是重點剩下要怎麼知道一個節點的 min,max 值,還有 right 集合的大小。max 值在構造的時候就已經算好了,min 值比較好算,因為如果 $u$ 在自動機中的父親是 $p$ ,那麼 $u$ 的 min 值就會是$p$的 max 值再$+1$,詳細也在上面那篇裡有。至於 right 集合的大小,作法為:首先將所有用前綴轉移到的節點的值設成$1$,那麼一個節點的 right 集合大小就會是他在 parent 樹中子數裡的數值總和。而 parent 樹有個特性,就是一個節點的的 max 值會比他父親的 max 值還大,所以可以把所有節點按照 max 值排序,按照 max 值大到小拿自己的 right 集合大小值更新父親的 right 集合大小值就可以了。
code :
#include<bits/stdc++.h> #define LL long long #define MOD 1000000007 using namespace std; const int maxn=300000+10 ; struct node { node *p ; int l,sz=0 ; /// l : Max node *trans[27] ; bool vis=0 ; node(node *u) { p=u->p ; l=u->l ; for(int i=0;i<27;i++) trans[i]=u->trans[i] ; } node(int _l,node *_p) { l = _l ; p = _p ; memset(trans,0,sizeof(trans)) ; } }; vector<node*> vn[maxn] ; void build_SAM(char *A,node *&root) { for(int i=0;i<maxn;i++) vn[i].clear() ; node *curnode ; root=curnode=new node(0,NULL);//最开始的后缀自动机只有一个节点,长度是0,父亲是空 for(int i=0;A[i];i++) { int x=A[i]-'a';//增加一个字符 node *p=curnode; curnode=new node(i+1,NULL);//建立一个Lth为i+1的节点 vn[i+1].push_back(curnode) ; for(;p && p->trans[x]==NULL ; p=p->p) p->trans[x]=curnode;//沿祖先向上,寻找插入位置。同时更新Trans if(!p)curnode->p=root;//插入到根的下面 else { node *q=p->trans[x]; if (q->l==p->l+1)curnode->p=q;//成为q的孩子 else { node *r=new node(q);r->l=p->l+1;//新建一个节点,表示curnode和q的公共前缀 vn[r->l].push_back(r) ; q->p=r;curnode->p=r;//兄弟 for (;p && p->trans[x]==q;p=p->p)p->trans[x]=r;//更新第二部分的Trans } } } node *u=root ; for(int i=0;A[i];i++) u=u->trans[A[i]-'a'] , u->sz++ ; for(int i=maxn-1;i>=0;i--) for(auto j : vn[i]) j->p->sz+=j->sz ; } LL ans[maxn] ; void dfs(node *u1,node *u2,node *u3,node *u) { if(u->vis) return ; u->vis=1 ; if(u->p) { LL add=((LL)u1->sz*u2->sz)*u3->sz%MOD ; ans[u->p->l+1]+=add ; ans[u->l+1]-=add ; } for(int i=0;i<26;i++) if(u1->trans[i] && u2->trans[i] && u3->trans[i]) dfs(u1->trans[i],u2->trans[i],u3->trans[i],u->trans[i]) ; } char s[4][maxn] ; int len[4] ; node *root[4] ; main() { scanf("%s%s%s",s[0],s[1],s[2]) ; for(int i=0;i<3;i++) len[i]=strlen(s[i]) , build_SAM(s[i],root[i]) ; for(int i=0;i<len[1];i++) s[0][len[0]+1+i]=s[1][i] ; for(int i=0;i<len[2];i++) s[0][len[0]+len[1]+2+i]=s[2][i] ; s[0][len[0]]=s[0][len[0]+len[1]+1]='z'+1 ; build_SAM(s[0],root[3]) ; dfs(root[0],root[1],root[2],root[3]) ; int n=min(len[0],min(len[1],len[2])) ; for(int i=0;i<n;i++) ans[i+1]=(ans[i+1]+ans[i])%MOD+MOD , printf("%lld\n",ans[i+1]%MOD) ; }
沒有留言:
張貼留言