考慮一個斜的正方形,設他下面頂點指向右邊頂點的向量為(x,y)(也就是如果下端點是(0,0),那麼剩下三個點的座標為(x,y),(x-y,x+y),(-y,x)),我們想要知道這個正方形裡面包含了幾個單位正方形,而一個著名的結論是:給定一個n\times m的方格表,那麼他的對角線會經過n+m-gcd(n,m)個格點。因此就可以用這個結論算出這個正方形包住了幾個單位正方形了:把這個正方形補成邊平形座標軸的正方形,並把他切成四個x\times y的長方形,還有中間邊長|x-y|的正方形。那麼就可以知道在四個長方形中分別包含了\frac{xy-x-y+gcd(x,y)}{2}個所求格子了,再加上中間的正方形(x-y)^2個,就共有(x-y)^2+2(xy-x-y+gcd(x,y))=x^2+y^2-2(x+y-gcd(x,y))個。並且注意到這種正方形因為要用長x+y的正方形包住他,所以會在原本給定的範圍裡出現(n-x-y+1)(m-x-y+1)次,其中x+y\leq min(n,m)。並且注意到我們也要算x=0的情況(上面通式也會對),但是不能也算y=0的,否則會重複算到。因此就可以把答案寫成:
\displaystyle \sum_{x\geq 0,y>0,x+y\leq min(n,m)} (n-x-y+1)(m-x-y+1)(x^2+y^2-2(x+y-gcd(x,y)))
考慮把x+y值相同的一起算,令t=x+y,那麼就可以改寫成:
\displaystyle \sum_{t=1}^{min(n,m)}\sum_{x=0}^{t-1} (n-t+1)(m-t+1)(x^2+(t-x)^2-2t+2gcd(t,x))
其中用到了簡單的gcd的性質:gcd(x,t-x)=gcd(x,t)。我們可以把所求的gcd的部份拆出來,也就是改寫成
\displaystyle \sum_{t=1}^{min(n,m)}\sum_{x=0}^{t-1} (n-t+1)(m-t+1)(x^2+(t-x)^2-2t)+\displaystyle 2\sum_{t=1}^{min(n,m)}\sum_{x=0}^{t-1} (n-t+1)(m-t+1)gcd(t,x)
首先第一項其實就是一個t的多項式,只是係數會有n,m之類的,因此我們可以先預處理好t^i的前綴和陣列,其中i=1,...,6,就可以在O(1)算出第一項了(或是借住電腦的力量直接把通式爆出來XD)。再來則是第二項,把(n-t+1)(m-t+1)乘開後會得到其實我們主要需要的東西就是\displaystyle \sum_{x=0}^{t-1} gcd(t,x),還有\displaystyle \sum_{x=0}^{t-1} t\cdot gcd(t,x)和\displaystyle \sum_{x=0}^{t-1} t^2\cdot gcd(t,x)的前綴和陣列,而我們只要有辦法對每個t都算出\displaystyle \sum_{x=0}^{t-1} gcd(t,x)的值就可以了。記我們要求的這個陣列為f,我們看某個數i在哪些gcd(t,x)的項出現了,首先i要整除t,因此只有index 被i整除的f值要加上一些i。至於要加上多少個i,因為gcd(t,x)=i,所以我們也可以設x=y\cdot i,其中y<\frac{t}{i}。這樣就可以兩邊同除以i,變成gcd(\frac{t}{i},y)=1,因此我們需要知道1,...,\frac{t}{i}-1中有幾個數和\frac{t}{i}互質,這正是歐拉函數 phi ,因此只要先預處理好每個數的 phi 值,再用他算出f陣列的值就可以了。
code :
#include<bits/stdc++.h> #define LL long long #define MOD 1000000007 using namespace std; const int maxn=1000000+10 ; LL pw(LL x,int n) { if(n<=1) return n ? x : 1 ; LL t=pw(x,n/2) ; if(n%2) return (t*t%MOD)*x%MOD ; else return t*t%MOD ; } LL inv(LL x) { return pw(x,MOD-2) ; } int phi[maxn] ; LL f[maxn],s1[maxn],s2[maxn],s3[maxn] ; main() { for(int i=1;i<maxn;i++) { phi[i]+=i ; for(int j=2*i;j<maxn;j+=i) phi[j]-=phi[i] ; } for(int i=1;i<maxn;i++) { f[i]=(f[i]+i)%MOD ; s1[i]=(s1[i-1]+f[i])%MOD ; s2[i]=(s2[i-1]+f[i]*i)%MOD ; s3[i]=(s3[i-1]+f[i]*i%MOD*i)%MOD ; for(int j=2*i,k=2;j<maxn;j+=i,k++) f[j]+=i*phi[k] ; } int T ; scanf("%d",&T) ; while(T--) { LL n,m,r ; scanf("%lld%lld",&n,&m) ; r=min(m,n) ; LL X=r*(r+1)%MOD*inv(180)%MOD ; LL Y=6*m*(5*n*((r*r-3*r-1)%MOD)%MOD-4*r*r%MOD*r%MOD+14*r*r%MOD-4*r-6)%MOD +(r-1)*(n*((-24*r*r%MOD+60*r+36)%MOD)%MOD+ (20*r*r%MOD)*r%MOD-60*r*r%MOD-5*r+30)%MOD ; LL ans=X*Y%MOD ; ans+=((m+1)*(n+1)%MOD*s1[r]%MOD-(m+n+2)*s2[r]%MOD+s3[r])*2%MOD + MOD ; printf("%lld\n",((ans%MOD)+MOD)%MOD) ; } }
沒有留言:
張貼留言