POJ3233 Matrix Power Series
2023-10-12 11:59:37 # OI # Problem

题意

给你一个 $n\times n$ 的矩阵 $A$ 和正整数 $k$,设 $S = A + A^{2} + A^{3} + \dots + A^{k}$,求 $S$(每个元素对给定的模数取模)。

思路

用矩阵快速幂加速递推。记 $S_k = A + A^{2} + \dots + A^{k}$,关键是构造一个转移矩阵,使得每乘一次就能从 $S_{k-1}$ 推到 $S_k$。

记 $n\times n$ 的单位矩阵为 $E$,零矩阵为 $O$。

推导:

\begin{gather} \begin{bmatrix} E & E \\ O & A \\ \end{bmatrix} \times \begin{bmatrix} S_{k - 1} \\ A^{k} \\ \end{bmatrix} = \begin{bmatrix} S_k \\ A^{k + 1} \\ \end{bmatrix} \\ \Downarrow \\ \begin{bmatrix} E & E \\ O & A \\ \end{bmatrix}^{k-1} \times \begin{bmatrix} S_1 \\ A^2 \\ \end{bmatrix} = \begin{bmatrix} S_k \\ A^{k+1} \end{bmatrix} \end{gather}

其中 $S_1 = A$。于是先构造初始列 $\begin{bmatrix} A \\ A^2 \end{bmatrix}$,再用快速幂求出转移矩阵的 $k-1$ 次幂并左乘它,结果的上半块即为 $S_k$。

代码

// #include <bits/stdc++.h>
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;

// Problem input, filled by main():
//   h - order of the input matrix A (A is h x h)
//   k - highest power in the sum A + A^2 + ... + A^k
//   p - modulus applied by matrix::operator*
ll h, k, p;

// Dense matrix with 1-based indices: n rows, m columns, entries in z[1..n][1..m].
// operator* reduces every entry modulo the global p.
struct matrix
{
    int n, m;
    // Valid indices are 1..69. The largest matrix built in this file is 2h x 2h
    // (see Init), so with 1-based indexing this supports h <= 34.
    // Original author's note (translated): "Too small overflowed the array once
    // the size doubled; too large blew the stack. Took an hour and a half to debug."
    ll z[70][70];

    // Empty 0 x 0 matrix with every slot zeroed. Dimensions start at 0 instead
    // of being left uninitialized. C++98-style init list keeps the old-judge build working.
    matrix() : n(0), m(0)
    {
        memset(z, 0, sizeof(z));
    }

    // Returns (*this) * x as an n x x.m matrix with entries reduced mod p.
    // Assumes the inner dimensions agree (this->m == x.n); this is not checked.
    matrix operator * (const matrix &x) const
    {
        matrix u;
        u.n = n;
        u.m = x.m;
        // Inner index is named t, not k, so it does not shadow the global k.
        for(int i = 1; i <= n; i++)
            for(int t = 1; t <= m; t++)
                for(int j = 1; j <= x.m; j++)
                    // "+ p" keeps the result non-negative even if a product is negative.
                    u.z[i][j] = (u.z[i][j] + (z[i][t] * x.z[t][j] % p) + p) % p;
        return u;
    }
};
// Returns the x-by-x identity matrix: ones on the diagonal z[d][d] (d = 1..x),
// zeros everywhere else (the constructor zeroes all slots).
matrix Ach(int x)
{
    matrix id;
    id.n = x;
    id.m = x;
    int d = x;
    while (d >= 1)
    {
        id.z[d][d] = 1;
        --d;
    }
    return id;
}
// Returns x^y for a square matrix x, with entries reduced mod p (via operator*).
// y == 0 yields the identity of size x.n. Precondition: y >= 0; a negative y
// also yields the identity.
//
// Iterative binary exponentiation. Each matrix holds a 70x70 array of
// long long (~38 KB). The recursive version copied x into every call frame and
// kept extra temporaries per level, which can overflow the stack for large y.
// This loop keeps a constant number of matrices alive, and x is borrowed by
// const reference. The exponent is ll to match the type of the global k.
matrix Ksm(const matrix &x, ll y)
{
    matrix res = Ach(x.n);
    if (y <= 0) return res;
    matrix base = x;
    while (true)
    {
        if (y & 1) res = res * base;
        y >>= 1;
        if (y == 0) break;   // skip squaring that would never be used
        base = base * base;
    }
    return res;
}
// Builds the 2n x 2n block transfer matrix
//     [ E  E ]
//     [ O  A ]
// from the n x n matrix a (E = identity, O = zero; O comes from the zeroing
// constructor). Left-multiplying the column block [S_{k-1}; A^k] by it gives
// [S_k; A^{k+1}]. Assumes a is square.
matrix Init(matrix a)
{
    const int half = a.n;
    matrix blk;
    blk.n = half * 2;
    blk.m = a.m * 2;
    for (int r = 1; r <= half; ++r)
    {
        // lower-right block: copy row r of a
        for (int c = 1; c <= half; ++c)
            blk.z[r + half][c + half] = a.z[r][c];
        // upper-left and upper-right blocks: identity
        blk.z[r][r] = 1;
        blk.z[r][r + half] = 1;
    }
    return blk;
}
void Print(matrix x)
{
    for(int i = 1; i <= x.n; i++)
    {
        for(int j = 1; j <= x.m; j++)
            printf("%lld%c", x.z[i][j], j == x.n? '\n': ' ');
    }
}
// Builds the 2n x m seed column block [S_1; A^2] = [A; A^2] (S_1 = A).
// The upper block is copied from a unchanged. The lower block is a * a, whose
// entries are reduced mod p by operator*.
matrix Get(matrix a)
{
    const matrix sq = a * a;
    matrix col;
    col.n = a.n + a.n;
    col.m = a.m;
    for (int r = 1; r <= a.n; ++r)
    {
        for (int c = 1; c <= a.m; ++c)
        {
            col.z[r][c] = a.z[r][c];
            col.z[r + a.n][c] = sq.z[r][c];
        }
    }
    return col;
}

signed main()
{
    matrix A;
    scanf("%lld%lld%lld", &h, &k, &p);
    A.n = A.m = h;
    for(int i = 1; i <= h; ++i)
        for(int j = 1; j <= h; j++)
            scanf("%lld", &A.z[i][j]);
    matrix E = Init(A);
    matrix ans = Get(A);
    ans = Ksm(E, k - 1) * ans;
    ans.n /= 2;
    Print(ans);
    return 0;
}