작성일 :

문제 링크

11049번 - 행렬 곱셈 순서

설명

N개의 행렬이 주어지는 상황에서, N (1 ≤ N ≤ 500)과 각 행렬의 크기 (r, c)가 주어질 때, 모든 행렬을 곱하는 데 필요한 최소 곱셈 연산 횟수를 구하는 문제입니다.

행렬 곱셈은 결합법칙이 성립하므로 순서를 바꾸지 않고도 괄호를 치는 방법에 따라 연산 횟수가 달라집니다. 예를 들어 A(r×c) × B(c×d) × C(d×e)를 계산할 때, (A×B)×C는 r×c×d + r×d×e번의 연산이 필요하고, A×(B×C)는 c×d×e + r×c×e번의 연산이 필요합니다.

행렬의 순서는 바꿀 수 없고, 괄호를 어떻게 치느냐에 따라 연산 횟수가 최소가 되도록 해야 합니다.


접근법

구간 동적 프로그래밍을 사용하여 최적 분할 지점을 찾습니다.


먼저 dp[l][r]을 l번째부터 r번째 행렬까지 곱하는 최소 연산 횟수로 정의합니다. 길이가 1인 구간(단일 행렬)은 곱셈이 필요 없으므로 dp[i][i] = 0으로 초기화합니다.

다음으로, 구간 [l, r]을 두 부분으로 나누는 지점 k를 선택합니다. k는 l부터 r-1까지 가능합니다. [l, k]와 [k+1, r]을 각각 곱한 후 두 결과를 곱하는 비용은 dp[l][k] + dp[k+1][r] + (l번째 행렬의 행 수) × (k번째 행렬의 열 수) × (r번째 행렬의 열 수)입니다.

이후, 모든 가능한 k에 대해 위 비용을 계산하여 최솟값을 dp[l][r]에 저장합니다. 구간 길이를 1부터 N-1까지 늘려가며 모든 구간을 채웁니다.

이렇게 계산하면 dp[1][N]이 전체 행렬을 곱하는 최소 연산 횟수가 됩니다.


예를 들어, 3개의 행렬 A(5×3), B(3×2), C(2×6)가 있는 경우:

(A×B)×C: A×B는 5×3×2 = 30번, 결과(5×2)와 C를 곱하면 5×2×6 = 60번, 총 90번 A×(B×C): B×C는 3×2×6 = 36번, A와 결과(3×6)를 곱하면 5×3×6 = 90번, 총 126번

따라서 최소 연산 횟수는 90번입니다.


구간 길이와 분할 지점을 모두 시도하므로 시간 복잡도는 O(N³)입니다.



Code

C#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
using System;

namespace Solution {
  class Program {
    static void Main(string[] args) {
      var n = int.Parse(Console.ReadLine()!);
      var r = new int[n + 1];
      var c = new int[n + 1];

      for (var i = 1; i <= n; i++) {
        var parts = Console.ReadLine()!.Split();
        r[i] = int.Parse(parts[0]);
        c[i] = int.Parse(parts[1]);
      }

      var dp = new int[n + 1, n + 1];
      const int INF = int.MaxValue;

      for (var len = 1; len < n; len++) {
        for (var l = 1; l + len <= n; l++) {
          var rIdx = l + len;
          dp[l, rIdx] = INF;

          for (var k = l; k < rIdx; k++) {
            var cost = (long)dp[l, k] + dp[k + 1, rIdx] + (long)r[l] * c[k] * c[rIdx];
            if (cost < dp[l, rIdx])
              dp[l, rIdx] = (int)cost;
          }
        }
      }

      Console.WriteLine(dp[1, n]);
    }
  }
}

C++

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef vector<int> vi;
typedef vector<vi> vvi;

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n; cin >> n;
  vi r(n + 1), c(n + 1);
  for (int i = 1; i <= n; i++)
    cin >> r[i] >> c[i];

  const int INF = INT_MAX;
  vvi dp(n + 1, vi(n + 1, 0));

  for (int len = 1; len < n; len++) {
    for (int l = 1; l + len <= n; l++) {
      int rIdx = l + len;
      dp[l][rIdx] = INF;

      for (int k = l; k < rIdx; k++) {
        ll cost = (ll)dp[l][k] + dp[k + 1][rIdx] + 1LL * r[l] * c[k] * c[rIdx];
        if (cost < dp[l][rIdx])
          dp[l][rIdx] = (int)cost;
      }
    }
  }

  cout << dp[1][n] << "\n";

  return 0;
}