세그먼트 트리 나중에 업데이트 해야지!

배열 A가 있고, 여기서 다음과 같은 두 연산을 수행해야하는 문제가 있습니다. 10999번 문제: 구간 합 구하기 2

  1. 구간 l, r (l ≤ r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구해서 출력하기
  2. i번째 수부터 j번째 수에 v를 더하기. A[i] += v, A[i+1] += v, ..., A[j-1] += v, A[j] += v

수행해야하는 연산은 최대 M번입니다.

이 문제를 블로그: 세그먼트 트리 (Segment Tree)를 이용해서 풀어볼 수 있습니다.

#include <cstdio>
#include <cmath>
#include <vector>
using namespace std;
// a: 배열 a
// tree: 세그먼트 트리
// node: 세그먼트 트리 노드 번호
// node가 담당하는 합의 범위가 start ~ end
long long init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) {
    if (start == end) {
        return tree[node] = a[start];
    } else {
        return tree[node] = init(a, tree, node*2, start, (start+end)/2) + init(a, tree, node*2+1, (start+end)/2+1, end);
    }
}
void update(vector<long long> &tree, int node, int start, int end, int index, long long diff) {
    if (index < start || index > end) return;
    tree[node] = tree[node] + diff;
    if (start != end) {
        update(tree,node*2, start, (start+end)/2, index, diff);
        update(tree,node*2+1, (start+end)/2+1, end, index, diff);
    }
}
long long sum(vector<long long> &tree, int node, int start, int end, int left, int right) {
    if (left > end || right < start) {
        return 0;
    }
    if (left <= start && end <= right) {
        return tree[node];
    }
    return sum(tree, node*2, start, (start+end)/2, left, right) + sum(tree, node*2+1, (start+end)/2+1, end, left, right);
}
int main() {
    int n, m, k;
    scanf("%d %d %d",&n,&m,&k);
    vector<long long> a(n);
    int h = (int)ceil(log2(n));
    int tree_size = (1 << (h+1)) - 1;
    vector<long long> tree(tree_size);
    m += k;
    for (int i=0; i<n; i++) {
        scanf("%lld",&a[i]);
    }
    init(a, tree, 1, 0, n-1);
    while (m--) {
        int t1,t2,t3;
        scanf("%d",&t1);
        if (t1 == 1) {
            int start, end;
            long long v;
            scanf("%d %d %lld",&start,&end, &v);
            start -= 1;
            end -= 1;
            for (int i=start; i<=end; i++) {
                a[i] += v;
                update(tree, 1, 0, n-1, i, v);
            }
        } else if (t1 == 2) {
            int t2,t3;
            scanf("%d %d",&t2,&t3);
            printf("%lld\n",sum(tree, 1, 0, n-1, t2-1, t3-1));
        }
    }
    return 0;
}

세그먼트 트리를 이용하면 1번 연산은 O(NlgN)만에 구할 수 있습니다. 세그먼트 트리에서 수 하나를 변경하는데 O(lgN)만큼의 시간이 걸리게 됩니다. 따라서, 수 j-i+1개를 업데이트해야 하기 때문에, 총 걸리는 시간은 O(NlgN)이 걸리게 됩니다.

update 함수를 조금 변경해, 수 1개가 아닌 여러 개를 업데이트할 수 있게 바꿀 수 있습니다.

