2015年5月18日 星期一

[CF 504E] Misha and LCP on Tree


作法:

這題我是寫hash的作法,官方解裡有提到用樹鍊剖分 + 後綴數組的作法,雖然他講的沒有很清楚,不過官方解的code寫的很好懂,建議可以參考看看。

首先二分搜答案,只要先求出「這條路徑上的第$t$個點是誰」,就可以把問題轉化成判斷樹上的兩條路徑代表的字串是否相等,因此我們需要求出這兩個字串的hash值,也就是我們需要計算樹上任意一條路徑的hash值。而這可以分成「沿著路競走一路往上」、「沿著路徑走一路往下」、「沿著路徑走先往上再往下」幾種情形,所以對於前兩種情形,我們就必須在每個節點 $i$ 紀錄兩種hash值,兩個hash值都是從 $i$ 走到根所對應的字串的hash值,不過一個是$s[root]+...+s[i]\cdot X^{dep[i]}$,一個是$s[root]\cdot X^{dep[i]}+...+s[i]$,其中 $s$ 為原始字串,$dep[i]$ 為這個點的深度,並且$dep[root]=0$,這兩個值都可以在DFS時順便算好。再來則是如何求出所求的hash值,首先前兩種情況是顯然的,至於第三種情況,我們要先找到兩個點的LCA,在把兩段的hash值合併,而合併的式子稍微推一下就可以了。

但這樣傳上去TLE了,畢竟是$O(Qlog^2n)$的解,官方解也說hash的解太慢了需要優化。首先就是官方解提到的「離線作詢問$k$級祖先」,因為離線處理「詢問節點$x$的$k$級祖先」可以做到$O(n+Q)$,方法也不難,在DFS時維護一個stack就可以了。具體作法是,先把所有詢問讀進來,求出每條路徑兩端點的LCA,並且對每個詢問都紀錄兩個值 $l,r$ ,代表當前這個詢問的解的區間,在每個階段對每個詢問跑一次,如果這個詢問的區間長度已經是$1$了那就略過,至於不是的話,等於是現在我們要詢問長度$mid=\frac{l+r}{2}$是否可行,所以此時可以知道這個詢問會需要知道「誰的幾級祖先是誰」,所以就可以把所有這些東西紀錄下來。跑完所有詢問一次之後DFS一次處理剛才所需要的答案,然後就可以再跑一次所有詢問計算此時的兩個hash值,並且判斷他們兩個是否相等了。

但這樣傳上去還是TLE了,後來我測了一下,發現光預處理每個詢問裡兩點LCA的部份就已經花了$4$秒了,所以我把求LCA的部份也改成離線的作法,離線作LCA可以做到$O((n+Q)\alpha (n))$,詳細作法可以查查LCA的 tarjan 算法。苦苦的改完後竟然還TLE了......最後發現是因為取模太慢了,把 「$ret=(ret\% MOD+MOD)\% MOD$」的外面那層改成用「如果小於 $0$ 就加MOD」就壓線過了,這題的時限太恐怖了OAO

後來我去看別人的作法,這份code跑的速度最快,他的作法是樹鍊剖分套hash,寫的也蠻好懂的,建議也可以參考看看。

code :



#include<bits/stdc++.h>
#define MOD 1000000007
#define LL long long
#define debugf(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
const int maxn=300000+10,maxq=1000000+10 ;
const LL X=123LL ;
LL xpow[maxn],ixpow[maxn] ;
 
int getint()
{
    char c=getchar() ;
    while(c<'0'||c>'9') c=getchar() ;
    int ret=0 ;
    while(1)
    {
        ret=ret*10+c-'0' ;
        c=getchar() ;
        if(c<'0'||c>'9') return ret ;
    }
}
 
char cc[12] ;
void printint(int x)
{
    int cnt=0 ;
    while(x) cc[cnt++]='0'+x%10 , x/=10 ;
    if(!cnt) cc[cnt++]='0' ;
    for(int i=cnt-1;i>=0;i--) putchar(cc[i]) ;
    putchar('\n') ;
}
 
LL pow(LL x,int n)
{
    if(n<=1) return n ? x : 1LL ;
    LL tmp=pow(x,n/2) ;
    if(n&1) return (tmp*tmp%MOD)*x%MOD ;
    return tmp*tmp%MOD ;
}
 
LL inv(LL x)
{
    return pow(x,MOD-2) ;
}
 
int bs[maxn] ;
int getbs(int x)
{
    return x==bs[x] ? x : bs[x]=getbs(bs[x]) ;
}
 
vector<int> v[maxn] ;
 
int n,fa[maxn],dep[maxn] ;
char s[maxn] ;
LL h1[maxn],h2[maxn] ;
void dfs(int x,int f,int d=0)
{
    dep[x]=d ; fa[x]=f ;
    if(x)
        h1[x]=(h1[f]*X+s[x]-'a')%MOD ,
        h2[x]=(h2[f]+xpow[d]*(s[x]-'a'))%MOD ;
    for(auto i : v[x]) if(i!=f)
        dfs(i,x,d+1) ;
}
 
struct P{int val,id;};
vector<P> v2[maxn] ;
int ansval[2*maxq] ;
 
