시간 제한 | 메모리 제한 |
1초 | 128MB |
문제
(원문이 영어라 gtp에게 번역 하청)
매일 우유를 짜기 위해 Farmer John의 N마리의 소들 (1 ≤ N ≤ 50,000)은 항상 같은 순서로 줄을 섭니다. 어느 날 Farmer John은 일부 소들과 함께 궁극의 프리스비 게임을 하기로 결정했습니다. 단순하게 하기 위해, 그는 우유 짜는 줄에서 연속적인 범위의 소들을 선택하여 게임에 참여시킬 것입니다. 하지만 모든 소들이 즐거운 시간을 보내려면 키 차이가 너무 크지 않아야 합니다.
Farmer John은 Q개의 가능한 소 그룹 (1 ≤ Q ≤ 180,000)에 대해 각 소의 키 (1 ≤ height ≤ 1,000,000)를 기록했습니다. 각 그룹에 대해, 그는 해당 그룹에서 가장 키가 큰 소와 가장 키가 작은 소의 키 차이를 알고 싶어합니다.
참고: 가장 큰 테스트 케이스에서는 입출력(I/O) 시간이 대부분을 차지합니다.
입력
첫 번째 줄: 두 개의 공백으로 구분된 정수 N과 Q
다음 N개의 줄: i + 1번째 줄에는 i번째 소의 키가 주어집니다.
다음 Q개의 줄: 두 개의 정수 A와 B가 주어지며 (1 ≤ A ≤ B ≤ N), 이는 소의 범위가 A부터 B까지 포함됨을 나타냅니다.
출력
Q개의 줄: 각 줄에는 입력된 쿼리에 대한 답을 나타내는 단일 정수가 포함되며, 해당 범위에서 가장 키가 큰 소와 가장 키가 작은 소의 키 차이를 나타냅니다.
풀이
최대 50,000개의 데이터에서 180,000번의 구간 최솟값과 최댓값을 구해야 한다. 매 쿼리마다 단순 선형탐색을 실시할 경우, 매회 모든 데이터를 살펴본다면 9,000,000,000번의 연산이 필요하다. 가능한 모든 구간 조합에 대해 미리 최솟값과 최댓값을 구해놓는다면? 그래도 약 12.5억 번의 연산이 필요하다. 선형 탐색으로는 해결할 수 없는데, 각 쿼리마다 O(logn)으로 최댓값과 최솟값을 구할 수 있는 방법이 있다. 그럼 대강 최악의 경우 log250,000 * 180,000 ≒ 2,700,000번의 연산으로 시간 제한 내에 해결 가능하다. 모든 데이터의 최댓값과 최솟값, 누적합, 누적곱 등의 데이터를 이진트리로 관리하고 이진탐색으로 찾을 수 있는 세그먼트 트리를 이용해보자.
우선 입력을 받고 세그먼트 트리의 크기부터 정해야 한다. n개의 데이터를 이진트리의 리프 노드로 담아야 하기 때문에 이진트리의 크기는 2⌈log2n⌉+1 로 정할 수 있다.
N, Q = map(int, input().split())
cow = [0] + [int(input()) for _ in range(N)]
seg_height = 0
# math 모듈의 log2 함수를 써도 된다.
while 1 << seg_height < N:
seg_height += 1
seg_tree = [(1000000, 1)] * (1 << (seg_height + 1))
트리는 지금 (100000, 1) 튜플로 초기화 되었는데 각각 현재 노드의 최솟값과 최댓값을 나타낸다.
다음은 입력받은 cow의 값을 seg_tree에 넣어야 한다. 세그먼트 트리를 관리할 때 원본 배열의 인덱스만을 가지고 와서 참조하며 구성하는 방법도 가능하지만 너무 헷갈려서 그냥 원본 배열의 값을 그대로 가져오기로 하자. 값 초기화는 상향식으로 이루어진다. 가장 먼저 리프 노드 수준에서의 최댓값, 최솟값(자기 자신임)을 구하고 부모 노드에서 서브 트리 내의 최댓값과 최솟값을 구해야 한다. 이는 왼쪽 자식 노드와 오른쪽 자식 노드의 값을 비교해서 얻을 수 있다. 상향식으로 구성되기 때문에 현재 구하고자 하는 노드의 자식 노드에서는 최댓값과 최솟값이 모두 구해져있다.
def seg_init(tree, node, s, e, array):
# 현재 노드가 가리키는 구간인 s와 e가 같은 경우,
# 리프 노드에 해당하므로 원본 배열의 값을 그대로 가져온다.
if s == e:
tree[node] = (array[s], array[s])
else:
m = (s + e) // 2
# l은 왼쪽 자식 노드, r은 오른쪽 자식 노드
l = node * 2
r = l + 1
# 자식 노드의 값 구성하기
seg_init(tree, l, s, m, array)
seg_init(tree, r, m + 1, e, array)
# 자식 노드의 최솟값 중 더 작은 값,
# 자식 노드의 최댓값 중 더 큰 값 저장
minimum = min(tree[l][0], tree[r][0])
maximum = max(tree[l][1], tree[r][1])
tree[node] = (minimum, maximum)
이렇게 모든 구간을 이진트리에 나누어 담아 최댓값과 최솟값을 저장했다. 이제 쿼리가 들어올 때마다 구간의 값을 이진트리 내에서 찾아내 반환하면 된다. 역시 이진탐색이다.
def seg_get(tree, node, s, e, l, r):
# s, e는 현재 호출된 함수가 탐색 중인 노드가 가리키는 범위
# l, r은 구해야 하는 쿼리의 범위
# 반환할 최솟값과 최댓값의 초기화
minimum, maximum = 1000000, 1
# s와 e가 쿼리의 범위를 벗어난 경우
if l > e or r < s: return (minimum, maximum)
# s와 e가 쿼리의 탐색 범위에 완전히 포함되는 경우
elif l <= s and r >= e: return tree[node]
# 현재 노드의 자식 노드에서 각각의 최댓값과 최솟값 불러오기
m = (s + e) // 2
left_minimum, left_maximum = seg_get(tree, node * 2, s, m, l, r)
right_minimum, right_maximum = seg_get(tree, node * 2 + 1, m + 1, e, l, r)
return (min(left_minimum, right_minimum), max(left_maximum, right_maximum))
이진탐색으로 쿼리 구간 내의 최댓값과 최솟값을 O(logn)에 구할 수 있다.
이제 얻은 구간 최댓값에서 최솟값을 빼주고 출력하면 끝
정답 코드
import sys
input = sys.stdin.readline
def seg_init(tree, node, s, e, array):
if s == e:
tree[node] = (array[s], array[s])
else:
m = (s + e) // 2
l = node * 2
r = l + 1
seg_init(tree, l, s, m, array)
seg_init(tree, r, m + 1, e, array)
mn = min(tree[l][0], tree[r][0])
mx = max(tree[l][1], tree[r][1])
tree[node] = (mn, mx)
def seg_get(tree, node, s, e, l, r):
mn, mx = 1000000, 0
if l > e or r < s: return (1000000, 0)
elif l <= s and e <= r: return tree[node]
m = (s + e) // 2
lmn, lmx = seg_get(tree, node * 2, s, m, l, r)
rmn, rmx = seg_get(tree, node * 2 + 1, m + 1, e, l, r)
return (min(lmn, rmn), max(lmx, rmx))
def solution():
N, Q = map(int, input().split())
cow = [0] + [int(input()) for _ in range(N)]
seg_h = 0
while 1 << seg_h < N:
seg_h += 1
seg_h += 1
seg = [(1000000, 0)] * (1 << seg_h)
seg_init(seg, 1, 1, N, cow)
for _ in range(Q):
A, B = map(int, input().split())
mn, mx = seg_get(seg, 1, 1, N, A, B)
print(mx - mn)
solution()