[백준 1562] 계단 수 (C#, C++) - soo:bak
작성일 :
문제 링크
설명
인접한 자릿수의 차이가 1인 수를 계단 수라고 합니다.
길이가 N이면서 0부터 9까지 모든 숫자가 등장하는 계단 수의 개수를 1,000,000,000으로 나눈 나머지를 구하는 문제입니다. 선행 0은 허용되지 않습니다.
접근법
먼저, 0부터 9까지 모든 숫자가 등장해야 하므로 어떤 숫자를 사용했는지 추적해야 합니다. 숫자가 10개이므로 10비트 비트마스크로 사용 여부를 관리할 수 있습니다. 예를 들어 0, 1, 2를 사용했다면 마스크는 0000000111(=7)이 됩니다.
다음으로, DP 상태를 현재 길이, 마지막 숫자, 사용한 숫자 집합으로 정의합니다. 계단 수는 마지막 숫자에서 1을 더하거나 뺀 숫자만 다음에 올 수 있으므로, 마지막 숫자가 d일 때 다음 숫자는 d-1 또는 d+1입니다. 이때 마스크에 새 숫자를 추가합니다.
이후, 길이 1부터 N까지 전이를 반복합니다. 메모리를 절약하기 위해 이전 길이와 현재 길이 두 층만 번갈아 사용합니다. 초기값은 길이 1일 때 1부터 9까지 각 숫자로 시작하는 경우입니다.
마지막으로, 길이 N에서 마스크가 1111111111(=1023)인 모든 경우를 합산하면 답이 됩니다.
Code
C#
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
using System;
namespace Solution {
class Program {
const int MOD = 1_000_000_000;
static void Main(string[] args) {
var N = int.Parse(Console.ReadLine()!);
var dp = new int[2, 10, 1 << 10];
for (var d = 1; d <= 9; d++)
dp[1 % 2, d, 1 << d] = 1;
for (var len = 1; len < N; len++) {
var cur = len % 2;
var nxt = (len + 1) % 2;
Array.Clear(dp, nxt * 10 * (1 << 10), 10 * (1 << 10));
for (var d = 0; d <= 9; d++) {
for (var mask = 0; mask < (1 << 10); mask++) {
var val = dp[cur, d, mask];
if (val == 0)
continue;
if (d > 0) {
var nm = mask | (1 << (d - 1));
dp[nxt, d - 1, nm] = (dp[nxt, d - 1, nm] + val) % MOD;
}
if (d < 9) {
var nm = mask | (1 << (d + 1));
dp[nxt, d + 1, nm] = (dp[nxt, d + 1, nm] + val) % MOD;
}
}
}
}
var ans = 0;
var full = (1 << 10) - 1;
var last = N % 2;
for (var d = 0; d <= 9; d++)
ans = (ans + dp[last, d, full]) % MOD;
Console.WriteLine(ans);
}
}
}
C++
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1'000'000'000;
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int N; cin >> N;
static int dp[2][10][1 << 10];
for (int d = 1; d <= 9; d++)
dp[1 % 2][d][1 << d] = 1;
for (int len = 1; len < N; len++) {
int cur = len % 2;
int nxt = (len + 1) % 2;
memset(dp[nxt], 0, sizeof(dp[nxt]));
for (int d = 0; d <= 9; d++) {
for (int mask = 0; mask < (1 << 10); mask++) {
int val = dp[cur][d][mask];
if (!val)
continue;
if (d > 0) {
int nm = mask | (1 << (d - 1));
dp[nxt][d - 1][nm] = (dp[nxt][d - 1][nm] + val) % MOD;
}
if (d < 9) {
int nm = mask | (1 << (d + 1));
dp[nxt][d + 1][nm] = (dp[nxt][d + 1][nm] + val) % MOD;
}
}
}
}
int full = (1 << 10) - 1;
int last = N % 2;
int ans = 0;
for (int d = 0; d <= 9; d++)
ans = (ans + dp[last][d][full]) % MOD;
cout << ans << "\n";
return 0;
}