[백준 11049] 행렬 곱셈 순서 (C#, C++) - soo:bak
작성일 :
문제 링크
설명
N개의 행렬이 주어지는 상황에서, N (1 ≤ N ≤ 500)과 각 행렬의 크기 (r, c)가 주어질 때, 모든 행렬을 곱하는 데 필요한 최소 곱셈 연산 횟수를 구하는 문제입니다.
행렬 곱셈은 결합법칙이 성립하므로 순서를 바꾸지 않고도 괄호를 치는 방법에 따라 연산 횟수가 달라집니다. 예를 들어 A(r×c) × B(c×d) × C(d×e)를 계산할 때, (A×B)×C는 r×c×d + r×d×e번의 연산이 필요하고, A×(B×C)는 c×d×e + r×c×e번의 연산이 필요합니다.
행렬의 순서는 바꿀 수 없고, 괄호를 어떻게 치느냐에 따라 연산 횟수가 최소가 되도록 해야 합니다.
접근법
구간 동적 프로그래밍을 사용하여 최적 분할 지점을 찾습니다.
먼저 dp[l][r]을 l번째부터 r번째 행렬까지 곱하는 최소 연산 횟수로 정의합니다. 길이가 1인 구간(단일 행렬)은 곱셈이 필요 없으므로 dp[i][i] = 0으로 초기화합니다.
다음으로, 구간 [l, r]을 두 부분으로 나누는 지점 k를 선택합니다. k는 l부터 r-1까지 가능합니다. [l, k]와 [k+1, r]을 각각 곱한 후 두 결과를 곱하는 비용은 dp[l][k] + dp[k+1][r] + (l번째 행렬의 행 수) × (k번째 행렬의 열 수) × (r번째 행렬의 열 수)입니다.
이후, 모든 가능한 k에 대해 위 비용을 계산하여 최솟값을 dp[l][r]에 저장합니다. 구간 길이를 1부터 N-1까지 늘려가며 모든 구간을 채웁니다.
이렇게 계산하면 dp[1][N]이 전체 행렬을 곱하는 최소 연산 횟수가 됩니다.
예를 들어, 3개의 행렬 A(5×3), B(3×2), C(2×6)가 있는 경우:
(A×B)×C: A×B는 5×3×2 = 30번, 결과(5×2)와 C를 곱하면 5×2×6 = 60번, 총 90번 A×(B×C): B×C는 3×2×6 = 36번, A와 결과(3×6)를 곱하면 5×3×6 = 90번, 총 126번
따라서 최소 연산 횟수는 90번입니다.
구간 길이와 분할 지점을 모두 시도하므로 시간 복잡도는 O(N³)입니다.
Code
C#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
using System;
namespace Solution {
class Program {
static void Main(string[] args) {
var n = int.Parse(Console.ReadLine()!);
var r = new int[n + 1];
var c = new int[n + 1];
for (var i = 1; i <= n; i++) {
var parts = Console.ReadLine()!.Split();
r[i] = int.Parse(parts[0]);
c[i] = int.Parse(parts[1]);
}
var dp = new int[n + 1, n + 1];
const int INF = int.MaxValue;
for (var len = 1; len < n; len++) {
for (var l = 1; l + len <= n; l++) {
var rIdx = l + len;
dp[l, rIdx] = INF;
for (var k = l; k < rIdx; k++) {
var cost = (long)dp[l, k] + dp[k + 1, rIdx] + (long)r[l] * c[k] * c[rIdx];
if (cost < dp[l, rIdx])
dp[l, rIdx] = (int)cost;
}
}
}
Console.WriteLine(dp[1, n]);
}
}
}
C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef vector<vi> vvi;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n; cin >> n;
vi r(n + 1), c(n + 1);
for (int i = 1; i <= n; i++)
cin >> r[i] >> c[i];
const int INF = INT_MAX;
vvi dp(n + 1, vi(n + 1, 0));
for (int len = 1; len < n; len++) {
for (int l = 1; l + len <= n; l++) {
int rIdx = l + len;
dp[l][rIdx] = INF;
for (int k = l; k < rIdx; k++) {
ll cost = (ll)dp[l][k] + dp[k + 1][rIdx] + 1LL * r[l] * c[k] * c[rIdx];
if (cost < dp[l][rIdx])
dp[l][rIdx] = (int)cost;
}
}
}
cout << dp[1][n] << "\n";
return 0;
}