TIW: Binary Indexed Tree

Binary indexed tree, also called Fenwick tree, is a fairly advanced data structure with a specific use. Recall the range sum post: a binary indexed tree is used to compute prefix sums. In the prefix sum problem, we have an input array v, and we need to calculate the sum from the first item to index k. There are two operations: update, which changes the number at one index by adding a value (not resetting it), and query, which returns the sum from the beginning to a certain index. How do we do it? There are two trivial ways:

  1. Every time someone queries the sum, just loop through it and return the sum. O(1) update, O(n) query.
  2. Precompute the prefix sum array, and return the precomputed value from the table. O(n) update, O(1) query.

To illustrate the differences and better explain what we’re trying to achieve, I will write the code for both approaches. They are not the focus of this post, though.

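#include <vector>
using namespace std;

class Method1 {
private:
    vector<int> x;  // the raw values
public:
    Method1(int size) {
        x = vector<int>(size);
    }
    void update(int v, int k) {  // O(1): add v to x[k]
        x[k] += v;
    }
    int query(int k) {  // O(n): sum x[0..k] every time
        int ans = 0;
        for (int i = 0; i <= k; i++)
            ans += x[i];
        return ans;
    }
};
class Method2 {
private:
    vector<int> s;  // s[k] = x[0] + ... + x[k]
public:
    Method2(int size) {
        s = vector<int>(size);
    }
    void update(int v, int k) {  // O(n): every prefix sum from k onward changes
        for (; k < (int)s.size(); k++)
            s[k] += v;
    }
    int query(int k) {  // O(1): read the precomputed prefix sum
        return s[k];
    }
};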

Read through this and make sure you can write this code with ease. One note before we move on: we’re computing the sum from the first item to index k, but in general we want the range sum from index i to index j. To get it, simply subtract two prefix sums: query(j)-query(i-1), treating query(-1) as 0.

OK, that looks good. If we make a lot of updates, we use method 1; if we make a lot of queries, we use method 2. What if we make roughly the same number of updates and queries? Say we make n of each; then whichever method we use, the total time is O(n^2) (verify!). We either spend too much time pre-computing or too much time summing over and over again. Is there any way to do better?

Yes, of course! Instead of showing the code and convincing you that it works, I will derive it from scratch.

The quest for log(n)

The problem: say we have the same number of updates and queries, and we do not want to favor either of them. So we do a bit of pre-computation and a bit of summation. That’s the goal.

Say we have an array of 8 numbers, {1, 2, 3, 4, 5, 6, 7, 8}. To calculate the sum of the first 7 numbers, we want to add up a few pre-computed numbers (since there has to be a bit of summation), but how many numbers we add has to be sub-linear. Let’s aim for log(n). log2(7) is about 2.8, so maybe we can sum 3 numbers. In this case, we choose to sum these 3: sum{1, 2, 3, 4}, sum{5, 6} and sum{7}. Assuming these sums are already pre-computed, we only have about log(n) numbers to add, so a query takes O(log(n)). For clarity, let me put everything in a table:

Table 1a

sum{1} = sum{1}

sum{1, 2} = sum{1, 2}

sum{1, 2, 3} = sum{1, 2} + sum{3}

sum{1, 2, 3, 4} = sum{1, 2, 3, 4}

sum{1, 2, 3, 4, 5} = sum{1, 2, 3, 4} + sum{5}

sum{1, 2, 3, 4, 5, 6} = sum{1, 2, 3, 4} + sum{5, 6}

sum{1, 2, 3, 4, 5, 6, 7} = sum{1, 2, 3, 4} + sum{5, 6} + sum{7}

sum{1, 2, 3, 4, 5, 6, 7, 8} = sum{1, 2, 3, 4, 5, 6, 7, 8}

The left hand side of the table is the query, and all the terms on the right hand side are pre-computed. If you look closely enough you will see the pattern: for summing k numbers, first take the largest power of 2, 2^m, that is ≤ k, and pre-compute it. Then for the rest of the numbers, k-2^m, take the largest power of 2, 2^m’ such that 2^m’ ≤ k-2^m, and pre-compute it, and so on.
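To make the pattern in Table 1a concrete, here is a small sketch that lists which ranges a query over the first k numbers adds up, using 1-based positions like the table. The function name decompose is made up for illustration; a real binary indexed tree never builds this list, it just walks the indices.

```cpp
#include <utility>
#include <vector>

// Split positions 1..k into the ranges whose pre-computed sums a query adds,
// always taking the largest power of 2 that still fits.
std::vector<std::pair<int, int>> decompose(int k) {
    std::vector<std::pair<int, int>> ranges;
    int start = 1;  // first position not yet covered
    while (k > 0) {
        int p = 1;
        while (p * 2 <= k)  // largest power of 2 <= remaining count
            p *= 2;
        ranges.push_back({start, start + p - 1});
        start += p;
        k -= p;
    }
    return ranges;
}
```

For k = 7 it gives {1, 4}, {5, 6}, {7, 7}, the same three terms as the table. The number of ranges equals the number of 1 bits in k, which is at most about log2(k)+1.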

There are two things to show: that querying (adding up the terms on the right hand side) takes O(log(n)), and that updating the pre-computed terms also takes O(log(n)).

That querying is O(log(n)) is easy to see: by taking out the largest power of 2 each time, we take out at least half of the remaining numbers (use proof by contradiction). Since at least half goes each time, after O(log(n)) steps we have taken out all of it.

Now we are one step from finishing on the theoretical side: how do we pre-compute those terms?

Let’s say we want to change the number 1 into 2, essentially carrying out update(1, 0). Look at the terms above: we need to change sum{1}, sum{1, 2}, sum{1, 2, 3, 4} and sum{1, 2, 3, 4, 5, 6, 7, 8}. Each time we update one more pre-computed term, we cover double the number of elements in the array. Therefore we also only need to update log(n) terms. Let’s see it in a table:

Table 1b

update 1: sum{1}, sum{1, 2}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 2: sum{1, 2}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 3: sum{3}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 4: sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 5: sum{5}, sum{5, 6}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 6: sum{5, 6}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 7: sum{7}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 8: sum{1, 2, 3, 4, 5, 6, 7, 8}

Cool, now we have a vague idea about what to pre-compute for update and what to add for query. Now we should figure out the details of the code.

How is the code written?

First, we need to determine the representation of the pre-computed terms. Here is a list of all pre-computed terms:

{1}, {1, 2}, {3}, {1, 2, 3, 4}, {5}, {5, 6}, {7}, {1, 2, 3, 4, 5, 6, 7, 8}

The last numbers of the terms are all different, and together they cover 1 to 8 exactly once. That’s great news! We can store these terms in a vector and use the last number of each term as its index. For example, the sum of {5, 6} will be stored at bit[6].

First, the query operation. Let’s revisit the table with binary representation of numbers:

Table 2a: revised version of table 1a, with sums written as bit elements, indices in binary

query 0001: bit[0001]

query 0010: bit[0010]

query 0011: bit[0011]+bit[0010]

query 0100: bit[0100]

query 0101: bit[0101]+bit[0100]

query 0110: bit[0110]+bit[0100]

query 0111: bit[0111]+bit[0110]+bit[0100]

query 1000: bit[1000]

Do you see the pattern yet? Hint: for queries that have k ones, we have k terms on the right. The pattern is that while the index has at least 2 ones, we remove the lowest bit that is one, then move on to the next term. 0111->0110->0100. Finally, here’s the code:

int query(vector<int>& bit, int k) {
    int ans = 0;
    for (k++; k; k -= k & (-k))
        ans += bit[k];
    return ans;
}

After all the work we’ve been through, the code is extremely concise! Two things to notice: the k++ is to change the indexing from 0-based to 1-based, as we can see from the above derivation we go from 1 to 8, instead of 0 to 7. The second thing is the use of k & (-k) to calculate the lowest bit. You can refer to the previous blog post on bitwise operations.

OK, we’re almost done. What about update? Another table:

Table 2b: revised version of table 1b

update 0001: bit[0001], bit[0010], bit[0100], bit[1000]

update 0010: bit[0010], bit[0100], bit[1000]

update 0011: bit[0011], bit[0100], bit[1000]

update 0100: bit[0100], bit[1000]

update 0101: bit[0101], bit[0110], bit[1000]

update 0110: bit[0110], bit[1000]

update 0111: bit[0111], bit[1000]

update 1000: bit[1000]

What’s the pattern this time? Hint: again, look for the lowest bit! Yes, this time instead of removing the lowest bit, we add the lowest bit of the index to itself. This is less intuitive than the last part. For example, the lowest bit of 0101 is 0001, so the next index is 0101+0001 = 0110; its lowest bit is 0010, so the next index is 0110+0010 = 1000.

So here’s the code, note also the k++:

void update(vector<int>& bit, int v, int k) {
    for (k++; k < bit.size(); k += k & (-k))
        bit[k] += v;
}

This is deceptively easy! That can’t be right. It can’t be that easy… Actually it can, if you look at the category of this post; nothing I write about is hard.

Actually, it was easy because we were just matching patterns and assuming they would generalize. Let’s study the for loops a little more to understand why and how low bits are involved. This part is rather complicated, so for practical purposes you might as well skip it.

First, observe that the low bit of each index is the number of integers that index sums over. For example, 0110 has low bit 0010, which is 2, so bit[6] is a sum of two numbers: 5 and 6. This is by design, since this is exactly how we picked the pre-computed terms, so there is nothing more to explain.

Second, bit[k] is the sum from index k-lowbit(k)+1 to k. This follows directly from (1) bit[k] is a sum that ends at the kth number and (2) bit[k] sums over lowbit(k) numbers.
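A quick way to convince yourself of this second fact is to build the tree with the update loop from earlier and compare each bit[k] against a direct sum. The sketch below does that; buildByUpdates and coversClaimedRange are hypothetical helper names, and the input vector is 0-based while bit is 1-based, as in the post.

```cpp
#include <vector>

// Build a binary indexed tree from all zeros by running the update loop
// once per element (bit is 1-based, v is 0-based).
std::vector<int> buildByUpdates(const std::vector<int>& v) {
    std::vector<int> bit(v.size() + 1, 0);
    for (int i = 0; i < (int)v.size(); i++)
        for (int k = i + 1; k < (int)bit.size(); k += k & (-k))
            bit[k] += v[i];
    return bit;
}

// Check that bit[k] equals the plain sum of 1-based positions
// k-lowbit(k)+1 .. k for every k.
bool coversClaimedRange(const std::vector<int>& v) {
    std::vector<int> bit = buildByUpdates(v);
    for (int k = 1; k < (int)bit.size(); k++) {
        int low = k & (-k);
        int direct = 0;
        for (int p = k - low + 1; p <= k; p++)
            direct += v[p - 1];
        if (direct != bit[k])
            return false;
    }
    return true;
}
```

For {1, 2, 3, 4, 5, 6, 7, 8}, bit[6] comes out as 5+6 = 11, matching the example above.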

In light of this fact, the code for querying becomes clear: for an index k, we first get the sum from k-lowbit(k)+1 to k from bit[k], then we need to find the sum from 1 to k-lowbit(k). The latter becomes a sub-problem, which is solved by setting k-lowbit(k) as the new k value and going into the next iteration.

Updating is much trickier. From the above, bit[l] includes k if and only if l-lowbit(l) < k ≤ l. Below is a sketch of the proof; the full proof has more details and is tedious to go through. For the kth number, bit[k+lowbit(k)] must include it. This is because lowbit(k+lowbit(k)) is at least 2 times lowbit(k), so k+lowbit(k)-lowbit(k+lowbit(k)) ≤ k-lowbit(k) < k ≤ k+lowbit(k), which satisfies the inequality. Also, k can be replaced by k+lowbit(k) in the next iteration, because if lowbit(k) < lowbit(m) < lowbit(n), bit[m] includes k and bit[n] includes m, then bit[n] includes k as well. So far, we have shown that every bit[l] modified in the for loop includes k.

Next, we need to show that every bit[l] that includes k is modified in our loop. We can count all the bit[l] values that include k: the count is one plus the number of zeros above lowbit(k), writing indices with log2(n)+1 bits when n is a power of 2 (4 bits for n = 8). Each iteration of the for loop reduces the number of zeros above the low bit by exactly one: adding lowbit(k) turns the run of ones starting at the low bit into zeros and flips the first zero above them into the new low bit. The only question remaining is why the count is that number. Let’s look at Table 2b again. The numbers of terms for the first four entries, i.e. {4, 3, 3, 2}, are one more than the numbers of terms for the second four entries, i.e. {3, 2, 2, 1}. This is by design, because bit[4] covers the first four but not the second four, and everything else is pretty symmetric. Again, the first two entries have one more term than the next two, because bit[2] covers the first two but not the next two. Hence, each time we “go down the tree” on a “zero edge” (appending a 0 to the index prefix), the index is covered by one more term than if we “go down the tree” on a “one edge” (appending a 1 to the index prefix). After we hit the low bit, no terms with smaller low bits cover this index, and of course the index includes itself, hence the plus one. This is a basic and very rough explanation of how the number of zeros relates to the number of terms that include a certain index. Here we have argued semi-convincingly that the update loop is valid and complete.
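If the argument above feels hand-wavy, a brute-force check for n = 8 is easy to write. This sketch (helper names are made up for illustration) compares the indices visited by the update loop against every l with l-lowbit(l) < k ≤ l, and against the "one plus the zeros above the low bit" count using 4-bit indices.

```cpp
#include <vector>

// Indices visited by the update loop when 1-based position k changes.
std::vector<int> updatePath(int k, int n) {
    std::vector<int> path;
    for (; k <= n; k += k & (-k))
        path.push_back(k);
    return path;
}

// Every l in 1..n whose term bit[l] includes k: l-lowbit(l) < k <= l.
std::vector<int> coveringTerms(int k, int n) {
    std::vector<int> terms;
    for (int l = 1; l <= n; l++)
        if (l - (l & (-l)) < k && k <= l)
            terms.push_back(l);
    return terms;
}

// One plus the number of 0 bits above lowbit(k), looking at `width` bits.
int predictedCount(int k, int width) {
    int low = k & (-k), zeros = 0;
    for (int b = 0; b < width; b++)
        if ((1 << b) > low && !((k >> b) & 1))
            zeros++;
    return zeros + 1;
}
```

For example, updatePath(5, 8) visits 5, 6, 8, the same three terms listed for update 0101 in Table 2b. This only checks a small n that is a power of 2; it is a sanity check, not a proof.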

 

OK, anyway, time for practice: Range Sum Query – Mutable

It’s literally implementing a binary indexed tree, nothing more.

class NumArray {
private:
    vector<int> bit;
    void update_helper(int v, int k) {
        for (k++; k < bit.size(); k += k & (-k))
            bit[k] += v;
    }
    int query_helper(int k) {
        int ans = 0;
        for (k++; k; k -= k & (-k))
            ans += bit[k];
        return ans;
    }
public:
    NumArray(vector<int> &nums) {
        bit.resize(nums.size()+1);
        for (int i = 0; i < nums.size(); i++)
            update_helper(nums[i], i);
    }
    void update(int i, int val) {
        update_helper(val-query_helper(i)+query_helper(i-1), i);
    }
    int sumRange(int i, int j) {
        return query_helper(j)-query_helper(i-1);
    }
};

It got a little complicated because I didn’t store the original values, so update() needs two extra prefix queries to work out the change at index i from the new value and the old range sums. But that’s nothing important.
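If you do want to store the original values, update becomes a single tree walk. Below is a sketch of that variant, assuming an extra vector vals that holds the current numbers; the class name NumArrayWithValues is hypothetical and is not the Leetcode signature.

```cpp
#include <vector>
using namespace std;

// Keeps a copy of the current values so update() can apply the difference
// directly instead of recovering the old value with two prefix queries.
class NumArrayWithValues {
private:
    vector<int> bit, vals;
    void add(int v, int k) {
        for (k++; k < (int)bit.size(); k += k & (-k))
            bit[k] += v;
    }
    int prefix(int k) {  // sum of vals[0..k]; prefix(-1) == 0
        int ans = 0;
        for (k++; k; k -= k & (-k))
            ans += bit[k];
        return ans;
    }
public:
    NumArrayWithValues(const vector<int>& nums) : bit(nums.size() + 1), vals(nums) {
        for (int i = 0; i < (int)nums.size(); i++)
            add(nums[i], i);
    }
    void update(int i, int val) {
        add(val - vals[i], i);  // apply only the change
        vals[i] = val;
    }
    int sumRange(int i, int j) {
        return prefix(j) - prefix(i - 1);
    }
};
```

The trade-off is O(n) extra memory in exchange for one tree walk per update instead of one walk plus two prefix queries.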

That’s it for the basic introduction to binary indexed trees. There are some variants, such as replacing the + in the update function with min or max to get prefix min or max, or extending the algorithm to a 2D matrix, aka the 2D binary indexed tree. We can even use it for some dynamic programming questions. There are in fact a few more questions on Leetcode that use this data structure. But that’s for later.
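As a sketch of the first variant, here is a prefix-maximum tree with + replaced by max. It assumes values at an index only ever increase: unlike +, max cannot be undone, so a decrease cannot be applied, and a range max over [i, j] cannot be obtained by subtracting two prefixes. The class and method names are made up for illustration.

```cpp
#include <algorithm>
#include <climits>
#include <vector>

// Prefix-maximum binary indexed tree: "+" in update/query replaced by max.
// Assumption: the value at an index is only ever raised, never lowered.
class PrefixMaxBIT {
private:
    std::vector<int> bit;
public:
    explicit PrefixMaxBIT(int n) : bit(n + 1, INT_MIN) {}
    void raise(int k, int v) {  // value at index k becomes max(old, v)
        for (k++; k < (int)bit.size(); k += k & (-k))
            bit[k] = std::max(bit[k], v);
    }
    int query(int k) {  // max over indices 0..k, INT_MIN if nothing set
        int ans = INT_MIN;
        for (k++; k; k -= k & (-k))
            ans = std::max(ans, bit[k]);
        return ans;
    }
};
```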

I learned binary indexed tree through the TopCoder tutorial. If you think I did a really bad job and you do not understand at all, you can refer to it as well.

TIW: Bitwise

Bitwise operations are black magic. They are so simple, but with them you can do things that you never thought would be so easy. For those who have never seen them, bitwise operations take one or two integer-type variables and treat them as arrays of booleans (bits). Each operation acts on the bits individually. Let’s see what bitwise operations we have:

  1. And: a & b. 0011 & 0101 = 0001 in binary, so 3 & 5 = 1.
  2. Or: a | b. 0011 | 0101 = 0111 in binary, so 3 | 5 = 7.
  3. Exclusive or: a ^ b. 0011 ^ 0101 = 0110 in binary, so 3 ^ 5 = 6.
  4. Not: ~a. ~00001111 = 11110000. Depending on the number of bits of your integer data type, the values could vary.
  5. Shift left and right: 001011 << 1 = 010110, 001011 >> 2 = 000010. Shifting by k is essentially multiplying or dividing by 2 to the power of k (for non-negative numbers, with division rounding down).
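The examples in the list can be checked at compile time. In this sketch ~ is applied to an 8-bit unsigned type so the result matches the 8-bit picture above; with a wider type, more leading 1 bits would appear.

```cpp
#include <cstdint>

// Compile-time checks of the list's examples (values written in decimal).
static_assert((3 & 5) == 1, "0011 & 0101 = 0001");
static_assert((3 | 5) == 7, "0011 | 0101 = 0111");
static_assert((3 ^ 5) == 6, "0011 ^ 0101 = 0110");
// ~ flips every bit of the type; in 8 bits, 00001111 becomes 11110000.
static_assert(static_cast<std::uint8_t>(~std::uint8_t{0x0F}) == 0xF0, "~00001111");
static_assert((11 << 1) == 22, "001011 << 1 = 010110");  // times 2^1
static_assert((11 >> 2) == 2, "001011 >> 2 = 000010");   // divided by 2^2, rounded down
```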

Applications are sort of miscellaneous and interesting. First let me go through some common routines, then I will go over some problems.

Taking a bit at a certain index

int bitAtIndex(int v, int k) {
    return (v >> k) & 1;
}

Shift the bit you want to the least significant bit, then and it with 1 to get rid of all the other higher bits.

Clearing and setting a bit

void clearBit(int& v, int k) {
    v &= ~(1 << k);
}
void setBit(int& v, int k) {
    v |= 1 << k;
}

This idea is called masking: create a mask, apply it on the number by either or-ing or and-ing.

Getting the lowest 1 bit

int lowBit(int v) {
    return v & (-v);
}

This is sort of tricky. Let’s walk through it. Say our number is 00011000. In two’s complement, the negative number of v is obtained by 1+(~v). So the negative of 00011000 is 11100111 plus 1, which is 11101000. Taking the and result with 00011000 will yield 00001000, which is the lowest bit we want. The way to understand it is that the tail of our input number v, defined as the pattern 100000… at the end, will remain the same after taking the negative. Everything to the left of the tail will be flipped. Therefore taking the and result will yield the lowest bit that is a 1. This is particularly useful for Binary Indexed Tree, which I will talk about in a coming post.
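Here is the same computation with the negation spelled out as flip-then-add-1, as in the walkthrough. The name lowBitExplicit is just for illustration. It uses uint32_t, where ~v + 1 wraps around with defined behavior; with a plain int, -v overflows for the most negative value.

```cpp
#include <cstdint>

// Lowest set bit, negating by "flip every bit, then add 1" (two's complement).
std::uint32_t lowBitExplicit(std::uint32_t v) {
    std::uint32_t negated = ~v + 1u;  // the two's complement negative of v
    return v & negated;               // only the lowest 1 bit survives
}
```

For 00011000 (24) it returns 00001000 (8), and for 0 it returns 0.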

Here’s the most cliched problem.

Single Number

Given an array where every number appears twice except for one, find the number that appears only once. This problem is called Single Number for a reason: imagine it’s Christmas time and you’re out there on the streets alone, and everyone you see is with their significant others. Probably how they got the problem idea. The goal is O(n) time with O(1) space. To solve this problem, we need an operation that gives back the original value when applied with the same number twice, i.e. f(f(a, b), b) = a. This function had better be commutative and associative, so we can apply it in any order and cancel out all the pairs. Obviously +, -, *, / don’t meet the requirements. The answer is exclusive or. The ideas are: a^a = 0 and a^0 = a. Therefore 3^2^1^2^3 = (2^2)^(3^3)^1 = 0^0^1 = 1, so exclusive or-ing all the numbers gives you the answer.

int singleNumber(vector<int>& nums) {
    int ans = 0;
    for (int x : nums)
        ans ^= x;
    return ans;
}

One more remark: it’s possible to extend this algorithm to find the only number that appears once, given that all other numbers appear n times for some n ≥ 2. What we want is a commutative operation that returns to the identity after being applied n times with the same number. Here it is: maintain a vector of 32 integers, counting the number of 1s at each bit position mod n. Each number that appears n times adds either 0 or n to every count, and both are 0 mod n, so only the bits of the answer remain. Our solution above is just the degenerate case n = 2: a vector of ints mod 2 can be replaced by a single integer, and addition mod 2 is exclusive or. Single Number II is the problem for n = 3. Don’t ask me where they got the problem idea from 🙂

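int singleNumber(vector<int>& nums) {
    vector<int> c(32);  // c[j] = number of 1s at bit j, mod 3
    for (int x : nums)
        for (int j = 0; j < 32; j++)
            c[j] = (c[j]+(((unsigned)x>>j)&1))%3;  // take bit j of x
    unsigned ans = 0;  // unsigned, so setting bit 31 is well defined
    for (int i = 0; i < 32; i++)
        ans |= (unsigned)c[i] << i;  // set bit i if its count is nonzero
    return (int)ans;
}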

For people who have never seen bitwise operations, this is sort of complicated. You can see taking a bit in the inner loop (shift right by j, then and with 1) and setting a bit in the last loop (or-ing in a value shifted left by i).

Generating the power set

Given a set s, the power set is the set of all subsets of s. For example if s = {1, 2, 3}, P(s) = {{}, {3}, {2}, {2, 3}, {1}, {1, 3}, {1, 2}, {1, 2, 3}}. Here is one possible implementation using bitwise and:

vector<vector<int> > powerSet(vector<int>& s) {
    vector<vector<int> > ans;
    for (int i = 0; i < (1 << s.size()); i++) {
        ans.push_back({});
        for (int j = 0; j < s.size(); j++)
            if (i & (1 << j))
                ans.back().push_back(s[j]);
    }
    return ans;
}

The key idea is the outer loop of i. Say for s = {1, 2, 3}, i will loop from 0 to 7, or 000 to 111 in binary. We will have all the bit patterns in that case: 000, 001, 010, 011, 100, 101, 110, 111. Now for each number, each bit indicates whether we include a certain element of the original set. By looping through these bits, we can create each subset and generate the power set.

Counting the number of 1 bits

Given a number v, count how many bits are 1. There are different ways to do it, two will be shown below.

int countBits(unsigned int v) {
    int ans = 0;
    for (int i = 0; i < 32; i++)
        if ((v >> i) & 1)
            ans++;
    return ans;
}
int countBits(unsigned int v) {
    int ans = 0;
    for (; v > 0; v -= v & (-v), ans++);
    return ans;
}

The second method uses the low bit trick, removing the lowest 1 bit each time, so it loops once per 1 bit instead of 32 times. It may be faster for numbers with few 1 bits, but it probably doesn’t make a real difference.

Swapping two numbers without extra space

void swapNumbers(int& a, int& b) {
    a = a^b;  //  a ^= b;
    b = a^b;  //  b ^= a;
    a = a^b;  //  a ^= b;
}

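Not very useful, but good to know. One catch: if a and b refer to the same variable, the first line sets it to 0 and the value is lost.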

Sum of Two Integers

Here’s a brainteaser: add two numbers, but the code cannot use any + or -. The idea: compute the sum without carries and the carry bits separately, then add those two results together recursively. The base case is when one number becomes 0. There will not be an infinite loop because the number of trailing zeros of the carry bits increases with each recursive call. In the following code, if either number is 0, the sum is equal to the bitwise or. Otherwise, the bits without carry are a exclusive or b, and the carry bits are where both a and b are 1, shifted to the left by 1.

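int getSum(int a, int b) {
    // a ^ b: sum without carries; (a & b) << 1: the carries. The shift is done on
    // unsigned because left-shifting a negative int is undefined before C++20.
    return (a == 0 || b == 0)? a | b : getSum(a ^ b, (int)((unsigned)(a & b) << 1));
}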

Sudoku Solver by Lee Hsien Loong

This code was written by the prime minister of Singapore. It is written in pure C, so it is kind of hard to read, and he used a lot of bitwise operators in it. I read it two years ago, and I don’t want to spend the time understanding everything again, so my short explanation might be faulty. The algorithm is not tricky, as he simply picks a cell with the smallest number of choices and tries everything recursively (line 180). To speed things up, he used integers as boolean arrays to indicate which numbers are still available for a certain row, column or 3×3 block. Therefore getting the possible placements at a certain cell simply requires taking the bitwise and of those integers (line 171). To try one possible placement, he took the lowest bit (line 173). Picking the cell with the fewest possible choices is itself a trick to reduce the runtime (lines 162, 163, 188 I assume). He also pre-computed some functions into arrays to avoid repeated work. Most of these are optimizations that reduce the constant factor, replacing one O(1) operation with a faster O(1) operation. Tricky and efficient, but also less readable, in my opinion.
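To illustrate the two tricks described above (intersecting availability masks with bitwise and, then peeling off candidates with the lowest bit), here is a small sketch. It is not taken from his solver; the mask layout (bit d-1 set means digit d is still free) and the function names are assumptions made for illustration.

```cpp
#include <vector>

// Hypothetical layout: bit d-1 of a mask is 1 if digit d (1..9) is still
// unused in that row, column, or 3x3 box.
int candidates(int rowFree, int colFree, int boxFree) {
    return rowFree & colFree & boxFree;  // digits allowed by all three
}

// List the allowed digits by repeatedly taking and removing the lowest set bit.
std::vector<int> candidateDigits(int mask) {
    std::vector<int> digits;
    while (mask) {
        int low = mask & (-mask);  // lowest remaining candidate
        int d = 1;
        while ((1 << (d - 1)) != low)
            d++;
        digits.push_back(d);
        mask -= low;
    }
    return digits;
}
```

To choose the cell with the fewest choices, you would compare the number of 1 bits in each cell's candidate mask, for example with the countBits function from earlier.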

 

Anyway, that’s a lot already; I will split the Binary Indexed Tree part into a separate post. Sure, these are mostly brainteasers, but some interviewers do like them, and for some lower level (closer to hardware) jobs they are quite important.

TIW: Linked List Cycle

This is more like a special topics post, because it is a very specific algorithm with a very narrow application. The problem statement: given a linked list which has a cycle, determine where the cycle begins.

To explain further, a linked list with a cycle looks something like the number 6. In fact, the only 2 topologies (read: shapes) a linked list can have are a straight line or the number 6. We start walking from the top, and end up walking in a loop forever. Determining whether there is a loop is fairly simple: create an unordered set and insert each visited node (or just its pointer, or anything unique to the node) into the set, until we reach a node that is already in the set or we run into the end of the linked list (in which case there is no cycle).

Let’s say we have this declaration of list node.

struct ListNode {
    ListNode* next;
    int val;
};

And this would be the function to return the first node in a cycle.

ListNode* detectCycle(ListNode* head) {
    unordered_set<ListNode*> vis;
    while (head) {
        if (vis.count(head))
            return head;
        vis.insert(head);
        head = head->next;
    }
    return NULL;
}

This is trivial. What is not trivial is to accomplish the exact same task with O(n) time as before but with O(1) space.

The algorithm that does this uses 2 pointers, one fast and one slow. They both start from the head, and the fast pointer moves 2 nodes while the slow pointer moves 1 node per iteration. If there is no loop, the fast pointer will reach the end. If there is a loop, both will fall into it, and eventually they will end up at the same node. At that point we are sure that there is a loop. Why are we sure that they will always collide? Suppose both are already in the cycle. In each iteration, the distance from the slow pointer to the fast pointer (measured along the cycle) increases by 1. When that distance hits a multiple of the cycle length, it is effectively 0, so they are at the same node.

OK, that sounds clever. But we still do not know the beginning of the loop, do we?

Here’s the real genius: yes, we can figure it out! After they collide, move the fast pointer back to the head. Now in each iteration, move both pointers at the same pace of 1 node, and eventually they will collide again. When that happens, we have found the beginning of the loop. I do not know of an intuitive explanation for this, but it is provable with a little algebra. Let’s say the part before the loop has m nodes, and the loop itself is n nodes long. Say that after k iterations, the fast pointer meets the slow pointer (both are in the loop by then, so k ≥ m). Then we know (k*2-m)-(k-m) = 0 (mod n), i.e. the relative distance between the two pointers is 0. Hence k is a multiple of n. After m more steps, the slow pointer has moved k+m steps in total. That is equivalent to moving m steps and then moving k more steps. But moving k steps in the loop does nothing, because k is a multiple of n. Therefore after k+m steps, the slow pointer points at the beginning of the loop. Meanwhile, the fast pointer, which in the second phase restarts from the head and moves single steps, also arrives at the beginning of the loop after m steps. Therefore, we have proven (not very rigorously) that the first time the two pointers meet in the second phase of the algorithm is at the beginning of the loop.
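The algebra can also be checked by brute force. This sketch models a 6-shaped list as integer positions: 0 to m-1 form the part before the loop, m to m+n-1 form the loop, and the last position links back to m. nextPos and loopStart are hypothetical names; the two phases mirror the algorithm described above.

```cpp
// Positions 0..m-1 are the part before the loop; m..m+n-1 are the loop,
// and the last position links back to m. Requires n >= 1.
int nextPos(int p, int m, int n) {
    return p + 1 < m + n ? p + 1 : m;
}

// Run both phases of the two-pointer algorithm on the model and return the
// position where the pointers meet in the second phase.
int loopStart(int m, int n) {
    int fast = 0, slow = 0;
    do {  // phase 1: fast moves 2, slow moves 1, until they meet
        fast = nextPos(nextPos(fast, m, n), m, n);
        slow = nextPos(slow, m, n);
    } while (fast != slow);
    fast = 0;  // phase 2: fast restarts from the head, both move 1
    while (fast != slow) {
        fast = nextPos(fast, m, n);
        slow = nextPos(slow, m, n);
    }
    return fast;
}
```

Running loopStart over small m and n should always return m, the first position of the loop.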

Perhaps I should show some code to make it clearer.

ListNode* detectCycle(ListNode* head) {
    ListNode *fast = head, *slow = head;
    do {
        for (int i = 0; i < 2; i++) {
            if (!fast)
                return NULL;
            fast = fast->next;
        }
        slow = slow->next;
    } while (fast != slow);
    fast = head;
    while (fast != slow) {
        fast = fast->next;
        slow = slow->next;
    }
    return fast;
}       

That’s the algorithm. There is one problem on Leetcode that uses it: Find the Duplicate Number. It is not exactly obvious, so please spend some time convincing yourself that this problem represents a linked list. Basically, each index in the array is a node (its address), and the number stored there is the address of the next node. Since every number is between 1 and n, nothing points back to index 0, so index 0 works as the head. The part of the graph reachable from the head looks like a 6, and the node where the loop begins is pointed to by two different indices, so it must be the duplicated number. Here’s the code for it:

int findDuplicate(vector<int>& nums) {
    int fast = 0, slow = 0;
    do {
        fast = nums[nums[fast]];
        slow = nums[slow];
    } while (fast != slow);
    fast = 0;
    while (fast != slow) {
        fast = nums[fast];
        slow = nums[slow];
    }
    return fast;
}

TIW: Reverse Linked List

I’m just going to say this first: I hate linked list problems. But I also hate waking up, yet I do it every day anyways.

First, what is a linked list? It is a sequence that supports O(1) insertion (once you have a pointer to the insertion point) and O(n) random access, in contrast to vector’s O(n) insertion and O(1) random access. Here’s what a linked list looks like in real life:

      oooOOOOOOOOOOO"
     o   ____          :::::::::::::::::: :::::::::::::::::: __|-----|__
     Y_,_|[]| --++++++ |[][][][][][][][]| |[][][][][][][][]| |  [] []  |
    {|_|_|__|;|______|;|________________|;|________________|;|_________|;
     /oo--OO   oo  oo   oo oo      oo oo   oo oo      oo oo   oo     oo
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

Here’s what a linked list looks like in C++:

struct ListNode {
    int val;
    ListNode* next;
};

First, here’s the link to the Leetcode problem Reverse Linked List.

Like many other problems, linked list reversal could be achieved in two ways: iterative and recursive. The one thing we need to do is point all the “next” pointers backwards. Let’s look at iterative solution first.

ListNode* reverseList(ListNode* head) {
    ListNode* last = NULL;
    while (head) {
        ListNode* next = head->next;
        head->next = last;
        last = head;
        head = next;
    }
    return last;
}

This code is trivially easy to write. I have found 2 slightly different ways to interpret what this code does, and you can see which one you find easier to understand.

The first way to look at it: imagine the linked list 1->2->3->NULL. Let () denote last, and [] denote head. After the first iteration: NULL<-(1) [2]->3->NULL; after the second: NULL<-1<-(2) [3]->NULL; after the third: NULL<-1<-2<-(3) [NULL]. Essentially we use extra variables to store the last and next nodes, so we can move forward without losing track of anybody. When head hits NULL, last will be the new head.

The second way to look at it is to think of “last” as a new linked list. What we attempt here is to pop the first node of the original linked list and insert it into the front of the new linked list. Therefore, “head” is the first node in the old linked list, and “last” is the first node in the new one. 1->2->3->NULL, NULL; 2->3->NULL, 1->NULL; 3->NULL, 2->1->NULL; NULL, 3->2->1->NULL. In a certain sense, this describes exactly the same thing, but it might be clearer now why we return the “last” variable: it is the head of the new linked list.
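To see the "two lists" reading in action, this sketch runs the same loop as reverseList but records the values of both lists after each step. reverseTrace and toVector are helper names made up for illustration.

```cpp
#include <utility>
#include <vector>

struct ListNode {
    int val;
    ListNode* next;
};

// Collect the values of a list into a vector.
std::vector<int> toVector(ListNode* p) {
    std::vector<int> out;
    for (; p; p = p->next)
        out.push_back(p->val);
    return out;
}

// Same loop as reverseList, recording (old list, new list) after each step.
std::vector<std::pair<std::vector<int>, std::vector<int>>> reverseTrace(ListNode* head) {
    std::vector<std::pair<std::vector<int>, std::vector<int>>> states;
    ListNode* last = nullptr;
    states.push_back({toVector(head), toVector(last)});
    while (head) {
        ListNode* next = head->next;
        head->next = last;  // push head onto the front of the new list
        last = head;
        head = next;        // the old list loses its first node
        states.push_back({toVector(head), toVector(last)});
    }
    return states;
}
```

For 1->2->3 the recorded states are ([1, 2, 3], []), ([2, 3], [1]), ([3], [2, 1]) and ([], [3, 2, 1]), matching the sequence above.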

OK, that’s not bad; let’s look at a recursive way. Hold on for a minute and think about how we can reduce the problem size, given that we can handle a smaller case. Perhaps we can take the first node out and reverse the rest. Then we need to append the original first node to the back of the new list. Oh no, we do not know where the end of the new linked list is! How do we solve this?

// Returns {new head, new tail} of the reversed list.
pair<ListNode*, ListNode*> helper(ListNode* head) {
    if (!head || !head->next)
        return make_pair(head, head);
    auto ans = helper(head->next);
    ans.second->next = head;  // append head after the current tail
    head->next = NULL;
    ans.second = head;        // head is the new tail
    return ans;
}
ListNode* reverseList(ListNode* head) {
    return helper(head).first;
}

One way is, instead of writing only the function with the given signature, to write a helper function that also returns a pointer to the last node of the partially reversed list. That way we can append to the list easily. But this is really not an elegant solution, as we can see from the next code snippet.

ListNode* reverseList(ListNode* head) {
    if (!head || !head->next)
        return head;
    ListNode* ans = reverseList(head->next);
    head->next->next = head;
    head->next = NULL;
    return ans;
}

The key line is head->next->next = head. The tail of the reversed sublist is the node that used to come right after head, so head->next still points to it. By setting head->next->next to head, we append head to the end, which temporarily creates a cycle. We then break the cycle by setting head->next to NULL.

It all looks alright: the code is 2 lines shorter and all that. But here’s the catch: although it is O(n) time, the recursive solution is not O(1) space. The recursion goes n levels deep, so n stack frames are alive at once, each with its own copy of the local variables. What is the problem, you may ask? There are two: first, stack space is much more limited than heap space (dynamic memory allocation), so we might hit a stack overflow if the linked list is huge. This is not a problem for the iterative solution. Second: if we could use that much space, why don’t we just store everything in an external vector and trivially random access everything? What is the whole point of using a linked list anyway?

Despite my rant, some interviewers actually don’t care if you write recursion. They might even think it’s cleaner. So it’s still good to know.

For some trickier problems, a recursive solution (or at least a non-O(1) space solution) might be necessary. But if you can do it in O(1) space, you should prefer to do so, because using linear space for linked list problem is kinda cheating.

Now if you cannot wait to challenge yourself, you can try this one: Reverse Linked List II. It’s definitely more code than 10 lines though.

TIW: Dijkstra

Shortest path problems using Dijkstra are actually very easy to write. There is a fixed format and you just need to fill in the blanks. The format:

  1. Optional pre-computation to create the graph;
  2. Make a set of pairs, insert initial state;
  3. While the set is not empty, pop the best one;
  4. If it is the destination, return the distance;
  5. If we have been to this node, continue to the next, otherwise mark this node as visited;
  6. Insert all the neighbors into the set.
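Here is a sketch of the format on a general weighted graph, with the step numbers in comments. It assumes the graph is already built (step 1) as an adjacency list of (neighbor, weight) pairs with non-negative weights; the function name and graph layout are assumptions, not a fixed API.

```cpp
#include <set>
#include <utility>
#include <vector>
using namespace std;

// adj[u] holds (v, w) pairs for edges u -> v with weight w >= 0.
// Returns the shortest distance from src to dst, or -1 if unreachable.
int dijkstra(const vector<vector<pair<int, int>>>& adj, int src, int dst) {
    set<pair<int, int>> st;               // (distance, node), smallest first
    vector<bool> vis(adj.size(), false);
    st.insert({0, src});                  // step 2: initial state
    while (!st.empty()) {
        auto best = *st.begin();          // step 3: pop the best one
        st.erase(st.begin());
        int d = best.first, u = best.second;
        if (u == dst)                     // step 4: destination reached
            return d;
        if (vis[u])                       // step 5: skip visited nodes
            continue;
        vis[u] = true;
        for (auto& e : adj[u])            // step 6: insert the neighbors
            if (!vis[e.first])
                st.insert({d + e.second, e.first});
    }
    return -1;
}
```

With edges 0->1 (4), 0->2 (1), 2->1 (2), 2->3 (5) and 1->3 (1), the shortest distance from 0 to 3 is 4, along 0->2->1->3.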

In fact, Dijkstra is such a no-brainer that I sometimes write Dijkstra where BFS suffices, even though it gives an extra log(n) factor to the runtime.

Let’s make up a toy problem and see how it works. Say, a 2D maze search. Input: a matrix of characters, walk from any ‘S’ to any ‘T’, using only ‘.’ as path. Return the length of the shortest path. For example

..S
.XX
..T

Will return 6.

Here’s the code:

 

bool inboard(int x, int y, int m, int n) {
    return x >= 0 && x < m && y >= 0 && y < n;
}
int mazeSearch(vector<string>& maze) {
    int m = maze.size(), n = maze[0].size();
    set<pair<int, pair<int, int> > > st;
    vector<vector<bool> > vis(m, vector<bool>(n));
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            if (maze[i][j] == 'S')
                st.insert(make_pair(0, make_pair(i, j)));  // insert all initial states with distance 0
    vector<vector<int> > dir{{1, 0}, {0, 1}, {-1, 0}, {0, -1}};
    while (!st.empty()) {
        auto sm = *st.begin();  // smallest item in the set, closest unvisited state from source
        st.erase(sm);  // popping the heap
        int r = sm.second.first, c = sm.second.second;
        if (vis[r][c])  // already reached this cell with a shorter or equal path
            continue;
        vis[r][c] = true;
        if (maze[r][c] == 'T')
            return sm.first;  // found shortest path
        for (auto d : dir) {
            int nr = r+d[0], nc = c+d[1];
            if (inboard(nr, nc, m, n) && (maze[nr][nc] == '.' || maze[nr][nc] == 'T'))
                st.insert(make_pair(sm.first+1, make_pair(nr, nc)));  // next states with 1 distance farther from source
        }
    }
    return -1;  // no path found
}

The main tricks here are to use a set as a heap and a pair to denote a state. A few benefits of this trick:

  1. Set has log(n) insert, remove and maximum/minimum query, ideal for Dijkstra.
  2. Set has built-in duplicate removal, potentially saving a lot of extra work.
  3. Pair has default comparator as mentioned in STL#1, so we do not have to write our own.

The alternative is to use std::priority_queue<>. I cannot think of any reason why using a priority queue would be superior. There are a few reasons I do not prefer it: the name is longer to type, it does not have duplicate removal, it requires extra work to get a min-heap (since it is a max-heap by default), and I don’t remember the syntax and would have to Google it every time. Getting a min-heap or max-heap out of std::set<> is trivial: just change *st.begin() to *st.rbegin(). The first item from the end is the largest item.
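For comparison, here is a minimal sketch of the min-heap syntax mentioned above: passing std::greater as the third template argument flips the default max-heap. The `State` and `MinHeap` aliases and the `drain` helper are hypothetical names for illustration only.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

// std::priority_queue is a max-heap by default; std::greater makes it a min-heap.
typedef pair<int, int> State;  // (distance, node)
typedef priority_queue<State, vector<State>, greater<State> > MinHeap;

// Pops every state and returns them in the order the heap produced them.
vector<State> drain(MinHeap pq) {
    vector<State> out;
    while (!pq.empty()) {
        out.push_back(pq.top());  // top() is the smallest state
        pq.pop();
    }
    return out;
}
```

Unlike std::set, the heap keeps duplicates: pushing (2, 7) twice gives two copies back, which is why a Dijkstra built on it still needs the "skip if visited" check.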

Of course this problem only requires BFS. You can modify the above code by changing the set to a vector and changing the while loop to a for loop, as shown in the BFS blog post earlier. But this is just an example to demonstrate the structure of a Dijkstra implementation.

Unfortunately, no problem on LeetCode strictly needs Dijkstra, although you could certainly use it to replace BFS in some problems. So instead I’ll briefly go over A* using the same code structure, just for fun. If you really understand the above, it is trivial to modify it for similar problems.

A* by augmenting Dijkstra

A* is just Dijkstra with a “postpone” function that estimates a lower bound for the distance from a node to the destination. When the postpone function is 0, we get exactly Dijkstra. Think of it this way: you are a college student with a homework assignment that must be submitted before it is due. You don’t know the due date, but you know it is no earlier than tomorrow night. Are you going to do it now? Of course not. But if you only know the due date is before next Sunday, it might actually be tonight, so you have to do it now. The same idea applies to A*: if we know a path definitely takes at least 10 steps, we can safely postpone walking it until after we have tried all paths that could possibly take 9 steps.

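To implement A* with the above code: instead of putting the distance from the source in the first item of each pair, put the distance from the source plus the estimated lower bound on the distance to the destination. Since each next state’s distance is computed from the current one, you also need to keep the plain distance from the source around, either as an extra field in the state or by subtracting the estimate back out.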
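Here is a sketch of that change applied to the maze example, assuming a single ‘T’ so that the estimate can be the Manhattan distance to it. The name `mazeAStar` and the four-field state {estimate + distance, distance, row, col} are my own choices, not a standard interface. Manhattan distance never overestimates on a 4-connected grid and changes by at most 1 per step, which is what lets the "skip if visited" logic still return the shortest path.

```cpp
#include <array>
#include <cassert>
#include <cstdlib>
#include <set>
#include <string>
#include <vector>
using namespace std;

// A* sketch: same loop as mazeSearch, but ordered by distance + estimate.
// Assumption: at most one 'T'; the estimate is the Manhattan distance to it.
int mazeAStar(const vector<string>& maze) {
    int m = maze.size(), n = maze[0].size();
    int tr = -1, tc = -1;
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            if (maze[i][j] == 'T') {
                tr = i;
                tc = j;
            }
    if (tr < 0)
        return -1;  // no destination at all
    // Lower bound on the remaining steps from (r, c) to 'T'.
    auto h = [&](int r, int c) { return abs(r - tr) + abs(c - tc); };
    set<array<int, 4> > st;  // {distance + estimate, distance, row, col}
    vector<vector<bool> > vis(m, vector<bool>(n, false));
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            if (maze[i][j] == 'S') {
                array<int, 4> start = {h(i, j), 0, i, j};
                st.insert(start);
            }
    int dr[4] = {1, 0, -1, 0}, dc[4] = {0, 1, 0, -1};
    while (!st.empty()) {
        array<int, 4> cur = *st.begin();
        st.erase(st.begin());
        int g = cur[1], r = cur[2], c = cur[3];  // g: true distance from a source
        if (vis[r][c])
            continue;
        vis[r][c] = true;
        if (maze[r][c] == 'T')
            return g;
        for (int d = 0; d < 4; d++) {
            int nr = r + dr[d], nc = c + dc[d];
            if (nr >= 0 && nr < m && nc >= 0 && nc < n &&
                (maze[nr][nc] == '.' || maze[nr][nc] == 'T')) {
                array<int, 4> nxt = {g + 1 + h(nr, nc), g + 1, nr, nc};
                st.insert(nxt);
            }
        }
    }
    return -1;
}
```

On the example maze above it still returns 6; the only differences from the Dijkstra version are the extra distance field and the h() term added to the ordering key.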

Once again I don’t have a good problem, and it is also not that common, so I’ll skip the details. The actual form of the function depends heavily on the problem. In some cases it might even be impossible or useless to calculate a lower bound. So A* is definitely not the one ring to rule them all, but it’s still good to know.

TIW: Monotonic Stack

I actually made this name up, since I don’t know what it is called in English; apparently this is what it is called in Chinese, 單調棧. It sometimes comes in handy when a problem that seemingly needs more than O(n) time asks you to find a particular element.

Basic facts about monotonic stacks:

  1. It is a kind of precomputation, just like precomputing the prefix sum array.
  2. It is O(n), so it almost never hurts your runtime, just like prefix sum array.
  3. As the name suggests, it is a method that uses a stack (implemented using a vector) that has all elements sorted.
  4. It is used to find the next (or previous) element in the array that is larger (or smaller) than the current element.

Let’s look at the code:

vector<int> lastBigger(vector<int> v) {
    vector<int> ans(v.size()), mono;
    for (int i = 0; i < v.size(); i++) {
        while (!mono.empty() && v[mono.back()] <= v[i])
            mono.pop_back();
        ans[i] = mono.empty() ? -1 : mono.back();
        mono.push_back(i);
    }
    return ans;
}

This piece of code finds the index of the closest previous element that is larger than the current one. For example, for an input vector {1, 5, 4, 2, 3}, it will return {-1, -1, 1, 2, 2}. The index is stored instead of the actual element, because it is more informative and we could look up the actual values from the indices.

The algorithm is based on the idea that to find the last bigger value, we do not need to look at all the previous values. We only go farther to the left when we need a larger value. For example, in the above array, at the last element we only need to look at the values {5, 4, 2} but not the value 1, because 5 is both closer and larger, so taking 5 is strictly better than taking 1. And once we reach the element 3, 2 becomes useless to all future elements, if there are any. Therefore we keep popping the stack until we find something that is potentially useful later.

To find a smaller instead of bigger value, simply change <= to >=. To find the next instead of the previous, just start looping from the end to the front.
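Applying both changes at once gives a "next smaller" table. A minimal sketch, with `nextSmaller` as an illustrative name:

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Mirror image of lastBigger(): index of the closest element to the RIGHT
// that is strictly smaller than v[i], or -1 if there is none.
// Two changes: loop from the back, and pop while the top is >= v[i].
vector<int> nextSmaller(const vector<int>& v) {
    int n = v.size();
    vector<int> ans(n), mono;
    for (int i = n - 1; i >= 0; i--) {
        while (!mono.empty() && v[mono.back()] >= v[i])
            mono.pop_back();
        ans[i] = mono.empty() ? -1 : mono.back();
        mono.push_back(i);
    }
    return ans;
}
```

For the same input {1, 5, 4, 2, 3} it returns {-1, 2, 3, -1, -1}: 5 is followed by 4, 4 is followed by 2, and nothing to the right of 1 or 2 is smaller.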

I have seen this technique come up from time to time in seemingly unrelated problems. I will go through 3 of them.

Sliding Window Maximum

The first problem: given an integer array v of size n and an integer k, return an integer array w of size n-k+1 where w[i] = max(v[i], v[i+1], …, v[i+k-1]).

The most naive approach is to run a loop of k iterations for each element of w, taking the maximum each time. An obvious observation is that each time we only move the window by one index, so the maximum often remains unchanged. So a slight speedup is to check whether the element we discard by moving the window is the previous maximum. If not, the new maximum is simply the larger of the previous maximum and the incoming value.

But the worst case is still O(nk), because if we always discard the maximum, we have to find the maximum again by looping. Here’s how a monotonic stack compresses the runtime to O(n). For each element in w:

  1. Maintain a variable, idx, that stores the index of the maximum element in the previous window.
  2. Move the window to the right by one.
  3. If idx stays in the window, and the new value is larger than the one pointed to by idx, update idx to the new index.
  4. If idx stays in the window, and the new value is not larger, then idx remains unchanged.
  5. If idx falls out of the window, first move idx to the right by one. Now it might not be the largest anymore, so we look for its next bigger element, which is an O(1) lookup because the monotonic stack has precomputed it. While the next bigger element is in the window, update idx to point to that element.

The runtime of this algorithm is O(n) because idx and the window only move to the right, and we never go backwards.

vector<int> maxSlidingWindow(vector<int>& nums, int k) {
    if (nums.empty())
        return {};
    vector<int> nextBigger(nums.size(), INT_MAX), mono;
    for (int i = nums.size()-1; i >= 0; i--) {
        while (!mono.empty() && nums[mono.back()] <= nums[i])
            mono.pop_back();
        if (!mono.empty())
            nextBigger[i] = mono.back();
        mono.push_back(i);
    }
    int pt = 0;
    vector<int> ans;
    for (int i = k-1; i < nums.size(); i++) {
        if (pt == i-k)
            pt++;
        while (nextBigger[pt] <= i)
            pt = nextBigger[pt];
        ans.push_back(nums[pt]);
    }
    return ans;
}

There is a 2D version of this problem: given a 2D array v of size n by n and an integer k, return a 2D array w of size n-k+1 by n-k+1 where w[i][j] = max(v[i][j], v[i][j+1], …, v[i][j+k-1], v[i+1][j], …, v[i+k-1][j+k-1]). It is in fact very easy: run this algorithm twice, first horizontally on every row, then vertically on every column of that result. It will be O(n^2), which is optimal since the input itself is O(n^2).
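A sketch of the two-pass idea, assuming 1 <= k <= n. `slidingMax1D` is the maxSlidingWindow code above with the empty-input check replaced by that assumption, and `maxSlidingWindow2D` is a hypothetical name.

```cpp
#include <cassert>
#include <climits>
#include <vector>
using namespace std;

// 1D pass: the maxSlidingWindow logic from above (assumes 1 <= k <= nums.size()).
vector<int> slidingMax1D(const vector<int>& nums, int k) {
    int n = nums.size();
    vector<int> nextBigger(n, INT_MAX), mono;
    for (int i = n - 1; i >= 0; i--) {
        while (!mono.empty() && nums[mono.back()] <= nums[i])
            mono.pop_back();
        if (!mono.empty())
            nextBigger[i] = mono.back();
        mono.push_back(i);
    }
    vector<int> ans;
    int pt = 0;  // index of the maximum of the current window
    for (int i = k - 1; i < n; i++) {
        if (pt == i - k)
            pt++;
        while (nextBigger[pt] <= i)
            pt = nextBigger[pt];
        ans.push_back(nums[pt]);
    }
    return ans;
}

// 2D version: every row first, then every column of the row results.
vector<vector<int> > maxSlidingWindow2D(const vector<vector<int> >& v, int k) {
    int n = v.size();
    if (n == 0 || k < 1 || k > n)
        return {};
    vector<vector<int> > rows;  // n rows, each of length n-k+1
    for (const vector<int>& row : v)
        rows.push_back(slidingMax1D(row, k));
    int w = n - k + 1;
    vector<vector<int> > ans(w, vector<int>(w));
    for (int j = 0; j < w; j++) {
        vector<int> col(n);
        for (int i = 0; i < n; i++)
            col[i] = rows[i][j];
        vector<int> colMax = slidingMax1D(col, k);  // length n-k+1 = w
        for (int i = 0; i < w; i++)
            ans[i][j] = colMax[i];
    }
    return ans;
}
```

Each pass is linear in the number of cells it touches, so the whole thing stays O(n^2).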

Largest Rectangle in Histogram

This one, although categorized as hard, is fairly straightforward: enumerate the bars and use each one as the height of the rectangle, then find the closest strictly shorter bar on each side; the rectangle extends to just inside those two bars. Therefore we need to apply the monotonic stack twice.

int largestRectangleArea(vector<int>& heights) {
    if (heights.empty())
        return 0;
    vector<int> lastSmaller(heights.size()), nextSmaller(heights.size()), mono;
    for (int i = 0; i < heights.size(); i++) {
        while (!mono.empty() && heights[mono.back()] >= heights[i])
            mono.pop_back();
        lastSmaller[i] = mono.empty() ? -1 : mono.back();
        mono.push_back(i);
    }
    mono.clear();
    for (int i = heights.size()-1; i >= 0; i--) {
        while (!mono.empty() && heights[mono.back()] >= heights[i])
            mono.pop_back();
        nextSmaller[i] = mono.empty() ? heights.size() : mono.back();
        mono.push_back(i);
    }
    int ans = 0;
    for (int i = 0; i < heights.size(); i++)
        ans = max(ans, heights[i]*(nextSmaller[i]-lastSmaller[i]-1));
    return ans;
}

On an unrelated note, some people like to condense code that should span multiple lines into one line, like squeezing return 0, i++ etc. onto the previous line. There is no way to be consistent about this, because sometimes the lines get too long and need to be split into two lines. Inconsistent coding style is a sin. It leads to bugs and confusion. Code should be made as short as possible, but no shorter.

132 Pattern

The last problem is fairly new. For problems like this, we usually need to enumerate one of the elements; here we enumerate the “2”. The algorithm is simple: for each “2”, pick the “3” greedily, and see whether we can find a “1”. The greedy choice for “3” is the closest number to the left of “2” that is larger than it, which is exactly the last bigger element. This choice is the best one because the minimum of a prefix can only shrink as the prefix grows, so the rightmost valid “3” leaves the most room for a “1”. Checking for a “1” is then simply taking the prefix minimum up to “3” and checking whether it is smaller than “2” (including the “3” itself in that minimum is harmless, since it is larger than the “2”).

bool find132pattern(vector<int>& nums) {
    if (nums.size() < 3)
        return false;
    vector<int> mins{nums[0]};
    for (int i = 1; i < nums.size(); i++)
        mins.push_back(min(nums[i], mins.back()));
    vector<int> mono, lastbigger(nums.size());
    for (int i = 0; i < nums.size(); i++) {
        while (!mono.empty() && nums[mono.back()] <= nums[i])
            mono.pop_back();
        lastbigger[i] = mono.empty() ? -1 : mono.back();
        mono.push_back(i);
    }
    for (int i = nums.size()-1; i > 1; i--)
        if (lastbigger[i] > 0 && mins[lastbigger[i]] < nums[i])
            return true;
    return false;
}

Now that you have seen the above, you should be able to solve Maximal Rectangle with minimum effort.

TIW: BFS, DFS

Breadth First Search and Depth First Search are both essential skills for passing medium problems. If you’ve been following along, I’ve already written BFS once in STL#2. Depending on the problem, there are usually two ways I would go about writing BFS, and also two ways for DFS. But before I go on with solving problems, let me point out a few basic facts:

  1. BFS is good for finding shortest paths on a graph with no edge weight, or just traversing the whole graph in general.
  2. DFS is good for traversing the whole graph too, especially trees.
  3. DFS can have time complexity issues depending on the graph you are walking, in which case it might be replaceable by BFS/Dijkstra.
  4. DFS with recursion uses the program stack and might overflow much more easily than dynamically allocated memory (using vector, for example).
  5. There are probably many more use cases of BFS and DFS than you think, on things you would not think of as a graph.

I have a feeling this is going to be long. So here are the key points I will demonstrate:

  1. BFS on 2D array
  2. BFS layer by layer
  3. DFS on tree, iterative and recursive
  4. DFS as brute force search

BFS on 2D array

Let’s start with a standard problem for BFS: Number of Islands. Given a 2D matrix of ‘1’s and ‘0’s, count the connected pieces of ‘1’s. The idea: go through each cell; if it is land and we haven’t seen it, walk the whole island and mark every land cell we see as ‘0’.

bool inboard(int x, int y, int n, int m) { // whether a point (x, y) is in the board with size n times m
    return x >= 0 && x < n && y >= 0 && y < m;
}

int numIslands(vector<vector<char>>& grid) {
    if (grid.empty())
        return 0;
    int m = grid.size(), n = grid[0].size();
    int ans = 0;
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            if (grid[i][j] == '1') {
                ans++;
                queue<pair<int, int> > bfs;
                bfs.push(make_pair(i, j));
                grid[i][j] = '0';
                while (!bfs.empty()) {
                    pair<int, int> pos = bfs.front();
                    bfs.pop();
                    for (auto& v : vector<vector<int> >{{1, 0}, {0, 1}, {-1, 0}, {0, -1}}) {
                        int next_i = pos.first+v[0], next_j = pos.second+v[1];
                        if (inboard(next_i, next_j, m, n) && grid[next_i][next_j] == '1') {
                            grid[next_i][next_j] = '0';
                            bfs.push(make_pair(next_i, next_j));
                        }
                    }
                }
            }
        }
    }
    return ans;
}

The inboard() function is often useful in 2D array graph problems. Notice the line:

for (auto& v : vector<vector<int> >{{1, 0}, {0, 1}, {-1, 0}, {0, -1}})

This for loop will give you four directions along which to go, and you just need to add v[0] to your current position’s row index and v[1] to column index to get to the next position. Stop copy pasting code 4 times, please.

A queue<> has 3 functions we need: pop, push and front. If you need a queue that you can access from both sides, you should consider using deque<>, which has [push/pop]_[front/back], and also front() and back(). Everything is of course O(1), so you don’t have to worry.
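A quick sketch of the deque operations listed above; `demoDeque` is just an illustrative name.

```cpp
#include <cassert>
#include <deque>
using namespace std;

// std::deque as a double-ended queue: push/pop at both ends, peek with front()/back().
deque<int> demoDeque() {
    deque<int> dq;
    dq.push_back(2);   // {2}
    dq.push_back(3);   // {2, 3}
    dq.push_front(1);  // {1, 2, 3}
    dq.push_front(0);  // {0, 1, 2, 3}
    dq.pop_back();     // {0, 1, 2}
    dq.pop_front();    // {1, 2}
    return dq;
}
```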

It is important to note that in the above problem, there is no need to walk the graph in any particular order as long as you get the job done, so I did not record the distance to each point. A DFS approach would also have worked and might even be preferred. In a recursive DFS approach, the extra memory is allocated on the program stack instead of the heap, which is quite advantageous in terms of efficiency.
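For comparison, a recursive DFS sketch of the same problem; the names `sink` and `numIslandsDFS` are hypothetical. Keep in mind that one large island can make the recursion as deep as the number of cells, which is the stack overflow risk from the list of facts above.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Recursive flood fill: turn every '1' reachable from (i, j) into '0'.
void sink(vector<vector<char> >& grid, int i, int j) {
    if (i < 0 || i >= (int)grid.size() || j < 0 || j >= (int)grid[i].size() ||
        grid[i][j] != '1')
        return;  // off the board, water, or already visited
    grid[i][j] = '0';
    static const int dir[4][2] = {{1, 0}, {0, 1}, {-1, 0}, {0, -1}};
    for (const auto& d : dir)
        sink(grid, i + d[0], j + d[1]);
}

int numIslandsDFS(vector<vector<char> >& grid) {
    int ans = 0;
    for (int i = 0; i < (int)grid.size(); i++)
        for (int j = 0; j < (int)grid[i].size(); j++)
            if (grid[i][j] == '1') {
                ans++;  // new island found; sink all of it
                sink(grid, i, j);
            }
    return ans;
}
```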

BFS layer by layer

Sometimes doing things layer by layer will make life easier. See Maximum Depth of Binary Tree:

int maxDepth(TreeNode* root) {
    if (!root)
        return 0;
    vector<TreeNode*> bfs{root};
    int ans;
    for (ans = 0; !bfs.empty(); ans++) {
        vector<TreeNode*> next;
        for (TreeNode* node : bfs)
            for (TreeNode* c : {node->left, node->right})
                if (c)
                    next.push_back(c);
        bfs = move(next);
    }
    return ans;
}

This is essentially the same approach I used for the problem in STL#2. DFS would also work for this problem, and might perform better in terms of space for some test cases, since space for a recursive approach would be O(height) instead of O(width). Anyway I still included this problem here for the sake of completeness. Noteworthy: (1) I made a for loop since I want to get the height, otherwise we can write a while (!bfs.empty()) loop; (2) there is no need to keep track of the visited nodes because we know for sure we will not visit the same node twice.

DFS on tree, iterative and recursive

While BFS is always written using a queue/vector, DFS uses a stack, so we could use either vector or recursion to accomplish the same goal.

I will redo the Maximum Depth problem twice to illustrate the idea:

// Iterative DFS
int maxDepth(TreeNode* root) {
    if (!root)
        return 0;
    vector<pair<TreeNode*, int> > dfs{make_pair(root, 1)};
    int ans = 0;
    while (!dfs.empty()) {
        pair<TreeNode*, int> back = dfs.back();
        dfs.pop_back();
        ans = max(ans, back.second);
        for (TreeNode* c : {back.first->left, back.first->right})
            if (c)
                dfs.push_back(make_pair(c, back.second+1));
    }
    return ans;
}
// Recursive DFS
int maxDepth(TreeNode* root) {
    return root ? 1+max(maxDepth(root->left), maxDepth(root->right)) : 0;
}

As you can see, recursion on trees is black magic. I mean, what’s shorter than a one-liner? Whenever applicable, it often shortens the code by a great margin and presents the logic in a very clear and concise way.

DFS with recursion is great for tree problems that are self-similar. For example, the height of this tree is the height of the left subtree + 1, or the right subtree + 1, whichever is larger.

  1. Use recursion if you are going to DFS. Except for some LeetCode problems that test memory limits, I have not found any reason why iterative DFS is superior. Recursion runs on the program stack, while a vector dynamically allocates memory on the heap, so in that particular case a vector has fewer space limitations.
  2. In recursion, first handle all base cases, then implement the recurrence relation (well… what else are you supposed to do anyway?)
  3. The essential question here is: to solve the original problem, what information do you need from a sub-problem? Sometimes you need extra information from your parent or children to solve a task.

To illustrate the third point above, I’m going to do a slightly more complicated one (not a one-liner): given a TreeNode*, determine whether the tree is a binary search tree. This is a medium problem on LeetCode.

pair<int, int> helper(TreeNode* root, bool& ans) {
    int l = root->val, r = root->val;
    if (root->left) {
        pair<int, int> left = helper(root->left, ans);
        if (left.second >= root->val)
            ans = false;
        l = left.first;
    }
    if (root->right) {
        pair<int, int> right = helper(root->right, ans);
        if (right.first <= root->val)
            ans = false;
        r = right.second;
    }
    return make_pair(l, r);
}
bool isValidBST(TreeNode* root) {
    if (!root)
        return true;
    bool ans = true;
    helper(root, ans);
    return ans;
}

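The recurrence relation here is that the largest element in the left subtree must be smaller than the current node, and the smallest element in the right subtree must be larger than the current node. That is the extra information from the children: helper() returns the (smallest, largest) pair of each subtree and records any violation in ans.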
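The helper above passes information up from the children. The other direction from point 3, extra information from the parent, gives an alternative sketch: pass down the open interval that every value in a subtree must fall into. `TreeNode` mirrors LeetCode’s definition, and `inRange`/`isValidBSTBounds` are hypothetical names.

```cpp
#include <cassert>
#include <climits>
using namespace std;

struct TreeNode {  // same shape as LeetCode's TreeNode
    int val;
    TreeNode *left, *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Every value in this subtree must lie strictly inside (lo, hi).
// long long bounds let INT_MIN and INT_MAX themselves be valid node values.
bool inRange(TreeNode* root, long long lo, long long hi) {
    if (!root)
        return true;  // base case: an empty tree is valid
    if (root->val <= lo || root->val >= hi)
        return false;
    return inRange(root->left, lo, root->val) &&  // left side: tighten the upper bound
           inRange(root->right, root->val, hi);   // right side: tighten the lower bound
}

bool isValidBSTBounds(TreeNode* root) {
    return inRange(root, (long long)INT_MIN - 1, (long long)INT_MAX + 1);
}
```

The bounds catch violations that only show up against an ancestor: in the tree 5 → left 4 → right 6, each parent/child pair looks fine locally, but 6 is in the left subtree of 5.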

DFS as brute force search

Some problems require brute-force enumeration of all possibilities, and recursive DFS is a good approach for them. Let’s look at Decode String: given a string in the recursive format #[x], where # is a number and x is a string, repeat x # times. For example, “2[abc]3[cd]ef” becomes “abcabccdcdcdef”.

There is not much to optimize in this problem, since the runtime is lower bounded by the length of the output. The algorithm runs one recursive call of helper() per bracket level of the original string, passing the current index in the string by reference as we go along.

string helper(string& s, int& i) {
    string ans;
    while (i < s.length() && s[i] != ']') {
        if (s[i] >= 'a' && s[i] <= 'z') {
            ans += s[i];
            i++;
        } else {
            int k = 0;
            for (; s[i] != '['; i++)
                k = k*10+s[i]-'0';
            i++;
            string ret = helper(s, i);
            for (int j = 0; j < k; j++)
                ans += ret;
        }
    }
    i++;
    return ans;
}
string decodeString(string s) {
    int i = 0;
    return helper(s, i);
}

There are many applications of DFS brute force, but they could be very different from each other and hard for me to generalize. So this is only a short description of what you can do with DFS.

Recap

Use BFS or recursive DFS for 2D board traversal; use vectors or queues for BFS; use recursion for trees almost always; use recursive DFS to brute force.

Binary Search: A Better Way

So you think you know how to write binary search – just like everyone else. I mean, it’s easy, right? Maybe, but you can probably do it better. Let’s start with an example. Say you have a sorted integer array, and you want to find the largest integer no larger than k, and return its index.

So it usually goes like this: let n = size of array, l = 0 and r = n-1, and dive into a while loop. In the loop, you find a mid point between l and r. And you quit when you have only one number left, which is l = r.

int l = 0, r = n - 1;
while (l < r) {
    int mid = (l + r) / 2;
}

Then what do you do? Well, you could compare it with k.

while (l < r) {
    int mid = (l + r) / 2;
    if (v[mid] > k)
        r = mid - 1;
    else
        l = mid;
}
return l;

If v[mid] is larger than k, then all the valid answers can only be between l and mid-1. Otherwise, anywhere from mid to r could be the answer.

Now you’re happy and you try to run it – and it doesn’t work! In fact, not only does it not work, it’s a terrible piece of crap. Here’s why.

1. Edge cases

What happens if (1) the array is empty, (2) all integers are smaller than k, (3) all integers are larger than k? Well in (1) you’ll return 0 which is wrong, in (2) you’ll get an infinite loop (to be explained later), and in (3) you’ll get 0, which is also wrong. The correct results would be (1) -1 for not found, (2) n-1, and (3) -1 for not found. This algorithm gets all edge cases wrong!

2. Infinite loop

Say if we are given k = 2 and the array {0, 1}. l starts at 0 and r starts at 1, so mid would be (0 + 1) / 2 = 0. And 0 is not larger than 2, so we set l to mid, which is 0. We’re back to the same situation, and this loop will never end!

There are ways to fix the code, of course: you handle the special cases by spraying if statements all over the place, and you fix the infinite loop by setting mid to (l + r + 1)/2. Like this:

if (n == 0 || v[0] > k)
    return -1;
int l = 0, r = n - 1;
while (l < r) {
    int mid = (l + r + 1) / 2;
    if (v[mid] > k)
        r = mid - 1;
    else
        l = mid;
}
return l;

But it’s ugly! We have to pick between (l + r) and (l + r + 1), which seems totally arbitrary, and there are edge cases to consider. What went wrong?

Everything went wrong from the very beginning. At the point we set l to 0 and r to n-1, we’re already making an implicit statement (a loop invariant). We’re assuming that every time we enter the loop, the answer lies somewhere between l and r. But this isn’t even true! If the answer doesn’t exist, then obviously it can’t be between l and r. The other assumption is that r-l always decreases after each iteration (so the loop eventually terminates), which is also not true, as the infinite loop case shows. Since the loop invariants never held, we get gibberish at the end of the algorithm.

It turns out that if we change our loop invariant, we could be in a much better situation. Here’s the proposal: the transition between the last integer that’s at most k and the first integer that’s larger than k is always between l and r.

That means l must start at -1, and r at n. This is because the transition could happen from -1 to 0 (before the first element) or from n-1 to n (after the last element). The stopping condition would be l + 1 = r, because by then we would know the transition is from l to r, and the answer would be l. Let’s try to code this out:

int l = -1, r = n;
while (l + 1 < r) {
    int mid = (l + r) / 2;
    if (v[mid] > k)
        r = mid;
    else
        l = mid;
}
return l;

Magically, this code handles all edge cases correctly! It’s absolutely correct given any input. We don’t have any more edge cases because we can express all cases using l and r. If the list is empty or all numbers are greater than k, then our “transition” would occur before the first number, which means l = -1 and r = 0. Even though v[-1] is invalid, (l, r) = (-1, 0) is a valid statement that implies there is no element at most k. We also don’t have infinite loops anymore, because when l and r differ by at least 2, (l + r) / 2 is guaranteed to be strictly between l and r, so r - l shrinks every iteration.
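The same loop packaged as a function, with the edge cases above as checks. `lastAtMost` is an illustrative name; computing mid as l + (r - l) / 2 is an extra precaution against overflow for large l and r, and it gives the same midpoint as (l + r) / 2 here.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Index of the largest element <= k in sorted v, or -1 if there is none.
// Invariant: v[l] <= k < v[r], treating v[-1] as -infinity and v[n] as +infinity.
int lastAtMost(const vector<int>& v, int k) {
    int l = -1, r = v.size();
    while (l + 1 < r) {
        int mid = l + (r - l) / 2;  // strictly between l and r
        if (v[mid] > k)
            r = mid;
        else
            l = mid;
    }
    return l;  // r = l + 1 is the index of the first element > k
}
```

The empty array, the all-smaller case {0, 1} with k = 2 that looped forever before, and the all-larger case all fall out of the invariant without special code.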

As we can see here, instead of searching for an element, we’re really searching for a transition. And by making this change, our code becomes more elegant.

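Let’s try this again and do Guess Number Higher or Lower. The problem: you are given a function [int guess(int num)], which behaves like [int compare(int magic, int num)], and you need to guess the magic number in the range 1 to n. Here the invariant is l < magic <= r, so l starts at 0, r starts at n, and once r-l = 1 the answer is r. The bounds are long long so that l+r cannot overflow when n is close to INT_MAX.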

int guessNumber(int n) {
    long long l = 0, r = n;
    while (r-l > 1) {
        long long mid = (l+r)/2;
        int res = guess(mid);
        if (!res) return mid;
        else if (res == -1)
            r = mid;
        else
            l = mid;
    }
    return r;
}

Binary search over things other than array indices

Split Array Largest Sum (hard)

Sometimes, instead of binary searching over the indices of an array, we can binary search over the space of possible answers. Take this problem: given an array of nonnegative integers, return the least possible threshold such that you can partition the array into m subarrays where the sum of each subarray does not exceed the threshold. Directly finding this number is hard, but given a threshold, it is easy to compute the minimum number of subarrays needed to meet it: greedily extend the current subarray until the next number would push it over the threshold. The algorithm is then to binary search for the threshold, compute the minimum number of partitions for each candidate, and find the smallest threshold whose number of partitions is at most m.

bool valid(vector<int>& nums, int m, long long sum) {
    int cnt = 1;
    long long run = 0;
    for (int x : nums) {
        if (run + x > sum) {
            cnt++;
            run = x;
        } else {
            run += x;
        }
    }
    return cnt <= m;
}
 
int splitArray(vector<int>& nums, int m) {
    long long tot = 0;  // sum of the whole array: an upper bound on the answer
    int low = nums[0];
    for (int x : nums) {
        tot += x;
        low = max(low, x);
    }
    long long l = low-1, r = tot;
    while (r-l > 1) {
        long long mid = (l+r)/2;
        if (valid(nums, m, mid))
            r = mid;
        else
            l = mid;
    }
    return r;
}

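In the valid function, I calculate the number of sections needed; if it goes over m, the current threshold is too low, hence invalid. We also know that the answer will not be smaller than the largest element in the array, and will not be larger than the sum of the entire array. That is why, following the transition invariant from before, l starts at low-1 (a threshold that is always invalid) and r starts at tot (a threshold that is always valid).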

Searching for real

And of course you can binary search over a real number range as well. Imagine in the previous example, the input is instead an array of doubles, and the threshold is also a double. The algorithm is basically the same, except that we need to handle floating point comparisons using an [int sign(double x)] function.

int sign(double x) {
    double eps = 1e-10;
    return x < -eps ? -1 : x > eps;
}
 
bool valid(vector<double>& nums, int m, double sum) {
    int cnt = 1;
    double run = 0;
    for (double x : nums) {
        if (sign(run + x - sum) > 0) {
            cnt++;
            run = x;
        } else {
            run += x;
        }
    }
    return (cnt <= m);
}
 
double splitArray(vector<double>& nums, int m) {
    double tot = 0, low = nums[0];  // low must be a double too, or max(low, x) will not compile
    for (double x : nums) {
        tot += x;
        low = max(low, x);
    }
    double l = low, r = tot;  // the bounds are real numbers, not long long
    while (sign(r - l) > 0) {
        double mid = (l + r) / 2;
        if (valid(nums, m, mid))
            r = mid;
        else
            l = mid;
    }
    return r;
}

The sign function lets us compare doubles with a tolerance. This is needed because double arithmetic is not exact, and we need to make sure we are comparing values correctly. Using this function, a < b is rewritten as sign(b-a) > 0, and a >= b is rewritten as sign(a-b) >= 0. The easy way to think about this is to move terms across the inequality sign so that one side becomes zero, then wrap the other side in the sign function.
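One caveat with the eps-based loop: once l and r are large, adjacent doubles can be farther apart than eps (near 1e6 they are already about 1.2e-10 apart), so (l + r) / 2 may round back to l or r and the loop can stop making progress. A common alternative is to run a fixed number of halvings instead. Below is a minimal sketch on a hypothetical square-root problem rather than splitArray, just to show the loop shape; `bisectSqrt` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
using namespace std;

// Hypothetical example: sqrt(x) for x >= 0, found as the smallest t with t * t >= x.
// Runs a fixed number of halvings instead of testing r - l against an eps.
double bisectSqrt(double x) {
    double l = 0, r = max(1.0, x);  // the answer always lies in [l, r]
    for (int iter = 0; iter < 100; iter++) {
        double mid = (l + r) / 2;
        if (mid * mid >= x)
            r = mid;  // mid is large enough: the answer is at most mid
        else
            l = mid;  // mid is too small: the answer is above mid
    }
    return r;
}
```

100 halvings shrink any starting range far below the spacing of doubles, and once mid rounds to an endpoint the remaining iterations are harmless no-ops, so the loop always ends.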

That’s it for now.

TIW: Range Sum

The prefix sum array is a useful tool that shows up often. It achieves the following:

  1. For a given array of n integers that will not change over the course of program execution, precomputation runs in O(n).
  2. After the precomputation step, querying the range sum (summation of v[l..r]) is O(1).

Here’s how you do it in code:

vector<int> v{1, 2, 3, 4, 5};
vector<int> s{0};
for (int x : v)
    s.push_back(s.back()+x);
int i = 2, j = 4;
cout << "3+4+5 = " << s[j+1]-s[i] << endl;

There are two things to note: first, the size of s, our prefix sum array, is one larger than our original array, and the indices are shifted by one; second, in all summation problems, always think about overflow and whether you need to use long long instead of int.
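A minimal sketch of the overflow point: the same prefix array, stored as long long so that sums of large ints do not wrap around. The names `prefixSums` and `rangeSum` are illustrative.

```cpp
#include <cassert>
#include <climits>
#include <vector>
using namespace std;

// s[i] = v[0] + ... + v[i-1], accumulated in long long; s[0] = 0 is the padding.
vector<long long> prefixSums(const vector<int>& v) {
    vector<long long> s(1, 0);
    for (int x : v)
        s.push_back(s.back() + x);  // x is converted to long long before the add
    return s;
}

// Sum of v[i..j], inclusive.
long long rangeSum(const vector<long long>& s, int i, int j) {
    return s[j + 1] - s[i];
}
```

One trap if you reach for std::partial_sum from <numeric> instead of the loop: its running total has the value type of the input iterator, so summing a vector<int> still accumulates in int even when the output vector holds long long.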

For those who haven’t seen this before, s[i] is the sum of integers from v[0] to v[i-1]. We’re essentially generating the sums of {}, {1}, {1, 2}, {1, 2, 3}, {1, 2, 3, 4} and {1, 2, 3, 4, 5}, and if we want the sum of {3, 4, 5}, we just subtract sum {1, 2} from sum {1, 2, 3, 4, 5}.

The trick here that I want to bring out is the initial value in the array, 0. Without the 0, the code will look like this:

vector<int> v{1, 2, 3, 4, 5};
vector<int> s{v[0]};
for (int i = 1; i < v.size(); i++) {
    s.push_back(s.back()+v[i]);
}
int i = 2, j = 4;
cout << "3+4+5 = " << s[j]-(i > 0 ? s[i-1] : 0) << endl;

OMG, no. Please don’t do this.

This is actually all it takes to solve Range Sum Query – Immutable:

class NumArray {
private:
    vector<int> s;
public:
    NumArray(vector<int> &nums) {
        s.push_back(0);
        for (int x : nums)
            s.push_back(s.back()+x);
    }

    int sumRange(int i, int j) {
        return s[j+1]-s[i];
    }
};

It’s just my personal preference to keep data members private. Everything else is the same as above, nothing to see here.

A slightly harder version is doing this in 2 dimensions, but it’s merely putting the above code in some loops and doing it over and over again. Anyway, let’s look at this problem first:

Range Sum Query 2D – Immutable

Given a 2D array, compute summation of a rectangle of the array in O(1) time after precomputation.

So there are two parts to explain: what our “2D prefix sum array” looks like, and how to compute it in linear time.

In 1D, we have one index, i, for each number, and we sum everything from index 0 to index i-1. Similarly in 2D, we have two indices, i and j, and we sum every v[i'][j'] with 0 <= i' < i and 0 <= j' < j. With the same shift-by-one convention, our notation will be: s[i][j] = sum {v[0..i-1][0..j-1]}. To calculate the sum of numbers in the rectangle v[r1..r2][c1..c2], the equation will be: s[r2+1][c2+1] - s[r1][c2+1] - s[r2+1][c1] + s[r1][c1]. To understand why, suppose the matrix is split into four blocks like this:

[ A ] [ B ]

[ C ] [ D ]

Then the sum of D is {D} = {A+B+C+D} - {A+B} - {A+C} + {A}: subtracting both {A+B} and {A+C} removes A twice, so it has to be added back once. Therefore there are 4 terms in the expression.

Now down to the second part: how to calculate it in linear time. It’s the same relation, rearranged, with D now being the single cell v[i-1][j-1]: {A+B+C+D} = {D} + {A+B} + {A+C} - {A}. In other words, s[i][j] = v[i-1][j-1] + s[i-1][j] + s[i][j-1] - s[i-1][j-1].

vector<vector<int> > v{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
int r = v.size(), c = v[0].size();
vector<vector<int> > s(r+1, vector<int>(c+1));
for (int i = 1; i <= r; i++)
    for (int j = 1; j <= c; j++)
        s[i][j] = v[i-1][j-1]+s[i-1][j]+s[i][j-1]-s[i-1][j-1];
int r1 = 1, r2 = 2, c1 = 0, c2 = 2;
cout << "sum {{4, 5, 6}, {7, 8, 9}} = "
     << s[r2+1][c2+1]-s[r1][c2+1]-s[r2+1][c1]+s[r1][c1] << endl;
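
The snippet above happens to use c1 = 0, so two of the four terms are zero. To see all four terms at work, here is the padded table s for the same v, computed by hand, together with one step of the recurrence and the rectangle v[1..2][1..2] = {{5, 6}, {8, 9}} (a rectangle picked only for illustration):

```latex
\begin{aligned}
s &= \begin{pmatrix} 0 & 0 & 0 & 0\\ 0 & 1 & 3 & 6\\ 0 & 5 & 12 & 21\\ 0 & 12 & 27 & 45 \end{pmatrix},\\
s_{2,2} &= v_{1,1} + s_{1,2} + s_{2,1} - s_{1,1} = 5 + 3 + 5 - 1 = 12,\\
\textstyle\sum v[1..2][1..2] &= s_{3,3} - s_{1,3} - s_{3,1} + s_{1,1} = 45 - 6 - 12 + 1 = 28 = 5 + 6 + 8 + 9.
\end{aligned}
```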

As a remark, with some more care you can do this without any extra space: just store everything in the original array. However, you then need to handle the first row and the first column specially, since without the padding row and column of zeros, indices like i-1 or j-1 fall outside the array (undefined behavior, often a segmentation fault).
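
Here is a minimal sketch of that space-saving remark, assuming you are allowed to overwrite the input matrix. The names inPlacePrefix and rectSum are hypothetical helpers of mine, not part of any problem's interface. Without padding, the table becomes inclusive (entry [i][j] holds the sum of the original v[0..i][0..j]), and every term whose index would be -1 is simply skipped:

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper: overwrite m so that m[i][j] = sum of the original m[0..i][0..j]
// (inclusive indices, no padding row/column). Row-major order guarantees the
// neighbors above and to the left have already been converted.
void inPlacePrefix(vector<vector<int> > &m) {
    for (int i = 0; i < (int)m.size(); i++)
        for (int j = 0; j < (int)m[i].size(); j++) {
            if (i > 0) m[i][j] += m[i-1][j];            // block above
            if (j > 0) m[i][j] += m[i][j-1];            // block to the left
            if (i > 0 && j > 0) m[i][j] -= m[i-1][j-1]; // counted twice, remove once
        }
}

// Hypothetical helper: sum of the original m[r1..r2][c1..c2], given the in-place table p.
// A term whose row or column index would be -1 stands for an empty block, i.e. 0.
int rectSum(const vector<vector<int> > &p, int r1, int c1, int r2, int c2) {
    int total = p[r2][c2];
    if (r1 > 0) total -= p[r1-1][c2];
    if (c1 > 0) total -= p[r2][c1-1];
    if (r1 > 0 && c1 > 0) total += p[r1-1][c1-1];
    return total;
}
```

For example, after inPlacePrefix(v) on v = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, rectSum(v, 1, 1, 2, 2) gives 45 - 6 - 12 + 1 = 28, and rectSum(v, 1, 0, 2, 2) gives 39, matching the padded version. The price of saving the padding is exactly these extra branches.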

The idea of padding a matrix with dummy items is also useful in 2D BFS sometimes, but we can talk about that later.

With some adaptation, this is the code that gets accepted:

class NumMatrix {
private:
    vector<vector<int> > s;
public:
    NumMatrix(vector<vector<int> > &matrix) {
        int r = matrix.size();
        if (r == 0)
            return;
        int c = matrix[0].size();
        s = vector<vector<int> >(r+1, vector<int>(c+1));  // (r+1) x (c+1), zero-padded
        for (int i = 0; i < r; i++)
            for (int j = 0; j < c; j++)
                s[i+1][j+1] = matrix[i][j] + s[i][j+1] + s[i+1][j] - s[i][j];
    }

    int sumRegion(int row1, int col1, int row2, int col2) {
        return s[row2+1][col2+1]-s[row1][col2+1]-s[row2+1][col1]+s[row1][col1];
    }
};

Few problems are such a direct application of this approach, but it is often useful as a subroutine in other problems to reduce the running time, for example in Count of Range Sum. This is because computing a range sum naively takes time linear in the number of items in the range, whereas with the precomputed array every range sum takes constant time.
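
To make the subroutine claim concrete, here is a sketch of how the padded prefix array turns each range sum in Count of Range Sum (count the ranges whose sum lies in [lower, upper]) into a single subtraction. It is only an illustration under my own naming (countRangeSumQuadratic is hypothetical): it still checks all O(n^2) ranges, which beats the O(n^3) of re-summing every range, but it is likely too slow for the real problem's constraints, where the prefix sums have to be combined with something faster.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical illustration: count pairs (i, j), i <= j, with
// lower <= nums[i] + ... + nums[j] <= upper.
// Each range sum is s[j+1] - s[i] in O(1) thanks to the padded prefix array.
int countRangeSumQuadratic(const vector<int> &nums, long long lower, long long upper) {
    int n = nums.size();
    vector<long long> s(n+1, 0);  // long long: sums of many ints can overflow int
    for (int i = 0; i < n; i++)
        s[i+1] = s[i] + nums[i];
    int count = 0;
    for (int i = 0; i < n; i++)
        for (int j = i; j < n; j++) {
            long long sum = s[j+1] - s[i];  // sum of nums[i..j]
            if (lower <= sum && sum <= upper)
                count++;
        }
    return count;
}
```

With nums = {-2, 5, -1}, lower = -2 and upper = 2, the qualifying ranges are {-2}, {-1} and {-2, 5, -1}, so it returns 3.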

That’s basically it. You can call this a data structure with O(n) update and O(1) query, where n is the number of entries, update means changing an entry of the original matrix, and query means calculating the sum over a rectangle. O(1) update and O(n) query is trivial: you just keep the original array and run for loops to calculate the sum every time you want it. But if you want something in between, there’s a magical thing called a binary indexed tree, which is quite advanced and which I will cover in a (very) later post. It achieves O(log(n)) update and query in around 10 lines of code.

TIW: Two Pointers

It’s quite surprising how much you can accomplish with constant extra space. There are a few classic problems with integer arrays that can be solved by two pointers.

So what is “two pointers”? Basically it’s an approach where you put one pointer on the left and one pointer on the right of an array, and keep moving them towards each other one step at a time (there are also cases where you put both pointers on the left and move them both to the right). When this approach applies, the solution is often easy to code and fast to run. I will go through 2-sum, 3-sum and *gasp* a hard problem on Leetcode, Trapping Rain Water.

Let’s take the first example, Two Sum II. Given a sorted array, output any pair of indices whose numbers sum to a given target (weirdly, the indices start at 1).

We have done a similar problem before in the post about upper_bound. But here the twists are that the array is sorted and we need to sum exactly to that number. The solution looks something like this:

vector<int> twoSum(vector<int>& numbers, int target) {
    int l = 0, r = numbers.size()-1;
    while (l < r) {
        int sum = numbers[l]+numbers[r];
        if (sum == target)
            return {l+1, r+1};
        else if (sum > target)
            r--;
        else
            l++;
    }
    return {};
}

It is easier to explain with the code in front of us. The idea is that for each number k in the array, we only need to check whether target-k is also in the array, and if so, return the two indices. In a sorted array there is only one place to look, namely where binary searching for target-k would end up. If numbers[l]+numbers[r] is too small and numbers[l]+numbers[r+1] is too big, then numbers[l] cannot form a solution with any other number.

The algorithm therefore proceeds from both ends, checking whether the current sum is too small or too large: if it is too small, make it bigger by incrementing l; if it is too big, make it smaller by decrementing r. Note that when you increment l, you are discarding all the solutions that use numbers[l]. This is not a problem, because the largest number we have not discarded yet (numbers[r]) is already too small for it, so numbers[l] will never form a solution. Symmetrically, when you decrement r, the smallest remaining number (numbers[l]) is already too big for numbers[r]. By induction, no solution is ever discarded, which shows the correctness of the algorithm.
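
One way to gain confidence in the induction argument is to test it. Below is a sketch that compares the two-pointer twoSum (repeated from above so the block compiles on its own) against a brute-force check on many small random sorted arrays, duplicates included. hasPairBruteForce and crossCheck are hypothetical helpers made up for this illustration:

```cpp
#include <cassert>
#include <cstdlib>
#include <vector>
using namespace std;

// Two pointers, as above (returns 1-based indices, or {} if no pair exists).
vector<int> twoSum(vector<int>& numbers, int target) {
    int l = 0, r = numbers.size()-1;
    while (l < r) {
        int sum = numbers[l]+numbers[r];
        if (sum == target)
            return {l+1, r+1};
        else if (sum > target)
            r--;  // numbers[r] is too big even for the smallest remaining number
        else
            l++;  // numbers[l] is too small even for the largest remaining number
    }
    return {};
}

// Hypothetical O(n^2) reference: does any pair i < j sum to target?
bool hasPairBruteForce(const vector<int>& numbers, int target) {
    for (size_t i = 0; i < numbers.size(); i++)
        for (size_t j = i+1; j < numbers.size(); j++)
            if (numbers[i]+numbers[j] == target)
                return true;
    return false;
}

// Hypothetical checker: on random non-decreasing arrays, two pointers must find a pair
// exactly when brute force does, and any pair it returns must really sum to target.
bool crossCheck(int trials) {
    srand(12345);  // fixed seed so the run is reproducible
    for (int t = 0; t < trials; t++) {
        int n = rand() % 8;
        vector<int> numbers;
        int cur = rand() % 5 - 2;
        for (int k = 0; k < n; k++) {
            numbers.push_back(cur);
            cur += rand() % 3;  // step of 0 creates duplicates
        }
        int target = rand() % 21 - 6;
        vector<int> got = twoSum(numbers, target);
        if (got.empty() == hasPairBruteForce(numbers, target))
            return false;  // one found a pair and the other did not
        if (!got.empty() &&
            (got[0] >= got[1] || numbers[got[0]-1]+numbers[got[1]-1] != target))
            return false;  // returned pair is invalid
    }
    return true;
}
```

crossCheck should return true for any number of trials; if you deliberately break the pointer moves (say, swap l++ and r--), it should start returning false.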

This can very easily be generalized to Three Sum: return all the unique triplets that sum to 0.

vector<vector<int> > threeSum(vector<int>& nums) {
    vector<vector<int> > ans;
    sort(nums.begin(), nums.end());
    for (int i = 0; i < (int)nums.size()-2; i++) {
        if (i > 0 && nums[i] == nums[i-1])
            continue;
        int l = i+1, r = nums.size()-1;
        while (l < r) {
            int sum = nums[i]+nums[l]+nums[r];
            if (sum > 0 || (r < (int)nums.size()-1 && nums[r] == nums[r+1]))
                r--;
            else if (sum < 0 || (l > i+1 && nums[l] == nums[l-1]))
                l++;
            else {
                ans.push_back({nums[i], nums[l], nums[r]});
                l++;
                r--;
            }
        }
    }
    return ans;
}

Sorting is O(n log(n)) and the rest of the algorithm is O(n^2) anyway, so sorting costs nothing asymptotically and makes our life much easier. Here we are essentially repeating our two-pointer solution n times (n-2 to be exact). In the outer for loop, I enumerate the leftmost number. To skip duplicates for this number, we just need to make sure the previous number is not the same; otherwise we would have already used this value as the starting point. Then each inner loop is simply Two Sum on the subarray from i+1 to nums.size()-1 with target -nums[i]. Here we skip duplicates by making sure the middle number has not been used before (the previous one is different), and that the right number has not been used before (the one after it is different, since the right pointer moves to the left).
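
The duplicate-skipping rules are the easiest part to get wrong, so here is a similar sketch that checks threeSum (repeated from above so the block compiles on its own) against an O(n^3) brute force that collects the distinct sorted triplets in a std::set. threeSumBruteForce and sameAsBruteForce are hypothetical helpers; the check requires every valid triplet to appear exactly once:

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <vector>
using namespace std;

vector<vector<int> > threeSum(vector<int>& nums) {
    vector<vector<int> > ans;
    sort(nums.begin(), nums.end());
    for (int i = 0; i < (int)nums.size()-2; i++) {
        if (i > 0 && nums[i] == nums[i-1])
            continue;  // this leftmost value was already enumerated
        int l = i+1, r = nums.size()-1;
        while (l < r) {
            int sum = nums[i]+nums[l]+nums[r];
            if (sum > 0 || (r < (int)nums.size()-1 && nums[r] == nums[r+1]))
                r--;  // sum too big, or this right value was already used
            else if (sum < 0 || (l > i+1 && nums[l] == nums[l-1]))
                l++;  // sum too small, or this middle value was already used
            else {
                ans.push_back({nums[i], nums[l], nums[r]});
                l++;
                r--;
            }
        }
    }
    return ans;
}

// Hypothetical O(n^3) reference: every distinct sorted triplet summing to 0.
set<vector<int> > threeSumBruteForce(vector<int> nums) {
    sort(nums.begin(), nums.end());
    set<vector<int> > found;
    int n = nums.size();
    for (int i = 0; i < n; i++)
        for (int j = i+1; j < n; j++)
            for (int k = j+1; k < n; k++)
                if (nums[i]+nums[j]+nums[k] == 0)
                    found.insert(vector<int>{nums[i], nums[j], nums[k]});
    return found;
}

// Hypothetical checker: same triplets as brute force, and no triplet reported twice.
bool sameAsBruteForce(vector<int> nums) {
    set<vector<int> > expected = threeSumBruteForce(nums);
    vector<vector<int> > got = threeSum(nums);  // note: sorts nums in place
    set<vector<int> > gotSet(got.begin(), got.end());
    return got.size() == gotSet.size() && gotSet == expected;
}
```

On the input {-1, 0, 1, 2, -1, -4}, threeSum returns {{-1, -1, 2}, {-1, 0, 1}}; an all-zero input such as {0, 0, 0, 0} is a good stress test for the skipping logic, since it must yield {0, 0, 0} only once.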

Here’s another problem that could use two pointers in a different way: Trapping Rain Water

Given a height array, return the amount of rain water that would be trapped by the valleys.

This one is harder than the previous ones in terms of figuring out the algorithm. The key observation is that we only need to calculate the silhouette of the mountain after the valleys are filled with rain, and subtract the mountain shape from it. Read from left to right, the silhouette must first go up (or stay level) and then go down (or stay level), so you can imagine two guys climbing the hills from both sides who must always go up or stay at the same level. Or you can imagine two pointers instead of two guys:

int trap(vector<int>& height) {
    if (height.size() < 3)
        return 0;
    int l = 0, r = height.size()-1, ans = 0, t1 = height[l], t2 = height[r];
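    while (l <= r) {
        // Update the water level before using it, so we never read
        // height[l] or height[r] outside the range [l, r].
        if (t1 <= t2) {
            t1 = max(t1, height[l]);
            ans += t1-height[l];
            l++;
        } else {
            t2 = max(t2, height[r]);
            ans += t2-height[r];
            r--;
        }
    }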
    return ans;
}

There is nothing special in this problem once you figure out the monotonicity of the silhouette. We have two pointers, and we record the current water level that can be trapped from each side (t1 from the left, t2 from the right). Two things to note here: the while condition is l <= r instead of l < r, because the pointers point at positions we haven’t calculated yet; and we always advance the side whose water level is lower, so that neither pointer overshoots the peak.
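
For comparison, the silhouette observation can also be coded directly, which may make the idea easier to see: the water surface above each bar is the lower of the two climbers' levels, i.e. the minimum of the tallest bar to its left and the tallest bar to its right (both including the bar itself). This is a sketch with a hypothetical name, trapWithSilhouette; it is O(n) time like the two-pointer version but uses O(n) extra space, which is exactly what the two pointers avoid:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical alternative: build the silhouette explicitly.
// leftMax[i]  = tallest bar in height[0..i]   (the climber coming from the left)
// rightMax[i] = tallest bar in height[i..n-1] (the climber coming from the right)
// Water above bar i = min(leftMax[i], rightMax[i]) - height[i].
int trapWithSilhouette(const vector<int>& height) {
    int n = height.size();
    if (n == 0)
        return 0;
    vector<int> leftMax(n), rightMax(n);
    leftMax[0] = height[0];
    for (int i = 1; i < n; i++)
        leftMax[i] = max(leftMax[i-1], height[i]);
    rightMax[n-1] = height[n-1];
    for (int i = n-2; i >= 0; i--)
        rightMax[i] = max(rightMax[i+1], height[i]);
    int ans = 0;
    for (int i = 0; i < n; i++)
        ans += min(leftMax[i], rightMax[i]) - height[i];  // silhouette minus mountain
    return ans;
}
```

On the classic input {0, 1, 0, 2, 1, 0, 1, 3, 2, 1, 2, 1} this gives 6. The two-pointer version computes the same min(leftMax, rightMax) on the fly: whichever side has the lower running maximum already knows its min.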

Bonus problem here: Container With Most Water. The code is very similar so I will skip it, but you guys should be able to figure out why two pointers can give a correct solution here. Hint: think about what solutions you are discarding by moving a pointer, and why those solutions will not be optimal.