洛谷:https://www.luogu.org/problemnew/show/U50124
Description
给定一个 n*n 的矩阵 A 以及一个正整数 k,计算\(S = A^1 + A^2 + A^3+...+A^k\)
Input
输入只有一组测试数据。输入的第一行包括三个正整数 n, k, m。接下来的 n 行每行包括 n 个非负整数,按照行优先的顺序输入矩阵 A 的元素。
Output
输出 S 中每一个元素 mod(%)m 以后的值
Sample Input
2 2 4 0 1 1 1
Sample Output
1 2 2 3
解析及代码
暴力:
我们考虑暴力解法,根据题意模拟,每次令i从1枚举到k,矩阵快速幂求解\(power(A,i)\),进行矩阵加法将\(power(A,i)\)累加起来。我觉得复杂度是\(O(k(n^3logn + n^2))\)的,反正肯定过不了。
//matrix.cpp
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
const int MAXN = 35;
struct Mat {
int a[MAXN][MAXN];
};
int n,k,MOD;
Mat x;
Mat Ans;
inline Mat jz(Mat a,Mat b){
Mat tmp;
memset(tmp.a,0,sizeof tmp.a);
for(int i = 1;i <= n;++i)
for(int j = 1;j <= n;++j)
for(int k = 1;k <= n;++k)
tmp.a[i][j] = (tmp.a[i][j] + (a.a[i][k] * b.a[k][j]) % MOD) % MOD;
return tmp;
}
inline Mat power(Mat a,long long b){
Mat ans;
memset(ans.a,0,sizeof ans.a);
for(int i = 1;i <= n;++i)
ans.a[i][i]=1;
while(b){
if (b & 1) ans = jz(ans,a);
a = jz(a,a);
b >>= 1;
}
return ans;
}
inline void Add(Mat T) {
for(int i = 1;i <= n;++i)
for(int j = 1;j <= n;++j)
Ans.a[i][j] += T.a[i][j],Ans.a[i][j] %= MOD;
}
int main() {
freopen("matrix.in","r",stdin);
freopen("matrix.out","w",stdout);
scanf("%d%d%d",&n,&k,&MOD);
for(int i = 1;i <= n;++i)
for(int j = 1;j <= n;++j)
scanf("%d",&x.a[i][j]);
for(int i = 1;i <= k;++i)
Add(power(x,i));
for(int i = 1;i <= n;++i) {
for(int j = 1;j <= n;++j)
printf("%d ",Ans.a[i][j] % MOD);
puts("");
}
return 0;
}
正解:
考虑分类讨论,将k分为奇数和偶数进行分类讨论,对于k为偶数的情况,我们可以推出如下规律(例如k为6)
\(S(6)=A^1+A^2+A^3+A^4+A^5+A^6\) \(=A^1+A^2+A^3+A^3*(A+A^2+A^3)\) \(=S(3)*(1 + A^3)\)所以当k为偶数时,就有:\(S(k)=S(k/2)*(1+A^\frac{k}{2})\)
那么当k为奇数时,显然有:\(S(k)=S(k-1)+A^k\)
根据以上规律,二分进行求解 复杂度不知道多少,反正是:\(O(AC)\)
//matrix.cpp
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
const int MAXN = 35;
struct Mat {
int a[MAXN][MAXN];
void clear() {
memset(a,0,sizeof a);
}
};
int n,k,MOD;
Mat x;
Mat Ans;
inline Mat Mul(Mat a,Mat b){
Mat tmp;
tmp.clear();
for(int i = 1;i <= n;++i)
for(int j = 1;j <= n;++j)
for(int k = 1;k <= n;++k)
tmp.a[i][j] = (tmp.a[i][j] + (a.a[i][k] * b.a[k][j]) % MOD) % MOD;
return tmp;
}
inline Mat power(Mat a,long long b){
Mat ans;
ans.clear();
for(int i = 1;i <= n;++i)
ans.a[i][i]=1;
while(b){
if (b & 1) ans = Mul(ans,a);
a = Mul(a,a);
b >>= 1;
}
return ans;
}
inline Mat Add(Mat T_1,Mat T_2) {
Mat Tmp;
Tmp.clear();
for(int i = 1;i <= n;++i)
for(int j = 1;j <= n;++j)
Tmp.a[i][j] = T_1.a[i][j] + T_2.a[i][j],Tmp.a[i][j] %= MOD;
return Tmp;
}
inline Mat Work(Mat a,int k) {
if(k == 1) return a;
if(k & 1) return Add(Work(a,k - 1),power(a,k));
else return Mul(Add(power(a,0),power(a,k >> 1)),Work(a,k >> 1));
}
int main() {
freopen("matrix.in","r",stdin);
freopen("matrix.out","w",stdout);
scanf("%d%d%d",&n,&k,&MOD);
for(int i = 1;i <= n;++i)
for(int j = 1;j <= n;++j)
scanf("%d",&x.a[i][j]);
Ans = Work(x,k);
for(int i = 1;i <= n;++i) {
for(int j = 1;j <= n;++j)
printf("%d ",Ans.a[i][j] % MOD);
puts("");
}
return 0;
}
文章评论