
세그먼트 트리는 구간의 정보를 빠르게 얻기 위한 자료구조입니다. 연속된 구간의 합이나 곱, xor 값 등을 처리하기에 좋습니다. 길이 N인 배열의 구간 정보를 얻기 위해서 선형탐색을 실시하면 O(N)으로 탐색이 가능합니다. 아주 느린 건 아니지만 전체 데이터가 좀 크다면 답이 안나옵니다. 그리고 쿼리의 실행 횟수가 많으면 많을수록, 트리의 업데이트가 잦을수록 이런 방법은 좋지 않습니다. 세그먼트 트리는 전체 배열을 이진트리로 분할하여 관리합니다. 양쪽 노드를 타고 가면서 구간의 좌, 우 경계 값과 맞닿은 노드들의 종합으로 원하는 결과를 O(logN)에 구할 수 있습니다. 다만 트리의 업데이트는 lazy propagation을 사용하지 않으면 여전히 O(N)의 시간이 소모됩니다.
우선 트리의 초기화, 쿼리 코드만 살펴봅시다.
1. 재귀 방식(recursive)
1) 세그먼트 트리의 초기화(top-down)
N = int(input())
arr = list(map(int, input().split()))
# 세그먼트 트리의 크기를 결정합니다.
# 필요한 최소 크기는 N보다 크거나 같은 가장 작은 2의 거듭제곱의 2배입니다.
size = 1
while size < N:
size <<= 1
seg = [0] * size * 2
# arr의 값을 세그먼트 트리에 배치합니다.
# node: 세그먼트 트리의 인덱스
# start, end: 현재 node가 관리할 원본 배열 구간의 시작과 끝(1~N)
def init_tree(arr, node, start, end):
# start, end 매개변수가 일치한다면 리프에 도달했다는 의미, 즉 구간의 길이 1입니다.
# 배열의 값을 그대로 불러와 노드에 저장하고 return 합니다.
if start == end:
seg[node] = arr[start - 1]
return
# 현재 구간을 반으로 나누어 양쪽 자식 노드의 값을 구하러 갑니다.
mid = (start + end) // 2
init_tree(arr, node*2, start, mid)
init_tree(arr, node*2+1, mid+1, end)
# 하위 구간 값으로 현재 구간 값 구성
# 여기서는 구간 합을 구성하는 예시를 썼습니다.
seg[node] = seg[node*2] + seg[node*2+1]
return
# node는 1, start와 end는 구간 전체인 1, N으로 하여 탐색을 시작합니다.
# 리프부터 상향식으로 트리가 초기화됩니다.
init_tree(arr, 1, 1, N)
재귀 방식으로 세그먼트 트리가 초기화 되었습니다. N이 16인 경우를 살펴봅시다.

N이 16이라면 트리 구성을 위해 최소 31개의 노드가 필요합니다. 하지만 구간 분할 연산의 편의를 위해 트리에서 0번 인덱스는 사용하지 않는 경우가 많습니다. 따라서 32개의 노드를 갖는 위의 트리가 완성됩니다. 루트의 인덱스는 1, 깊이 2의 노드는 각각 2, 3, 깊이 3의 노드는 4, 5, 6, 7을 인덱스로 부여 받습니다.
2) 구간 쿼리(구간 합의 경우)
이제 재귀 방식으로 구현된 세그먼트 트리가 어떻게 구간 값의 쿼리를 처리하는 지 코드로 살펴봅시다.
# left로부터 right까지의 구간 값을 가져올 쿼리입니다.
# start와 end가 left, right를 벗어나지 않는 탐색만 실행합니다.
def query(node, start, end, left, right):
# end가 left보다 작거나 start가 right보다 큰 경우는
# 현재 탐색하는 node가 관리하는 구간이 쿼리 구간을 완전히 벗어났음을 의미합니다.
# 이때는 구간 데이터에 어떤 연산을 하고 있는지에 따라 0, 1 등을 return 합니다.
# 지금 예시로 보고 있는 세그먼트 트리는 구간 합을 관리하기 때문에 0을 return 합니다.
if end < left or right < start:
return 0
# 반대로 start와 end가 [left, right] 구간에 완전히 포함되는 경우입니다.
# 이 경우는 현재 node의 자식을 살펴볼 필요가 없습니다.
# 현재 node에서 관리 중인 값을 그대로 return 합니다.
if left <= start and end <= right:
return seg[node]
# 위의 두 분기에서 처리되지 않은 탐색은 구간이 걸친 경우입니다.
# 좌, 우 자식 노드를 각각 쿼리하여 결과를 반환합니다.
mid = (start + end) // 2
return query(node*2, start, mid, left, right) + query(node*2+1, mid+1, end, left, right)
아까 N = 16이었던 트리에서 [3, 11] 구간의 값을 쿼리한다고 해봅시다. 호출 스택은 다음과 같습니다.
query(node:1, [1, 16])
├query(node:2, [1, 8])
│├query(node:4, [1, 4])
││├query(node:8, [1, 2]) -> out of query range
││└query(node:9, [3, 4]) -> return current node
││
│└query(node:5, [5, 8]) -> return current node
│
└query(node:3, [9, 16])
├query(node:6, [9, 12])
│├query(node:12, [9, 10]) -> return current node
│└query(node:13, [11, 12])
│ ├query(node:26, [11, 11]) -> return current node
│ └query(node:27, [12, 12]) -> out of query range
│
└query(node:7, [13, 16]) -> out of query range
이렇게 보니까 더 복잡한 것 같아요. 그림으로 봅시다.