bool vis[maxn] ;
void dfs_solve_lca(int x,int f)
{
    vis[x]=1 ;
    for(auto i : v2[x]) if(vis[i.val])
        ansval[i.id]=getbs(i.val) ;
    v2[x].clear() ;
    for(auto i : v[x]) if(i!=f)
        dfs_solve_lca(i,x) , bs[getbs(i)]=x ;
}
 
int sta[maxn],top=0 ;
void dfs_solve_fa(int x,int f)
{
    sta[top++]=x ;
    for(auto i : v2[x]) ansval[i.id]=sta[top-1-i.val] ;
    v2[x].clear() ;
    for(auto i : v[x]) if(i!=f)
        dfs_solve_fa(i,x) ;
    top-- ;
}
 
struct query
{
    int x1,y1,x2,y2,lca1,lca2,l,r ;
    void get(int qid)
    {
        x1=getint() , y1=getint() ;
        x2=getint() , y2=getint() ;
        v2[x1].push_back((P){y1,2*qid-1}) ;
        v2[y1].push_back((P){x1,2*qid-1}) ;
        v2[x2].push_back((P){y2,2*qid}) ;
        v2[y2].push_back((P){x2,2*qid}) ;
    }
    void get2(int qid)
    {
        if(s[x1]!=s[x2]) {l=0 , r=1 ; return ;}
        lca1=ansval[2*qid-1] ;
        lca2=ansval[2*qid] ;
        l=1 ; r=min(dep[x1]+dep[y1]-2*dep[lca1],
                    dep[x2]+dep[y2]-2*dep[lca2])+2 ;
    }
}q[maxq];
int Qnum ;
 
void precal_query()
{
    dfs_solve_lca(0,0) ;
    for(int i=1;i<=Qnum;i++) q[i].get2(i) ;
}
 
LL gethash(int x,int y,int lca,int len,int query_pos)
{
    LL ret ;
    if(dep[x]-dep[lca]+1 >= len)
        ret=h1[x]-h1[query_pos]*xpow[len] ; /// = getfa(x.len)
    else
    {
        y=query_pos ; /// = getfa(y,dep[x]+dep[y]-2*dep[lca]+1-len) ;
        int a=dep[x]-dep[lca] ;
        ret=h1[x]-h1[fa[lca]]*xpow[a+1] ;
        int b=a+dep[y]-dep[lca] ;
        LL add=h2[y]-h2[fa[lca]] ;
        if(b>dep[y]) add*=xpow[b-dep[y]] ;
        else if(b<dep[y]) add*=ixpow[dep[y]-b] ;
        ret+=add-(s[lca]-'a')*xpow[a] ;
    }
    ret%=MOD ;
    if(ret<0) ret+=MOD ;
    return ret ;
}
 
bool process_query()
{
    bool ok=1 ;
    for(int i=1;i<=Qnum;i++) if(q[i].l+1!=q[i].r)
    {
        ok=0 ;
        int len=(q[i].l+q[i].r)>>1 ;
        int x1=q[i].x1 , y1=q[i].y1 , lca1=q[i].lca1 ;
        int x2=q[i].x2 , y2=q[i].y2 , lca2=q[i].lca2 ;
        if(dep[x1]-dep[lca1]+1>=len)
            v2[x1].push_back((P){len,2*i-1}) ;
        else
            v2[y1].push_back((P){dep[x1]+dep[y1]-2*dep[lca1]+1-len,2*i-1}) ;
        if(dep[x2]-dep[lca2]+1>=len)
            v2[x2].push_back((P){len,2*i}) ;
        else
            v2[y2].push_back((P){dep[x2]+dep[y2]-2*dep[lca2]+1-len,2*i}) ;
    }
    if(ok) return 1 ;
    dfs_solve_fa(0,0) ;
    for(int i=1;i<=Qnum;i++) if(q[i].l+1!=q[i].r)
    {
        int mid=(q[i].l+q[i].r)>>1 ;
        if(gethash(q[i].x1,q[i].y1,q[i].lca1,mid,ansval[2*i-1])
           ==gethash(q[i].x2,q[i].y2,q[i].lca2,mid,ansval[2*i]))
            q[i].l=mid ;
        else q[i].r=mid ;
    }
    return 0 ;
}
 
main()
{
    n=getint() ; for(int i=1;i<=n;i++) s[i]=getchar() ;
    LL IX=inv(X) ;
    for(int i=0;i<=n;i++)
    {
        xpow[i]= i ? xpow[i-1]*X%MOD : 1LL ;
        ixpow[i]= i ? ixpow[i-1]*IX%MOD : 1LL ;
        bs[i]=i ;
    }
    for(int i=1;i<n;i++)
    {
        int x=getint() , y=getint() ;
        v[x].push_back(y) ;
        v[y].push_back(x) ;
    }
    v[0].push_back(1) ;
    dfs(0,0,0) ;
 
    Qnum=getint() ;
    for(int i=1;i<=Qnum;i++) q[i].get(i) ;
    precal_query() ;
 
    while(!process_query()) ;
    for(int i=1;i<=Qnum;i++) printint(q[i].l) ;
}
 

沒有留言:

張貼留言