
세그먼트 트리에서 [a,b] 범위에서 연속 구간 합의 최댓값을 구하는 방법을 알아봅시다.
파이썬은 재귀 호출이 매우 느리기 때문에 쿼리 횟수가 많아지면 시간 내에 문제를 해결하기 어렵습니다. 따라서 본문에서는 연속 구간합의 최대 쿼리를 비재귀 방식으로 구현했습니다.
0. 사전 지식
이 내용을 알아보기 전에 다음과 같은 사전 지식이 필요합니다.
- 세그먼트 트리의 초기화
- 세그먼트 트리의 쿼리
- 비재귀 세그먼트 트리
1. 문제 상황
일단 지금 풀고 있는 문제는 백준 10167번: 금광입니다. 2014 KOI 중등부 문제로 악명이 높습니다. 지금은 이게 웰노운이 되었다고 하는데 또 저만 모르는 웰노운인 거예요. 아무튼 이걸 해결하기 위해 공부하다보니 세그먼트 트리 스위핑 이외에도 알아야 할 테크닉이 있었습니다. 바로 구간 합을 다루는 세그먼트 트리에서 연속된 구간의 구간합 최대를 빠르게 찾는 것입니다.
예를 들어 다음 배열을 봅시다.
arr = [-4, 1, -3, 6, -7, 3, 2, 7, 10, -4, 15]
11개의 값을 가진 배열입니다. 이걸 가지고 구간합 세그먼트 트리를 구성한다면 다음과 같은 트리가 만들어집니다.

여기서 최대가 되는 구간합은 어떻게 찾을 수 있을까요? 아쉽게도 이 세그먼트 트리 하나만으로는 원하는 답을 빠르게 찾을 수 없습니다. 만약 [1, 11] 구간에 대해 최대가 되는 구간합을 찾겠다고 한다면, [1, 11]에 포함되는 [l, r] 구간을 모두 설정해서 쿼리를 시도해봐야 합니다. 쿼리 한 번에 O(logN)이 소요되지만 l, r 쌍을 정하는 데 O(N^2)이 소요됩니다. 즉, O(N^2logN)이라는 시간이 소요됩니다.

원하는 답은 [6, 11]구간의 합인 33입니다. 루트 인덱스가 1이라고 했을 때, seg[21] + seg[11] + seg[12] + seg[26]으로 얻을 수 있는 답입니다.
2. 최대 구간합
전체 구간의 최대 구간합을 구하는 원래 문제 상황으로 돌아가봅시다. 세그먼트 트리는 탐색 구간을 둘로 나누며 관리합니다. 최대 구간합도 두 부분으로 나누어 구할 수 있습니다.

전체 구간 내에서 최대 구간합이 나타나는 구간은 반드시 "연속된" 구간입니다. 이 점을 유의하면 최댓이 나올 수 있는 구간은 다음 세 가지 경우 중 하나가 됩니다.

최대가 되는 연속 구간은 반드시 이 셋 중 하나에 존재합니다. 그럼 각각의 경우는 또 어떻게 구할 수 있을지 살펴봅시다.
1) 왼쪽 경계에 붙은 구간합
왼쪽 경계에 붙은 경우는 다음 두 가지 중 하나입니다.

전체 구간을 둘로 나누었을 때, 최대 구간이 왼쪽에만 포함된 경우와 왼쪽 전체를 포함하고 오른쪽의 왼쪽 최댓을 합한 경우입니다. 현 구간 node에 대해 왼쪽 경계에 붙은 최댓값을 l_max[node]라고 해보면, 각각의 경우는 다음과 같습니다.
왼쪽 구간에만 포함되는 경우 = l_max[node*2]
오른쪽 구간 일부를 포함하는 경우 = seg[node*2] + l_max[node*2+1]
l_max[node] = max(l_max[node*2], seg[node*2] + l_max[node*2+1]
2) 오른쪽 경계에 붙은 구간합
오른쪽도 마찬가지입니다.

