세그먼트 트리

6 minute read

변하는 배열의 구간 합 구하기

크기가 $N$인 정수 배열 $Arr$가 있는데, 다음과 같은 연산을 수행하려고 한다.

  • 구간 $[l, r] (0 \leq l \leq r \leq N - 1)$의 합 $A[l] + A[l+1] + \cdots + A[r]$ 구하기
  • $A[i] = x(0 \leq i \leq N - 1)$ : 배열의 값 변경

연산이 $M$개일 때 문제를 해결하려고 한다.

그냥 구하기

구간 합

int sum = 0;
for (int i = l; i <= r; i++) {
    sum += arr[i];
}

$l$부터 $r$까지의 배열의 원소에 접근하여 순회하며 합을 구한다. 시간복잡도는 $O(N)$

배열 값 변경

배열 접근 시간이 $O(1)$이니 시간복잡도는 $O(1)$

전체 시간복잡도

구간 합은 $O(N)$, 배열 값 변경은 연산 $M$개에 대해 $O(M)$로 총 시간복잡도는 $O(NM)$

누적합으로 구하기

구간 합

cumulSum[0] = arr[0];
for (int i = 1; i < N; i++) {
    cumulSum[i] = arr[i] + cumulSum[i - 1];
}

누적합 배열을 $O(N)$의 시간을 들여 만들어놓으면 구간합은 $O(1)$에 가능하다.

배열 값 변경

배열 값 변경은 $O(1)$이지만, 누적합 배열 또한 변경을 해줘야 하기 때문에 $O(N)$

전체 시간복잡도

연산 $M$개에 대해 $O(NM)$
위 두 방법처럼 $O(NM)$의 연산을 수행하면 많이 느려서 문제의 시간 제한을 맞추지 못할 수 가 있다.

세그먼트 트리

세그먼트 트리는 노드에 구간 합을 저장해 구간 합 연산과 배열 값 변경을 $O(\log N)$에 할 수 있게 하는 자료구조다.

0-4
0-4
0-2
0-2
3-4
3-4
3
3
4
4
0-1
0-1
2
2
0
0
1
1
1
1
2
2
3
3
6
6
7
7
4
4
5
5
8
8
9
9
Text is not SVG - cannot display

리프 노드가 아닌 노드들은 노드가 가리키고 있는 범위의 합이 각각의 노드에 저장돼있다. 리프 노드의 경우 배열의 값이 그대로 저장돼있다. 부모 노드의 값은 자식 노드의 값의 합과 같다.

트리 만들기

세그먼트 트리는 정 이진 트리(full binary tree)로 모든 노드가 0개 또는 2개의 자식을 갖는다. 그러므로 배열로 표현하는 것이 구현하기 좋다.

  • index 노드의 왼쪽 자식 노드: index * 2
  • index 노드의 왼쪽 자식 노드: index * 2

트리를 배열로 표현하기 위해서는 배열의 크기, 즉 트리가 가진 노드의 개수가 필요하다. 트리의 높이 $H=\lceil \log N \rceil$이고, 높이가 $H$인 포화 이진 트리(perfect binary tree)의 노드의 개수는 $2^{H+1}-1$이다. index 0를 사용하지 않으니 배열의 크기를 $2^{H+1}$로 설정한다.

int h = static_cast<int>(ceil(log2(N)));
int treeSize = 1 << (h + 1);
vector<long long> tree(treeSize);

트리의 크기를 결정해서 배열을 선언했으면 트리 안의 값을 채워야 한다. 이는 재귀적으로 가능하다.

long long initTree(vector<long long> &arr, vector<long long> &tree, int node, int low, int high) {
    if (low == high) {
        return tree[node] = arr[low];
    }
    int mid = low + (high - low) / 2;
    return tree[node] = initTree(arr, tree, node * 2, low, mid) + initTree(arr, tree, node * 2 + 1, mid + 1, high);
}

lowhigh는 각각 arr에서 참조할 범위의 맨 왼쪽 index와 오른쪽 index를 나타낸다. low == high의 경우 리프 노드를 의미하기 때문에 tree[node] = arr[low]로 대입하고 그 값을 반환한다. 아닐 경우 lowhigh의 중간값 mid를 정의하고 왼쪽 자식 노드의 값과 오른쪽 자식 노드의 값을 재귀적으로 설정해준 후 합을 저장한다.

0-4
15
0-4...
0-2
6
0-2...
3-4
9
3-4...
3
4
3...
4
5
4...
0-1
3
0-1...
2
3
2...
0
1
0...
1
2
1...
1
1
2
2
3
3
6
6
7
7
4
4
5
5
8
8
9
9
Text is not SVG - cannot display

노드의 붉은색 숫자가 노드에 저장된 값이라고 할 때, 리프 노드는 배열이 그대로 저장되고 리프 노드가 아닌 노드는 자식 노드의 값들의 합을 저장하고 있다.

트리 업데이트하기

arr[idx]가 업데이트된다고 했을 때, 해당 원소를 구간에 포함하고 있는 트리의 노드의 값 또한 변경해야 한다. 이 또한 재귀적으로 가능하다.

void updateTree(vector<long long> &tree, int node, int idx, int low, int high, long long diff) {
    if (idx < low || idx > high) {
        return;
    }
    tree[node] += diff;
    if (low != high) {
        int mid = low + (high - low) / 2;
        updateTree(tree, node * 2, idx, low, mid, diff);
        updateTree(tree, node * 2 + 1, idx, mid + 1, high, diff);
    }
}
  • idx[low, high]를 벗어나는 경우(idx < low || idx > high)
    해당 노드와 자식 노드들에게 영향이 없으니 바로 반환한다.
  • 그 외의 경우
    tree[node]diff(기존 값과의 차이)를 더한 후 리프 노드가 아니면 자식 노드들에 대해 updateTree를 호출한다. 트리의 자식을 타고 내려가며 연산을 하니 시간복잡도는 $O(\log N)$이다.
