작성일 :

문제 링크

1967번 - 트리의 지름

설명

가중치가 있는 루트 트리가 주어지는 상황에서, 노드의 개수 n (1 ≤ n ≤ 10,000), 루트 노드(1번), 그리고 각 간선의 정보(부모 노드, 자식 노드, 가중치)가 주어질 때, 트리에서 가장 긴 경로(트리의 지름)의 길이를 구하는 문제입니다.

트리의 지름은 트리에 속한 임의의 두 노드 사이의 거리 중 가장 긴 것을 의미합니다. 간선에 가중치가 있으므로 경로의 길이는 경로상의 가중치 합으로 계산됩니다.


접근법

트리의 지름을 구하기 위해 깊이 우선 탐색(DFS)을 두 번 수행하는 방법을 사용합니다.


먼저 인접 리스트로 트리를 저장합니다. 입력에서는 부모-자식 관계로 주어지지만, 양방향 간선으로 저장하여 어느 노드에서든 탐색할 수 있도록 합니다.

임의의 노드(1번)에서 DFS를 시작하여 가장 먼 노드를 찾습니다. 이 노드는 트리 지름의 한쪽 끝점이 됩니다. 트리에서는 임의의 노드에서 가장 먼 노드가 반드시 지름의 끝점 중 하나이기 때문입니다.

찾은 끝점에서 다시 DFS를 수행하여 가장 먼 거리를 구합니다. 이 거리가 트리의 지름입니다.


예를 들어, 노드 1-2(가중치 3), 1-3(가중치 10), 3-4(가중치 2) 간선이 있는 트리에서:

노드 1에서 시작하여 가장 먼 노드를 찾으면 노드 4(거리 12 = 10 + 2)가 됩니다. 노드 4에서 다시 탐색하면 노드 2까지의 거리가 15(= 2 + 10 + 3)가 되어 이것이 트리의 지름입니다.


각 DFS는 모든 노드를 한 번씩 방문하므로 O(N)입니다. DFS를 두 번 수행하므로 전체 시간 복잡도는 O(N)입니다.



Code

C#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
using System;
using System.Collections.Generic;

class Program {
  static List<(int to, int w)>[] adj;
  static bool[] visited;
  static int farNode;
  static int maxDist;

  static void Dfs(int u, int dist) {
    visited[u] = true;
    if (dist > maxDist) { maxDist = dist; farNode = u; }
    foreach (var edge in adj[u]) {
      int v = edge.to, w = edge.w;
      if (!visited[v]) Dfs(v, dist + w);
    }
  }

  static void Main() {
    int n = int.Parse(Console.ReadLine()!);
    adj = new List<(int,int)>[n + 1];
    for (int i = 1; i <= n; i++) adj[i] = new List<(int,int)>();

    for (int i = 0; i < n - 1; i++) {
      var parts = Array.ConvertAll(Console.ReadLine()!.Split(), int.Parse);
      int a = parts[0], b = parts[1], w = parts[2];
      adj[a].Add((b, w));
      adj[b].Add((a, w));
    }

    visited = new bool[n + 1];
    maxDist = 0;
    Dfs(1, 0);

    visited = new bool[n + 1];
    maxDist = 0;
    Dfs(farNode, 0);

    Console.WriteLine(maxDist);
  }
}

C++

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include <bits/stdc++.h>
using namespace std;

vector<vector<pair<int,int>>> adj;
vector<bool> visited;
int farNode = 0;
int maxDist = 0;

void dfs(int u, int dist) {
  visited[u] = true;
  if (dist > maxDist) {
    maxDist = dist;
    farNode = u;
  }
  for (auto [v, w] : adj[u]) {
    if (!visited[v]) dfs(v, dist + w);
  }
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n; cin >> n;
  adj.assign(n + 1, {});
  for (int i = 0; i < n - 1; i++) {
    int a, b, w; cin >> a >> b >> w;
    adj[a].push_back({b, w});
    adj[b].push_back({a, w});
  }

  visited.assign(n + 1, false);
  maxDist = 0;
  dfs(1, 0);

  visited.assign(n + 1, false);
  maxDist = 0;
  dfs(farNode, 0);

  cout << maxDist << "\n";
  return 0;
}