이렇게 두 가지 경우가 있으며 역시 현재 구간 node에 대하여 오른쪽 경계에 붙은 최값을 r_max[node]라고 한다면,
오른쪽 구간에만 포함되는 경우 = r_max[node*2+1]
왼쪽 구간 일부를 포함하는 경우 = r_max[node*2] + seg[node*2+1]
r_max[node] = max(r_max[node*2+1], r_max[node*2] + seg[node*2+1]
3) 그 사이 어딘가에 있는 경우
이번엔 경우의 수가 3가지 입니다.

이번에는 현재 구간 node 내에서 최대 구간합이 나타나는 부분을 t_max[node]라고 표현하겠습니다. 여기서 주의할 점은 t_max[node]는 경우에 따라 l_max[node] 또는 r_max[node]와 같을 수도 있습니다. 어쨌든 구간 내에서의 최대이기 때문에 경계에 붙어있는 경우가 최대일 수도 있다는 거죠. 각각의 경우는 또 이렇게 나타낼 수 있게 됩니다.
왼쪽 구간에만 포함되는 경우 = t_max[node*2]
오른쪽 구간에만 포함되는 경우 = t_max[node*2+1]
양쪽에 걸친 경우 = r_max[node*2] + l_max[node*2+1]
세 번째 경우가 좀 특이합니다. 양쪽에 걸쳤다는 건, 왼쪽 구간에서는 오른쪽에 붙어있는 구간이, 오른쪽 구간에서는 왼쪽에 붙어있는 구간이 합쳐져 이루어지는 구간이라는 뜻입니다.
4) 종합
이제 현재 구간 node에 대해 왼쪽, 오른쪽 경계에 붙은 각각의 구간 최대합과 양쪽 구간에 걸친 구간 최대합을 모두 구할 수 있습니다. seg, l_max, r_max, t_max를 다음과 같이 초기화할 수 있습니다.
n = int(input())
arr = list(map(int, input().split()))
void = -float('inf')
size = 1
while size < n:
size <<= 1
seg = [0] * size * 2
l_max = [void] * size * 2
r_max = [void] * size * 2
t_max = [void] * size * 2
for i in range(n):
seg[size+i] = arr[i]
l_max[size+i] = arr[i]
r_max[size+i] = arr[i]
t_max[size+i] = arr[i]
for i in range(size-1, 0, -1):
seg[i] = seg[i*2] + seg[i*2+1]
l_max[i] = max(l_max[i*2], seg[i*2] + l_max[i*2+1])
r_max[i] = max(r_max[i*2+1], seg[i*2+1] + r_max[i*2])
t_max[i] = max(max(t_max[i*2], t_max[i*2+1]), r_max[i*2] + l_max[i*2+1])
void 값은 구간합이 음수가 되는 곳에서 초기 값이 0으로 되었을 때 값 반영이 되지 않는 오류를 방지하기 위한 값입니다. 구간곱 문제에는 이 테크닉을 사용해본 적 없는데 이에 대해서는 아마 0이 되어야 할 것 같습니다.
3. 쿼리
이제 초기화된 세그먼트 트리를 가지고 구간 내 연속 최대 구간합(? 뭐라 불러야 할지 참 애매합니다)을 구하는 쿼리를 알아보겠습니다. 아까 예시는 버리고 길이 13인 배열을 가지고 만들어진 세그먼트 트리를 살펴보겠습니다.

배열의 3번째 값과 11번째 값이 들어있는 seg[18]과 seg[26]에 각각 l, r 포인터가 붙었습니다. 비재귀 세그먼트 트리의 쿼리 방식은 사실상 투포인터입니다. 요 두 포인터를 가지고 나중에 병합해봐야 할 데이터가 있는 노드만 찾아보겠습니다. 포인터의 위치를 보면 우리가 찾아야 할 쿼리는 [3, 11] 구간입니다. 비재귀 방식은 리프에서 출발해야 하기 때문에 각각 size-1을 더해 18, 26 인덱스를 가리키도록 합니다. 이제 가리키는 인덱스가 홀수인지 짝수인지 확인하며 상위 노드로 이동합니다.
그 전에 우리가 목표로 하는 구간을 담고 있는 노드를 표시해보겠습니다.

