작성일 :

문제 링크

10830번 - 행렬 제곱

설명

크기 N×N 행렬 A와 거듭제곱 지수 B (2 ≤ N ≤ 5, 1 ≤ B ≤ 10¹¹)가 주어질 때, A^B를 계산하여 각 원소를 1000으로 나눈 나머지를 출력하는 문제입니다.

B가 최대 10¹¹로 매우 크므로 단순 반복으로는 시간 내에 계산할 수 없습니다.

행렬 곱셈 과정에서 각 원소를 매번 1000으로 나눈 나머지로 유지하여 오버플로를 방지해야 합니다.


접근법

B가 매우 크므로 단순 반복으로는 시간 내에 계산할 수 없어, 빠른 거듭제곱(분할 정복)을 행렬에 적용합니다.

단위 행렬을 결과의 초기값으로 설정하고, 지수 B가 0보다 큰 동안 다음을 반복합니다: B가 홀수면 결과에 현재 기저 행렬을 곱하고, 기저 행렬을 제곱한 후 지수를 절반으로 줄입니다.

행렬 곱셈 시 각 원소의 중간 합계를 64비트 정수로 계산하여 오버플로우를 방지하고, 최종 결과를 1000으로 나눈 나머지로 저장합니다.


행렬 크기가 최대 5이므로 행렬 곱셈은 O(N³), 빠른 거듭제곱은 O(log B)로 전체 시간 복잡도는 O(N³ log B)입니다.



Code

C#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
using System;

namespace Solution {
  class Program {
    const int MOD = 1000;

    static int N;

    static int[,] Mul(int[,] a, int[,] b) {
      var r = new int[N, N];

      for (var i = 0; i < N; i++)
        for (var j = 0; j < N; j++) {
          var sum = 0L;

          for (var k = 0; k < N; k++)
            sum += (long)a[i, k] * b[k, j];

          r[i, j] = (int)(sum % MOD);
        }

      return r;
    }

    static int[,] Pow(int[,] baseM, long exp) {
      var res = new int[N, N];

      for (var i = 0; i < N; i++)
        res[i, i] = 1;

      while (exp > 0) {
        if ((exp & 1) == 1) res = Mul(res, baseM);

        baseM = Mul(baseM, baseM);
        exp >>= 1;
      }

      return res;
    }

    static void Main(string[] args) {
      var first = Console.ReadLine()!.Split();
      N = int.Parse(first[0]);
      var b = long.Parse(first[1]);

      var a = new int[N, N];

      for (var i = 0; i < N; i++) {
        var line = Console.ReadLine()!.Split();

        for (var j = 0; j < N; j++)
          a[i, j] = int.Parse(line[j]) % MOD;
      }

      var ans = Pow(a, b);

      for (var i = 0; i < N; i++) {
        for (var j = 0; j < N; j++) {
          Console.Write(ans[i, j] % MOD);
          if (j + 1 < N) Console.Write(' ');
        }
        Console.WriteLine();
      }
    }
  }
}

C++

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef vector<vector<int>> vvi;

const int MOD = 1000;

struct Mat {
  int n;
  vvi a;
  Mat(int n): n(n), a(n, vector<int>(n, 0)) {}
};

Mat mul(const Mat& x, const Mat& y) {
  int n = x.n;
  Mat r(n);

  for (int i = 0; i < n; i++)
    for (int j = 0; j < n; j++) {
      ll sum = 0;

      for (int k = 0; k < n; k++)
        sum += 1LL * x.a[i][k] * y.a[k][j];

      r.a[i][j] = (int)(sum % MOD);
    }

  return r;
}

Mat mpow(Mat base, ll exp) {
  int n = base.n;
  Mat res(n);

  for (int i = 0; i < n; i++)
    res.a[i][i] = 1;

  while (exp > 0) {
    if (exp & 1) res = mul(res, base);

    base = mul(base, base);
    exp >>= 1;
  }

  return res;
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n; ll b; cin >> n >> b;

  Mat a(n);

  for (int i = 0; i < n; i++)
    for (int j = 0; j < n; j++) {
      cin >> a.a[i][j];
      a.a[i][j] %= MOD;
    }

  Mat ans = mpow(a, b);

  for (int i = 0; i < n; i++)
    for (int j = 0; j < n; j++)
      cout << ans.a[i][j] % MOD << (j + 1 < n ? ' ' : '\n');

  return 0;
}