표시된 노드들은 전부 쿼리로 호출된 노드들입니다.
- 그 중 주황색으로 표시된 노드는 구간을 완전히 벗어난 노드입니다.
- 파란색으로 표시된 노드들은 구간을 완전히 포함하기 때문에 하위 노드로 탐색이 진행되지 않은 노드들입니다.
- 노란색으로 표시된 노드들은 하위 쿼리를 통해 결과를 반환한 노드들입니다.
결국 [3, 11] 구간의 쿼리는 [3, 4] + [5, 8] + [9, 10] + [11, 11] 이렇게 4개의 노드가 관리하는 값을 종합해 결과를 내놓게 됩니다.
3) N이 2의 거듭제곱이 아닌 경우
N이 2의 거듭제곱인 경우는 트리가 깔끔하게 완전 이진 트리로 구성됩니다. 하지만 그렇지 않은 경우는 초기화 단계에서 몇 노드들을 빠뜨리게 됩니다. 아래 그림을 봅시다.

아까 예시로 작성했던 init_tree 함수가 N=13인 배열을 통해 세그먼트 트리를 초기화 한다면 이와 같은 트리가 완성됩니다.여기서 한 가지 불편한 점이 발생합니다. 바로 seg 배열의 인덱스가 연속성을 잃는다는 점입니다. 리프에 해당하는 노드들은 구간의 길이가 1이므로 각각의 값들을 관리하는 유일한 지점들이 됩니다. 그런데 6번째 값을 관리하는 seg[21]의 다음인 seg[22]에는 아무 것도 들어있지 않고 7번째 값은 seg[11]에 저장되어 있습니다. 8번째 값은 seg[12]가 아닌 seg[24]에 저장되어있습니다. 재귀 방식의 쿼리를 수행할 때 이 점이 문제점으로 작용하지는 않지만 세그먼트를 처음 배우는 입장에서 이 부분이 이해하는데 큰 걸림돌이 되었던 기억이 있습니다.
하지만 파이썬으로 세그먼트 트리를 구현한다면 백준 12844번 XOR 문제 같은 경우 이런 점이 큰 걸림돌이 되기도 합니다. 리프의 개수를 파악하기 어렵기 때문입니다.
2. 비재귀 방식(iterative)
세그먼트 트리를 비재귀 방식으로 구현할 수도 있습니다. 코드가 좀 더 복잡해지긴 하지만 트리의 상태는 보다 이해하기 편한 상태가 됩니다. 먼저 초기화 코드부터 봅시다.
1) 세그먼트 트리의 초기화(bottom-up)

딱 보면 리프 노드들이 가지런히 마지막 깊이에 놓인 걸 볼 수 있습니다. 이렇게 배치가 되면 리프를 다루는 데 압도적인 편리함이 생깁니다. 기존 세그먼트 트리에서는 각 리프에 도달 할 때마다 O(logN)의 시간을 소모하며 쿼리 구간이 1이 될 때까지 구간을 분할하며 탐색을 이어가야 합니다. 하지만 이 방식을 사용하면, 완전 이진트리를 구성했을 때 마지막 깊이의 첫 노드를 첫 번째 값으로, 이후 인덱스들을 순차적으로 리프에 할당해주면 됩니다.