#include <cstdio>
#include <cmath>
#include <vector>
using namespace std;
// a: 배열 a
// tree: 세그먼트 트리
// node: 세그먼트 트리 노드 번호
// node가 담당하는 합의 범위가 start ~ end
long long init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) {
    if (start == end) {
        return tree[node] = a[start];
    } else {
        return tree[node] = init(a, tree, node*2, start, (start+end)/2) + init(a, tree, node*2+1, (start+end)/2+1, end);
    }
}
void update_range(vector<long long> &tree, int node, int start, int end, int left, int right, long long diff) {
    if (left > end || right < start) {
        return;
    }
    // leaf 노드
    if (start == end) {
        tree[node] += diff;
        return;
    }
    update_range(tree,node*2, start, (start+end)/2, left, right, diff);
    update_range(tree,node*2+1, (start+end)/2+1, end, left, right, diff);
    tree[node] = tree[node*2] + tree[node*2+1];
}
long long sum(vector<long long> &tree, int node, int start, int end, int left, int right) {
    if (left > end || right < start) {
        return 0;
    }
    if (left <= start && end <= right) {
        return tree[node];
    }
    return sum(tree, node*2, start, (start+end)/2, left, right) + sum(tree, node*2+1, (start+end)/2+1, end, left, right);
}
int main() {
    int n, m, k;
    scanf("%d %d %d",&n,&m,&k);
    vector<long long> a(n);
    int h = (int)ceil(log2(n));
    int tree_size = (1 << (h+1)) - 1;
    vector<long long> tree(tree_size);
    m += k;
    for (int i=0; i<n; i++) {
        scanf("%lld",&a[i]);
    }
    init(a, tree, 1, 0, n-1);
    while (m--) {
        int t1,t2,t3;
        scanf("%d",&t1);
        if (t1 == 1) {
            int start, end;
            long long v;
            scanf("%d %d %lld",&start,&end, &v);
            start -= 1;
            end -= 1;
            for (int i=start; i<=end; i++) {
                a[i] += v;
            }
            update_range(tree, 1, 0, n-1, start, end, v);
        } else if (t1 == 2) {
            int t2,t3;
            scanf("%d %d",&t2,&t3);
            printf("%lld\n",sum(tree, 1, 0, n-1, t2-1, t3-1));
        }
    }
    return 0;
}

update_range의 원리는 suminit을 합치는 것입니다. sum은 어떤 구간 [left, right]에 포함되는 노드 중에서 루트에서 가장 가까운 노드를 찾는 방식이라 left <= start && end <= right에는 더 이상 탐색을 수행하지 않고 바로 return을 했습니다.

update_range는 그 숫자를 트리상에서 찾고, 업데이트 해야 합니다. 따라서, 리프 노드를 나타내는 start == end가 나올 때 까지 계속해서 탐색을 하게 되고, 리프 노드가 아닌 경우에는 두 자식의 합을 계산하는 방식을 사용하게 됩니다.

효율적으로 보이지만, 전체 숫자를 업데이트할면 트리의 모든 노드를 방문해야 합니다. 트리의 노드 개수는 NlgN이기 때문에, update_range도 O(NlgN)이 걸립니다.

Segment Tree Lazy Propagation (세그먼트 트리 나중에 업데이트 하기로 번역했습니다) 을 사용하면 이런 구간 업데이트 연산을 효율적으로 사용할 수 있게 됩니다.

아래 그림은 세그먼트 트리에서 3~7을 변경하는 경우에 변경해야 하는 노드를 초록색 또는 파란색으로 칠한 그림입니다.

파란색으로 색칠되어 있는 3~4와 5~7은 변경해야하는 구간 3~7에 포함됩니다. 따라서, 3~4와 5~7을 루트로하는 서브트리는 모두 3~7에 포함되게 됩니다. 이런 경우에는 더 이상 업데이트를 수행하지 않고, 나중에 다시 업데이트를 수행하러 그 노드에 방문했을 때, 업데이트를 진행해도 됩니다.

이렇게 업데이트를 미룰 때 사용하는 배열이 lazy가 됩니다.

lazy[i]i번 노드가 담당하는 구간에 더해져야할 수가 저장되어 있습니다.

3~4를 나타내는 노드의 lazy에 10이 저장되어 있다면, 3번째 수와 4번째 수에 10을 더해야 하는데, 나중에 10을 더하겠다는 의미를 가지게 되고, 5~7의 lazy에 20이 저장되어 있다면, 5, 6, 7번째 수에 20을 더해야 하지만, 지금은 더하지 않고 나중에 더하겠다는 의미를 가지게 됩니다.

예를 들어, 다음과 같은 경우를 살펴봅시다.

A[0] A[1] A[2] A[3] A[4] A[5] A[6] A[7] A[8] A[9]
3 6 2 5 3 1 8 9 7 3

이 정보를 가지고 세그먼트 트리를 만들면 아래 그림과 같습니다. 위쪽 숫자는 각 노드가 담당하고 있는 범위, 아래 숫자는 저장되어 있는 값입니다.

여기서 3~7번째 수에 2를 더한다면, 업데이트해야하는 노드는 아래 그림과 같은 초록색과 파란색입니다.