0-4
16
0-4...
0-2
7
0-2...
3-4
9
3-4...
3
4
3...
4
5
4...
0-1
3
0-1...
2
4
2...
0
1
0...
1
2
1...
1
1
2
2
3
3
6
6
7
7
4
4
5
5
8
8
9
9
arr[2] = 3→4
arr[2] = 3→4
Text is not SVG - cannot display

예를 들어 arr[2] = 4로 업데이트하면 해당 index==2를 포함하는 범위의 노드는 diff==1만큼 증가한다. 초록색 범위가 index를 포함하는 범위고 파란색이 변경된 값이다.

구간 합 구하기

구간 합을 구하는 연산도 $O(\log N)$에 가능하다.

long long sum(vector<long long> &tree, int node, int low, int high, int left, int right) {
    if (left > high || right < low) {
        return 0;
    }
    if (left <= low && right >= high) {
        return tree[node];
    }
    int mid = low + (high - low) / 2;
    return sum(tree, node * 2, low, mid, left, right) + sum(tree, node * 2 + 1, mid + 1, high, left, right);
}

[left, right]의 구간 합을 구할 때 세 가지 경우가 존재한다.

  • [left, right][low, high]와 겹치지 않을 때
    해당 범위에 대응하는 노드의 값을 사용하지 않으니 0을 반환한다.
  • [left, right][low, high]를 완전히 포함할 때
    해당 범위의 구간 합이 노드에 저장돼있으니 tree[node]를 반환한다.
  • [left, right][low, high]와 일부 겹칠 경우(위 두 가지 경우가 아닌 경우)
    자식 노드들에 대해 sum을 호출한 후 합을 반환한다.
0-4
16
0-4...
0-2
7
0-2...
3-4
9
3-4...
3
4
3...
4
5
4...
0-1
3
0-1...
2
4
2...
0
1
0...
1
2
1...
1
1
2
2
3
3
6
6
7
7
4
4
5
5
8
8
9
9
[2, 4]
[2, 4]
0
0
4
4
4
4
9
9
13
13
Text is not SVG - cannot display

[2, 4]의 구간 합을 구하는 것을 표현해보았다. 분홍색은 일부 겹치는 경우, 주황색은 겹치지 않는 경우, 초록색은 완전히 포함하는 경우를 나타낸다. 분홍색 노드는 자식 노드의 합, 주황색 노드는 0, 녹색 노드는 노드 값을 반환하여 [2, 4]의 구간 합은 13인 것을 알 수 있다.

전체 시간복잡도

$M$개의 $O(\log N)$ 연산을 하니 총 시간복잡도는 $O(M \log N)$이다.

예제

백준 #2042 구간 합 구하기를 해당 방식으로 풀 수 있다.

#include <iostream>
#include <vector>
#include <cmath>

using namespace std;

long long initTree(vector<long long> &arr, vector<long long> &tree, int node, int low, int high) {
    if (low == high) {
        return tree[node] = arr[low];
    }
    int mid = low + (high - low) / 2;
    return tree[node] = initTree(arr, tree, node * 2, low, mid) + initTree(arr, tree, node * 2 + 1, mid + 1, high);
}

void updateTree(vector<long long> &tree, int node, int idx, int low, int high, long long diff) {
    if (idx < low || idx > high) {
        return;
    }
    tree[node] += diff;
    if (low != high) {
        int mid = low + (high - low) / 2;
        updateTree(tree, node * 2, idx, low, mid, diff);
        updateTree(tree, node * 2 + 1, idx, mid + 1, high, diff);
    }
}

long long sum(vector<long long> &tree, int node, int low, int high, int left, int right) {
    if (left > high || right < low) {
        return 0;
    }
    if (left <= low && right >= high) {
        return tree[node];
    }
    int mid = low + (high - low) / 2;
    return sum(tree, node * 2, low, mid, left, right) + sum(tree, node * 2 + 1, mid + 1, high, left, right);
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    int N, M, K;
    cin >> N >> M >> K;
    vector<long long> arr(N);
    int h = static_cast<int>(ceil(log2(N)));
    int treeSize = 1 << (h + 1);
    vector<long long> tree(treeSize);
    for (int i = 0; i < N; i++) {
        cin >> arr[i];
    }
    initTree(arr, tree, 1, 0, N - 1);
    for (int i = 0; i < M + K; i++) {
        int op;
        long long left, right;
        cin >> op >> left >> right;
        if (op == 1) {
            updateTree(tree, 1, left - 1, 0, N - 1, right - arr[left - 1]);
            arr[left - 1] = right;
        } else {
            cout << sum(tree, 1, 0, N - 1, left - 1, right - 1) << '\n';
        }
    }
    return 0;
}

백준 #11505 구간 곱 구하기 또한 세그먼트 트리를 사용해 풀 수 있다. 단, 트리 업데이트 시 diff가 아닌 값 자체를 업데이트하는 방식으로 살짝 변경해야 한다.

변하는 배열의 구간 합, 구간 곱뿐만 아니라 구간 최솟값, 구간 최댓값 또한 세그먼트 트리 방식으로 구현 가능하다.

변하지 않는 경우 구간 연산은 희소 테이블을 사용하는 것을 고려할 수 있다.
희소 테이블

Comments