트리의 크기를 정해주기 위해 재귀 방식과 동일하게 N개의 값을 담을 수 있는 완전 이진 트리에 필요한 노드 수를 구해줍니다. 그리고 리프의 시작은 트리 크기의 절반부터입니다.
N = int(input()) # 트리가 관리할 데이터의 수
# 2**(ceil(log2(N))+1) 과 동일한 연산
size = 1
while size < N:
size <<= 1
size *= 2
seg = [0] * size
leaf = size // 2
그럼 값의 할당은 매우 빠르게 이루어집니다. 재귀 방식의 세그먼트 트리에서는 N개의 데이터를 초기화할 때 O(NlogN)이 소모되지만 비재귀 방식에서는 O(N)으로 끝나게 됩니다.
arr = list(map(int, input().split()))
# 리프 노드에 값 배치
for i in range(N):
seg[leaf + i] = arr[i]
# 리프로부터 루트로 돌아오며 노드 값 갱신
for i in range(leaf-1, 0, -1):
seg[i] = seg[i*2] + seg[i*2+1]
리프 노드들이 모두 채워져 있기 때문에 이런 연산이 가능합니다. 다만 구간 곱을 다루는 쿼리일 경우, 당연히 비어있는 리프들은 1로 배치해두어야 부모 노드들의 값이 오염되지 않습니다. 예시의 세그먼트 트리는 구간 합이기 때문에 이런 부분은 적절하게 처리하면 되겠습니다.
2) 쿼리의 처리
이제 [5, 11] 구간의 쿼리를 처리하는 과정을 살펴봅시다. 쿼리 역시 상향식으로 이루어집니다. 5와 11번 값이 저장된 리프에서 출발합니다. 쿼리를 처리할 때는 구간의 왼쪽 끝과 오른쪽 끝이 가리키는 노드의 인덱스가 홀수인지 짝수인지가 중요합니다. 여기서 세그먼트 트리의 루트를 제외한 모든 노드들이 부모 노드의 왼쪽 또는 오른쪽 자식 노드라는 성질을 이용합니다. 만약 인덱스가 짝수라면 왼쪽 자식 노드이고 홀수라면 오른쪽 자식 노드입니다. 만약 내가 살펴보고 있는 노드가 구간의 왼쪽 끝인데 짝수 인덱스를 갖고 있다면 이 노드를 포함하는 부모 노드가 구간에 들어있다는 의미가 됩니다.

구간의 양 끝이 들어있는 리프에서 탐색을 시작합니다. left 값은 [20]을, right 값은 [26]을 가리키면서 구간의 양 끝을 나타내고 있습니다. 그리고 쿼리가 수행되는 과정에서 노드의 값들을 누적적으로 관리할 result 값을 생성했습니다.
이제 노드의 인덱스를 살펴봅시다. left 값은 현재 짝수입니다. 구간의 왼쪽이 짝수라는 건 부모 노드가 쿼리 구간을 포함하고 있다는 의미가 됩니다. 반대로 구간의 왼쪽이 홀수 인덱스라는 건 현재 노드의 형제(?) 노드는 구간 밖이라는 걸 의미합니다. 즉, 현재 left는 부모 노드로 옮겨가도 됩니다. left >> 1을 처리합니다.
right 값은 짝수입니다. 구간의 오른쪽이 짝수라는 건 왼쪽 끝과는 반대로 형제 노드가 구간 밖이라는 것을 의미합니다. 즉, 현재 노드만을 result에 포함시키고 현재 수준에서 왼쪽 노드로 한 칸 이동합니다. right - 1을 처리합니다. 그리고 여기서 왼쪽으로 이동한 노드는 반드시 부모의 오른쪽 자식이고, 쿼리 구간에 반드시 포함됩니다. 따라서 right >> 1도 처리합니다.

한 수준 위로 올라왔습니다.
이번에도 left가 나타내는 왼쪽 끝 인덱스는 짝수입니다. 아까와 동일하게 부모 노드로 이동합니다. left >> 1을 처리합니다.
이번에도 right가 나타내는 오른쪽 끝 인덱스는 짝수입니다. 현재 노드를 result에 포함시키고 right - 1, right >> 1을 처리합니다.

이제 left와 right가 가리키는 노드가 같아졌습니다. 현재 노드는 구간에 포함되어있습니다. result에 포함시키고 탐색을 마칩니다.

빨간색으로 표시된 노드들은 result에 포함된 노드들입니다. 각 값이 5~8, 9~10, 11~11 구간을 나타냅니다. 5~11 구간이 모두 얻어졌습니다.
코드를 살펴볼까요?
def query(l, r):
left = leaf + l - 1
right = leaf + r - 1
result = 0
while left < right:
if left & 1:
result += seg[left]
left += 1
left >>= 1
if right % 2 == 0:
result += seg[right]
right -= 1
right >>= 1
# left 와 right 값이 엇갈린 경우는
# 하위 노드가 구간이 이어지며 모든 쿼리가 끝난 상태
# 따라서 left == right일 때만 마지막으로 result에 현재 노드를 포함
if left == right:
result += seg[left]
return result