파란색으로 표시된 노드들의 구간을 합치면 [3, 11] 구간이 나오게 됩니다. 한 번의 상향식 탐색으로 이 노드들을 모두 찾았을 때, 구간 순서로 정렬이 되어있다면 이 작업이 가능합니다.

먼저 첫 번째 구간인 [3, 4] 구간에 대해 구간 최댓값을 구합니다. 그 다음으로 [5, 8] 구간의 구간 최댓값을 가지고 [3, 8] 구간의 구간 최댓값을 구합니다. 구간을 하나씩 붙여가며 이 과정을 반복해 [3, 11] 구간의 구간 최댓값을 구할 수 있습니다.
병합할 때는 지금까지 합쳐진 구간의 최대 구간과 오른쪽 끝 구간을 알고 있어야 합니다. 다음 두 가지 경우가 존재하기 때문입니다.

세부적으로 다음 구간이 포함되는 경우를 3가지로 나누었지만 결국 현재 구간의 l_max 값이 사용되는 경우는 없다는 걸 알 수 있습니다. 혹시 왼쪽 경계에 붙은 구간이 최대가 되는 경우가 있다면, 그 값은 역시 t_max에 동일하게 들어있기 때문에 l_max는 결국 필요 없습니다. 다만 다음 구간은 l_max가 반드시 필요합니다. 즉, 현재 구간 now에 대해 다음 구간 next를 병합한 새로운 구간 new의 구간 최댓값들은 이렇게 정리됩니다.
현재 구간에만 포함되는 경우 = now_total
다음 구간에 포함되는 경우 = max(now_right+l_max[next], t_max[next])
그리고 병합되는 구간도 오른쪽 경계에 붙은 최댓값이 필요합니다.
병합된 구간의 오른쪽 최댓값 = max(now_right + seg[next], r_max[next])
이제 병합까지 끝나면 쿼리 구간의 최대 연속합은 now_total에 저장되어있게 됩니다.
정답 코드
import sys
input = sys.stdin.readline
n = int(input())
arr = list(map(int, input().split()))
void = -float('inf')
size = 1
while size < n:
size <<= 1
seg = [0] * size * 2
l_max = [void] * size * 2
r_max = [void] * size * 2
t_max = [void] * size * 2
for i in range(n):
seg[size+i] = arr[i]
l_max[size+i] = arr[i]
r_max[size+i] = arr[i]
t_max[size+i] = arr[i]
for i in range(size-1, 0, -1):
seg[i] = seg[i*2] + seg[i*2+1]
l_max[i] = max(l_max[i*2], seg[i*2] + l_max[i*2+1])
r_max[i] = max(r_max[i*2+1], seg[i*2+1] + r_max[i*2])
t_max[i] = max(max(t_max[i*2], t_max[i*2+1]), r_max[i*2] + l_max[i*2+1])
for _ in range(int(input())):
l, r = map(int, input().split())
l += size-1; r += size-1
l_merge = []
r_merge = []
while l < r:
if l & 1:
l_merge.append(l)
l += 1
l >>= 1
if not(r & 1):
r_merge.append(r)
r -= 1
r >>= 1
if l == r:
l_merge.append(l)
r_merge.reverse()
l_merge += r_merge
right, total = r_max[l_merge[0]], t_max[l_merge[0]]
for i in range(1, len(l_merge)):
total = max(max(total, right+l_max[l_merge[i]]), t_max[l_merge[i]])
right = max(right+seg[l_merge[i]], r_max[l_merge[i]])
print(max(total, right))