觀察$f$函數的式子可以發現他就類似某種最短路,$i$指的就是走了$i$步,因為$f(i,j)$都是從$f(i-1,x)$轉移過來的。觀察到這件事就可以先把原題轉換成最短路問題:考慮一張$n+1$個點的圖,第$i$有連到自己的邊權為$a[i]$的邊($i=1,...,n$),$i$也有連到$i+1$的邊權為$a[i+1]$的邊($i=1,...,n-1$),並且$n+1$有連到$i$的邊權為$a[i]$的邊($i=1,..,n$),那麼$f(i,j)$就等於從$n+1$開始走$i$步到達$j$的最短距離。假設現在要算$f(x,y)$,如果從$n+1$開始第一步走到$k$,那麼走$i$步的距離就可以表示成$\displaystyle \sum_{i=k}^{y} c[i]\cdot a[i]$,其中$c[i]$同時也代表這條路徑在$i$走了$c[i]-1$次自環。並且由走$x$步的限制可知$\displaystyle \sum_{i=k}^{y} c[i]=x$,並且顯然的有$c[i]\geq 1$,因此$k$必須滿足$k\geq y-x+1$。接下來看達到最小值時$c$陣列應該要長怎樣,考慮$a[k],...,a[y]$中最小的那個數,那麼把所有的$c$都集中到那個數一定是最好的,並且如果那個數不是$k$的話可以把$k$的$c$值也丟到他身上,因此最佳解中$c$陣列一定長的像:$c[k+1]=...=c[y]=1$。此時我們得到的路徑長度$=a[k+1]+...+a[y]+(x-y+k)a[k]$
$=S[y]-S[k]+(x-y+k)a[k]$
$=S[y]-(y-x)a[k]+(k\cdot a[k]-S[k])$,其中$S[i]=a[1]+...+a[i]$。
其中的$S[y]$可以先不用理他,最後再加到答案上就可以了。至於後面那項,因為在詢問的時候只有$y-x$的值會變,$a[k],S[k]$的值都是不變的($k$固定時),因此當不同的$y-x$代入時可以把這東西看成一個一次函數,也就是令直線$L[i]:y=-a[k]\cdot x+(k\cdot a[k]-S[k])$,那麼這樣在詢問$f(x,y)$時就等價於詢問$y-x$這個數代入$L[y-x+1],...,L[y]$的最小值是多少。這可以用線段樹套凸包來作,具體來說,假設線段樹的某個節點區間為$[L,R]$,那麼就在這個節點先算好第$L,...,R$條直線形成的上凸包長怎樣(紀錄頂點們和直線們),這樣就可以在這個節點上用$log$的時間查詢某個數代入這個區間中所有直線所得到的最小值是多少了。而當詢問一個區間時就只要把他拆成好幾個線段樹的節點,合併起來就是答案了。
code :
#include<bits/stdc++.h> #define DB double using namespace std; const int maxn=100000+10 ; struct line { int a,b ; /// y=ax+b bool operator < (const line &rhs) const { return a==rhs.a ? b<rhs.b : a>rhs.a ; } }; DB inter(const line &p,const line &q) { return (q.b-p.b)*1.0/(p.a-q.a) ; } struct node { node *l,*r ; vector<line> vl ; vector<DB> vx ; void build(int L,int R) ; int query(int x) ; }; line li[maxn],tmp[maxn] ; void node::build(int L,int R) { for(int i=L;i<=R;i++) tmp[i-L]=li[i] ; int sz=R-L+1 ; sort(tmp,tmp+sz) ; for(int i=0;i<sz;i++) { if(!vl.empty() && vl.back().a==tmp[i].a) continue ; while(vl.size()>=2) { int s=vl.size() ; if(inter(vl[s-2],vl[s-1])<inter(vl[s-1],tmp[i])) break ; vl.pop_back() ; vx.pop_back() ; } if(!vl.empty()) vx.push_back(inter(vl.back(),tmp[i])) ; vl.push_back(tmp[i]) ; } } int node::query(int x) { int id=lower_bound(vx.begin(),vx.end(),x)-vx.begin() ; int ret=vl[id].a*x+vl[id].b ; if(id+1<vl.size()) ret=min(ret,vl[id+1].a*x+vl[id+1].b) ; if(id-1>=0) ret=min(ret,vl[id-1].a*x+vl[id-1].b) ; return ret ; } node *build(int l,int r) { node *u=new node ; u->build(l,r) ; if(l==r) return u ; int mid=(l+r)/2 ; u->l=build(l,mid) ; u->r=build(mid+1,r) ; return u ; } int query(int l,int r,int L,int R,int x,node *u) { if(l==L && r==R) return u->query(x) ; int mid=(L+R)/2 ; if(r<=mid) return query(l,r,L,mid,x,u->l) ; else if(l>mid) return query(l,r,mid+1,R,x,u->r) ; else return min(query(l,mid,L,mid,x,u->l), query(mid+1,r,mid+1,R,x,u->r)) ; } int a[maxn],s[maxn] ; main() { int n,Q ; scanf("%d",&n) ; for(int i=1;i<=n;i++) { scanf("%d",&a[i]) , s[i]=s[i-1]+a[i] ; li[i].a=-a[i] ; li[i].b=i*a[i]-s[i] ; } node *root=build(1,n) ; scanf("%d",&Q) ; while(Q--) { int i,j ; scanf("%d%d",&i,&j) ; printf("%d\n",query(j-i+1,j,1,n,j-i,root)+s[j]) ; } }
沒有留言:
張貼留言