觀察f函數的式子可以發現他就類似某種最短路,i指的就是走了i步,因為f(i,j)都是從f(i-1,x)轉移過來的。觀察到這件事就可以先把原題轉換成最短路問題:考慮一張n+1個點的圖,第i有連到自己的邊權為a[i]的邊(i=1,...,n),i也有連到i+1的邊權為a[i+1]的邊(i=1,...,n-1),並且n+1有連到i的邊權為a[i]的邊(i=1,..,n),那麼f(i,j)就等於從n+1開始走i步到達j的最短距離。假設現在要算f(x,y),如果從n+1開始第一步走到k,那麼走i步的距離就可以表示成\displaystyle \sum_{i=k}^{y} c[i]\cdot a[i],其中c[i]同時也代表這條路徑在i走了c[i]-1次自環。並且由走x步的限制可知\displaystyle \sum_{i=k}^{y} c[i]=x,並且顯然的有c[i]\geq 1,因此k必須滿足k\geq y-x+1。接下來看達到最小值時c陣列應該要長怎樣,考慮a[k],...,a[y]中最小的那個數,那麼把所有的c都集中到那個數一定是最好的,並且如果那個數不是k的話可以把k的c值也丟到他身上,因此最佳解中c陣列一定長的像:c[k+1]=...=c[y]=1。此時我們得到的路徑長度=a[k+1]+...+a[y]+(x-y+k)a[k]
=S[y]-S[k]+(x-y+k)a[k]
=S[y]-(y-x)a[k]+(k\cdot a[k]-S[k]),其中S[i]=a[1]+...+a[i]。
其中的S[y]可以先不用理他,最後再加到答案上就可以了。至於後面那項,因為在詢問的時候只有y-x的值會變,a[k],S[k]的值都是不變的(k固定時),因此當不同的y-x代入時可以把這東西看成一個一次函數,也就是令直線L[i]:y=-a[k]\cdot x+(k\cdot a[k]-S[k]),那麼這樣在詢問f(x,y)時就等價於詢問y-x這個數代入L[y-x+1],...,L[y]的最小值是多少。這可以用線段樹套凸包來作,具體來說,假設線段樹的某個節點區間為[L,R],那麼就在這個節點先算好第L,...,R條直線形成的上凸包長怎樣(紀錄頂點們和直線們),這樣就可以在這個節點上用log的時間查詢某個數代入這個區間中所有直線所得到的最小值是多少了。而當詢問一個區間時就只要把他拆成好幾個線段樹的節點,合併起來就是答案了。
code :
#include<bits/stdc++.h> #define DB double using namespace std; const int maxn=100000+10 ; struct line { int a,b ; /// y=ax+b bool operator < (const line &rhs) const { return a==rhs.a ? b<rhs.b : a>rhs.a ; } }; DB inter(const line &p,const line &q) { return (q.b-p.b)*1.0/(p.a-q.a) ; } struct node { node *l,*r ; vector<line> vl ; vector<DB> vx ; void build(int L,int R) ; int query(int x) ; }; line li[maxn],tmp[maxn] ; void node::build(int L,int R) { for(int i=L;i<=R;i++) tmp[i-L]=li[i] ; int sz=R-L+1 ; sort(tmp,tmp+sz) ; for(int i=0;i<sz;i++) { if(!vl.empty() && vl.back().a==tmp[i].a) continue ; while(vl.size()>=2) { int s=vl.size() ; if(inter(vl[s-2],vl[s-1])<inter(vl[s-1],tmp[i])) break ; vl.pop_back() ; vx.pop_back() ; } if(!vl.empty()) vx.push_back(inter(vl.back(),tmp[i])) ; vl.push_back(tmp[i]) ; } } int node::query(int x) { int id=lower_bound(vx.begin(),vx.end(),x)-vx.begin() ; int ret=vl[id].a*x+vl[id].b ; if(id+1<vl.size()) ret=min(ret,vl[id+1].a*x+vl[id+1].b) ; if(id-1>=0) ret=min(ret,vl[id-1].a*x+vl[id-1].b) ; return ret ; } node *build(int l,int r) { node *u=new node ; u->build(l,r) ; if(l==r) return u ; int mid=(l+r)/2 ; u->l=build(l,mid) ; u->r=build(mid+1,r) ; return u ; } int query(int l,int r,int L,int R,int x,node *u) { if(l==L && r==R) return u->query(x) ; int mid=(L+R)/2 ; if(r<=mid) return query(l,r,L,mid,x,u->l) ; else if(l>mid) return query(l,r,mid+1,R,x,u->r) ; else return min(query(l,mid,L,mid,x,u->l), query(mid+1,r,mid+1,R,x,u->r)) ; } int a[maxn],s[maxn] ; main() { int n,Q ; scanf("%d",&n) ; for(int i=1;i<=n;i++) { scanf("%d",&a[i]) , s[i]=s[i-1]+a[i] ; li[i].a=-a[i] ; li[i].b=i*a[i]-s[i] ; } node *root=build(1,n) ; scanf("%d",&Q) ; while(Q--) { int i,j ; scanf("%d%d",&i,&j) ; printf("%d\n",query(j-i+1,j,1,n,j-i,root)+s[j]) ; } }
沒有留言:
張貼留言