초록색 노드는 담당하고 잇는 구간이 3~7에 일부만 포함되는 경우, 파란색은 모두 포함되는 경우입니다.

초록색 노드인 경우에는 일반적인 세그먼트 트리를 업데이트하는 방식으로 진행하면 되지만, 파란색 노드의 경우에는 조금 특별하게 진행해야 합니다.

파란색 노드의 아래에 있는 노드는 모두 업데이트하려고 하는 구간 3~7에 포함됩니다. 따라서, 아래에 있는 노드의 업데이트는 나중에 필요할 때 하기로 하고, 그 값을 lazy[i]에 적어둡니다.

A[0] A[1] A[2] A[3] A[4] A[5] A[6] A[7] A[8] A[9]
3 6 2 7 5 3 10 11 7 3

앞으로는 항상 어떤 노드를 방문할 때마다 lazy 값이 있는지를 검사해야 합니다. 만약에, lazy 값이 0이 아니라면, 현재 노드에 해당하는 값을 올바르게 더해주고, 자식 노드에게 lazy를 물려줘야 합니다. 여기서 올바르게란, 단순히 lazy[i]의 값을 더하는 것이 아니고, lazy[i]의 값에 end-start+1의 값을 곱해서 더하는 것을 의미합니다. 해당하는 노드가 담당하는 구간이 start ~ end라면, 총 담당하는 수의 개수는 end-start+1개 이기 때문에, 곱해서 더해주어야 합니다.

이제 4~9번째 수에 1을 더해봅시다.

A[0] A[1] A[2] A[3] A[4] A[5] A[6] A[7] A[8] A[9]
3 6 2 7 6 4 11 12 8 4

업데이트해야 하는 노드는 아래와 같습니다.

실제로는 3~4에서 3번과 4번을 담당하는 노드를 호출하기 때문에, 3번만 담당하는 노드도 호출하게 됩니다.

3번과 4를 담당하는 노드에 방문했을 때는, lazy값이 0보다 크기 때문에, lazy값을 먼저 업데이트해주고나서 1을 더해주게 됩니다. 따라서, 트리에 저장되어 있는 값과 lazy 값은 아래 그림과 같게 됩니다.

마지막으로 6에서 8까지 합을 구하는 과정을 살펴보겠습니다.

6에서 8의 합을 구하려면 아래 그림에서 초록색으로 색칠되어 있는 정점을 방문해야 합니다.

  1. 루트 노드 입니다. 0~9는 6~8과 겹치기 때문에, 좌우 자식 0~4와 5~9를 호출합니다.
  2. 0~4는 6~8과 전혀 겹치지 않기 때문에, 0을 리턴합니다.
  3. 5~9는 6~8과 겹치기 때문에, 좌우 자식 5~7과 8~9를 호출합니다.
  4. 5~7을 방문했는데, lazy값이 있습니다. 현재 노드를 업데이트하고, 자식에게 lazy를 물려주게 됩니다.

5~7에 저장되어 있는 값은 24이고, lazy의 값은 1이기 때문에, 5~7에 저장될 값은 24 + 1*(7-5+1) = 27이 됩니다. 이제 자식들에게 lazy를 물려주기 때문에, 자식들의 lazy에는 모두 1이 더해지게 됩니다.

5~7은 6~8과 겹치기 때문에, 자식 5~6과 7을 호출하게 됩니다.

  1. 5~6에는 lazy값이 있기 때문에, 현재 노드를 업데이트하고, 자식에게 lazy를 물려줍니다.

5~6도 6~8과 겹치기 때문에, 자식 5와 6을 호출하게 됩니다.

  1. 자식 5와 6은 합쳐서 설명합니다. 5와 6은 모두 lazy 값이 있기 때문에, 노드를 업데이트합니다. 자식은 없기 때문에, lazy를 물려줄 수 없습니다. 5는 6~8에 포함되지 않기 때문에 0을 리턴하고, 6은 6~8에 포함되기 때문에, 저장되어 있는 값을 리턴합니다.

  1. 7번 노드는 lazy값이 있기 때문에, 먼저 노드를 업데이트합니다. 6~8에 7은 포함되기 때문에, 저장되어 있는 값을 리턴하게 됩니다.

  1. 이 8~9번 노드를 방문하게 됩니다. 역시 lazy가 있기 때문에, 노드를 업데이트하고, 자식에게 lazy를 물려주게 됩니다. 그 다음, 8~9는 6~8과 겹치기 때문에, 두 자식을 각각 호출해야 합니다.

  1. 8번과 9번노드도 함께 설명합니다. 두 노드 모두 lazy가 있기 때문에, 노드를 업데이트합니다. 8은 6~8에 포함되기 때문에 저장되어 있는 값을 리턴하고, 9는 포함되지 않기 때문에 0을 리턴합니다.

