PS/BOJ

11660: 구간 합 구하기 5

전라남도교육지원청 2023. 10. 30. 22:29

 

11660번: 구간 합 구하기 5

첫째 줄에 표의 크기 N과 합을 구해야 하는 횟수 M이 주어진다. (1 ≤ N ≤ 1024, 1 ≤ M ≤ 100,000) 둘째 줄부터 N개의 줄에는 표에 채워져 있는 수가 1행부터 차례대로 주어진다. 다음 M개의 줄에는 네

www.acmicpc.net

기초적인 다이나믹 프로그래밍 문제다. 입력에서 테스트케이스 개수가 100,000까지 주어지는 걸로 봐서는 절대 매번 따로 연산하는 방법으로 해결할 수 없다. 두 좌표가 주어지면, 그 좌표를 기준으로 미리 처리된 값으로 연산해 연산 횟수를 일정하게 유지하도록 해야 하는 문제.

 

입력은 N*N 정사각행렬 형태로 주어진다. 처음 문제를 봤을 때는 (x1, y1)부터 개행하며 (x2, y2)까지의 모든 항을 합하는 문제인 줄 알았는데 그게 아니라 (x1, y1)을 좌상단, (x2, y2)를 우하단으로 하는 사각형 안쪽 항의 값을 더하는 문제였다. 그래서 처음 짠 코드는

for i in range(1, N):
    dp[0][i] += dp[0][i - 1]
for i in range(1, N):
    dp[i][0] += dp[i - 1][N - 1]
    for j in range(1, N):
        dp[i][j] += dp[i][j - 1]

이런 식이었다. 헛짓거리였고...

 

데이터의 위치가 [i][j]일 때, [0][0]부터의 누적합이 필요하다.

a b c
d e f
g h i

이렇게 행렬이 주어지면

a a+b a+b+c
a+d a+b+d+e a+b+c+d+e+f
a+d+g a+b+d+e+g+h a+b+c+d+e+f+g+h+i

이걸 만들어야 한다. 이렇게 만들고 나면 (x1, y1)부터 (x2, y2)까지의 합은 누적합 행렬의 (x2, y2) - (x1 - 1, y2) - (x2, y1 - 1) + (x1 - 1, y1 - 1)을 구하면 된다. 각 테스트케이스마다 연산횟수가 4번으로 고정된다. 여기서 엿같았던 점은 x1과 y1은 1씩 빼주어야 올바르게 연산이 된다는 점이었다. 디버깅할 때 이거 찾느라 거의 20분 썼다...

 

누적합이랍시고 짜본 코드

for i in range(1, N):
    dp[0][i] += dp[0][i - 1]
for i in range(1, N):
    dp[i][0] += dp[i - 1][0]
    for j in range(1, N):
        dp[i][j] += dp[i - 1][j] + dp[i][j - 1]

이 코드는 틀렸다. 내가 원한 누적합은 현재 좌표보다 x값이나 y값이 크지 않은 모든 데이터의 합인데 이 코드로 누적합을 구하면 x값과 y값이 작은 값들이 여러번 더해진다. 문제에서 주어진 4*4 행렬의 누적합 행렬은

1 3 6 10
3 8 15 24
6 15 27 42
10 24 42 64

로 나와야 하는데 위 코드로 구한 누적합은 이렇다.

1 3 6 10
3 9 19 34
6 19 43 83
10 34 83 173

dp[i][j] += dp[i - 1][j] + dp[i][j - 1] 부분이 문제였다. 행에 대한 누적합을 먼저 구해놓고, 열에 대한 누적합을 구해야 중복없는 결과가 나온다.

 

수정한 코드

for i in range(N):
    for j in range(1, N):
        dp[i][j] += dp[i][j - 1]
        
for i in range(N):
    for j in range(1, N):
        dp[i][j] += dp[i - 1][j]

이후로는 index out of range 예방용 배열을 하나 더 만들어서(N + 1 * N + 1) 0행 0열을 0으로 채워주고 끝냈다.

여기서 뭔가 더 줄일 부분이 있었겠지만 일단 제출은 시간초과였다. 어디서 pypy3로 제출하면 문법은 같지만 시간에서 유리하다고 한 말을 주워듣고 pypy3로 제출했더니 맞았다. 좌표 연산이 복잡해서 이 부분 최적화하면 해결 될 것 같다.

 

정답 코드

더보기

PyPy3

#11660: 구간 합 구하기 5
N,M = map(int, input().split())
t = [0 for i in range(N)]
for i in range(N):
    t[i] = list(map(int, input().split()))
for i in range(N):
    for j in range(1,N):
        t[i][j] += t[i][j - 1]
for j in range(N):
    for i in range(1,N):
        t[i][j] += t[i - 1][j]
table = [[0 for i in range(N + 1)] for j in range(N + 1)]
for i in range(N):
    for j in range(N):
        table[i + 1][j + 1] = t[i][j]

for i in range(M):
    xy = list(map(int, input().split()))
    result = table[xy[2]][xy[3]]
    result -= table[xy[0] - 1][xy[3]]
    result -= table[xy[2]][xy[1] - 1]
    result += table[xy[0] - 1][xy[1] - 1]
    print(result)