我們考慮一個非常樸素的DP:設$dp[i][j]$代表目前已把前$i$個數弄成非嚴格遞增了,則第$i$個數的值為$j$的最小花費,那麼顯然可以寫出轉移式:$dp[i][j]=min\{dp[i-1][k]+|a[i]-j|,k=1,...,j\}$。想像一下$dp[i]$的函數圖形,那麼他就是由$dp[i-1]$先取前綴 min ,再加上$|x-a[i]|$這個函數所形成的圖形(前綴 min 的意思就是前面轉移式中的$min\{ dp[i-1][k],k=1,...,j \}$)。不難發現他其實會形成一個下凸包,並且轉折的點只會出現在$x$座標為原本$a$數列裡出現過的值,並且每次加上$|x-a[i]|$這個函數時,可以看成$a[i]$左邊的斜率全部$-1$,右邊的斜率全部$+1$,因此就可以用線段樹來維護斜率們(所以在這之前要先離散化$a[i]$),並且取前綴 min 的操作可以等價成:把斜率$>0$的部份都設成$0$(因為他會一直維持他是下凸包的良好性質),這也不難在線段樹上做到。最後我們想知道的是$dp[n]$這個函數圖形的最小值是多少,但線段樹中只有紀錄每個區間的斜率,因此只要再多紀錄函數在$0$的值是多少就可以了(而他顯然會是所有$a[i]$的和),從左到右掃一遍就可以獲得$dp[n]$圖形的每個轉折點的座標,取$y$座標最小的點就是答案了。
最後,在「把斜率$>0$的部份都設成$0$」的操作其實可以很簡潔的完成,我們可以對每個線段樹區間都維護他的最小值,那麼就可以直接從根走下去,發現如果當前節點的右孩子的最小值$>0$(其實也只有可能是$1$),就把右孩子全部$-1$,往左孩子遞迴,反之則往右孩子遞迴就可以了。
code :
#include<bits/stdc++.h> #define LL long long using namespace std; const int maxn=3000000+10 ; struct node { node *l,*r ; int mi,tag ; node(){l=r=NULL ; mi=tag=0 ;} }; void push(node *u) { if(!u->tag) return ; u->l->mi+=u->tag ; u->l->tag+=u->tag ; u->r->mi+=u->tag ; u->r->tag+=u->tag ; u->tag=0 ; } void pull(node *u){u->mi=min(u->l->mi,u->r->mi) ;} node *build(int l,int r) { if(l==r) return new node ; node *u=new node ; int mid=(l+r)/2 ; u->l=build(l,mid) ; u->r=build(mid+1,r) ; return u ; } void modify(int l,int r,int L,int R,node *u,int add) { if(l==L && r==R){u->tag+=add ; u->mi+=add ; return ;} push(u) ; int mid=(L+R)/2 ; if(r<=mid) modify(l,r,L,mid,u->l,add) ; else if(l>mid) modify(l,r,mid+1,R,u->r,add) ; else modify(l,mid,L,mid,u->l,add) , modify(mid+1,r,mid+1,R,u->r,add) ; pull(u) ; } int query(int L,int R,node *u,int pos) { if(L==R) return u->tag ; int mid=(L+R)/2 ; if(pos<=mid) return query(L,mid,u->l,pos)+u->tag ; else return query(mid+1,R,u->r,pos)+u->tag ; } void setzero(int L,int R,node *u) { if(L==R) { if(u->tag>0) u->tag=u->mi=0 ; return ; } push(u) ; int mid=(L+R)/2 ; if(u->r->mi>0) { u->r->tag-- ; u->r->mi-- ; setzero(L,mid,u->l) ; } else setzero(mid+1,R,u->r) ; } int a[maxn] ; vector<int> v(maxn) ; int ID(int x) { return lower_bound(v.begin(),v.end(),x)-v.begin() ; } main() { int n ; scanf("%d",&n) ; LL tot=0 ; for(int i=1;i<=n;i++) scanf("%d",&a[i]) , v[i-1]=a[i] , tot+=a[i] ; v.resize(n) ; sort(v.begin(),v.end()) ; v.resize(unique(v.begin(),v.end())-v.begin()) ; int sz=v.size() ; node *root=build(0,sz) ; for(int i=1;i<=n;i++) { setzero(0,sz,root) ; int id=ID(a[i]) ; modify(0,id,0,sz,root,-1) ; modify(id+1,sz,0,sz,root,1) ; } LL ans=tot,now=tot ; for(int i=0,last=0;i<sz;last=v[i++]) { now+=(LL)query(0,sz,root,i)*(v[i]-last) ; ans=min(ans,now) ; } printf("%lld\n",ans) ; }
沒有留言:
張貼留言