洛谷:https://www.luogu.org/problemnew/show/U50124
Description
给定一个 n*n 的矩阵 A 以及一个正整数 k,计算\(S = A^1 + A^2 + A^3+...+A^k\)
Input
输入只有一组测试数据。输入的第一行包括三个正整数 n, k, m。接下来的 n 行每行包括 n 个非负整数,按照行优先的顺序输入矩阵 A 的元素。
Output
输出 S 中每一个元素 mod(%)m 以后的值
Sample Input
2 2 4 0 1 1 1
Sample Output
1 2 2 3
解析及代码
暴力:
我们考虑暴力解法,根据题意模拟,每次令i从1枚举到k,矩阵快速幂求解\(power(A,i)\),进行矩阵加法将\(power(A,i)\)累加起来。我觉得复杂度是\(O(k(n^3logn + n^2))\)的,反正肯定过不了。
//matrix.cpp #include <cstdio> #include <cmath> #include <cstring> #include <algorithm> #include <iostream> #include <vector> #include <queue> using namespace std; const int MAXN = 35; struct Mat { int a[MAXN][MAXN]; }; int n,k,MOD; Mat x; Mat Ans; inline Mat jz(Mat a,Mat b){ Mat tmp; memset(tmp.a,0,sizeof tmp.a); for(int i = 1;i <= n;++i) for(int j = 1;j <= n;++j) for(int k = 1;k <= n;++k) tmp.a[i][j] = (tmp.a[i][j] + (a.a[i][k] * b.a[k][j]) % MOD) % MOD; return tmp; } inline Mat power(Mat a,long long b){ Mat ans; memset(ans.a,0,sizeof ans.a); for(int i = 1;i <= n;++i) ans.a[i][i]=1; while(b){ if (b & 1) ans = jz(ans,a); a = jz(a,a); b >>= 1; } return ans; } inline void Add(Mat T) { for(int i = 1;i <= n;++i) for(int j = 1;j <= n;++j) Ans.a[i][j] += T.a[i][j],Ans.a[i][j] %= MOD; } int main() { freopen("matrix.in","r",stdin); freopen("matrix.out","w",stdout); scanf("%d%d%d",&n,&k,&MOD); for(int i = 1;i <= n;++i) for(int j = 1;j <= n;++j) scanf("%d",&x.a[i][j]); for(int i = 1;i <= k;++i) Add(power(x,i)); for(int i = 1;i <= n;++i) { for(int j = 1;j <= n;++j) printf("%d ",Ans.a[i][j] % MOD); puts(""); } return 0; }
正解:
考虑分类讨论,将k分为奇数和偶数进行分类讨论,对于k为偶数的情况,我们可以推出如下规律(例如k为6)
\(S(6)=A^1+A^2+A^3+A^4+A^5+A^6\) \(=A^1+A^2+A^3+A^3*(A+A^2+A^3)\) \(=S(3)*(1 + A^3)\)所以当k为偶数时,就有:\(S(k)=S(k/2)*(1+A^\frac{k}{2})\)
那么当k为奇数时,显然有:\(S(k)=S(k-1)+A^k\)
根据以上规律,二分进行求解 复杂度不知道多少,反正是:\(O(AC)\)
//matrix.cpp #include <cstdio> #include <cmath> #include <cstring> #include <algorithm> #include <iostream> #include <vector> #include <queue> using namespace std; const int MAXN = 35; struct Mat { int a[MAXN][MAXN]; void clear() { memset(a,0,sizeof a); } }; int n,k,MOD; Mat x; Mat Ans; inline Mat Mul(Mat a,Mat b){ Mat tmp; tmp.clear(); for(int i = 1;i <= n;++i) for(int j = 1;j <= n;++j) for(int k = 1;k <= n;++k) tmp.a[i][j] = (tmp.a[i][j] + (a.a[i][k] * b.a[k][j]) % MOD) % MOD; return tmp; } inline Mat power(Mat a,long long b){ Mat ans; ans.clear(); for(int i = 1;i <= n;++i) ans.a[i][i]=1; while(b){ if (b & 1) ans = Mul(ans,a); a = Mul(a,a); b >>= 1; } return ans; } inline Mat Add(Mat T_1,Mat T_2) { Mat Tmp; Tmp.clear(); for(int i = 1;i <= n;++i) for(int j = 1;j <= n;++j) Tmp.a[i][j] = T_1.a[i][j] + T_2.a[i][j],Tmp.a[i][j] %= MOD; return Tmp; } inline Mat Work(Mat a,int k) { if(k == 1) return a; if(k & 1) return Add(Work(a,k - 1),power(a,k)); else return Mul(Add(power(a,0),power(a,k >> 1)),Work(a,k >> 1)); } int main() { freopen("matrix.in","r",stdin); freopen("matrix.out","w",stdout); scanf("%d%d%d",&n,&k,&MOD); for(int i = 1;i <= n;++i) for(int j = 1;j <= n;++j) scanf("%d",&x.a[i][j]); Ans = Work(x,k); for(int i = 1;i <= n;++i) { for(int j = 1;j <= n;++j) printf("%d ",Ans.a[i][j] % MOD); puts(""); } return 0; }
文章评论