2015年5月25日 星期一

[CF 494D] Birthday

作法:

首先觀察到我們有兩種方法來計算題目所求的值,令$ \displaystyle g(u)=\sum_{i=1}^{n} d(u,i)^2  $,那麼$\displaystyle f(u,v)=g(u)-2\cdot \sum_{x\notin S(v)} d(u,x)^2=2\cdot \sum_{x\in S(v)} d(u,x)^2 - g(u)$
首先先想辦法求出$g(u)$,如果我們只要知道特定一個$u$的$g$值的話,就直接對以$u$為根建立有根樹並DP就好了,DP時對每個節點$x$紀錄三個值:以$x$為根的子樹的大小,以$x$為根的子樹中的所有點走到$x$的距離和,還有以$x$為根的子樹中的所有點走到$x$的距離平方和,這三個東西的遞推式子算是顯然的(以後簡稱他們叫作三個DP值)。但我們需要對每個$u$都算出他的$g$值,如果對每個點都DFS一次顯然會太慢。這時只要注意到,實際上對兩個不同的點DFS的時候,會有很多點的那三個值是一模一樣的,更具體來說,如果現在對$X$和$Y$作DFS,那麼對於任何一個點$u$,如果從$X$走到$u$的路徑上的最後一條邊等於從$Y$走到$u$的路徑上的最後一條邊,則兩次的DFS在$u$求出的值是一樣的,因為這兩次的$u$的子樹是一模一樣的。這啟發了我們可以對每個點分別計算「當他的父節點是他的某個鄰居時,他的三個DP值會是多少」。嚴謹來說,定義對於相鄰的兩點$u,v$,定義$T(u,v)$代表砍掉$(u,v)$這條邊後以$u$為根的子樹,那麼把這棵子樹對應的三個DP值算出來。這樣只有$O(n)$個值要算,並且也不難得知這些值的$O(n)$或是$O(nlogn)$的算法,因此這裡成功解決了遇處理每個點的$g$值的部份。

有了每個點的$g$值之後,再回到題目要求的式子。當$u$不在以$v$為根的子樹內時,那麼第二條和所求等價的式子會很好算,因為我們可以把$d(u,x)$拆成$d(u,v)+d(v,x)$,這樣所求就可以由剛才求出的$T(v,fa[v])$的三個DP值計算出來了,其中$fa[v]$代表$v$的父節點(題目給的是以$1$為根的有根樹)。至於$u$落在$v$的子樹內的情形則使用第一條和所求等價的式子,此時對應的子樹就會是$T(fa[v],v)$,因此算法也根剛才類似了。另外注意到這個算法的回答詢問複雜度會是$O(logn)$,因為會需要求兩點之間的距離。

code :


#include<bits/stdc++.h>
#define LL long long
#define MOD 1000000007
using namespace std;
const int maxn=100000+10 ;
 
struct P{int to ; LL dis;};
vector<P> v[maxn] ;
map<int,int> mp[maxn] ;
vector<LL> psum[maxn],sqsum[maxn] ;
vector<int> sz[maxn] ;
LL sqval[maxn] ;
 
int anc[17][maxn],dep[maxn] ;
int tin[maxn],tout[maxn],tick ;
void dfs0(int x,int f,int d)
{
    tin[x]=tick++ ;
    dep[x]=d ;
    anc[0][x]=f ;
    for(int i=1;i<17;i++) anc[i][x]=anc[i-1][anc[i-1][x]] ;
 
    for(auto i : v[x]) if(f!=i.to)
        dfs0(i.to,x,(i.dis+d)%MOD) ;
    tout[x]=tick++ ;
}
inline bool isfa(int x,int y)
{
    return tin[x]<=tin[y] && tout[x]>=tout[y] ;
}
int LCA(int x,int y)
{
    if(isfa(x,y)) return x ;
    if(isfa(y,x)) return y ;
    for(int i=16;i>=0;i--) if(!isfa(anc[i][x],y))
        x=anc[i][x] ;
    return anc[0][x] ;
}
int getlen(int x,int y)
{
    int lca=LCA(x,y) , ret=(dep[x]+dep[y]-2*dep[lca])%MOD ;
    if(ret<0) ret+=MOD ;
    return ret ;
}
 
