https://www.acmicpc.net/problem/25489
시간 제한 | 메모리 제한 |
4초 | 1536MB |
문제
국렬이는 작년 여름에 사용하다 남은 전구 스트립을 이용해서 자신의 자취방에 있는 N개의 정점으로 이루어진 트리를 장식하려고 한다.
전구는 트리의 각 정점에 하나씩 설치되어 있다. i번째 정점에 붙어있는 전구는 전원을 넣었을 때 p_i의 확률로 켜진다. 그리고 전구에 여유가 넘치는 국렬이는 간선에 추가 전구를 하나씩 달았다. 이 추가 전구들은 연결된 두 정점에 설치된 전구들 중 하나만이 켜졌을 때 불이 켜진다.
추가로 국렬이는 Q번 특정 정점에 설치된 전구를 다른 전구로 바꿀 것이다. 각 시점별로 불이 들어오는 전구 개수의 기댓값을 구하여라.
입력
첫 번째 줄에는 정점의 개수 N이 주어진다. (2 ≤ N ≤ 500,000)
두 번째 줄에는 p_1부터 p_N이 주어진다. 이는 정확히 소수점 아래 두 자리까지 주어진다. (0.00 ≤ p_i ≤ 1.00)
세 번째 줄부터 (N+1)번째 줄까지 두 정수 u, v가 주어진다. 이는 주어진 트리에 정점 u와 정점 v를 연결하는 간선을 의미한다. (1 ≤ u,v ≤ N, u ≠ v)
(N+2)번째 줄에는 전구를 바꾸는 횟수를 의미하는 Q가 주어진다. (1 ≤ Q ≤ 500,000)
그다음 Q개의 줄에 걸쳐서 양의 정수 u와 음이 아닌 실수 p가 주어진다. 이는 정점 u에 설치된 전구를 켜질 확률이 p인 전구로 교체하는 것을 의미한다. p는 정확히 소수점 아래 두 자리까지 주어진다. (1 ≤ u ≤ N, 0.00 ≤ p ≤ 1.00)
출력
첫 번째 줄에는 초기의 불이 들어오는 전구의 기댓값을 출력한다. 그다음 Q개의 줄에 걸쳐서 각 시점별로 불이 들어오는 전구 개수의 기댓값을 출력한다. 출력한 값과 정답과의 절대 오차 또는 상대 오차가 10^-6 이하여야 한다.
풀이
뭔가 복잡해보입니다. 전체 트리에서 켜지는 전구 개수의 기댓값을 어떻게 구할 수 있을까요?
여기에는 "기댓값의 선형성"이라는 개념이 필요합니다. 이것 하나만으로도 꽤 깊게 공부할 수 있는 주제이지만 저도 문제 풀이에 딱 필요한 만큼만 생각해보고 풀었으니 이 문제에 필요한 만큼만 다뤄보도록 하겠습니다. 트리에 N개의 전구가 있고 간선에 N-1개의 전구가 있습니다. 각각의 전구는 켜질 확률를 갖고 있습니다. 간선에 있는 전구도 입력으로 주어지지 않았을 뿐이지 양끝에 있는 정점의 전구가 켜지는 확률로 켜질 확률을 구할 수 있습니다. 정점 u와 v를 잇는 간선에 설치된 다음 전구 e의 켜질 확률 pe은 이렇게 구할 수 있습니다.
u가 켜져있고 v가 꺼져있을 때 간선의 전구는 켜집니다. 반대로 u가 꺼져있고 v가 켜져있을 때 간선의 전구는 켜집니다.
pe = pu * (1 - pv) + (1 - pu) * pv
이렇게 보면 각각의 전구가 켜질 확률을 얻는 것은 그리 어렵지 않은 것 같습니다. 그럼 전체 트리에서 켜지는 전구 개수의 기댓값은 어떻게 구할까요? pi의 확률로 켜지는 전구를 생각해봅시다. 이 하나의 전구만을 놓고 봤을 때 켜지는 전구 개수의 기댓값은 이 전구가 켜질 확률과 같습니다. 같은 확률의 전구가 2개 있다면 어떨까요? 기댓값은 당연히 2배가 될겁니다. 3개라면 3배가 됩니다. 그럼 서로 다른 확률을 가진 전구 두 개가 있다면 어떨까요? 켜지는 전구 개수의 기댓값은 두 전구가 각각 갖는 기댓값의 합과 같습니다. 이것이 기댓값의 선형성에서 가장 먼저 다루는 확률변수의 특성입니다. 이것만 알면 이 문제를 해결 할 수 있습니다.
그럼 전체 트리에서 켜지는 전구 개수의 기댓값은 각각의 전구가 갖는 켜질 확률의 총합과 같습니다. 쉽게 구할 수 있습니다. 여기서 문제는 쿼리가 주어지는데, 최대 500,000번의 쿼리에서 정점 전구의 확률이 조정된다고 합니다. 그렇다면 그 정점과 연결된 간선의 전구도 확률에 영향을 받습니다. 만약 이 정점이 나머지 모든 정점과 다 연결되어있다고 해봅시다. 그렇게 되면 최대 500,000개의 정점을 갖는 트리에서 시간 복잡도는 O(N*Q)로 탐색 횟수만 2,500억회에 이릅니다. 줄일 방법을 찾아봐야 겠습니다.
이 트리를 봅시다. 정점 번호를 따로 부여하지 않았습니다. 전체 기댓값의 총합, 그러니까 켜지는 전구 개수의 기댓값은 현재 9.5개입니다. 여기서 하나의 정점을 건드려 보겠습니다.
빨간색으로 표시된 정점의 확률을 1.00에서 0.50으로 조정했습니다. 연결된 간선은 총 4개, 각각의 확률을 모두 조정해 기댓값을 다시 구했습니다. 이걸 한꺼번에 구할 수 있는 방법이 있습니다. 정점 전구의 확률이 변할 때 간선 전구의 확률이 어떻게 바뀌는지 알아봅시다.
트리의 일부분을 가져왔습니다. 각 정점 전구의 켜질 확률은 n, p, s1, s2, s3입니다. 간선 전구의 켜질 확률은 ep, e1, e2, e3입니다. n을 m으로 변화시키면 어떤 변화가 있을까요?
[정점 전구 확률 P의 변화]
P → P - n + m
먼저 정점 전구의 확률은 하나의 정점만 변했기 때문에 n을 빼고 m을 더해주는 것으로 끝입니다. 간선 전구를 봅시다.
[간선 전구 확률 e의 변화]
ep = p * (1 - n) + (1 - p) * n = p + n - 2*p*n
→ p * (1 - m) + (1 - p) * m = p + m - 2*p*m = ep - n + m + 2*p*n - 2*p*m = ep - (n - m) * (1 - 2*p)
e1 = s1 * (1- n) + (1 - s1) * n = s1 + n - 2*s1*n
→ s1 * (1 - m) + (1 - s1) * m = s1 + m - 2*s1*m = e1 - n + m + 2*s1*n - 2*s1*m = e1 - (n - m) * (1 - 2*s1)
...
변화의 양상이 아주 일정합니다. 다른 변수 없이 n만 m으로 바꿔주기만 하면 끝입니다. 그리고 연결된 각 정점의 전구의 확률도 모두 합해서 한꺼번에 계산할 수 있을 것 같습니다.
그럼 이제 모든 임의로 1을 루트로 설정하고 모든 정점에 대해 트리 탐색을 시도합니다. 각 정점마다 부모 정점이 무엇인지, 자식 정점은 무엇인지 정리해두고, 각 정점별로 자식 정점 전구의 확률 총합을 관리합니다. 이제 쿼리가 들어올 때마다 해줘야 할 일은 세 가지입니다.
1. 정점 전구의 확률 변화 적용
2. 부모 정점으로 올라가 자식 정점의 확률 변화 적용
3. 자식 정점 전구 확률 총합으로 이전 확률을 연산해 연결된 모든 간선의 확률을 빼고 새로운 확률 더하기
정답 코드
import sys
from collections import deque
input = sys.stdin.readline
def solution():
n = int(input())
p = [0] + list(map(float, input().split()))
graph = [[] for _ in range(n + 1)]
for _ in range(n - 1):
u, v = map(int, input().split())
graph[u].append(v)
graph[v].append(u)
tree = [[] for _ in range(n + 1)]
link_p = [[0, 0] for _ in range(n + 1)]
check = [0] * (n + 1)
tree[1].append(0)
check[1] = 1
bfs = deque([1])
while bfs:
now = bfs.popleft()
for next in graph[now]:
if check[next]: continue
bfs.append(next)
check[next] = 1
tree[now].append(next)
tree[next].append(now)
link_p[now][0] += p[next]
link_p[now][1] += 1 - p[next]
res = sum(p)
for i in range(1, n + 1):
res += p[i] * link_p[i][1] + (1 - p[i]) * link_p[i][0]
print(res)
for _ in range(int(input())):
u, po = map(str, input().split())
u = int(u)
po = float(po)
parent = tree[u][0]
if parent:
res += (1 - p[parent] * 2) * (po - p[u])
link_p[parent][0] += po - p[u]
link_p[parent][1] += p[u] - po
res += (po - p[u]) * link_p[u][1] + (p[u] - po) * link_p[u][0]
res += po - p[u]
p[u] = po
print(res)
solution()