2015年5月14日 星期四

[HOJ 123][TIOJ 1810] 漫遊小鎮 / 小鎮DP

作法:

這題是標準的插頭DP題,需要把連通性壓入狀態表示中,在上一篇中提到的是簡單的版本。因為那題只要求用多個哈密頓圈覆蓋,而這題則是限定用一條哈密頓鍊覆蓋。另外還有要求是從左上角走到左下角,因此這時候能拼的拼圖總共有9種:
其中上面三種只能拼在左上角和左下角,並且左上角和左下角一定要用這三種拼圖來拼。但兩題不只有差這樣而已,如果一樣直接DP的話,$n=3$ 時會得出答案為 $3$ (多出來的路徑就是下面的左圖)。也就是在DP最後一格時,誤把同一個連通分量的兩條線接起來了。
左圖和右圖的情形在狀態表示中都是$0011$,但左圖的情形不能轉移,右圖的情形則可以,因此這樣的狀態表示會造成誤判,必須細分。所以此時只好把線的「連通性」納入狀態表示中,也就是在左圖的$0011$中,其實兩個$1$是連通的,而右圖的$0011$的兩個$1$則是不連通的。所以此時我們要想辦法把連通性壓入狀態中。對於一個狀態中的$1$代表的那條線來說,如果沿著他往回走,那麼有幾種可能:一種是走回這個狀態表示中的另外一個$1$,一種是走回左上角,一種是走回左下角。而「走回左下角」的情形在還沒有轉移左下角那格之前是不存在的。而轉移到左下角那格之前每個狀態會恰有一個$1$代表從左上角走過來的線。因此我們就可以用以下的方法表示狀態:如果這格的線是走回左上或左下角的,那麼就把這格位子的值設為$1$,而如果是走回這個狀態的另一條線的情形,就把這一對連通的線的值都設為$2$,如果有第三組就設為$3$,以此類推。這樣因為$n\leq 10$(TIOJ的數字範圍),也就是一個狀態最多有$11$個位子,用到最多數字的情形是類似$12233445566$的樣子,因此我們可以用$7$進制來表示一個儲存了連通訊息的狀態(當然沒有線的話就是$0$)。(註:事實上用$8$進制會更快,在編碼解碼的過程可以用位元運算加速。)

但這樣的表示還不夠,因為可能有很多重複的狀態,例如$12233000$和$13322000$兩個是一模一樣的狀態,所以對每個狀態我們可以把他「標準化」,把他改成「在所有其他和他等價的狀態表示中字典序最小的狀態表示法」,這只要$O(n)$掃過去就可以做到了,細節在此省略(這東西好像叫作最小表示法)。這樣可以用排列組合稍微算一下,會得到一個階段中的節點數量最多只會有幾萬個,再乘上$n^2$得到總狀態數最多幾百萬,還在可以接受的範圍內(好像還有一種表示法叫括號表示法,不過我還沒研究)。

有了狀態表示法後,轉移的部份也蠻麻煩的。首先當然要滿足不能在邊界放有線撞到邊界的拼圖,還有兩塊拼圖之間必須同時有線或無線。假設現在在轉移某個格子的某個狀態,記當前格的左邊界的狀態值為$x$,上邊界的狀態值為$y$,那麼會分成好幾種可能:

1. $x=y=0$
2. $x=0 , y\neq 0$
3. $x\neq 0,y=0$
4. $x,y\neq 0$

因為 2. 和 3. 幾乎一樣,所以等等就略過 3 的討論。對於 1. 來說,勢必要鋪上 ┏ 這塊拼圖,而轉移後的狀態的計算方法可以先在對應的兩個位子填上$7$,然後進行一次標準化得到。對於 2. 來說,可以鋪上 ┃ 和 ┗ 這兩種拼圖,而新的狀態就只要把對應的位子改成$y$就可以了,並且不用重新標準化。4. 則是最麻煩的,這時必須要放上 ┛ 。首先當 $x=y$ 時不可以放,否則就會把已經連通的分量再連起來,形成一個圈。但有個特例要判,就是如果當前格是右下角的話,那麼就可以轉移了,因為此時是把兩個$1$連起來,才能形成最後的一條哈密頓鍊。至於$x\neq y$時又分成幾種情形,如果$x=1$的話,當放上了這塊拼圖之後,另外一個值為$y$的位置就會變成$1$,因為這時候那條線就變成可以走到左上(或左下)角了,同理如果$y=1$的話,則是另一個值為$x$的位置要換成$1$。如果是$x$和$y$都不為$1$,那麼這時候就把$x$和$y$的另一個位置的值都改成同一個數就可以了。以上這些屬於 4. 的情形都必須要重新標準化一次。

