TIW: Binary Indexed Tree

Binary indexed tree, also called Fenwick tree, is a pretty advanced data structure for a specific use. Recall the range sum post: a binary indexed tree is used to compute prefix sums. In the prefix sum problem, we have an input array v, and we need to calculate the sum from the first item up to index k. There are two operations. Update: change the number at one index by adding a value to it (not resetting it). Query: get the sum from the beginning up to a certain index. How do we do it? There are two trivial ways:

  1. Every time someone queries the sum, loop through the array up to index k and return the sum. O(1) update, O(n) query.
  2. Precompute the prefix sum array, and return the precomputed value from the table. O(n) update, O(1) query.

To illustrate the differences and make clear what we’re trying to achieve, I will write the code for both approaches, although they are not the theme of this post.

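#include <vector>
using namespace std;

// Method 1: store the raw values. O(1) update, O(n) query.
class Method1 {
private:
    vector<int> x;  // x[i] is the current value at index i
public:
    Method1(int size) : x(size) {}
    // Add v to the value at index k.
    void update(int v, int k) {
        x[k] += v;
    }
    // Sum of x[0..k], recomputed by looping every time.
    int query(int k) {
        int ans = 0;
        for (int i = 0; i <= k; i++)
            ans += x[i];
        return ans;
    }
};

// Method 2: store the prefix sums. O(n) update, O(1) query.
class Method2 {
private:
    vector<int> s;  // s[i] is the sum of the values at indices 0..i
public:
    Method2(int size) : s(size) {}
    // Adding v at index k changes every prefix sum that ends at or after k.
    void update(int v, int k) {
        for (; k < (int)s.size(); k++)
            s[k] += v;
    }
    int query(int k) {
        return s[k];
    }
};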

Read through this and make sure you can write this code with ease. One note before we move on: we’re computing the sum from the first item to index k, but in general we want the range sum from index i to index j. To obtain it, simply subtract two prefix sums: query(j)-query(i-1). For i = 0, treat query(-1) as 0 (Method1 already returns 0 there, while Method2 would read out of bounds).
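Here is a minimal sketch of the subtraction trick, assuming the prefix sums live in a plain vector. The helper names buildPrefix and rangeSum are hypothetical, not from the original code; the only subtle part is the i = 0 case.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper: prefix[i] = v[0] + v[1] + ... + v[i].
vector<int> buildPrefix(const vector<int>& v) {
    vector<int> prefix(v.size());
    int running = 0;
    for (size_t i = 0; i < v.size(); i++) {
        running += v[i];
        prefix[i] = running;
    }
    return prefix;
}

// Hypothetical helper: sum of v[i..j] (inclusive) computed as
// query(j) - query(i-1), where query(-1) is treated as 0.
int rangeSum(const vector<int>& prefix, int i, int j) {
    assert(0 <= i && i <= j && j < (int)prefix.size());
    return prefix[j] - (i > 0 ? prefix[i - 1] : 0);
}
```

With v = {1, 2, 3, 4, 5, 6, 7, 8}, rangeSum(prefix, 4, 5) is 5 + 6 = 11.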

OK, that looks good. If we make a lot of updates, we use method 1; if we make a lot of queries, we use method 2. What if we make the same number of updates and queries? Say we make n operations of each kind on an array of size n; then no matter which method we use, we end up with O(n^2) time complexity (verify!). We either spend too much time pre-computing or too much time calculating the sum over and over again. Is there any way to do better?

Yes, of course! Instead of showing the code and convincing you that it works, I will derive it from scratch.

The quest for log(n)

The problem: say we have the same number of updates and queries, and we do not want to bias the computation toward either one. So we do a bit of pre-computation and a bit of summation. That’s the goal.

Say we have an array of 8 numbers, {1, 2, 3, 4, 5, 6, 7, 8} (each number equals its 1-based position, which keeps the tables readable). To calculate the sum of the first 7 numbers, we would like to add up a few numbers (since there has to be a bit of summation), but the number of terms to add has to be sub-linear. Let’s say we want it to be log(n). log2(7) is almost 3, so maybe we can sum 3 terms. In this case, we choose to sum these 3 terms: sum{1, 2, 3, 4}, sum{5, 6} and sum{7}. Assuming these sums are already pre-computed, we only have about log(n) terms to add, hence a query takes O(log(n)) time. For clarity, let me put everything in a table:

Table 1a

sum{1} = sum{1}
sum{1, 2} = sum{1, 2}
sum{1, 2, 3} = sum{1, 2} + sum{3}
sum{1, 2, 3, 4} = sum{1, 2, 3, 4}
sum{1, 2, 3, 4, 5} = sum{1, 2, 3, 4} + sum{5}
sum{1, 2, 3, 4, 5, 6} = sum{1, 2, 3, 4} + sum{5, 6}
sum{1, 2, 3, 4, 5, 6, 7} = sum{1, 2, 3, 4} + sum{5, 6} + sum{7}
sum{1, 2, 3, 4, 5, 6, 7, 8} = sum{1, 2, 3, 4, 5, 6, 7, 8}

The left hand side of the table is the query, and all the terms on the right hand side are pre-computed. If you look closely enough you will see the pattern: to sum the first k numbers, first take the largest power of 2, 2^m, that is ≤ k; the sum of the first 2^m numbers is one pre-computed term. Then for the remaining k-2^m numbers, take the largest power of 2, 2^m’, such that 2^m’ ≤ k-2^m; the sum of the next 2^m’ numbers is another pre-computed term, and so on.
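The greedy rule can be sketched directly: repeatedly peel off the largest power of two that fits into what remains. The helper decompose is hypothetical and only reproduces the right hand side of Table 1a; positions are 1-based as in the table, and each pair is the first and last position of one pre-computed term.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Hypothetical helper: split the prefix 1..k into pre-computed blocks,
// largest power of two first. Returns {first, last} position pairs.
vector<pair<int, int>> decompose(int k) {
    assert(k >= 0);
    vector<pair<int, int>> blocks;
    int start = 1;       // first position not yet covered
    int remaining = k;   // number of positions still uncovered
    while (remaining > 0) {
        int p = 1;
        while (p * 2 <= remaining)  // largest power of two <= remaining
            p *= 2;
        blocks.push_back({start, start + p - 1});
        start += p;
        remaining -= p;
    }
    return blocks;
}
```

For example, decompose(7) yields {1, 4}, {5, 6}, {7, 7}, the three terms in the row for sum{1, 2, 3, 4, 5, 6, 7}.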

There are two things to show: that querying (adding the terms on the right hand side) takes O(log(n)) time, and that keeping those pre-computed terms up to date after an update also takes O(log(n)) time.

That querying is O(log(n)) is easy to see: if r numbers remain, the largest power of 2 that is ≤ r is more than r/2 (otherwise twice that power would still be ≤ r, contradicting the choice), so each step takes out more than half of the remaining numbers. After O(log(n)) steps we would have taken out all of them.

Now we are one step from finishing on the theoretical side: how do we pre-compute those terms?

Let’s say we want to change the number 1 into 2, essentially carrying out update(1, 0), i.e. adding 1 at index 0. Looking at the terms above, we need to change sum{1}, sum{1, 2}, sum{1, 2, 3, 4} and sum{1, 2, 3, 4, 5, 6, 7, 8}. Each pre-computed term we move on to covers at least twice as many elements as the previous one, so we only need to update O(log(n)) terms. Let’s see it in a table:

Table 1b

update 1: sum{1}, sum{1, 2}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}
update 2: sum{1, 2}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}
update 3: sum{3}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}
update 4: sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}
update 5: sum{5}, sum{5, 6}, sum{1, 2, 3, 4, 5, 6, 7, 8}
update 6: sum{5, 6}, sum{1, 2, 3, 4, 5, 6, 7, 8}
update 7: sum{7}, sum{1, 2, 3, 4, 5, 6, 7, 8}
update 8: sum{1, 2, 3, 4, 5, 6, 7, 8}

Cool, now we have a rough idea of which pre-computed terms an update must change and which terms a query must add. Now we should figure out the details of the code.

How is the code written?

First, we need to determine the representation of the pre-computed terms. Here is a list of all pre-computed terms:

{1}, {1, 2}, {3}, {1, 2, 3, 4}, {5}, {5, 6}, {7}, {1, 2, 3, 4, 5, 6, 7, 8}

The last number of each term is different, and together they cover every number from 1 to 8. That’s great news! We can easily store these terms in a vector, using the last number of each term as its index. For example, the sum of {5, 6} will be stored at bit[6].
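To make the storage scheme concrete, here is a sketch for the n = 8 example only. It hard-codes the eight terms listed above (the names kTerms and storeTerms are hypothetical), stores each term’s sum at its last index, and shows that the 7-element query from Table 1a becomes bit[4] + bit[6] + bit[7].

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// The eight pre-computed terms listed above, as {first, last} 1-based
// positions. Hard-coded for n = 8 purely for illustration.
const vector<pair<int, int>> kTerms = {
    {1, 1}, {1, 2}, {3, 3}, {1, 4}, {5, 5}, {5, 6}, {7, 7}, {1, 8}};

// Store each term's sum at the index of its last element.
// v is 1-based as well: v[0] is ignored, so v must have 9 entries.
vector<int> storeTerms(const vector<int>& v) {
    assert(v.size() == 9);
    vector<int> bit(9, 0);
    for (const auto& t : kTerms) {
        int sum = 0;
        for (int i = t.first; i <= t.second; i++)
            sum += v[i];
        bit[t.second] = sum;  // every last index is used exactly once
    }
    return bit;
}
```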

First, the query operation. Let’s revisit the table with the indices written in binary:

Table 2a:

revised version of table 1a, with sums written as bit elements, indices in binary

query 0001: bit[0001]
query 0010: bit[0010]
query 0011: bit[0011]+bit[0010]
query 0100: bit[0100]
query 0101: bit[0101]+bit[0100]
query 0110: bit[0110]+bit[0100]
query 0111: bit[0111]+bit[0110]+bit[0100]
query 1000: bit[1000]

Do you see the pattern yet? Hint: a query whose index has j ones in binary has j terms on the right. The pattern is: add bit[index], remove the lowest bit that is one from the index, and move on to the next term, stopping when the index reaches 0. For example, 0111->0110->0100->0000. Finally, here’s the code:

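#include <vector>
using namespace std;

// Prefix sum of the original elements 0..k (0-based k).
// bit is 1-based: bit[0] is unused, so bit.size() == n + 1.
int query(const vector<int>& bit, int k) {
    int ans = 0;
    // k & (-k) isolates the lowest set bit; subtracting it moves to the next term.
    for (k++; k; k -= k & (-k))
        ans += bit[k];
    return ans;
}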

After all the work we’ve been through, the code is extremely concise! Two things to notice. First, the k++ changes the indexing from 0-based to 1-based: in the derivation above, the indices go from 1 to 8 instead of 0 to 7. Second, k & (-k) computes the lowest set bit of k, because in two’s complement -k equals ~k + 1. You can refer to the previous blog post on bitwise operations.
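To see k & (-k) at work, here is a small sketch that lists the 1-based indices a query visits. The helper names lowbit and queryPath are hypothetical, and the code assumes two’s-complement integers (guaranteed since C++20 and what mainstream compilers use anyway).

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Lowest set bit of k, e.g. lowbit(6) = lowbit(0b0110) = 0b0010 = 2.
int lowbit(int k) {
    assert(k > 0);
    return k & (-k);
}

// The 1-based bit[] indices that query(k) adds up, for a 0-based k.
vector<int> queryPath(int k) {
    vector<int> path;
    for (k++; k; k -= lowbit(k))
        path.push_back(k);
    return path;
}
```

queryPath(6) visits 7, 6, 4, which is exactly 0111->0110->0100 from Table 2a.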

OK, we’re almost done. What about update? Another table:

Table 2b:

revised version of table 1b

update 0001: bit[0001], bit[0010], bit[0100], bit[1000]
update 0010: bit[0010], bit[0100], bit[1000]
update 0011: bit[0011], bit[0100], bit[1000]
update 0100: bit[0100], bit[1000]
update 0101: bit[0101], bit[0110], bit[1000]
update 0110: bit[0110], bit[1000]
update 0111: bit[0111], bit[1000]
update 1000: bit[1000]

What’s the pattern this time? Hint: again, look at the lowest bit! This time, instead of removing the lowest bit, we add the lowest bit of the index to the index itself. This is less intuitive than the last part. For example, the lowest bit of 0101 is 0001, so the next index is 0101+0001 = 0110; its lowest bit is 0010, so the next index is 0110+0010 = 1000.

So here’s the code; note the k++ again:

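#include <vector>
using namespace std;

// Add v to the original element at 0-based index k (bit is 1-based, size n + 1).
void update(vector<int>& bit, int v, int k) {
    for (k++; k < (int)bit.size(); k += k & (-k))
        bit[k] += v;
}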
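Putting the two loops together gives a minimal self-contained sketch. The struct name Fenwick is hypothetical; the loops are the query and update functions above. The n + 1 size comes from the 1-based indexing, since bit[0] is never used.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Minimal binary indexed tree over n elements with a 0-based interface.
struct Fenwick {
    vector<int> bit;  // 1-based; bit[0] is unused

    explicit Fenwick(int n) : bit(n + 1, 0) {}

    // Add v to element k (0-based).
    void update(int v, int k) {
        assert(0 <= k && k + 1 < (int)bit.size());
        for (k++; k < (int)bit.size(); k += k & (-k))
            bit[k] += v;
    }

    // Sum of elements 0..k (0-based); query(-1) returns 0.
    int query(int k) const {
        int ans = 0;
        for (k++; k; k -= k & (-k))
            ans += bit[k];
        return ans;
    }
};
```

With n = 8, update(v, 4) touches bit[5], bit[6] and bit[8], matching the 0101 row of Table 2b.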

This is deceptively easy! That can’t be right. It can’t be that easy… Actually it can, if you look at the category of this post; nothing I write about is hard.

Actually, it was easy only because we matched patterns and assumed they would generalize. Let’s study the for loops a little more to understand why and how low bits are involved. This part is rather complicated, so for practical purposes you might as well skip it.

First, observe that the low bit of each index is the number of elements that index sums over. For example, 0110 has low bit 0010, which is 2, so bit[6] is the sum of two numbers: 5 and 6. This holds by design, since this is exactly how we picked the pre-computed terms, so there is nothing more to explain.

Second, bit[k] is the sum over indices k-lowbit(k)+1 to k. This is a direct consequence of (1) bit[k] being a summation that ends at the kth number and (2) bit[k] summing over lowbit(k) numbers.
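The range formula can be checked by building bit[] straight from it, without any update calls. This is a sketch only (buildFromDefinition is a hypothetical name). For v = {1, ..., 8} it reproduces the terms listed earlier, e.g. bit[6] = 5 + 6 = 11 and bit[4] = 1 + 2 + 3 + 4 = 10.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Build bit[] directly from bit[k] = v[k-lowbit(k)+1] + ... + v[k].
// v is 1-based (v[0] is ignored). Each bit[k] adds lowbit(k) values,
// which totals O(n log n) work.
vector<int> buildFromDefinition(const vector<int>& v) {
    assert(!v.empty());
    int n = (int)v.size() - 1;
    vector<int> bit(n + 1, 0);
    for (int k = 1; k <= n; k++) {
        int low = k & (-k);
        for (int i = k - low + 1; i <= k; i++)
            bit[k] += v[i];
    }
    return bit;
}
```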

In light of this fact, the code for querying becomes clear: for an index k, we first get the sum from k-lowbit(k)+1 to k from bit[k], then we need to find the sum from 1 to k-lowbit(k). The latter becomes a sub-problem, which is solved by setting k-lowbit(k) as the new k value and going into the next iteration.
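Writing the same loop as an explicit recursion makes the sub-problem visible. This is only an illustrative sketch (prefixSum is a hypothetical name and takes a 1-based k); the iterative loop above is what you would actually use.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Sum of positions 1..k (1-based k): bit[k] covers k-lowbit(k)+1..k,
// and the rest is the sub-problem for positions 1..k-lowbit(k).
int prefixSum(const vector<int>& bit, int k) {
    assert(0 <= k && k < (int)bit.size());
    if (k == 0)
        return 0;  // empty prefix
    return bit[k] + prefixSum(bit, k - (k & (-k)));
}
```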

Updating is much trickier. From the above, bit[l] includes k if and only if l-lowbit(l) < k ≤ l. Below is a proof sketch; an actual proof would include more details and be more tedious to go through. For the kth number, bit[k+lowbit(k)] must include it. This is because the lowbit of k+lowbit(k) is at least 2 times lowbit(k) (the addition clears the lowest set bit and carries upward), so k+lowbit(k)-lowbit(k+lowbit(k)) ≤ k-lowbit(k) < k ≤ k+lowbit(k), which satisfies the inequality. Also, k can be updated to k+lowbit(k) in the next iteration, because given lowbit(k) < lowbit(m) < lowbit(n), if bit[m] includes k and bit[n] includes m, then bit[n] must include k as well. So far, we have shown that every bit[k] value modified in the for loop includes k.

Then, we also need to show that every bit[l] value that includes k is modified in our loop. We can actually count the bit[l] values that include k: there are one plus the number of zeros above lowbit(k) in k’s binary representation (using 4-bit indices for our n = 8 example). It is not difficult to see that each iteration of the for loop reduces the number of zeros above the low bit by exactly one: adding lowbit(k) turns the lowest run of ones into zeros and sets the first zero above that run, which becomes the new low bit.

The only remaining question is why the count is that number. Let’s look at table 2b again. The numbers of terms for the first four entries, i.e. {4, 3, 3, 2}, are one more than the numbers of terms for the second four entries, i.e. {3, 2, 2, 1}. This is by design, because bit[4] covers the first four but not the second four, and everything else is pretty symmetric. Likewise, the first two entries have one more covering term than the next two entries, because bit[2] records the first two but not the next two. Hence, each time we “go down the tree” along a “zero edge” (appending a 0 to the index prefix), the index is covered once more than if we “go down the tree” along a “one edge” (appending a 1 to the index prefix). After we reach the low bit, no terms with smaller low bits cover this index, and of course the index includes itself, hence the plus one.

This is a basic and very rough explanation of how the number of zeros relates to the number of terms that include a certain index. Here we have argued, semi-convincingly, that the update loop is both valid and complete.
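Both halves of the argument can be checked by brute force. This sketch (all function names are hypothetical) compares the indices the update loop visits against every l satisfying l-lowbit(l) < k ≤ l, and compares the loop length with the “one plus the number of zeros above lowbit(k)” count. The count check assumes n is a power of two, so every index fits in a fixed bit width whose top bit is n.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Brute force: 1-based indices l in 1..n whose bit[l] includes position k,
// i.e. l - lowbit(l) < k <= l.
vector<int> coveringIndices(int k, int n) {
    vector<int> result;
    for (int l = 1; l <= n; l++)
        if (l - (l & (-l)) < k && k <= l)
            result.push_back(l);
    return result;
}

// Indices the update loop visits for a 1-based position k (i.e. after k++).
vector<int> updatePath(int k, int n) {
    assert(1 <= k && k <= n);
    vector<int> path;
    for (; k <= n; k += k & (-k))
        path.push_back(k);
    return path;
}

// The count from the text: one plus the number of zero bits strictly
// above lowbit(k), looking only at bit values up to n (n a power of two).
int predictedCount(int k, int n) {
    int zeros = 0;
    for (int b = (k & (-k)) << 1; b <= n; b <<= 1)
        if ((k & b) == 0)
            zeros++;
    return zeros + 1;
}
```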

OK, anyway, time for practice: Range Sum Query - Mutable on LeetCode. It’s literally implementing a binary indexed tree, nothing more.

class NumArray {
private:
    vector<int> bit;  // 1-based Fenwick array; bit[0] is unused
    void update_helper(int v, int k) {
        for (k++; k < (int)bit.size(); k += k & (-k))
            bit[k] += v;
    }
    int query_helper(int k) {
        int ans = 0;
        for (k++; k; k -= k & (-k))
            ans += bit[k];
        return ans;
    }
public:
    NumArray(vector<int> &nums) {
        bit.resize(nums.size()+1);
        for (int i = 0; i < nums.size(); i++)
            update_helper(nums[i], i);
    }
    void update(int i, int val) {
        update_helper(val-query_helper(i)+query_helper(i-1), i);
    }
    int sumRange(int i, int j) {
        return query_helper(j)-query_helper(i-1);
    }
};

It got a little complicated because I didn’t store the original values, so update() has to do some extra work: it recovers the old value at index i as query_helper(i)-query_helper(i-1) and passes val minus that old value as the change. But that’s nothing important.
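For comparison, here is one possible alternative, not the solution above: keep the original values so update() can compute the change directly, and build the tree in O(n) by pushing each finished bit[i] into bit[i + lowbit(i)]. That O(n) construction is a common trick rather than something from this post, and the class name NumArrayWithValues is hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

class NumArrayWithValues {
private:
    vector<int> vals;  // current values, 0-based
    vector<int> bit;   // 1-based Fenwick array; bit[0] is unused

    int prefix(int k) const {  // sum of vals[0..k]; prefix(-1) == 0
        int ans = 0;
        for (k++; k; k -= k & (-k))
            ans += bit[k];
        return ans;
    }

public:
    NumArrayWithValues(const vector<int>& nums)
        : vals(nums), bit(nums.size() + 1, 0) {
        int n = (int)nums.size();
        // O(n) build: when i is reached, all of its children have already
        // added into bit[i], so bit[i] is complete and can go to its parent.
        for (int i = 1; i <= n; i++) {
            bit[i] += nums[i - 1];
            int parent = i + (i & (-i));
            if (parent <= n)
                bit[parent] += bit[i];
        }
    }

    void update(int i, int val) {
        assert(0 <= i && i < (int)vals.size());
        int delta = val - vals[i];  // change is known without any queries
        vals[i] = val;
        for (int k = i + 1; k < (int)bit.size(); k += k & (-k))
            bit[k] += delta;
    }

    int sumRange(int i, int j) const {
        return prefix(j) - prefix(i - 1);
    }
};
```

The trade-off is an extra O(n) array in exchange for an update that needs no queries.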

That’s it for the basic introduction to binary indexed trees. There are some variants, such as replacing the + in the update and query functions with min or max to maintain prefix minima or maxima (this works as long as values only ever decrease for min, or only ever increase for max), or extending the algorithm to a 2D matrix, aka a 2D binary indexed tree. We can even use it for some dynamic programming questions. There are in fact a few more questions on Leetcode that use this data structure. But that’s for later.
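As a taste of the min/max variant, here is a sketch of a prefix-max tree, assuming values at each index only ever increase (a max cannot be undone, so lowering a value is not supported). The name FenwickMax is hypothetical. Both loops are unchanged; only += becomes max, and an empty prefix returns INT_MIN.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <vector>
using namespace std;

// Prefix-maximum binary indexed tree. Every element starts "unset" (INT_MIN).
struct FenwickMax {
    vector<int> bit;  // 1-based; bit[k] = max over positions k-lowbit(k)+1..k

    explicit FenwickMax(int n) : bit(n + 1, INT_MIN) {}

    // Raise element k (0-based) to at least v.
    void update(int v, int k) {
        assert(0 <= k && k + 1 < (int)bit.size());
        for (k++; k < (int)bit.size(); k += k & (-k))
            bit[k] = max(bit[k], v);
    }

    // Maximum of elements 0..k (0-based); INT_MIN for an empty prefix.
    int query(int k) const {
        int ans = INT_MIN;
        for (k++; k; k -= k & (-k))
            ans = max(ans, bit[k]);
        return ans;
    }
};
```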

I learned binary indexed tree through the TopCoder tutorial. If you think I did a really bad job and you do not understand at all, you can refer to it as well.