작성일 :

문제 링크

7453번 - 합이 0인 네 정수

설명

네 개의 배열에서 각각 하나씩 원소를 골라 합이 0이 되는 경우의 수를 구하는 문제입니다.


접근법

네 배열을 모두 순회하면 시간 복잡도가 n의 4승이 되어 시간 초과가 발생합니다. 두 배열씩 묶어서 미리 합을 계산하면 n의 2승으로 줄일 수 있습니다.

먼저 C와 D의 모든 쌍의 합을 계산하여 저장하고 정렬합니다. 이후 A와 B의 모든 쌍의 합을 순회하면서, 그 합의 음수 값이 정렬된 배열에 몇 개 있는지 이분 탐색으로 찾습니다. 같은 값이 여러 개 있을 수 있으므로 범위의 시작과 끝을 구해 개수를 셉니다.



Code

C#

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
using System;
using System.Collections.Generic;

class Program {
  static int LowerBound(List<long> arr, long val) {
    int lo = 0, hi = arr.Count;
    while (lo < hi) {
      var mid = (lo + hi) / 2;
      if (arr[mid] < val) lo = mid + 1;
      else hi = mid;
    }
    return lo;
  }

  static int UpperBound(List<long> arr, long val) {
    int lo = 0, hi = arr.Count;
    while (lo < hi) {
      var mid = (lo + hi) / 2;
      if (arr[mid] <= val) lo = mid + 1;
      else hi = mid;
    }
    return lo;
  }

  static void Main() {
    var n = int.Parse(Console.ReadLine()!);
    var m = new int[4, n];
    for (var i = 0; i < n; i++) {
      var parts = Console.ReadLine()!.Split();
      for (var j = 0; j < 4; j++) m[j, i] = int.Parse(parts[j]);
    }

    var cd = new List<long>(n * n);
    for (var i = 0; i < n; i++)
      for (var j = 0; j < n; j++)
        cd.Add((long)m[2, i] + m[3, j]);

    cd.Sort();

    var ans = 0L;
    for (var i = 0; i < n; i++) {
      for (var j = 0; j < n; j++) {
        var target = -(long)m[0, i] - m[1, j];
        ans += UpperBound(cd, target) - LowerBound(cd, target);
      }
    }

    Console.WriteLine(ans);
  }
}

C++

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef vector<int> vi;
typedef vector<vi> vvi;
typedef vector<ll> vll;

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n; cin >> n;
  vvi m(4, vi(n));
  for (int i = 0; i < n; i++)
    for (int j = 0; j < 4; j++) cin >> m[j][i];

  vll cd;
  cd.reserve(n * n);
  for (int i = 0; i < n; i++)
    for (int j = 0; j < n; j++)
      cd.push_back((ll)m[2][i] + m[3][j]);
  sort(cd.begin(), cd.end());

  ll ans = 0;
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < n; j++) {
      ll target = -(ll)m[0][i] - m[1][j];
      auto l = lower_bound(cd.begin(), cd.end(), target);
      if (l == cd.end() || *l != target) continue;
      auto r = upper_bound(cd.begin(), cd.end(), target);
      ans += (r - l);
    }
  }

  cout << ans << "\n";

  return 0;
}