작성일 :

문제 링크

1167번 - 트리의 지름

설명

가중치가 있는 트리가 주어지는 상황에서, 정점의 개수 V (2 ≤ V ≤ 100,000)와 각 정점에 연결된 간선 정보가 주어질 때, 트리에서 가장 긴 경로(트리의 지름)의 길이를 구하는 문제입니다.

입력 형식이 특이합니다. 각 줄마다 정점 번호가 먼저 주어지고, 그 다음 (연결된 정점 번호, 거리) 쌍이 반복되며 -1로 끝납니다. 예를 들어 “1 3 2 -1”은 정점 1이 정점 3과 거리 2로 연결됨을 의미합니다.

트리의 지름은 트리에 속한 임의의 두 노드 사이의 거리 중 가장 긴 것을 의미하며, 간선의 가중치 합으로 계산됩니다.


접근법

트리의 지름을 구하기 위해 깊이 우선 탐색(DFS)을 두 번 수행하는 방법을 사용합니다.


먼저 입력을 읽으면서 인접 리스트로 트리를 저장합니다. 입력이 방향성이 없으므로 양방향 간선으로 저장해야 합니다. 각 줄을 파싱하여 -1이 나올 때까지 (정점, 거리) 쌍을 읽어 인접 리스트에 추가합니다.

다음으로, 임의의 노드(1번)에서 DFS를 시작하여 가장 먼 노드를 찾습니다. 이 노드는 트리 지름의 한쪽 끝점이 됩니다. 트리에서는 임의의 노드에서 가장 먼 노드가 반드시 지름의 끝점 중 하나이기 때문입니다.

이후, 찾은 끝점에서 다시 DFS를 수행하여 가장 먼 거리를 구합니다. 이 거리가 트리의 지름입니다.


예를 들어, 정점 1-2(거리 3), 1-3(거리 10), 3-4(거리 2) 간선이 있는 트리에서:

노드 1에서 시작하여 가장 먼 노드를 찾으면 노드 4(거리 12 = 10 + 2)가 됩니다. 노드 4에서 다시 탐색하면 노드 2까지의 거리가 15(= 2 + 10 + 3)가 되어 이것이 트리의 지름입니다.


각 DFS는 모든 간선을 한 번씩 방문하므로 O(V + E)입니다. 트리에서는 E = V - 1이므로 O(V)가 됩니다. DFS를 두 번 수행하므로 전체 시간 복잡도는 O(V)입니다.



Code

C#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
using System;
using System.Collections.Generic;

class Program {
  static List<(int to, int w)>[] adj;
  static bool[] visited;
  static int farNode;
  static int maxDist;

  static void Dfs(int u, int dist) {
    visited[u] = true;
    if (dist > maxDist) { maxDist = dist; farNode = u; }
    foreach (var edge in adj[u]) {
      int v = edge.to, w = edge.w;
      if (!visited[v]) Dfs(v, dist + w);
    }
  }

  static void Main() {
    int n = int.Parse(Console.ReadLine()!);
    adj = new List<(int,int)>[n + 1];
    for (int i = 1; i <= n; i++) adj[i] = new List<(int,int)>();

    for (int i = 0; i < n; i++) {
      var tokens = Console.ReadLine()!.Split();
      int idx = 0;
      int u = int.Parse(tokens[idx++]);
      while (true) {
        int v = int.Parse(tokens[idx++]);
        if (v == -1) break;
        int w = int.Parse(tokens[idx++]);
        adj[u].Add((v, w));
        adj[v].Add((u, w));
      }
    }

    visited = new bool[n + 1];
    maxDist = 0;
    Dfs(1, 0);

    visited = new bool[n + 1];
    maxDist = 0;
    Dfs(farNode, 0);

    Console.WriteLine(maxDist);
  }
}

C++

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <bits/stdc++.h>
using namespace std;

vector<vector<pair<int,int>>> adj;
vector<bool> visited;
int farNode = 0;
int maxDist = 0;

void dfs(int u, int dist) {
  visited[u] = true;
  if (dist > maxDist) {
    maxDist = dist;
    farNode = u;
  }
  for (auto [v, w] : adj[u]) {
    if (!visited[v]) dfs(v, dist + w);
  }
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n; cin >> n;
  adj.assign(n + 1, {});

  for (int i = 0; i < n; i++) {
    int u; cin >> u;
    while (true) {
      int v; cin >> v;
      if (v == -1) break;
      int w; cin >> w;
      adj[u].push_back({v, w});
      adj[v].push_back({u, w});
    }
  }

  visited.assign(n + 1, false);
  maxDist = 0;
  dfs(1, 0);

  visited.assign(n + 1, false);
  maxDist = 0;
  dfs(farNode, 0);

  cout << maxDist << "\n";
  return 0;
}