inline LL cal_add_s(const P &i,int id)
{
    return sqsum[i.to][id]+(i.dis*i.dis%MOD)*sz[i.to][id]+2*i.dis*psum[i.to][id] ;
}
inline LL cal_add_p(const P &i,int id)
{
    return psum[i.to][id]+sz[i.to][id]*i.dis ;
}
 
int n,cnt[maxn] ;
void dfs(int x,int id)
{
    if(psum[x][id]!=-1) return ;
    sz[x][id]=1 ;
    LL &ans1=psum[x][id] ;
    LL &ans2=sqsum[x][id] ; ans1=0 ;
    if(v[x].size()==1) return ;
    for(auto i : v[x]) if(i.to!=v[x][id].to)
    {
        int id2=mp[i.to][x] ; dfs(i.to,id2) ;
        ans1=(ans1+cal_add_p(i,id2))%MOD ;
        ans2=(ans2+cal_add_s(i,id2))%MOD ;
        sz[x][id]+=sz[i.to][id2] ;
    }
    if(++cnt[x]==2)
    {
        int id2=mp[v[x][id].to][x] ;
        LL tot1=ans1+cal_add_p(v[x][id],id2) ; tot1%=MOD ;
        LL tot2=ans2+cal_add_s(v[x][id],id2) ; tot2%=MOD ;
        int totsz=sz[x][id]+sz[v[x][id].to][id2] ;
        for(int i=0;i<v[x].size();i++) if(psum[x][i]==-1)
        {
            int id3=mp[v[x][i].to][x] ;
            assert(psum[v[x][i].to][id3]!=-1) ;
            psum[x][i]=tot1-cal_add_p(v[x][i],id3)%MOD ;
            if(psum[x][i]<0) psum[x][i]+=MOD ;
            sqsum[x][i]=tot2-cal_add_s(v[x][i],id3)%MOD ;
            if(sqsum[x][i]<0) sqsum[x][i]+=MOD ;
            sz[x][i]=totsz-sz[v[x][i].to][id3] ;
        }
    }
}
 
LL getval(int x)
{
    LL ret=0 ;
    for(auto i : v[x])
    {
        int id=mp[i.to][x] ; dfs(i.to,id) ;
        ret=(ret+cal_add_s(i,id))%MOD ;
    }
    return ret ;
}
 
LL query(int u,int v)
{
    LL ret=0LL ;
    if(!isfa(v,u))
    {
        ret=-sqval[u]+MOD ;
        int id=mp[v][anc[0][v]] ;
        LL dis=getlen(u,v) , add=(dis*dis%MOD)*sz[v][id] ;
        add+=sqsum[v][id] ;
        add+=2*dis*psum[v][id] ; add%=MOD ;
        return (ret+2*add)%MOD ;
    }
    else if(v==1) return sqval[u] ;
    else
    {
        ret=sqval[u] ;
        int v2=anc[0][v] , id=mp[v2][v] ;
        LL dis=dep[u]-dep[v2] ; if(dis<0) dis+=MOD ;
        LL sub=(dis*dis%MOD)*sz[v2][id] ;
        sub+=sqsum[v2][id] ;
        sub+=2*dis*psum[v2][id] ; sub%=MOD ;
        return ((ret-2*sub)%MOD+MOD)%MOD ;
    }
}
 
main()
{
    scanf("%d",&n) ;
    for(int i=1;i<n;i++)
    {
        int x,y ; LL dis ;
        scanf("%d%d%I64d",&x,&y,&dis) ;
        v[x].push_back((P){y,dis}) ;
        v[y].push_back((P){x,dis}) ;
    }
    for(int i=1;i<=n;i++)
    {
        psum[i].resize(v[i].size()) ;
        sqsum[i].resize(v[i].size()) ;
        sz[i].resize(v[i].size()) ;
        for(int j=0;j<v[i].size();j++)
            mp[i][v[i][j].to]=j , psum[i][j]=-1 ;
    }
    dfs0(1,1,0) ;
    for(int i=1;i<=n;i++) sqval[i]=getval(i) ;
    int Q ; scanf("%d",&Q) ;
    while(Q--)
    {
        int x,y ; scanf("%d%d",&x,&y) ;
        printf("%I64d\n",query(x,y)) ;
    }
}
 

沒有留言:

張貼留言