Tinkering with Deep Learning Style Transfer

A while ago, Prisma was quite popular on social media and everyone was filtering pictures with its artistic filters. Got some free time yesterday, so I thought I should try out some neural network style transfer apps.

A few words about what deep learning style transfer does. Two pictures for input: one style image and one content image. The style image is supposed to be an artistic piece with a distinct look, and the content image is normally a photo you take. The algorithm will then produce an output image that uses the artistic style from the style image to draw the objects shown in the content image. You will see some examples below.

First I went to Google’s Deep Dream. The content image:

contentimage

This is a picture I took in New York, and you probably recognize it as I used this picture as the banner of this blog.

And the style image:

5-centimeters-per-second-14515

This is an artwork from Shinkai Makoto’s movie 5 Centimeters Per Second.

Alright, so I uploaded both images to Deep Dream, and this is what I got:

dream_f9cc0d2ae9

That’s pretty cool actually. You can see the color palette is transferred over accurately, and the objects are clearly visible. Here are a few pluses:

  1. Colors from the style image are preserved well. There was a lot of yellow in the content image but none in the style image, so the yellow was entirely removed.
  2. The most visibly styled objects are the humans. You can see the clothes turned into gradients, and the shapes got abstracted a little bit.
  3. Generation was relatively fast; I only waited a few minutes. It was also free, so I can't complain.

However there are two things I was not satisfied with:

  1. The resolution is pretty low and the result looks quite compressed. The pictures I uploaded were HD, yet I got back a tiny 693×520 picture.
  2. There are many visible artifacts in the sky. That is understandable, since there were clouds in the content image and a lot of electric cables in the style image, but it looks like the optimization was stopped prematurely.

Therefore, I decided to pull the source code and run it for myself.

First attempt

First I Googled “style transfer deep learning” and found this GitHub repo. I’m running it on a Mac, and the installation instructions were quite clear. With all default settings, I got these results:

out_200

out_400

out_700

out

These four images were produced sequentially; as you can see, the quality improved over time. There are no more line-shaped artifacts in the sky, though you can still see a few red and green dots near the skyline. Overall it looks similar to the one generated by Deep Dream, but I like the blue cloud in the center more.

However, these pictures are even smaller! The images were only 512 pixels wide, and it already took my Mac 2 hours. It’s sort of my fault for running deep neural nets on a MacBook Pro without a GPU. But I really want to generate larger and clearer pictures, and if 512 pixels already takes this long, a picture four times larger is going to take much, much longer. So I Googled again for a faster solution.

Second attempt

With some more Googling I found this repo. It is written by the same guy, Justin Johnson of the Stanford lab. It was more painful to install all the dependencies, and I had to modify some code to get it to compile, but eventually I got it working. The README claims that generation is hundreds of times faster and supports near-real-time video style transfer, so it should be good. Some results:

out1024_mosaic

out1024_scream

out1024_starry

These pictures are styled with the pre-trained models, and even at a width of 1024 they are generated almost in real time. The models were trained on a window mosaic artwork, The Scream, and Starry Night respectively. They are actually very lovely! The brush strokes are vivid, and the image quality is high.

But where’s my Shinkai Makoto?

It turns out that once we have a model trained on a style image, generating from a content image is very fast. But we need to train a new model for each new style image, and that, I assume, is the expensive part. Unfortunately I didn’t do it, for reasons explained below.

Sacrificing my CPU for art and science (and reading papers in the meantime)

Since I really want to make this one good picture, I went back to the original code. Not only does generating a large picture take a lot of CPU time, it also takes a lot of disk space; I had to delete some huge software I never used to free up enough space. It looks like the program will run for approximately a day. Meanwhile I should read the papers behind the code above, and maybe study the code a little bit.

Here’s the paper for the original algorithm by Gatys, Ecker and Bethge, and here’s the website for the faster code by Johnson. In my understanding (which could very well be wrong), here are the TL;DRs:

Original paper:

  1. The outline of the algorithm: extract the style and the content of an image separately, then start from a random noise picture and modify its pixels slightly, many times over, until it matches both the style and the content of our desired pictures.
  2. We already knew how to extract the content of a picture before this paper was published. There is a freely available pre-trained model called the VGG network, which is basically a computational graph with a fixed structure and fixed parameters, known for identifying objects about as well as humans do.
  3. The way VGG (or any other neural network) works is the following. On each layer, we have an input vector of numbers. We carry out certain mathematical operations on these numbers: we multiply some of them together, add them, add a constant, scale them, tanh them, take the max of them… all sorts of math, and generate a new vector of numbers. A deep neural network has many layers, one layer’s output feeding into the next layer’s input. If you just do random math, the generated numbers will be meaningless. But a “trained” network like VGG produces vectors of numbers that are meaningful. Maybe the 130th number in the output vector indicates how likely it is that this picture has a cat, that kind of thing. Rumor has it that the field of computer vision was started primarily to deal with cat pics.
  4. A convolutional neural network is just a neural network with a special set of mathematical operations that are designed to capture the information in a picture, as it employs a hierarchical structure of calculations that preserves the 2D structure of pixels.
  5. The key breakthrough of this paper: we already knew that feature activations represent the objects. But if we look at the different “features” of the network and take the correlation matrix of the output signals across the features of a certain layer (the Gram matrix), we obtain style information.
  6. So step one: run the content picture through VGG and capture the output signals of a certain layer. Step two: run the style picture through VGG and capture the correlation matrix of output signals from features of a certain layer. Step three: start with a random noise picture, run through VGG and capture the content and style information just as above. Step four: compare our random picture with our captured signals, and figure out how to change these random pixels a little bit so that the style matches with the style picture and the content matches with the content picture. Step five: go back to step three until your computer crashes and burns. Step six: output the “random picture” – it’s already not random anymore!
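The steps above can be summarized in formulas, as I understand the paper (notation loosely follows it; F^l and P^l are the layer-l activations of the generated and content pictures, and G^l and A^l are the Gram matrices, i.e. feature correlations, of the generated and style pictures):

```latex
\mathcal{L}_{\text{content}} = \tfrac{1}{2} \sum_{i,j} \left( F^l_{ij} - P^l_{ij} \right)^2
\qquad
G^l_{ij} = \sum_k F^l_{ik} F^l_{jk}
\qquad
\mathcal{L}_{\text{style}} = \sum_l \frac{w_l}{4 N_l^2 M_l^2} \sum_{i,j} \left( G^l_{ij} - A^l_{ij} \right)^2
```

and the loss being minimized over the pixels is the weighted combination \(\mathcal{L}_{\text{total}} = \alpha \, \mathcal{L}_{\text{content}} + \beta \, \mathcal{L}_{\text{style}}\).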

As you can imagine, changing the pixels a little bit at a time until the picture eventually looks like something is definitely not going to be fast. That is very understandable.

Faster code:

  1. The original paper framed the problem as an optimization problem: we have a function f(x), and we want to find the x that maximizes or minimizes it. This fits if we take the output picture as x, and the combined difference between x and our desired pictures, in terms of style and content, as f(x). f(x) is our loss function, and we are trying to minimize it; the style and content images are hidden inside the loss function.
  2. This new paper, however, frames the problem as a transformation problem. This means we have an input x, and we want to calculate y = g(x). This is actually very natural to think about, because we have 2 input pictures, and we want 1 output picture, so x could be the style and content images, y our generated picture, and g(x) will be our algorithm.
  3. Finding an unknown function is what machine learning does best: first make a really dumb robot, then show it x0 (some input); it spits out some random crap ŷ0. You say no, no, no, bad boy; you should have said y0. It’s really dumb, so it only remembers a little bit. Then you move on to the next input x1, and so on, until the robot picks up the patterns from your supervision and starts to make sense. So one way to solve the transformation version of style transfer could be: collect many style and content pictures, run them through the slow code to generate output pictures, then make a dumb robot and teach it the corresponding input-output pairs until the robot can do it by itself.
  4. All of the above was prior knowledge to the paper, and this approach has a great advantage over the old one: it is very fast and simple to generate a new picture now. We don’t have to guess anymore; just throw the pictures at the robot and it will instantly give a new one back to you. The downside of this approach, of course, is that getting the robot in the first place can be very expensive; you need to generate many thousands of pictures through slow code before you can generate one picture through the robot.
  5. The key insight of this paper is more of a subtle and technical one: when we teach the robot how to turn x into y, we don’t just compare the robot’s output to a picture we want, but instead we run the output image through the VGG network to extract the style and content, then we use the style and content differences to teach the robot how to do better. Teaching the robot has a formal name called “back propagation” because of how it is practically done. This approach gives higher quality pictures.
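Side by side, the two framings can be written roughly like this (my own notation, not the papers’; c is a content picture, s the fixed style picture, and the two losses are the VGG-based ones from the original paper):

```latex
% optimization framing: search over the pixels x of one picture
x^* = \arg\min_x \; \alpha \, \mathcal{L}_{\text{content}}(x, c) + \beta \, \mathcal{L}_{\text{style}}(x, s)

% transformation framing: train the weights W of a network f_W once,
% then reuse f_W for every new content picture
\min_W \; \mathbb{E}_c \left[ \alpha \, \mathcal{L}_{\text{content}}(f_W(c), c) + \beta \, \mathcal{L}_{\text{style}}(f_W(c), s) \right]
```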

Although training can be more expensive, generating new pictures can be real time now. This is great for commercialization. Let’s say a company trains many models based on some distinctive artistic styles, then when users upload a picture, they can get instant artistic filters provided by the company. That’s basically what Prisma does, I suppose. Yet for my purpose, it will not be any faster than the optimization approach.

There are some exciting new developments from Google as well. Their work builds on top of Johnson’s, and allows interpolation between styles, so you can mix Van Gogh with Monet, for example. It came out just a month ago! Since they also released the code, I tried it out a little bit. Here’s a quick Monet style:

stylized_0_1000_1_000_2_000_3_000_4_000_5_000_6_000_7_000_8_000_9_000

It’s alright, but doesn’t look too great. Probably Monet’s brush strokes are too small, so this big picture looks merely textured instead of styled. Unfortunately, training a new model takes YUGE space, like 500GB. YUGE. This is why the transformation approach is not suitable for a random individual like myself: training a model is very demanding in resources, and the benefits don’t outweigh the costs. Even more sadly, attempting to run this crashed my computer, and I had to restart my 1024-pixel Shinkai Makoto picture after it had been running for 18 hours.

Anyway, done with reading papers, I’m just going to sit here and wait for results. After about a day of computation:

ultimate

…I should really get myself a GPU.

TIW: Binary Indexed Tree

Binary indexed tree, also called Fenwick tree, is a fairly advanced data structure for a specific use: computing prefix sums, as in the range sum post. In the prefix sum problem, we have an input array v, and we need to calculate the sum from the first item up to index k. There are two operations. Update: add a value to the number at one index (not reset it). Query: get the sum from the beginning up to a certain index. How do we do it? There are two trivial ways:

  1. Every time someone queries the sum, just loop through it and return the sum. O(1) update, O(n) query.
  2. Precompute the prefix sum array, and return the precomputed value from the table. O(n) update, O(1) query.

To illustrate the differences and better explain what we’re trying to achieve, I will write the code for both approaches. They are not the theme of this post though.

class Method1 {
private:
    vector<int> x;
public:
    Method1(int size) {
        x = vector<int>(size);
    }
    void update(int v, int k) {
        x[k] += v;
    }
    int query(int k) {
        int ans = 0;
        for (int i = 0; i <= k; i++)
            ans += x[i];
        return ans;
    }
};
class Method2 {
private:
    vector<int> s;  // s[k] stores the precomputed prefix sum up to index k
public:
    Method2(int size) {
        s = vector<int>(size);
    }
    void update(int v, int k) {
        // adding v at index k changes every prefix sum from k onward
        for (; k < s.size(); k++)
            s[k] += v;
    }
    int query(int k) {
        return s[k];
    }
};

Read through this and make sure you can write this code with ease. One note before we move on: we’re computing the sum from the first item to index k, but in general we want the range sum from index i to index j. To obtain it, simply subtract the prefix sums: query(j)-query(i-1).
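To make the subtraction trick concrete, here is a tiny sketch (helper names are mine, not from the post’s classes):

```cpp
#include <vector>
using namespace std;

// Prefix sum of v from index 0 through k, by direct summation.
int prefix(const vector<int>& v, int k) {
    int ans = 0;
    for (int i = 0; i <= k; i++)
        ans += v[i];
    return ans;
}

// Range sum over indices [i, j] via subtraction of prefix sums.
// Guard i == 0 so we never ask for prefix(v, -1).
int rangeSum(const vector<int>& v, int i, int j) {
    return prefix(v, j) - (i > 0 ? prefix(v, i - 1) : 0);
}
```

For v = {1, 2, 3, 4, 5}, rangeSum(v, 1, 3) adds 2 + 3 + 4 = 9.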

OK, that looks good. If we make a lot of updates, we use method 1; if we make a lot of queries, we use method 2. What if we make the same number of updates and queries? Say we make n of each: no matter which method we use, we end up with O(n^2) total time (verify!). We either spend too much time pre-computing or too much time calculating the same sums over and over. Is there any way to do better?

Yes, of course! Instead of showing the code and convincing you that it works, I will derive it from scratch.

The quest for log(n)

The problem: we have the same number of updates and queries, and we do not want to bias the computation toward either one. So we do a bit of pre-computation and a bit of summation. That’s the goal.

Say we have an array of 8 numbers, {1, 2, 3, 4, 5, 6, 7, 8}. To calculate the sum of the first 7 numbers, we would like to add up a bunch of pre-computed numbers (since there has to be a bit of summation), but the number of terms has to be sub-linear; let’s aim for log(n). log2(7) is almost 3, so maybe we can sum 3 numbers. In this case, we choose these 3: sum{1, 2, 3, 4}, sum{5, 6} and sum{7}. Assuming these sums are already pre-computed, we add log(n) numbers, hence querying is log(n). For clarity, let me put everything in a table:

Table 1a

sum{1} = sum{1}

sum{1, 2} = sum{1, 2}

sum{1, 2, 3} = sum{1, 2} + sum{3}

sum{1, 2, 3, 4} = sum{1, 2, 3, 4}

sum{1, 2, 3, 4, 5} = sum{1, 2, 3, 4} + sum{5}

sum{1, 2, 3, 4, 5, 6} = sum{1, 2, 3, 4} + sum{5, 6}

sum{1, 2, 3, 4, 5, 6, 7} = sum{1, 2, 3, 4} + sum{5, 6} + sum{7}

sum{1, 2, 3, 4, 5, 6, 7, 8} = sum{1, 2, 3, 4, 5, 6, 7, 8}

The left hand side of the table is the query, and all the terms on the right hand side are pre-computed. If you look closely you will see the pattern: to sum the first k numbers, take the largest power of 2, 2^m, with 2^m ≤ k, and pre-compute the sum of that block. Then for the remaining k-2^m numbers, take the largest power of 2, 2^m’, with 2^m’ ≤ k-2^m, and so on.
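Another way to see the pattern: the block sizes are exactly the set bits in the binary representation of k. A small sketch (the helper name is mine) that builds the same decomposition from the small end by peeling off the lowest set bit each time:

```cpp
#include <vector>
using namespace std;

// The block sizes in the power-of-2 decomposition of k, obtained by
// repeatedly removing the lowest set bit. For k = 7 (binary 111) it
// yields {1, 2, 4}, i.e. sum{7}, sum{5, 6}, sum{1, 2, 3, 4} read
// from the back of the array.
vector<int> blockSizes(int k) {
    vector<int> sizes;
    while (k) {
        int low = k & (-k);  // lowest set bit of k
        sizes.push_back(low);
        k -= low;
    }
    return sizes;
}
```

The k & (-k) trick for the lowest set bit is the same one the final query and update code will rely on.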

There are two things to show: that querying (adding the terms on the right hand side) takes log(n) additions, and that updating the pre-computed terms on the right hand side takes log(n) modifications.

That querying is log(n) is easy to see: by taking out the largest power of 2 each time, we take out at least half of the remaining numbers (use proof by contradiction). Removing at least half each time, after O(log(n)) steps we will have taken out everything.
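The “at least half” claim also follows directly, without contradiction: let 2^m be the largest power of 2 with 2^m ≤ k; then

```latex
2^m \le k < 2^{m+1}
\;\Longrightarrow\; \frac{k}{2} < 2^m
\;\Longrightarrow\; k - 2^m < k - \frac{k}{2} = \frac{k}{2},
```

so each step leaves strictly less than half of what remained.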

Now we are one step from finishing on the theoretical side: how do we pre-compute those terms?

Let’s say we want to change the number 1 into 2, essentially carrying out update(1, 0). Look at the terms above: we need to change sum{1}, sum{1, 2}, sum{1, 2, 3, 4} and sum{1, 2, 3, 4, 5, 6, 7, 8}. Each time we update one more pre-computed term, we cover double the number of elements in the array. Therefore we also only need to update log(n) terms. Let’s see it in a table:

Table 1b

update 1: sum{1}, sum{1, 2}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 2: sum{1, 2}, sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 3: sum{3}, sum{1, 2, 3, 4}, sum {1, 2, 3, 4, 5, 6, 7, 8}

update 4: sum{1, 2, 3, 4}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 5: sum{5}, sum{5, 6}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 6: sum{5, 6}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 7: sum{7}, sum{1, 2, 3, 4, 5, 6, 7, 8}

update 8: sum{1, 2, 3, 4, 5, 6, 7, 8}

Cool, now we have a vague idea about what to pre-compute for update and what to add for query. Now we should figure out the details of the code.

How is the code written?

First, we need to determine the representation of the pre-computed terms. Here is a list of all pre-computed terms:

{1}, {1, 2}, {3}, {1, 2, 3, 4}, {5}, {5, 6}, {7}, {1, 2, 3, 4, 5, 6, 7, 8}

The last number of each term is unique and covers the range 1-8. That’s great news! We can use a vector to store these terms easily, and let the index of the array be the last number of the term. For example, the sum of {5, 6} will be stored at bit[6].

First, the query operation. Let’s revisit the table with binary representation of numbers:

Table 2a: revised version of table 1a, with sums written as bit elements, indices in binary

query 0001: bit[0001]

query 0010: bit[0010]

query 0011: bit[0011]+bit[0010]

query 0100: bit[0100]

query 0101: bit[0101]+bit[0100]

query 0110: bit[0110]+bit[0100]

query 0111: bit[0111]+bit[0110]+bit[0100]

query 1000: bit[1000]

Do you see the pattern yet? Hint: for a query whose index has k ones, there are k terms on the right. The pattern: keep removing the lowest set bit from the index until it reaches zero, adding one term at each step: 0111->0110->0100. Finally, here’s the code:

int query(vector<int>& bit, int k) {
    int ans = 0;
    for (k++; k; k -= k & (-k))
        ans += bit[k];
    return ans;
}

After all the work we’ve been through, the code is extremely concise! Two things to notice. First, the k++ changes the indexing from 0-based to 1-based: in the derivation above we went from 1 to 8 instead of 0 to 7. Second, k & (-k) calculates the lowest set bit; you can refer to the previous blog post on bitwise operations.

OK, we’re almost done. What about update? Another table:

Table 2b: revised version of table 1b

update 0001: bit[0001], bit[0010], bit[0100], bit[1000]

update 0010: bit[0010], bit[0100], bit[1000]

update 0011: bit[0011], bit[0100], bit[1000]

update 0100: bit[0100], bit[1000]

update 0101: bit[0101], bit[0110], bit[1000]

update 0110: bit[0110], bit[1000]

update 0111: bit[0111], bit[1000]

update 1000: bit[1000]

What’s the pattern this time? Hint: again, look for the lowest bit! Yes, this time instead of removing the lowest bit, we add the lowest bit of the index to itself. This is less intuitive than the last part. For example, lowest bit of 0101 is 1, so the next index is 0101+1 = 0110; lowest bit is 0010, next index is 0110+0010 = 1000.

So here’s the code, note also the k++:

void update(vector<int>& bit, int v, int k) {
    for (k++; k < bit.size(); k += k & (-k))
        bit[k] += v;
}

This is deceptively easy! That can’t be right. It can’t be that easy… Actually it can, if you look at the category of this post; nothing I write about is hard.

Actually, it was easy because we were just matching patterns and assuming they would generalize. Let’s study the for loops a little more to understand why and how low bits are involved. This part is rather involved, so for practical purposes you may as well skip it.

First, observe that the low bit of each index equals the number of integers that index sums over. Say, 0110 has low bit 0010, which is 2, so bit[6] is a sum of two numbers: 5 and 6. This holds by design: it is exactly how we picked the pre-computed terms, so there is nothing to derive.

Second, bit[k] is the sum from indices k-lowbit(k)+1 to k. This is a direct consequence from (1) bit[k] is a summation that ends at the kth number and (2) bit[k] sums over lowbit(k) numbers.

In light of this fact, the code for querying becomes clear: for an index k, we first get the sum from k-lowbit(k)+1 to k from bit[k], then we need to find the sum from 1 to k-lowbit(k). The latter becomes a sub-problem, which is solved by setting k-lowbit(k) as the new k value and going into the next iteration.

For updating, it is much trickier. From the above, bit[l] includes k iff l-lowbit(l) < k ≤ l. Below is a proof sketch; the actual proof would include more details and be tedious to read. For the kth number, bit[k+lowbit(k)] must include it: the lowbit of k+lowbit(k) is at least 2 times lowbit(k), so k+lowbit(k)-lowbit(k+lowbit(k)) ≤ k-lowbit(k) < k ≤ k+lowbit(k), satisfying the inequality. Also, k can be updated to k+lowbit(k) in the next iteration, because given lowbit(k) < lowbit(m) < lowbit(n), if bit[m] includes k and bit[n] includes m, then bit[n] includes k as well. So far, we have shown that the bit[k] values modified in the for loop all include k.

Then we also need to show that all the bit[l] values that include k are modified in our loop. We can in fact count the bit[l] values that include k: the count equals one plus the number of zeros above lowbit(k). It is not difficult to see that each iteration of the for loop reduces the number of zeros above lowbit(k). The only remaining question is: why that count? Look at table 2b again. The numbers of terms for the first four entries, i.e. {4, 3, 3, 2}, are each one more than those for the second four entries, i.e. {3, 2, 2, 1}. This is by design, because bit[4] covers the first four indices but not the second four, and everything else is symmetric. Again, the first two entries have one more term than the second two, because bit[2] covers the first two but not the second two. Hence, each time we “go down the tree” on a “zero edge” (appending a 0 to the index prefix), the indices below are covered once more than if we “go down the tree” on a “one edge” (appending a 1 to the index prefix). After we hit the low bit, no term with a smaller low bit covers this index, and of course the index covers itself, hence the plus one. This is a basic and rather rough explanation of how the number of zeros relates to the number of terms covering a given index. With that, we have argued (semi-convincingly) that the update loop is both valid and complete.
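Putting the two routines together, here is a quick sanity check on our running example {1, …, 8} (the query and update code from above, plus a small builder helper of my own):

```cpp
#include <vector>
using namespace std;

// bit[k] stores the sum over indices (k - lowbit(k), k], 1-based internally.
int query(vector<int>& bit, int k) {
    int ans = 0;
    for (k++; k; k -= k & (-k))
        ans += bit[k];
    return ans;
}

void update(vector<int>& bit, int v, int k) {
    for (k++; k < (int)bit.size(); k += k & (-k))
        bit[k] += v;
}

// Build a binary indexed tree over {1, 2, ..., n} by issuing n updates.
vector<int> buildExample(int n) {
    vector<int> bit(n + 1);  // 1-based internally, so size n+1
    for (int i = 0; i < n; i++)
        update(bit, i + 1, i);
    return bit;
}
```

With bit = buildExample(8), query(bit, 6) returns 1 + 2 + … + 7 = 28; and after update(bit, 1, 0), which turns the 1 into a 2 just like in the tables, query(bit, 7) returns 37.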

 

OK, anyway, time for practice: Range Sum Query – Mutable

It’s literally implementing a binary indexed tree, nothing more.

class NumArray {
private:
    vector<int> bit;
    void update_helper(int v, int k) {
        for (k++; k < bit.size(); k += k & (-k))
            bit[k] += v;
    }
    int query_helper(int k) {
        int ans = 0;
        for (k++; k; k -= k & (-k))
            ans += bit[k];
        return ans;
    }
public:
    NumArray(vector<int> &nums) {
        bit.resize(nums.size()+1);
        for (int i = 0; i < nums.size(); i++)
            update_helper(nums[i], i);
    }
    void update(int i, int val) {
        update_helper(val-query_helper(i)+query_helper(i-1), i);
    }
    int sumRange(int i, int j) {
        return query_helper(j)-query_helper(i-1);
    }
};

It got a little complicated because I didn’t store the original values, so update needs some extra work to compute the change at a given index from the new value and the old range sums: val-query_helper(i)+query_helper(i-1) is the new value minus the old one. But that’s nothing important.

That’s it for the basic introduction to binary indexed trees. There are some variants, such as replacing the + in the update function with min or max to get prefix min or max, or extending the algorithm to a 2D matrix, a.k.a. the 2D binary indexed tree. We can even use it for some dynamic programming questions. There are in fact a few more questions on Leetcode that use this data structure. But that’s for later.

I learned binary indexed tree through the TopCoder tutorial. If you think I did a really bad job and you do not understand at all, you can refer to it as well.

TIW: Bitwise

Bitwise operations are black magic. They are simple, yet with them you can do things you never thought would be so easy. For those who have never seen them: bitwise operations operate on one or two integer-type variables, treating them as arrays of booleans, and each operation acts on the bits individually. Let’s see what bitwise operations we have:

  1. And: a & b. 0011 & 0101 = 0001 in binary, so 3 & 5 = 1.
  2. Or: a | b. 0011 | 0101 = 0111 in binary, so 3 | 5 = 7.
  3. Exclusive or: a ^ b. 0011 ^ 0101 = 0110 in binary, so 3 ^ 5 = 6.
  4. Not: ~a. ~00001111 = 11110000. Depending on the number of bits of your integer type, the exact value will vary.
  5. Shift left and right: 001011 << 1 = 010110, 001011 >> 2 = 000010. Shifting by k is essentially multiplying or dividing by 2^k.
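These identities are easy to check directly. Note the parentheses around each operation: bitwise operators bind more loosely than ==, so leaving them out changes the meaning.

```cpp
// Sanity checks for the examples above. ~0x0F is masked to 8 bits so
// the 11110000 pattern is visible regardless of the integer width.
bool bitwiseExamplesHold() {
    return (3 & 5) == 1
        && (3 | 5) == 7
        && (3 ^ 5) == 6
        && (~0x0F & 0xFF) == 0xF0
        && (11 << 1) == 22   // 001011 -> 010110
        && (11 >> 2) == 2;   // 001011 -> 000010
}
```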

Applications are sort of miscellaneous and interesting. First let me go through some common routines, then I will go over some problems.

Taking a bit at a certain index

int bitAtIndex(int v, int k) {
    return (v >> k) & 1;
}

Shift the bit you want to the least significant bit, then and it with 1 to get rid of all the other higher bits.

Clearing and setting a bit

void clearBit(int& v, int k) {
    v &= ~(1 << k);
}
void setBit(int& v, int k) {
    v |= 1 << k;
}

This idea is called masking: create a mask, apply it on the number by either or-ing or and-ing.

Getting the lowest 1 bit

int lowBit(int v) {
    return v & (-v);
}

This is sort of tricky. Let’s walk through it. Say our number is 00011000. In two’s complement, the negative number of v is obtained by 1+(~v). So the negative of 00011000 is 11100111 plus 1, which is 11101000. Taking the and result with 00011000 will yield 00001000, which is the lowest bit we want. The way to understand it is that the tail of our input number v, defined as the pattern 100000… at the end, will remain the same after taking the negative. Everything to the left of the tail will be flipped. Therefore taking the and result will yield the lowest bit that is a 1. This is particularly useful for Binary Indexed Tree, which I will talk about in a coming post.

Here’s the most cliched problem.

Single Number

Given an array in which every number appears exactly twice except for one that appears only once, find that single number in O(n) time with O(1) space. This problem is called Single Number for a reason: imagine it’s Christmas time and you’re out there on the streets alone, and everyone you see is with their significant other. Probably how they got the problem idea. To solve it, we need an operation such that applying it with the same number twice yields the identity, i.e. f(f(a, b), b) = a. The function had better be commutative and associative too, so that we can apply it in any order and cancel out all the pairs. Obviously +, -, *, / don’t meet the requirements. The answer is exclusive or: a^a = 0 and a^0 = a. Therefore 3^2^1^2^3 = 2^2^3^3^1 = 0^1 = 1; exclusive or-ing all the numbers gives the answer.

int singleNumber(vector<int>& nums) {
    int ans = 0;
    for (int x : nums)
        ans ^= x;
    return ans;
}

One more remark: it’s possible to extend this algorithm to find the only number that appears once, given that every other number appears exactly n times for n ≥ 2. What we essentially want is a commutative function that returns to identity after n applications. Here it is: maintain a vector of 32 integers, counting the number of 1s at each bit position mod n. It is easy to see that after n occurrences of the same number, the counts it contributed are all 0 mod n, so we end up with the bits of the answer. Our solution above is just the degenerate case n = 2: a vector of ints mod 2 collapses into a single integer, and modular addition becomes exclusive or. Single Number II is the problem for n = 3. Don’t ask me where they got the problem idea from 🙂

int singleNumber(vector<int>& nums) {
    vector<int> c(32);
    for (int x : nums)
        for (int j = 0; j < 32; j++)
            c[j] = (c[j]+((x>>j)&1))%3;
    int ans = 0;
    for (int i = 0; i < 32; i++)
        ans |= c[i] << i;
    return ans;
}

For people who have never seen bitwise operations, this is sort of complicated. You can see bit-taking in (x>>j)&1 and bit-setting in ans |= c[i] << i.

Generating the power set

Given a set s, the power set is the set of all subsets of s. For example if s = {1, 2, 3}, P(s) = {{}, {3}, {2}, {2, 3}, {1}, {1, 3}, {1, 2}, {1, 2, 3}}. Here is one possible implementation using bitwise and:

vector<vector<int> > powerSet(vector<int>& s) {
    vector<vector<int> > ans;
    for (int i = 0; i < (1 << s.size()); i++) {
        ans.push_back({});
        for (int j = 0; j < s.size(); j++)
            if (i & (1 << j))
                ans.back().push_back(s[j]);
    }
    return ans;
}

The key idea is the outer loop of i. Say for s = {1, 2, 3}, i will loop from 0 to 7, or 000 to 111 in binary. We will have all the bit patterns in that case: 000, 001, 010, 011, 100, 101, 110, 111. Now for each number, each bit indicates whether we include a certain element of the original set. By looping through these bits, we can create each subset and generate the power set.
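To make the bit-to-element mapping concrete, here is the inner loop pulled out as its own function (the name is mine): for s = {1, 2, 3}, i = 5 is 101 in binary, so bits 0 and 2 are set, selecting s[0] = 1 and s[2] = 3.

```cpp
#include <vector>
using namespace std;

// The subset selected by one loop index i, extracted from the code
// above: bit j of i decides whether s[j] is included.
vector<int> subsetAt(int i, const vector<int>& s) {
    vector<int> subset;
    for (int j = 0; j < (int)s.size(); j++)
        if (i & (1 << j))
            subset.push_back(s[j]);
    return subset;
}
```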

Counting the number of 1 bits

Given a number v, count how many bits are 1. There are different ways to do it, two will be shown below.

int countBits(unsigned int v) {
    int ans = 0;
    for (int i = 0; i < 32; i++)
        if ((v >> i) & 1)
            ans++;
    return ans;
}
int countBits(unsigned int v) {
    int ans = 0;
    for (; v > 0; v -= v & (-v), ans++);
    return ans;
}

The second method uses the low bit function, removing the lowest set bit every time, so it loops only once per 1 bit. I suppose it is more efficient, but it probably doesn’t make a real difference.

Swapping two numbers without space

void swapNumbers(int& a, int& b) {
    a = a^b;  //  a ^= b;
    b = a^b;  //  b ^= a;
    a = a^b;  //  a ^= b;
}

Not very useful, but good to know.

Sum of Two Integers

Here’s a brainteaser: add two numbers, but the code cannot have any + or -. The idea of the code: calculate the carry bits and the addition without carry bits, then add the two results together. The base case is when one number becomes 0. There will not be an infinite loop because the number of trailing zeros of the carry bits must increase each time in a recursion. In the following code, if any of the numbers is 0, the sum is equal to bitwise or. Otherwise, the bits without carry will be a exclusive or b, and the carry bits will be where both a and b are 1, shifted to the left by 1.

int getSum(int a, int b) {
    // shift the carries as unsigned: left-shifting a negative int is undefined behavior
    return (a == 0 || b == 0) ? a | b : getSum(a ^ b, (int)((unsigned)(a & b) << 1));
}

Sudoku Solver by Lee Hsien Loong

This code is written by the prime minister of Singapore. It is written in pure C, so it is kind of hard to read. He used a lot of bitwise operators in it. I read it two years ago, and I don’t want to spend the time understanding everything again, so my short explanation might be faulty. The algorithm itself is not tricky: he simply picks a cell with the smallest number of choices and tries everything recursively (line 180). To speed things up, he uses integers as boolean arrays to indicate which numbers are still available for a certain row, column or 3×3 block, so getting the possible placements at a cell is just a bitwise and (line 171). To try one candidate, he takes the lowest bit (line 173). Another trick to reduce the runtime is picking the cell with the fewest possible choices (lines 162, 163, 188, I assume). He also pre-computed some functions into lookup arrays to avoid repeated work. Most of these optimizations reduce the constant factor, replacing one O(1) operation with a faster O(1) operation. Tricky and efficient, but also less readable, in my opinion.

 

Anyway, that’s a lot already; I will split the Binary Indexed Tree part into a separate post. Admittedly these are mostly brainteasers, but some interviewers do like them, and for some lower-level (closer to the hardware) jobs they are quite important.

TIW: Linked List Cycle

This is more like a special topics post, because it is a very specific algorithm with a very narrow application. The problem statement: given a linked list which has a cycle, determine where the cycle begins.

To explain further, a linked list with a cycle looks something like the number 6. In fact, the only 2 topologies (read: shapes) a linked list could have are a straight line or the number 6. We start walking from the top, and end up walking indefinitely in a loop. To determine whether there is a loop or not is fairly simple: create an unordered set, insert all the visited nodes (or just their pointers, or anything unique to the nodes) into the set, until we insert the same node twice or we run into the end of the linked list (in which case there will be no cycle).

Let’s say we have this declaration of a list node.

struct ListNode {
    ListNode* next;
    int val;
};

And this would be the function to return the first node in a cycle.

ListNode* detectCycle(ListNode* head) {
    unordered_set<ListNode*> vis;
    while (head) {
        if (vis.count(head))
            return head;
        vis.insert(head);
        head = head->next;
    }
    return NULL;
}

This is trivial. What is not trivial is to accomplish the exact same task with O(n) time as before but with O(1) space.

The algorithm that does this uses 2 pointers, one fast and one slow. They both start from the head; the fast pointer moves 2 nodes per iteration while the slow pointer moves 1. If there is no loop, the fast pointer will reach the end. If there is a loop, they will both fall into the loop, and eventually they will end up at the same node at some point. At that time we will be sure that there is a loop. Why are we sure that they will always collide? Consider them both in the cycle already. In each iteration, the relative distance between the two pointers increases by 1. When that distance hits a multiple of the length of the cycle, they effectively have a distance of 0, and hence are at the same node.

OK, that sounds clever. But we still do not know the beginning of the loop, do we?

Here’s the real genius: yes, we can figure it out! After they collide, move the fast pointer back to the head. Now move both pointers at the same pace of 1 node per iteration, and eventually they will collide again. When that happens, we have found the beginning of the loop. I do not know of an intuitive explanation, but it is provable with a little algebra. Say the part before the loop has m nodes, the loop itself is n nodes long, and the fast pointer first meets the slow pointer after k iterations. Then (2k-m)-(k-m) = 0 (mod n), i.e. the relative distance between the two pointers is 0, hence k is a multiple of n. After m more steps, the slow pointer has moved k+m steps in total. That is equivalent to moving m steps first, and then k more steps. But moving k steps inside the loop does nothing, because k is a multiple of n. Therefore after k+m steps, the slow pointer points at the beginning of the loop. Meanwhile the fast pointer, in the second phase of the algorithm, after moving m single steps, also arrives at the beginning of the loop. Therefore we have shown (not very rigorously) that the first time the two pointers meet in the second phase is at the beginning of the loop.

Perhaps I should show some code to make it clearer.

ListNode* detectCycle(ListNode* head) {
    ListNode *fast = head, *slow = head;
    do {
        for (int i = 0; i < 2; i++) {
            if (!fast)
                return NULL;
            fast = fast->next;
        }
        slow = slow->next;
    } while (fast != slow);
    fast = head;
    while (fast != slow) {
        fast = fast->next;
        slow = slow->next;
    }
    return fast;
}

That’s the algorithm. There is one problem on Leetcode that uses it: Find the Duplicate Number. It is not exactly obvious, so please spend some time convincing yourself that this problem represents a linked list. Basically, each number in the array is a node: the index is the address, and the value stored is the address of the next node. The graph looks like a 6 because the duplicate value gives some node an in-degree of at least 2, ignoring the parts of the graph that cannot be reached from the head. Note that since every value is between 1 and n, nothing points to index 0, so starting there we always walk into the cycle. Here’s the code for it:

int findDuplicate(vector<int>& nums) {
    int fast = 0, slow = 0;
    do {
        fast = nums[nums[fast]];
        slow = nums[slow];
    } while (fast != slow);
    fast = 0;
    while (fast != slow) {
        fast = nums[fast];
        slow = nums[slow];
    }
    return fast;
}

TIW: Reverse Linked List

I’m just going to say this first: I hate linked list problems. But I also hate waking up, yet I do it every day anyways.

First, what is a linked list: a sequence that supports O(1) insertion (at a known position) and O(n) random access, in contrast to a vector’s O(n) insertion and O(1) random access. Here’s how a linked list looks in real life:

      oooOOOOOOOOOOO"
     o   ____          :::::::::::::::::: :::::::::::::::::: __|-----|__
     Y_,_|[]| --++++++ |[][][][][][][][]| |[][][][][][][][]| |  [] []  |
    {|_|_|__|;|______|;|________________|;|________________|;|_________|;
     /oo--OO   oo  oo   oo oo      oo oo   oo oo      oo oo   oo     oo
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

Here’s how a linked list looks in C++:

struct ListNode {
    int val;
    ListNode* next;
};

First, here’s the link to the Leetcode problem Reverse Linked List.

Like many other problems, linked list reversal can be done in two ways: iterative and recursive. The one thing we need to do is point all the “next” pointers backwards. Let’s look at the iterative solution first.

ListNode* reverseList(ListNode* head) {
    ListNode* last = NULL;
    while (head) {
        ListNode* next = head->next;
        head->next = last;
        last = head;
        head = next;
    }
    return last;
}

This code is trivially easy to write. I have found 2 slightly different ways to interpret what this code does, and you can see which one you find easier to understand.

The first way to look at it: imagine the linked list 1->2->3->NULL. Let () denote the last, and [] denote head. After the first iteration: NULL<-(1) [2]->3->NULL; NULL<-1<-(2) [3]->NULL, NULL<-1<-2<-(3), [NULL]. Essentially we used more variables to store the last and next guy, so we can transition without losing track of anybody. When head hits null, the last guy will be the new head.

The second way to look at it is to think of “last” as a new linked list. What we attempt here is to pop the first node of the original linked list, and insert it to the front of the new linked list. Therefore, “head” is the first node in the old linked list, and “last” is the first node in the new one. 1->2->3->NULL, NULL; 2->3->NULL, 1->NULL; 3->NULL, 2->1->NULL; NULL, 3->2->1->NULL. In a certain sense, this is describing exactly the same thing, but it might be more clear now why we return the “last” variable: it is the head of the new linked list.

OK, that’s not bad; let’s look at a recursive way. Hold on for a minute and think: how do we reduce the problem size, given that we can handle a smaller case? Perhaps we can take the first node out and reverse the rest. Then we need to append the node we took out to the back of the reversed list. Oh no, we do not know where the end of the new linked list is! How do we solve this?

pair<ListNode*, ListNode*> helper(ListNode* head) {  // returns (new head, new tail)
    if (!head || !head->next)
        return make_pair(head, head);
    auto ans = helper(head->next);  // reverse the rest
    ans.second->next = head;        // append head after the tail
    head->next = NULL;
    ans.second = head;              // head is the new tail
    return ans;
}
ListNode* reverseList(ListNode* head) {
    return helper(head).first;
}

Instead of writing the function with the given signature directly, we write a helper that also returns a pointer to the last node of the partially reversed list, so that we can append to it easily. But this is really not an elegant solution, as the next code snippet shows.

ListNode* reverseList(ListNode* head) {
    if (!head || !head->next)
        return head;
    ListNode* ans = reverseList(head->next);
    head->next->next = head;
    head->next = NULL;
    return ans;
}

The trick is the line head->next->next = head. The end of the partially reversed list used to be the first node of the rest, so the current head still points to it via head->next. Setting head->next->next to head appends head to the end, momentarily creating a cycle, so we break the cycle by setting head->next to NULL.

It all looks alright; the code is 2 lines shorter and all that. But here’s the catch: although it is O(n) time, the recursive solution is not O(1) space. Calling the function n levels deep creates n copies of its local variables on the call stack. What is the problem, you may ask? There are two. First, stack space is much more limited than heap space (dynamic memory allocation), so we might hit a stack overflow if the linked list is huge; this is not a problem for the iterative solution. Second, if we could use that much space, why not just copy everything into a vector and random access it trivially? What would be the point of using a linked list anyway?

Despite my rant, some interviewers actually don’t care if you write recursion. They might even think it’s cleaner. So it’s still good to know.

For some trickier problems, a recursive solution (or at least a non-O(1) space solution) might be necessary. But if you can do it in O(1) space, you should prefer to do so, because using linear space for linked list problem is kinda cheating.

Now if you cannot wait to challenge yourself, you can try this one: Reverse Linked List II. It’s definitely more code than 10 lines though.

TIW: Dijkstra

Shortest path problems using Dijkstra are actually very easy to write. There is a fixed format and you just need to fill in the blanks. The format:

  1. Optional pre-computation to create the graph;
  2. Make a set of pairs, insert initial state;
  3. While the set is not empty, pop the best one;
  4. If it is the destination, return the distance;
  5. If we have been to this node, continue to the next, otherwise mark this node as visited;
  6. Insert all the neighbors into the set.

In fact, Dijkstra is such a no-brainer that I sometimes write Dijkstra where BFS suffices, even though it gives an extra log(n) factor to the runtime.

Let’s make up a toy problem and see how it works. Say, a 2D maze search. Input: a matrix of characters, walk from any ‘S’ to any ‘T’, using only ‘.’ as path. Return the length of the shortest path. For example

..S
.XX
..T

Will return 6.

Here’s the code:

 

bool inboard(int x, int y, int m, int n) {
    return x >= 0 && x < m && y >= 0 && y < n;
}
int mazeSearch(vector<string>& maze) {
    int m = maze.size(), n = maze[0].size();
    set<pair<int, pair<int, int> > > st;
    vector<vector<bool> > vis(m, vector<bool>(n));
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            if (maze[i][j] == 'S')
                st.insert(make_pair(0, make_pair(i, j)));  // insert all initial states with distance 0
    vector<vector<int> > dir{{1, 0}, {0, 1}, {-1, 0}, {0, -1}};
    while (!st.empty()) {
        auto sm = *st.begin();  // smallest item in the set, closest unvisited state from source
        st.erase(sm);  // popping the heap
        int r = sm.second.first, c = sm.second.second;
        if (vis[r][c])
            continue;
        vis[r][c] = true;
        if (maze[r][c] == 'T')
            return sm.first;  // found shortest path
        for (auto d : dir) {
            int nr = r+d[0], nc = c+d[1];
            if (inboard(nr, nc, m, n) && (maze[nr][nc] == '.' || maze[nr][nc] == 'T'))
                st.insert(make_pair(sm.first+1, make_pair(nr, nc)));  // next states with 1 distance farther from source
        }
    }
    return -1;  // no path found
}

The main tricks here are to use a set as a heap and a pair to denote a state. A few benefits of this trick:

  1. Set has log(n) insert, remove and maximum/minimum query, ideal for Dijkstra.
  2. Set has built-in duplicate removal, potentially saving a lot of extra work.
  3. Pair has default comparator as mentioned in STL#1, so we do not have to write our own.

The alternative is to use std::priority_queue<>. I cannot think of any reason why a priority queue would be superior here, and there are a few reasons I prefer not to use it: the name is longer to type, it does not deduplicate, it requires extra work to get a min-heap (it is a max-heap by default), and I don’t remember the syntax and would have to Google it every time. Getting a min-heap or max-heap out of std::set<> is trivial: just change *st.begin() to *st.rbegin(), since the first item from the end is the largest.

Of course this problem only requires BFS. You can modify the above code by changing the set to a vector and the while to a for loop, as shown in the earlier BFS blog post. But this is just an example to demonstrate the structure of a Dijkstra implementation.

Unfortunately there is no problem on Leetcode that needs Dijkstra, although you could certainly use it in place of BFS in some problems. So maybe I’ll briefly go over A* using the same code structure, just for fun. If you really understand the above, modifying it for similar problems is trivial.

A* by augmenting Dijkstra

A* is just Dijkstra with a “postpone” function that estimates a lower bound on a node’s distance to the destination. When that function is identically 0, we get exactly Dijkstra. Think of it this way: you are a college student with homework to submit before it is due. You don’t know the due date, but you know it is at least tomorrow night. Are you going to do it now? Of course not. But if all you know is that the due date is before next Sunday, it might actually be tonight, so you have to do it now. The same idea applies to A*: if we know a path definitely takes at least 10 steps, we can safely postpone walking it until after we have tried all paths that could possibly take 9 steps.

To implement A* with the above code: instead of putting the distance from the source as the first item of each pair, put the distance from the source plus the estimated lower-bound distance to the destination.

Once again I don’t have a good problem for this, and it is not that common, so I’ll skip the details. The actual heuristic function depends heavily on the problem; in some cases it might even be impossible or useless to compute a lower bound. So A* is definitely not the one ring to rule them all, but it’s still good to know.