펜윅 트리 (바이너리 인덱스 트리)

블로그: 세그먼트 트리 (Segment Tree) 에서 풀어본 문제를 Fenwick Tree를 이용해서 풀어보겠습니다. Fenwick Tree는 Binary Indexed Tree라고도 하며, 줄여서 BIT라고 합니다.

Fenwick Tree를 구현하려면, 어떤 수 X를 이진수로 나타냈을 떄, 마지막 1의 위치를 알아야 합니다.

  • 3 = 112
  • 5 = 1012
  • 6 = 1102
  • 8 = 10002
  • 9 = 10012
  • 10 = 10102
  • 11 = 10112
  • 12 = 11002
  • 16 = 100002

마지막 1이 나타내는 값을 L[i]라고 표현하겠습니다. L[3] = 1, L[10] = 2, L[12] = 4이 됩니다.

수 N개를 A[1] ~ A[N]이라고 했을 때, Tree[i]A[i] 부터 앞으로 L[i]개의 합이 저장되어 있습니다.

아래 그림은 각각의 i에 대해서, L[i]를 나타낸 표입니다. 아래 초록 네모는 i부터 앞으로 L[i]개가 나타내는 구간입니다.

L[i] = i & -i가 됩니다. 그 이유는 아래와 같습니다.

      -num = ~num + 1
       num = 100110101110101100000000000
      ~num = 011001010001010011111111111
      -num = 011001010001010100000000000
num & -num = 000000000000000100000000000

A = [3, 2, 5, 7, 10, 3, 2, 7, 8, 2, 1, 9, 5, 10, 7, 4]인 경우에, 각각의 Tree[i]가 저장하고 있는 값은 다음과 같게 됩니다.

예를 들어, Tree[12]에는 12부터 앞으로 L[12] = 4개의 합은 A[9] + A[10] + A[11] + A[12]가 저장되어 있습니다. Tree[7]에는 7부터 앞으로 L[7] = 1개의 합인 A[7]이 저장되어 있습니다.

합 구하기

Tree를 이용해서 A[1] + ... + A[13]은 어떻게 구할 수 있을까요?

13을 이진수로 나타내면 1101입니다. 따라서, A[1] + ... + A[13] = Tree[1101] + Tree[1100] + Tree[1000]이 됩니다. Tree의 인덱스는 이진수입니다.

1101 -> 1100 -> 1000는 마지막 1의 위치를 빼면서 찾을 수 있습니다. 이것을 코드로 작성해보면 다음과 같습니다.

int sum(int i) {
    int ans = 0;
    while (i > 0) {
        ans += tree[i];
        i -= (i & -i);
    }
    return ans;
}

모든 i에 대해서, A[1] + ... + A[i]를 구하는 과정을 그림으로 나타내면 다음과 같습니다.

어떤 구간의 합 A[i] + ... + A[j]A[1] + ... + A[j]에서 A[1] + ... + A[i-1]을 뺀 값과 같습니다. 따라서, sum(j) - sum(i-1)을 이용해서 구할 수 있습니다.

변경

어떤 수를 변경한 경우에는, 그 수를 담당하고 있는 구간을 모두 업데이트해줘야 합니다. 아래와 같이 마지막 1의 값을 더하는 방식으로 구현할 수 있습니다.

void update(int i, int num) {
    while (i <= n) {
        tree[i] += num;
        i += (i & -i);
    }
}

아래 그림은 i를 변경했을 때, 바꿔줘야하는 Tree[i]를 나타낸 그림입니다.

2042번 문제: 구간 합 구하기를 Fenwick Tree를 이용해서 풀어봤습니다.

#include <cstdio>
#include <vector>
using namespace std;
long long sum(vector<long long> &tree, int i) {
    long long ans = 0;
    while (i > 0) {
        ans += tree[i];
        i -= (i & -i);
    }
    return ans;
}
void update(vector<long long> &tree, int i, long long diff) {
    while (i < tree.size()) {
        tree[i] += diff;
        i += (i & -i);
    }
}
int main() {
    int n, m, k;
    scanf("%d %d %d",&n,&m,&k);
    vector<long long> a(n+1);
    vector<long long> tree(n+1);
    for (int i=1; i<=n; i++) {
        scanf("%lld",&a[i]);
        update(tree, i, a[i]);
    }
    m += k;
    while (m--) {
        int t1;
        scanf("%d",&t1);
        if (t1 == 1) {
            int t2;
            long long t3;
            scanf("%d %lld",&t2,&t3);
            long long diff = t3-a[t2];
            a[t2] = t3;
            update(tree, t2, diff);
        } else {
            int t2,t3;
            scanf("%d %d",&t2,&t3);
            printf("%lld\n",sum(tree, t3) - sum(tree, t2-1));
        }
    }
    return 0;
}

함께 읽으면 좋은 글


댓글 (6개) 댓글 쓰기


mic1021 10달 전

자료구조라는 게 추상적인 개념이라 이해하기 힘들었습니다만 이 글보고나서 트리구조를 이해하는 데 도움이 많이 되었습니다. 좋은 설명 감사합니다:)


nisroeld99 5달 전

감사합니다


taso 5달 전

감사합니다! ( )


bright2013 3달 전

감사합니다 ! 도움이 많이 되네요 ㅎㅎ


nberserk 3달 전

L[12] =3 이 아닐까요?


bachjs 2달 전

저도 첨에 이글만 보고 이상하다 싶어서 다른 글 찾아 봤는데 짝수는 0의 갯수로 범위를 정하네요~ 2=10 -> 2^1=2 4=100 -> 2^2=4 12=1100->2^2=4 이런식이네요~