這題用到了DP的斜率優化,APIO 2010 Commando ( HOJ 236 ) 也用到這個技巧,相關的東西可以在這個網站裡看到。
題目簡單來說就是要把一個數列分成 k 陀 ,使得第一坨裡的總和 * 0,加第二陀裡的總和 * 1 ,一直加到 第 k 陀的總和 * (k-1) ,再扣掉每坨的大小的平方,這個數字要最大。而會發現把整個數列先反過來會比較好DP,所以如果第一步是先把整個數列反過來,設 dp[ i ] 代表 1~ i 的最佳答案,那麼就可以寫出轉移式
dp[ i ] = max { dp[ j ] + S[ j ] - ( i-j )^2 } , j = max( 0,i-k ) ~ i-1
其中 S[ j ] 是前綴和。顯然直接作是O(n^2)的,如果把這條式子化簡,會得到
-i^2 + 2 * i * j - j^2 + dp[ j ] + S{ j ]
( 提醒一個小細節: -i^2 會爆 int ,記得轉成 long long 再乘。 )
因為 -i^2 不影響取不同的 j 的值之間的大小關係,所以可以先不理他,我們要讓後面那坨最大,所以考慮直線 y = ( 2 * j ) x + ( dp[ j ] + S[ j ] - j^2 ) ,記 A[ j ] 和 B[ j ] 分別代表 x 前的係數和常數項,所以這樣等於是要讓 i 這個點代入某條直線 A[ j ] * x + B[ j ] 之後要最大,並且因為每條直線是按照斜率由小到大加入的,又因為每次查詢的 x 坐標是遞增的,所以就可以用上面那個網站講的方法用 deque 作。
具體作法是,每次詢問時先從後面 pop 掉過期的直線( 他不在 i-k ~ i-1 的範圍內,所以不能考慮 ),然後從 deque 的後面開始,如果 i 代入最後一個直線比代入最後第二個直線的值小,則也把最後一個 pop 掉,因為從今以後最後一個都不會比最後第二個好了( 因為斜率遞增且詢問的值遞增!! )。處理完之後就可以得到目前在 deque 的最後的直線就是我們要的。然後要從 deque 前面加入這個 i 代表的直線,記在deque裡前面數來第二條直線叫L1,最前面的叫L2,現在要加入的叫L3,那麼如果L3和L2的交點在L1和L2的交點的左邊的話,L2就廢掉了,必須把他pop掉,pop完之後再加入 L3 就可以了。
但這個作法只拿到 WA 30分,我另外寫了個直接O(n^2)的作法傳上去,沒有TLE的測資都是AC的,這代表是後面斜率優化的地方出問題了,之後我自己 random 生測資並且在剛算出來 dp[ i ] 的時候 assert 那條最原始的式子,發現他常常出錯,於是我仔細把每個階段 deque 裡的東西 print 出來才發現,我有可能在處理 i 的時候,已經把最佳的直線 pop 掉了!
回到那條最佳的直線被 pop 掉的時候,這時候是在「加入新的直線並 pop 廢掉的直線」的階段,也就是L1和L3聯手把L2幹掉的時候,這代表我們認為「在L1和L3的交點以前,L1是最佳選擇,以後的話則是L3是最佳選擇」,但「在L1和L3的交點以前」有可能 L1 已經過期了!!! 也就是這裡的「直線」其實只是「線段」,不能完全按照之前的作法作。所以當在決定 L2 是否廢掉的時候,必須把L1和L2的交點和「L1線段的右端點 x 坐標」取 min ,再去和 L3和L2的交點比較。並且我們知道 L1 是在它的起點 x 座標 + k 的時候過期的。
對了,題目也沒有提到任何數字範圍,我也很怕在判斷交點那邊交叉相乘後 long long 會爆掉,轉成 double 因為數字太大精準度感覺會爛掉,還好最後是沒有爆 XD。
另外,這個網站也有這題的作法喔,也可以看看XD
code : (好短阿(汗
#include<bits/stdc++.h> #define LL long long using namespace std; const int maxn=1000000+10 ; LL a[maxn] ; LL dp[maxn],s[maxn],A[maxn],B[maxn] ; int dq[maxn] ; main() { int n,k ; scanf("%d%d",&n,&k) ; for(int i=1;i<=n;i++) scanf("%lld",&a[n+1-i]) ; for(int i=1;i<=n;i++) s[i]=s[i-1]+a[i] ; dp[0]=0LL ; A[0]=B[0]=0LL ; int l=0 , r=1 ; dq[0]=0 ; for(int i=1;i<=n;i++) { while(dq[l]+k<i) l++ ; while(l+1<r && A[dq[l]]*i+B[dq[l]] <= A[dq[l+1]]*i+B[dq[l+1]]) l++ ; dp[i]=-((LL)i)*((LL)i)+A[dq[l]]*i+B[dq[l]] ; A[i]=2*i ; B[i]=dp[i]+s[i]-((LL)i)*((LL)i) ; while(l+1<r && (B[dq[r-2]]-B[dq[r-1]])*(A[i]-A[dq[r-1]]) >= (B[dq[r-1]]-B[i])*(A[dq[r-1]]-A[dq[r-2]]) && (dq[r-2]+k)*(A[i]-A[dq[r-1]]) >= B[dq[r-1]]-B[i]) r-- ; dq[r++]=i ; } printf("%lld\n",dp[n]) ; }
沒有留言:
張貼留言