시간 제한 | 메모리 제한 |
2초 | 512MB |
문제
자연수를 저장하는 데이터베이스 S에 대해 다음의 쿼리를 처리합시다.
유형 1 : S에 자연수 X를 추가한다.
유형 2 : S에 포함된 숫자 중 X번째로 작은 수를 응답하고 그 수를 삭제한다.
입력
첫째 줄에 사전에 있는 쿼리의 수 N 이 주어집니다. (1 ≤ N ≤ 2,000,000)
둘째 줄부터 N개의 줄에 걸쳐 각 쿼리를 나타내는 2개의 정수 T X가 주어집니다.
T가 1이라면 S에 추가할 X가 주어지는 것입니다. (1 ≤ X ≤ 2,000,000)
T가 2라면 X는 S에서 삭제해야 할 몇 번째로 작은 수인지를 나타냅니다. S에 최소 X개의 원소가 있음이 보장됩니다.
출력
유형 2의 쿼리 개수만큼의 줄에 각 쿼리에 대한 답을 출력합니다.
풀이
S를 어떤 자료 구조로 관리할 것인지가 문제의 핵심이다. 1차원 배열을 사용한다면 구현이 굉장히 쉽겠지만 '유형 2'의 쿼리에서 매우 긴 시간이 소요된다. 다른 방법을 찾아보자.
'유형 1'의 자연수 X의 범위가 2,000,000 이하로 제한되어있기 때문에 계수 정렬(counting sort)를 사용할 수 있다. '유형 2'에서는 X번째로 작은 수를 찾아야 하므로 계수 정렬에 대한 누적합을 관리해야 한다. 누적합 구간에 해당하는 X를 찾아 삭제하고 출력하는 식이다. 누적합에 대한 쿼리는 세그먼트 트리로 구현할 수 있다.
1. 세그먼트 트리(segment tree)
세그먼트 트리(Segment Tree)
트리(Tree) 자료구조는 데이터간 위계가 있어 다양한 처리가 가능하다. 세그먼트 트리는 약간의 메모리를 할애해 구간의 값을 따로 관리하여 빠르게 구간 해를 얻도록 하는 트리이다. 세그먼트
celbeing.tistory.com
허접한 수준의 정리이긴 하지만 이전에 이미 세그먼트 트리에 대해 정리한 적이 있다.
세그먼트 트리는 구간의 정보를 트리로 관리하는 자료 구조다. 최댓값, 최솟값, 누적합, 누적곱 등을 세그먼트 트리로 관리할 수 있으며 트리의 탐색하여 원하는 구간 밖에 있는 경우, 걸친 경우, 포함되는 경우로 나눠진다. 걸친 경우는 한 단계 더 깊은 탐색을 실시해 구간이 포함되는 경우만을 확인하면 된다.
2. 구현
세그먼트 트리를 구현하기 위해서는 리프노드를 탐색해 데이터를 저장하고 탈출하며 구간 값을 저장하도록 해야 한다. 우선은 '유형 1'로 트리의 리프노드를 찾아 X를 추가하고 구간 값을 갱신하는 것부터 구현해보자.
2,000,000개의 데이터를 관리할 세그먼트 트리의 높이는 2⌈log22,000,000⌉+1이다. ⌈log22,000,000⌉+1 = 22이므로 다음과 같이 세그먼트 트리를 선언한다.
tree = [0]*(1<<22)
먼저 '유형 1'의 입력을 처리해보자. 노드 탐색 과정에 필요한 매개변수는 탐색 범위를 정하는 start, end 그리고 현재 탐색 중인 노드의 위치인 node, 목표 노드인 target이다. 세그먼트 트리의 탐색 과정에 따라 다음 세 가지 경우로 나뉜다.
- start = end인 경우. 리프노드에 도착한 것이다. 탐색을 종료하고 순차적으로 탈출한다.
- mid = (start + end) // 2의 값이 다음과 같은 경우
- mid가 target보다 큰 경우, 탐색 범위는 오른쪽 노드로 넘어간다.
- mid가 target보다 크지 않은 경우, 탐색 범위는 왼쪽 노드로 넘어간다.
Insert 함수를 구현하면 다음과 같다.
def Insert(start, end, node, target):
# 리프노드에 도달한 경우
if start == end:
tree[node] += 1
return
mid = (start + end) // 2
if target > min:
Insert(mid + 1, end, node * 2 + 1, target)
else:
Insert(start, mid, node * 2, target)
# tree의 노드 값을 1씩 증가시키고 탈출
tree[node] += 1
return
'유형 2'의 입력을 처리하기 위해서는 매개변수를 조금 다르게 넘겨주어야 한다. 각 노드의 값은 누적합을 의미하는데, 만약 현재 노드의 값이 10, 내가 찾고자 하는 값은 8, 하위 노드의 왼쪽은 7, 오른쪽은 3이라면 내가 찾는 값은 왼쪽의 7개 값들을 제거한 오른쪽 부분 트리 중 1번째 값이다. 따라서 탐색 과정에서 target 값이 현재 노드의 값보다 큰 경우와 그렇지 않은 경우로 나누어 생각해야 한다.
- start = end인 경우. 리프노드에 도달했으므로 탐색을 종료하고 순차적으로 탈출한다.
- target > tree[node * 2]인 경우. target에서 tree[node * 2] 값을 빼고 오른쪽 부분 트리만 탐색한다.
- target ≤ tree[node * 2]인 경우. 왼쪽 부분 트리를 이어서 탐색한다.
Delete 함수를 구현하면 다음과 같다.
def Delete(start, end, node, target):
if start == end:
print(start)
tree[node] -= 1
return
mid = (start + end) // 2
# 하위 항목의 갯수에 따라 탐색 진행
if target > tree[node * 2]:
Delete(mid + 1, end, node * 2 + 1, target)
else:
Delete(start, mid, node * 2, target)
# 노드의 값을 1씩 빼면서 탈출
tree[node] -= 1
return
정답 코드
import sys
input = sys.stdin.readline
def Insert(s,e,n,k):
if s == e:
tree[n] += 1
return
m = (s+e)//2
if k>m:
Insert(m+1,e,n*2+1,k)
else:
Insert(s,m,n*2,k)
tree[n] += 1
return
def Delete(s,e,n,k):
if s == e:
print(s)
tree[n] -= 1
return
m = (s+e)//2
if k > tree[n*2]:
Delete(m+1,e,n*2+1,k-tree[n*2])
else:
Delete(s,m,n*2,k)
tree[n] -= 1
return
tree = [0]*(1<<22)
N = int(input())
for _ in range(N):
S,X = map(int,input().split())
if S == 1:
Insert(1,2000000,1,X)
else:
Delete(1,2000000,1,X)
python3 로는 시간초과가 나온다. pypy3로 제출해서 6680ms가 떴다. 1000ms 안쪽으로 성공한 사람도 있으니 출력을 더 빠르게 하는 방법 찾아보면 시간 단축이 가능할 것 같다. 그런데 이러다가 C++로 갈아 타야할 날이 곧 오지 싶다... 자료구조 문제를 풀 때 마다 파이썬의 한계를 조금씩 실감해 간다.