2015年2月28日 星期六

[TIOJ 1243] 感染 Infection

第一次寫這種比較難轉移的樹上的DP,但這題的測資爛了,最後一筆測資和範測一模一樣,輸出 4 卻是錯的,沒辦法AC QQ
測資修正之後AC了!總共獲得了24個AC XD

作法:

因為其實是「把邊砍掉」之後才會有「哪些點被感染哪些點沒被感染」,所以可以反過來作砍掉 m 條邊之後最多有幾個點沒被感染。把沒被感染的叫白點,剩下的叫黑點,記 dp[ b ][ x ][ m ] 代表在以 x 為根的子樹中,當 x 的顏色是 b ,且砍掉 m 個邊的時候,白點的最大值。所以這樣可以寫出很大一坨的轉移式:

dp[ b ][ x ][ m ] = ( b==white ) + MAX{ sigma dp[ bi ][ vi ][ mi ] } ,其中 bi 和 mi 們滿足
sigma( mi + ( bi != b ) ) = m 。 vi 代表的是 x 的子結點們,而其實這句話的意思就是將可以砍的 m 條邊分配給 x 的子結點,並且當 x 的子結點和 x 不同色的時候必須要再多加一條邊分隔他們。

所以這個的意思就是我們必須從好幾個 DP 陣列 ( 也就是 dp[ bi ][ vi ] 們 ) 合併成一個新的陣列 dp[ b ][ x ] ,但當然不可能枚舉所有的 mi ,所以在這裡我們還需要另一個DP,設 dp2[ i ][ m ]代表目前已經考慮了 i 顆子樹了,那麼砍掉 m 條邊至多可以有幾個白點。每次多考慮一個子樹就會把那個子樹的DP陣列和現在的dp2 陣列跑過一次。

這樣看起來每個節點都會花掉 O( n^2 ) 的時間轉移,但其實只要注意到對於一個子樹 x ,他的邊數只有 x 的 size -1 ,也就是 dp[ b ][ x ][ m ] 的 m 其實只要作 0 ~ x 的 size -1 就夠了。所以設 x 的子樹的 size 分別為 n1 , n2 , ... , nt ,當在作 dp[ i+1 ] 的時候其實要跑過的陣列只有 n(i+1) * ( n1+n2+...+ni ) ,把他們 sum 起來會得到總合併時間 = sigma ni * nj ,其中 i , j 跑遍 1~t 且 i < j 。所以如果用數歸,假設 x 的子樹們都能在 O( size ^2 ) 的時間內完成,那麼完成 x 的時間就是 (sigma ni^2) + (sigma ni * nj) <= ( sigma ni ) ^2 < size[x]^2 ,得證。

code :

#include<bits/stdc++.h>
#define INF 10000000
using namespace std;
const int maxn=1000+50 ;
 
int sz[maxn],dp[2][maxn][maxn] ;
int dp2[maxn][maxn] ;
vector<int> v[maxn] ;
bool G[maxn] ;
 
void dfs(int x,int fa)
{
    sz[x]=1 ;
    if(x && (int)v[x].size()==1)
    {
        if(!G[x]) dp[0][x][0]=1 , dp[1][x][0]=0 ;
        else dp[0][x][0]=-INF , dp[1][x][0]=0 ;
        return ;
    }
    for(auto i : v[x]) if(i!=fa)
        dfs(i,x) , sz[x]+=sz[i] ;
    for(int b=0;b<2;b++)
    {
        if(b==0 && G[x])
        {
            for(int i=0;i<sz[x];i++) dp[b][x][i]=-INF ;
            continue ;
        }
        for(int i=0;i<sz[x];i++) dp2[0][i]=-INF ;
        dp2[0][0]=0 ;
        int tot=0,cnt=0 ;
        for(int i=0;i<v[x].size();i++) if(v[x][i]!=fa)
        {
            int y=v[x][i] ;
            tot+=sz[y] ;
            for(int j=0;j<sz[x];j++) dp2[cnt+1][j]=-INF ;
 
            for(int j=(b!=0);j<=tot;j++)
                for(int k=0;k<sz[y] && j-k-(b!=0)>=0;k++)
                dp2[cnt+1][j]=max(dp2[cnt+1][j],
                        dp2[cnt][j-k-(b!=0)]+dp[0][y][k]) ;
 
            for(int j=(b!=1);j<=tot;j++)
                for(int k=0;k<sz[y] && j-k-(b!=1)>=0;k++)
                dp2[cnt+1][j]=max(dp2[cnt+1][j],
                        dp2[cnt][j-k-(b!=1)]+dp[1][y][k]) ;
            cnt++ ;
        }
        for(int i=0;i<sz[x];i++)
            dp[b][x][i]= (b==0) + dp2[cnt][i] ;
    }
}
 
main()
{
    int n,k,num ;
    scanf("%d%d%d",&n,&k,&num) ;
    if(k+num>n) { printf("ACM rules!\n") ; return 0 ; }
    while(num--)
    {
        int x ; scanf("%d",&x) ;
        G[x]=1 ;
    }
    for(int i=1;i<n;i++)
    {
        int x,y ; scanf("%d%d",&x,&y) ;
        v[x].push_back(y) ;
        v[y].push_back(x) ;
    }
 
    dfs(0,-1) ;
 
    int ans1=n-1,ans2=n-1 ;
    for(int i=0;i<n;i++) if(dp[0][0][i]>=k)
        { ans1=i ; break ; }
    for(int i=0;i<n;i++) if(dp[1][0][i]>=k)
        { ans2=i ; break ; }
    printf("%d\n",min(ans1,ans2)) ;
}
 

沒有留言:

張貼留言