實作方法有很多種,一種是預先編碼,把每種狀態都和一個$id$值對應,那麼就可以用$dp[i][j][k]$代表當前格是$(i,j)$,並且此時為第$k$個狀態的方法數,但這樣會有很多無效的狀態。我寫的方法則是直接把$dp$陣列改成 map,用壓縮過後的七進制的那個數字直接當成狀態的 index ,這樣可以在一邊DP時把新的有效的狀態加入 map 裡,不會有無效的狀態在裡面。不過不管是用哪種方法,都必須要寫兩個函式對一個$7$(或$8$)進制的數做編碼和解碼。

最後,這裡也有關於插頭DP的文章,寫的很詳細,不過我還沒全部看完@@

code :



#include<bits/stdc++.h>
#define LL long long
#define F first
#define S second
using namespace std;
const int maxn=10+2 ;
 
int n ;
int encode(int *a)
{
    int ret=0 ;
    for(int i=0;i<=n;i++) ret=ret*7+a[i] ;
    return ret ;
}
void decode(int val,int *a)
{
    for(int i=n;i>=0;i--) a[i]=val%7 , val/=7 ;
}
 
int tmp[8] ;
void norm(int *a)
{
    memset(tmp,0,sizeof(tmp)) ;
    for(int i=0,j=2;i<=n;i++) if(a[i]>1)
    {
        if(!tmp[a[i]]) tmp[a[i]]=j++ ;
        a[i]=tmp[a[i]] ;
    }
}
 
map<int,LL> dp[2] ;
int t[maxn] ;
void trans(int cur,int i,int j,int S,LL add)
{
    decode(S,t) ;
    int x=t[j] , y=t[j+1] ;
    if(x&&y)
    {
        if(x==y && (i!=n-1||j!=n-1)) return ;
        if(x==y)
        {
            assert(x==1) ;
            t[j]=t[j+1]=0 ;
        }
        else if(x==1)
        {
            for(int k=0;k<=n;k++) if(k!=j+1&&t[k]==y)
                {t[k]=1 , t[j]=t[j+1]=0 ; break ;}
        }
        else if(y==1)
        {
            for(int k=0;k<=n;k++) if(k!=j&&t[k]==x)
                {t[k]=1 , t[j]=t[j+1]=0 ; break ;}
        }
        else
        {
            for(int k=0;k<=n;k++) if(k!=j+1&&t[k]==y)
                {t[k]=x ; t[j]=t[j+1]=0 ; break ;}
        }
        norm(t) ;
        dp[cur][encode(t)]+=add ;
    }
    else if(x&&!y)
    {
        if(i!=n-1) dp[cur][S]+=add ;
        swap(t[j],t[j+1]) ;
        if(j!=n-1) dp[cur][encode(t)]+=add ;
    }
    else if(y&&!x)
    {
        if(j!=n-1) dp[cur][S]+=add ;
        swap(t[j],t[j+1]) ;
        if(i!=n-1) dp[cur][encode(t)]+=add ;
    }
    else if(i!=n-1 && j!=n-1)
    {
        t[j]=t[j+1]=7 ;
        norm(t) ;
        dp[cur][encode(t)]+=add ;
    }
}
 
void solve()
{
    int cur=0 ;
    for(int i=0;i<n;i++) for(int j=0;j<=n;j++,cur^=1)
    {
        dp[cur].clear() ;
        if(i==0 && j==0)
        {
            memset(t,0,sizeof(t)) ;
            t[0]=1 ; dp[cur][encode(t)]=1 ; t[0]=0 ;
            t[1]=1 ; dp[cur][encode(t)]=1 ;
        }
        else if(j==n) for(auto it : dp[cur^1])
            dp[cur][it.F/7]+=it.S ;
        else if(i==n-1&&j==0) for(auto it : dp[cur^1])
        {
            decode(it.F,t) ;
            if(t[1]==1) continue ;
            if(!t[1]) t[1]=1 ;
            else for(int k=2;k<=n;k++) if(t[k]==t[1])
                {t[1]=0 ; t[k]=1 ; break ;}
            norm(t) ;
            dp[cur][encode(t)]+=it.S ;
        }
        else for(auto it : dp[cur^1])
            trans(cur,i,j,it.F,it.S) ;
    }
    printf("%lld\n",dp[cur^1][0]) ;
}
 
main()
{
    while(scanf("%d",&n)!=EOF) solve() ;
}
 

沒有留言:

張貼留言