POJ3233 Matrix Power Series
题意
给定一个 $n \times n$ 的矩阵 $A$ 以及正整数 $k$、$m$,设 $S = A + A^2 + A^3 + \cdots + A^k$,求 $S$(每个元素对 $m$ 取模)。
思路
矩阵快速幂,想想怎么加速递推。
记单位矩阵为 $E$,零矩阵为 $O$,并记 $S_k = A + A^2 + \cdots + A^k$(于是 $S_1 = A$,所求 $S = S_k$)。
推导(利用 $S_k = S_{k-1} + A^k$,$A^{k+1} = A \cdot A^k$):
\begin{gather} \begin{bmatrix} E & E \\ O & A \\ \end{bmatrix} \times \begin{bmatrix} S_{k - 1} \\ A^{k} \\ \end{bmatrix} = \begin{bmatrix} S_k \\ A^{k + 1} \\ \end{bmatrix} \\ \Downarrow \\ \begin{bmatrix} E & E \\ O & A \\ \end{bmatrix}^{k-1} \times \begin{bmatrix} S_1 \\ A^2 \\ \end{bmatrix} = \begin{bmatrix} S_k \\ A^{k+1} \end{bmatrix} \end{gather}然后就可以愉快地递推了。
代码
// #include <bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
// h: order of the input matrix A, k: highest power in the series, p: modulus.
// All three are read in main; p is used by matrix::operator* below.
ll h, k, p;

/// Dense matrix with 1-based indexing; products are reduced modulo the global p.
struct matrix
{
    // Capacity per dimension. Valid indices are 1..MAXN-1, so MAXN must exceed
    // the largest dimension used (the block matrix built from an h x h input is
    // 2h x 2h). Each matrix holds MAXN*MAXN long longs (~39 KB at 70) inline, so
    // making this too large overflows the stack when several matrices are live
    // at once; too small (forgetting the factor of two) overruns the array.
    static const int MAXN = 70;

    int n, m;          // rows / columns actually in use
    ll z[MAXN][MAXN];  // entries; row 0 and column 0 are never used

    // Zero-filled matrix with 0 x 0 logical size (callers set n and m).
    matrix() : n(0), m(0)
    {
        memset(z, 0, sizeof(z));
    }

    /// Returns (*this) x `x` with every entry reduced modulo p.
    /// Precondition: this->m == x.n. The result is n x x.m.
    matrix operator * (const matrix &x) const
    {
        matrix u;
        u.n = n;
        u.m = x.m;
        for(int i = 1; i <= n; i++)
            for(int t = 1; t <= m; t++)
            {
                // A zero factor adds nothing; the block matrices used here are
                // mostly zeros, so skipping them saves a large share of the work.
                if(z[i][t] == 0) continue;
                for(int j = 1; j <= x.m; j++)
                    // "+ p" keeps the sum non-negative even if an input is negative.
                    u.z[i][j] = (u.z[i][j] + (z[i][t] * x.z[t][j] % p) + p) % p;
            }
        return u;
    }
};
/// Builds the x-by-x identity matrix: ones on the main diagonal, zeros elsewhere.
matrix Ach(int x)
{
    matrix id;
    id.m = x;
    id.n = id.m;
    // Fill the diagonal from the bottom-right corner upward.
    for(int d = x; d >= 1; --d)
        id.z[d][d] = 1;
    return id;
}
/// Raises the square matrix `x` to the power `y` by binary exponentiation.
/// For y >= 1 the result is produced by matrix::operator*, so its entries are
/// reduced modulo p. For y <= 0 the identity of order x.n is returned
/// (negative exponents are not meaningful here).
///
/// Iterative on purpose: a recursive version keeps at least two ~39 KB matrices
/// per frame for about log2(y) frames, which can exhaust the stack. This loop
/// keeps only a constant number of matrices alive.
matrix Ksm(matrix x, int y)
{
    matrix res = Ach(x.n);
    while(y > 0)
    {
        if(y & 1) res = res * x;
        y >>= 1;
        if(y > 0) x = x * x; // skip the final squaring, whose result is unused
    }
    return res;
}
/// Builds the (2*a.n) x (2*a.m) block transition matrix
///     [ E  E ]
///     [ O  a ]
/// where E is the identity and O the zero block. Multiplying it by the stacked
/// column [S_{k-1}; A^k] yields [S_k; A^{k+1}]. Assumes `a` is square.
matrix Init(matrix a)
{
    const int sz = a.n;
    matrix res;
    res.n = sz * 2;
    res.m = a.m * 2;
    for(int r = 1; r <= 2 * sz; ++r)
    {
        if(r <= sz)
        {
            res.z[r][r] = 1;       // top-left identity block
            res.z[r][r + sz] = 1;  // top-right identity block
        }
        else
        {
            for(int c = sz + 1; c <= 2 * sz; ++c)
                res.z[r][c] = a.z[r - sz][c - sz]; // bottom-right copy of a
        }
    }
    return res;
}
void Print(matrix x)
{
for(int i = 1; i <= x.n; i++)
{
for(int j = 1; j <= x.m; j++)
printf("%lld%c", x.z[i][j], j == x.n? '\n': ' ');
}
}
/// Builds the (2*a.n) x a.m matrix that stacks `a` on top of `a*a`.
/// For a = A this is the starting column [S_1; A^2] of the recurrence.
matrix Get(matrix a)
{
    const matrix sq = a * a;
    matrix res;
    res.n = a.n + a.n;
    res.m = a.m;
    for(int j = 1; j <= res.m; ++j)
        for(int i = 1; i <= a.n; ++i)
        {
            res.z[i][j] = a.z[i][j];         // upper half: a itself
            res.z[i + a.n][j] = sq.z[i][j];  // lower half: a squared
        }
    return res;
}
signed main()
{
matrix A;
scanf("%lld%lld%lld", &h, &k, &p);
A.n = A.m = h;
for(int i = 1; i <= h; ++i)
for(int j = 1; j <= h; j++)
scanf("%lld", &A.z[i][j]);
matrix E = Init(A);
matrix ans = Get(A);
ans = Ksm(E, k - 1) * ans;
ans.n /= 2;
Print(ans);
return 0;
}