이렇게, 구간의 업데이트가 필요한 경우에는 꼭 필요할 때가 오기 전까지 업데이트를 미뤄두는 방법을 사용하면 조금 더 빠르게 문제를 풀 수 있습니다.

10999번 문제: 구간 합 구하기 2

#include <cstdio>
#include <cmath>
#include <vector>
using namespace std;
// a: 배열 a
// tree: 세그먼트 트리
// node: 세그먼트 트리 노드 번호
// node가 담당하는 합의 범위가 start ~ end
long long init(vector<long long> &a, vector<long long> &tree, int node, int start, int end) {
    if (start == end) {
        return tree[node] = a[start];
    } else {
        return tree[node] = init(a, tree, node*2, start, (start+end)/2) + init(a, tree, node*2+1, (start+end)/2+1, end);
    }
}
void update_lazy(vector<long long> &tree, vector<long long> &lazy, int node, int start, int end) {
    if (lazy[node] != 0) {
        tree[node] += (end-start+1)*lazy[node];
        // leaf가 아니면
        if (start != end) {
            lazy[node*2] += lazy[node];
            lazy[node*2+1] += lazy[node];
        }
        lazy[node] = 0;
    }
}
void update_range(vector<long long> &tree, vector<long long> &lazy, int node, int start, int end, int left, int right, long long diff) {
    update_lazy(tree, lazy, node, start, end);
    if (left > end || right < start) {
        return;
    }
    if (left <= start && end <= right) {
        tree[node] += (end-start+1)*diff;
        if (start != end) {
            lazy[node*2] += diff;
            lazy[node*2+1] += diff;
        }
        return;
    }
    update_range(tree, lazy, node*2, start, (start+end)/2, left, right, diff);
    update_range(tree, lazy, node*2+1, (start+end)/2+1, end, left, right, diff);
    tree[node] = tree[node*2] + tree[node*2+1];
}
long long sum(vector<long long> &tree, vector<long long> &lazy, int node, int start, int end, int left, int right) {
    update_lazy(tree, lazy, node, start, end);
    if (left > end || right < start) {
        return 0;
    }
    if (left <= start && end <= right) {
        return tree[node];
    }
    return sum(tree, lazy, node*2, start, (start+end)/2, left, right) + sum(tree, lazy, node*2+1, (start+end)/2+1, end, left, right);
}
int main() {
    int n, m, k;
    scanf("%d %d %d",&n,&m,&k);
    vector<long long> a(n);
    int h = (int)ceil(log2(n));
    int tree_size = (1 << (h+1)) - 1;
    vector<long long> tree(tree_size);
    vector<long long> lazy(tree_size);
    m += k;
    for (int i=0; i<n; i++) {
        scanf("%lld",&a[i]);
    }
    init(a, tree, 1, 0, n-1);
    while (m--) {
        int t1,t2,t3;
        scanf("%d",&t1);
        if (t1 == 1) {
            int start, end;
            long long v;
            scanf("%d %d %lld",&start,&end, &v);
            update_range(tree, lazy, 1, 0, n-1, start-1, end-1, v);
        } else if (t1 == 2) {
            int start, end;
            scanf("%d %d",&start,&end);
            printf("%lld\n",sum(tree, lazy, 1, 0, n-1, start-1, end-1));
        }
    }
    return 0;
}

댓글 (7개) 댓글 쓰기




chatterboy 11달 전

정말 좋은 글이네요. 감사합니다.


pl0892029 10달 전

게으른 전파 알고리즘이군요. 'ㅂ' 시간복잡도 얘기도 같이 다루면 더 좋을 것 같아요! (이 문서를 이해하는 사람들은 이미 코드를 보고 시간복잡도를 알겠지만...)


peter1201 4달 전

많은 도움이 됬습니다. 감사합니다!