Fast RNG in an Interval

https://arxiv.org/abs/1805.10941 – Fast Random Integer Generation in an Interval

Just read this interesting little paper recently. The original paper is already quite readable, but perhaps I can give a more readable write-up.

Problem Statement

Say you want to generate a random number in range [0, s), but you only have a random number generator that gives you a random number in range [0, 2^L) (inclusive, exclusive). One simple thing you can do is to first generate a number x, divide that by s and take the remainder. Another thing you can do is to scale the range down by something like this: x * (s / 2^L), with some floating point math, casting, whatever that works. Both ways will give you a resulting integer in the specified range.

But these are not “correct”, in the sense that they don’t generate random integers with a uniform distribution. Say s = 3 and 2^L = 4: then you will always end up with one number being generated with probability 1/2, and the other two with probability 1/4 each. Given 4 equally likely inputs, you just cannot convert that to 3 cases with equal probability. More generally, these simple approaches cannot work when s is not a power of 2.

First Attempt at Fixing Statistical Biases

To fix that, you will need to reject some numbers and try again. In the above example, when you get the number 3, you draw again until you get a number from 0 to 2. Then all outcomes are equally likely.

More generally, you need to throw away 2^L mod s numbers, so that the rest will be divisible by s. Let’s call that number r, for remainder. So you can throw away the first r numbers and use the first approach of taking remainder, as shown in this first attempt (pseudocode):

r = (2^L - s) mod s // 2^L does not fit in L bits, so we subtract s first
x = rand()
while x < r do
    x = rand()
return x mod s
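Here is a minimal Python sketch of this first attempt. It assumes a source of uniform L-bit integers; `random.getrandbits` stands in for it here, and the `rand` parameter is a hypothetical hook (not from the paper) that makes the function easy to test exhaustively.

```python
import random

def bounded_reject_low(s, L=32, rand=None):
    """Uniform integer in [0, s) from an L-bit generator, rejecting the first r values."""
    if rand is None:
        rand = lambda: random.getrandbits(L)  # assumed uniform over [0, 2^L)
    # Same value as 2^L mod s. Python integers never overflow, so the
    # subtraction only mirrors what fixed-width code would have to do.
    r = ((1 << L) - s) % s
    x = rand()
    while x < r:          # throw away the first r values and draw again
        x = rand()
    return x % s          # second division

print(bounded_reject_low(6))  # some value in 0..5
```

With s = 3 and L = 4, feeding in every 4-bit value once rejects only x = 0 and leaves 5 inputs for each of the three outputs.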

That’s a perfectly fine solution, and in fact it has been used in some popular standard libraries (e.g. GNU C++). However, division is a slow operation compared to others like multiplication, addition and branching, and in this function we are always doing two divisions (mod). If we can somehow cut down on our divisions, our function may run a lot faster.

Reducing the Number of Divisions

It turns out we can do just that, with just a simple twist. Instead of getting rid of the first r numbers, we get rid of the last r numbers. And we can verify whether x is in the last r numbers like so:

x = rand ()
x_mod_s = x mod s
while x - x_mod_s > 2^L - s do
    x = rand ()
    x_mod_s = x mod s
return x_mod_s

The greater-than comparison on line 3 is a little tricky. It’s mathematically the same as comparing x - x_mod_s + s with 2^L, but we write it this way because 2^L cannot be represented in L bits. Basically, the check says: if the next multiple of s after x is larger than 2^L, then x is in the last r numbers and must be thrown away. We never actually calculate r, but with a little cleverness we manage to do the same check.

How many divisions are we doing here? Well, at least one on line 2, and possibly 0 or many more, depending on how many times the loop is run. Since we’re rejecting less than half of the possible outcomes (we’re at least keeping s and at most rejecting s - 1), we have at least 1/2 chance of breaking out of the loop each time, which means the expected number of loops is at most 1 (0 * 1/2 + 1 * 1/4 + 2 * 1/8 … = 1). So we know that the expected number of divisions is at worst 2, equal to that of the previous attempt. But most of the time, the expected number is a lot closer to 1 (e.g. when s is small), so this can theoretically be almost a 2x speed up.
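The same idea as a Python sketch, under the same assumptions as before: `random.getrandbits` plays the L-bit generator, and the `rand` parameter is a hypothetical testing hook. Each trip through the loop costs one more division.

```python
import random

def bounded_reject_high(s, L=32, rand=None):
    """Uniform integer in [0, s) from an L-bit generator, rejecting the last r values."""
    if rand is None:
        rand = lambda: random.getrandbits(L)  # assumed uniform over [0, 2^L)
    x = rand()
    x_mod_s = x % s                           # the one unavoidable division
    # x - x_mod_s is the multiple of s at or below x. If the next multiple
    # would exceed 2^L, x sits in the incomplete last group and is rejected.
    while x - x_mod_s > (1 << L) - s:
        x = rand()
        x_mod_s = x % s                       # one extra division per retry
    return x_mod_s
```

With s = 3 and L = 4 the only rejected input is x = 15, so 0 through 14 map to each output exactly 5 times.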

So that’s pretty cool. But can we do even better?

Finally, Fast Random Integer

Remember other than taking remainders, there’s also the scaling approach x * (s / 2^L)? It turns out if you rewrite that as (x * s) / 2^L, it becomes quite efficient to compute, because computers can “divide” by a power of two by just chopping off bits from the right. Plus, a lot of hardware has support for getting the full multiplication results, so we don’t have to worry about x * s overflowing. In the approach using mod, we inevitably need one expensive division, but here we don’t anymore, due to quirks of having a denominator of power of 2. So this direction seems promising, but again we have to fix the statistical biases.

So let’s investigate how to do that with our toy example of s = 3, 2^L = 4. Let’s look at what happens to all possible values of x.

x    s * x    (s * x) / 2^L    (s * x) mod 2^L
0    0        0                0
1    3        0                3
2    6        1                2
3    9        2                1

Essentially we have s intervals of size 2^L, and each interval maps to one single unique outcome. In this case, [0,4) maps to 0, [4, 8) maps to 1, and [8, 12) maps to 2. From the third column, we have two cases mapping to 0, and we’d like to get rid of one of them.

Note that the fundamental reason behind this uneven distribution is that 2^L is not divisible by s, so any contiguous range of 2^L numbers will contain a variable number of multiples of s. That means we can fix it by rejecting r numbers in each range! More specifically, if we reject the first r numbers in each interval, then each interval will contain the same number of multiples of s. In the above example, the mapping becomes [1, 4) maps to 0, [5, 8) maps to 1, and [9, 12) maps to 2. Fair and square!
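To convince yourself of this on small cases, one option is to enumerate every input. The sketch below is only a brute-force check, not the algorithm itself, and `scaled_outcomes` is a name made up for this illustration. It counts how many x survive for each output once the first r numbers of every interval are rejected.

```python
from collections import Counter

def scaled_outcomes(s, L):
    """For every x in [0, 2^L), keep x if the low L bits of x*s are >= r; count outputs."""
    r = (1 << L) % s                 # number of values to reject per interval
    mask = (1 << L) - 1
    kept = Counter()
    for x in range(1 << L):
        product = x * s
        low = product & mask         # (s * x) mod 2^L
        if low >= r:                 # not among the first r numbers of its interval
            kept[product >> L] += 1  # (s * x) / 2^L
    return kept

print(scaled_outcomes(3, 2))  # the toy example: one surviving x per output
```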

Let’s put that in pseudocode:

r = (2^L - s) mod s
x = rand ()
x_s = x * s
x_s_mod = lowest_n_bits x_s L // equivalent to x_s mod 2^L
while x_s_mod < r do
    x = rand ()
    x_s = x * s
    x_s_mod = lowest_n_bits x_s L
return shift_right x_s L // equivalent to x_s / 2^L

Now that would work, and it would take exactly 1 expensive division on line 1 to compute r every single time. That beats both of the above algorithms! But wait, we can do even better! Since r < s, we can first check x_s_mod against s, and only compute r if that check fails. This is the algorithm proposed in the paper. It looks something like this:

x = rand ()
x_s = x * s
x_s_mod = lowest_n_bits x_s L
if x_s_mod < s then
    r = (2^L - s) mod s
    while x_s_mod < r do
        x = rand ()
        x_s = x * s
        x_s_mod = lowest_n_bits x_s L
return shift_right x_s L

Now the number of expensive divisions is either 0 or 1, with some probability depending on s and 2^L. This looks clearly faster than the other algorithms, and experiments in the paper confirmed that. But as is often the case, performance comes at the cost of less readable code. Also, in this case we’re relying on hardware support for full multiplication results, so the code is less portable and in reality looks pretty low level and messy. Go and Swift have adopted this, deciding the tradeoff is worth it, and according to the author’s blog (https://lemire.me/blog/2019/09/28/doubling-the-speed-of-stduniform_int_distribution-in-the-gnu-c-library/), C++ may also use this soon.
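For reference, here is the final pseudocode transcribed into Python. This is a sketch only: Python integers are arbitrary precision, so it shows the logic rather than the speed, and real implementations rely on a hardware widening multiply. `random.getrandbits` and the `rand` testing hook are assumptions of this sketch.

```python
import random

def bounded_lemire(s, L=32, rand=None):
    """Uniform integer in [0, s) using a multiply and shift, with a rare division."""
    if rand is None:
        rand = lambda: random.getrandbits(L)  # assumed uniform over [0, 2^L)
    mask = (1 << L) - 1
    product = rand() * s
    low = product & mask               # (x * s) mod 2^L
    if low < s:                        # cheap test: only now might we need r
        r = ((1 << L) - s) % s         # the single division, computed lazily
        while low < r:                 # reject the first r values of the interval
            product = rand() * s
            low = product & mask
    return product >> L                # (x * s) / 2^L
```

For s = 3 and L = 4 only x = 0 is rejected, and the remaining 15 inputs split 5/5/5 across the outputs.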

How Many Divisions Exactly?

There’s still one last part we haven’t figured out: we know the expected number of divisions is between 0 and 1, but what exactly is it? In other words, how many multiples of s in the range [0, s * 2^L) have a remainder less than s when divided by 2^L? To people with more number theory background, this is probably obvious. But starting from scratch, it can take quite a lot of work to prove, so I’ll just sketch the intuitions.

It’s a well known fact that if p and q are co-prime (no common factors other than 1), then the numbers { 0, p mod q, 2p mod q, 3p mod q ... (q-1) p mod q } will be exactly 0 to q-1. This is because if there is any repeated number, then we have a * p mod q = b * p mod q (assuming a > b), which indicates (a - b) * p mod q = 0. But we know that 0 < a - b < q, and p has no common factor with q, so if we multiply those two together, it cannot be a multiple of q. So it’s impossible to have duplicates, and multiples of p will evenly distribute among [0, q) when taken mod q.

Now if s and 2^L are co-prime, there will be exactly s multiples of s that have a remainder ranging from 0 to s - 1. Since each of the 2^L values of x is equally likely, the expected number of divisions in this case is s / 2^L.

If they aren’t co-prime, that means s is divisible by some power of 2. Say s = s' * 2^k, where s' is odd. Then s * 2^(L-k) = s' * 2^L will be 0 mod 2^L. So your multiples of s mod 2^L will go back to 0 after 2^(L-k) steps, and you have 2^k repetitions of that cycle. So if you count how many multiples land on each remainder, the counts go 2^k, followed by 2^k - 1 zeros, rinse and repeat. How many multiples land below s? There are s' nonzero counts below s, each one equal to 2^k, which gives, unsurprisingly, s again. So the expected number of divisions is still indeed s / 2^L.
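The claim is easy to check by brute force for a small L. The helper below (`division_count` is a made-up name) simply counts the inputs x whose product x * s has low L bits below s, which are exactly the inputs that would trigger the division.

```python
def division_count(s, L):
    """Number of x in [0, 2^L) for which (x * s) mod 2^L is below s."""
    mask = (1 << L) - 1
    return sum(1 for x in range(1 << L) if ((x * s) & mask) < s)

# For L = 8 the count equals s for every s from 1 to 2^L, odd or even,
# so the probability of needing the division is s / 2^L.
print(all(division_count(s, 8) == s for s in range(1, 257)))
```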

Final Thoughts

Earlier I said each time you need to throw away 2^L mod s numbers to make an even distribution, but that’s not completely necessary. For example, if s = 5 and 2^L = 8, you don’t have to fully reject the 3 leftover cases. Instead, you can save that little bit of randomness for the next iteration. Say you land in 1 of the 3 cases again in the next iteration. Combined with the 3 cases you saved up last time, you are now in 1 of 9 equally likely events. If you are in the first 5, you can safely return that value without introducing biases. However, this is only useful when generating the random bits is really expensive, which is totally not the case in non-cryptographic use cases.

One last note – we have established that the expected number of divisions is s / 2^L. As s gets close to 2^L, it seems like our code can become slower. But I think that’s not necessarily the case, because the time division takes is probably variable as well, if the hardware component uses any sort of short-circuiting at all. When s is close to 2^L, 2^L mod s is essentially one or two subtractions plus some branching, which can theoretically be done really fast. So, given my educated guess/pure speculation, s / 2^L growing isn’t a real concern.

Catenary Inversion: Curves of Sagrada Familia

Sagrada Familia is stunning and beautiful. If you ever go visit, don’t miss out on the bottom level: there are exhibitions about the constructions and history of this masterpiece by Gaudi. When I visited a while ago, I was surprised to find a model that explained how the curves of the arches of the church were designed, and it was really cool.

To start out, imagine you’re building the roof of a house. Usually they are like this: /\. Modern houses also look like this: Π. The point is, a flat roof is hard to support, so the older houses are all angled. If you hold a dumbbell horizontally to your side, you’ll feel tired a lot faster than holding it angled upwards. This is because materials in general are a lot better at handling compression than bending forces. By holding your arm at an angle, you are supporting part of the weight by compressing your arm along its own direction, reducing the amount of force perpendicular to that direction. Back to the roof: a flat one is fine if made of concrete and steel, but if we use a long piece of wood, maybe not.

Anyway, using the same materials, an arch shaped building will last much longer than a flat topped one, simply because the bricks are subject to bending forces to a lesser extent. The problem then becomes: how can we find the shape that minimizes bending force at every point on the arch (to zero, actually)?

If you remember high school physics, we can dive into it. Say we draw an arch like this: ∩, and we pick any brick on the arch (say on the left half). Let’s pretend this is the ideal curve, such that there is no bending force anywhere. This little brick you picked is going to have a tiny bit of mass, and the slope of the arch changes a little bit before and after it. Then we have three forces acting on this dot: gravity which points down, force from the left brick supporting it, and force to the right brick. The latter two forces have slightly different slopes, and the three add up to 0 (otherwise the arch will collapse. Note that we don’t have forces perpendicular to the arch between bricks, which is the whole point.) Oh no, we have a differential equation! It’s been 2 days since I took that exam, I have forgotten everything! What do I do?

The Catenary Curve

Consider this seemingly unrelated physical problem: you are given a string with beads on it at regular intervals, and you hold the two ends loosely so that it forms a U shape. What properties does this shape satisfy? As before, consider a single bead (again on the left side, but as they say: WLOG). There are three forces acting on this bead: gravity pulling it down, the force from the left bead pulling it up, and the force from the right bead. You probably saw this coming: these three forces are the same as those we just talked about on the bricks, but pointing in exactly the opposite directions! (I am too lazy to draw a picture, you can try to get the idea.) What’s more, these three forces also add up to 0, since the beads are not moving. This means that we can take the shape of the string, put it upside down, and make a perfect arch out of it.

To see why this is true, imagine you draw the perfect string curve on paper, drawing out the forces. Now you rotate the paper 180 degrees and negate all three forces on the beads. (1) The three forces still sum to 0; (2) gravity still points down with the right amount; (3) the remaining two forces still balance out with the forces on the bricks nearby, since those are also negated. Hence, this curve satisfies all of our requirements. We have found the answer! If I recall correctly, architects actually used this method to draw the curves for blueprints.

This curve of beads on a string is called the catenary curve, and the solution has the form of the hyperbolic cosine, which is essentially a sum of two exponential functions whose exponents are equal in size but opposite in sign.
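For the curious, here is the standard textbook form of that differential equation and its solution. It takes the continuous limit, a uniform chain instead of discrete beads, which is an assumption beyond the bead picture. The symbols are introduced here only for illustration: T_0 is the constant horizontal tension, \lambda the mass per unit length, g the gravitational acceleration, and x = 0 is the lowest point of the chain.

```latex
% Horizontal force balance: the horizontal tension T_0 is constant.
% Vertical force balance on a small piece of chain:
\frac{d}{dx}\left(T_0\, y'\right) = \lambda g \sqrt{1 + y'^2}
\quad\Longrightarrow\quad
y'' = \frac{1}{a}\sqrt{1 + y'^2}, \qquad a = \frac{T_0}{\lambda g}

% Solution (check: y' = \sinh(x/a),\; y'' = \tfrac{1}{a}\cosh(x/a),\; \sqrt{1+\sinh^2} = \cosh):
y(x) = a \cosh\!\left(\frac{x}{a}\right)
     = \frac{a}{2}\left(e^{x/a} + e^{-x/a}\right)

% The arch is the same curve flipped upside down, for some constant height h:
y_{\text{arch}}(x) = h - a \cosh\!\left(\frac{x}{a}\right)
```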

As a personal anecdote, a physics professor of mine once called a problem “physically solved” after he wrote out the equations which uniquely determine the answer, because the rest can be solved by mathematics, either analytically or numerically. This can lead to very uninteresting tests, for example simply writing out the Maxwell equations for every single EM problem. In our case, the arch problem is “physically solved” not only because we can derive the curve using the same differential equation as for catenary curves, but also because we can “physically solve” it using beads, a string and actual physics.

Theoretically Or Practically Fast: Two Types of Hash Tables

When people say a program is fast, they usually mean it takes a short time to execute; but when computer scientists say an algorithm is efficient, they mean the growth rate of the run-time with respect to the input size is slow. It is often the case that these two notions approximately mean the same thing, that a more efficient algorithm yields a faster program, even though it is not always true. I recently learned about two types of hash tables that show this difference: Cuckoo hashing and Robin Hood hashing. Even though in theory Cuckoo is more efficient, Robin Hood is in practice a lot faster.

Vanilla Hash Tables

Before going into details, let’s revisit how a hash table works. Simple thing, you have a hash function that takes a key and yields a random (but deterministic) number as the “hash”, and you use this hash to figure out where to put this item in the table. If you want to take something out, just compute the hash again and you can find exactly where the item is stored.

That’s not the whole story though. What if you want to insert item A into spot 1, but item B is already there? You can’t just throw away item B, what are you, a savage? There are multiple ways to deal with this problem called collisions. The easiest way is chaining, which is to make each slot into a list of items, and inserting is just appending to the corresponding list. Sure, that would work, but now your code is a lot slower. This is because each operation takes a look up to get the list, and another look up to get the contents of the list. This is really bad because the table and the lists will not align nicely in cache, and cache misses are BAD BAD things.

Linear Probing

To improve cache performance, people started to throw out bad ideas. One bad idea that kind of worked was: if this slot is occupied, just take the next one! If the next one is taken, just take the next next one! This is called linear probing, and is pretty commonly used. This solves the caching problem because we’re only accessing nearby memory slots, so our programs now run a lot faster. Why is this bad? Well, it turns out that as the table fills up, the probe count can get really high. Probe count is how many spots need to be checked before an item is found, or how far away the item is from where it should be. Imagine that out of 5 spots, the first 3 are occupied. Then the next insertion has a 60% chance of collision. If it does collide, the probe count is on average 2 (probe count = 3 if the hashed index is the first slot, 2 if second, and 1 if third). And after a collision, the size of the occupied segment increases by one, making the next insertion even worse.
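The numbers in that example can be reproduced with a few lines of Python. This is only a sketch of the probing step. `linear_probe_distance` is a hypothetical helper, slots are indexed from 0, and probing wraps around at the end of the table.

```python
def linear_probe_distance(occupied, home):
    """Steps past `home` until a free slot is found (0 means no collision)."""
    n = len(occupied)
    for step in range(n):
        if not occupied[(home + step) % n]:
            return step
    raise RuntimeError("table is full")

occupied = [True, True, True, False, False]   # first 3 of 5 slots are taken
distances = [linear_probe_distance(occupied, home) for home in range(5)]
collisions = [d for d in distances if d > 0]
print(distances)                              # one probe count per possible hash index
print(len(collisions) / len(distances))       # chance of a collision
print(sum(collisions) / len(collisions))      # average probe count given a collision
```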

Robin Hood to the Rescue

This brings us to Robin Hood hashing. It is actually very similar to linear probing, but with one simple trick. Say in the example above, we have item A sitting at slot 1, B at 2, C at 3 like this: [A B C _ _] and we have another collision trying to insert D at slot 1. Instead of placing D like [A B C D _], we arrange them as [A D B C _], then the probe count would be [0 1 1 1 _] instead of [0 0 0 3 _]. Now there are two things to understand: (1) What exactly happened? (2) Why is this better?

To answer (1), imagine we’re inserting D to slot 1. We see that A is there, then just like linear probing, we move on. Now B is there with probe count 0, while D already has probe count 1. That’s not fair! How dare you have a lower probe count than I do! So D kicks out B, and now B has to find a place to live. It looks at C, yells and kicks C out, and now C has to find a place, which happens to be slot 4. In other words, during probing, if the current item has a lower probe count, it is swapped out and probing continues with the evicted item.

But (2) why is this better? If you think about it, the total probe count (which is 3) is unchanged, so the average probe count is the same as just linear probing. It should have exactly the same performance! This approach is superior because the highest probe count is much lower. In this case, the highest probe count is reduced from 3 to 1. If we can keep the max probe count small, then all probing can be done with almost no cache misses. This makes Robin Hood hash tables really fast. Having a low maximum probe count also solves a really painful problem in linear probing, which is when you remove an item in the hash table and leave a gap. When we look for items, we can’t just declare missing when we first probe an empty spot, because maybe our item is stored further ahead, and this spot was just removed later on. With the maximum probe count really low in Robin Hood hashing, we can then say “I’m just going to look at these n spots, if you can’t find it it ain’t there.” Which is nice and simple. (It also enables shift back deletion instead of tombstone, which is another boost to the performance, but I don’t want to get too deep.)
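Here is a small sketch of the insertion rule in Python, replaying the A, B, C, D example. To keep it deterministic, each key comes with its home slot instead of going through a real hash function. That, the 0-based indices, and the function names are assumptions of this sketch, not part of any particular library.

```python
def robin_hood_insert(table, key, home):
    """Insert `key` whose hash maps to index `home`; slots hold (key, home) or None."""
    n = len(table)
    item = (key, home)
    idx = home
    for _ in range(n):
        if table[idx] is None:
            table[idx] = item
            return
        resident = table[idx]
        # Probe count = distance from the home slot, with wraparound.
        if (idx - resident[1]) % n < (idx - item[1]) % n:
            table[idx], item = item, resident   # evict the better-off resident
        idx = (idx + 1) % n
    raise RuntimeError("table is full")

def probe_counts(table):
    n = len(table)
    return [None if e is None else (i - e[1]) % n for i, e in enumerate(table)]

table = [None] * 5
for key, home in [("A", 0), ("B", 1), ("C", 2), ("D", 0)]:
    robin_hood_insert(table, key, home)
print([e[0] if e else "_" for e in table])  # ['A', 'D', 'B', 'C', '_']
print(probe_counts(table))                  # [0, 1, 1, 1, None]
```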

A Theoretically Better Solution

Even though the highest probe count is small, it still grows as we have more items in the hash table (probe count goes up when load factor goes up). What if I tell you I don’t want it to grow? Meet Cuckoo hashing, which has a maximum probe count of 1.

The algorithm is slightly more complicated. Instead of one table, we have two tables, and one hash function each, chosen randomly among the space of all hash functions (not practical, but close enough). When we look up an item, we just need to compute the hash for the first and the second table, and look at those two spots. Done!

I mean, sure, but that’s easy. How do you insert? What if both spots are full? That’s where the name cuckoo comes from. You pick one of the spots, kick that guy out and put your item there. Now with the evicted guy, you find its other available spot, kick that guy out, and so on and so forth. If you end up in a loop (or spend too long and decide to give up), you pick another two hash functions and rebuild the hash table. As long as at least half of the slots are empty, it is highly likely that the whole operation takes constant time in expectation. I don’t have any simple intuitions about why this is true. If you do, please let me know.
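A bare-bones sketch of lookup and insertion might look like the following. The class name, the `max_kicks` limit and the convention of returning the homeless key on failure are all choices made for this illustration. A real implementation would also do the rebuild with fresh hash functions, which is left to the caller here. Keys must not be None.

```python
class CuckooTable:
    def __init__(self, size, h1, h2, max_kicks=32):
        self.tables = [[None] * size, [None] * size]
        self.hashes = (h1, h2)          # one hash function per table
        self.size = size
        self.max_kicks = max_kicks

    def lookup(self, key):
        # Exactly two slots to check, one in each table.
        return any(self.tables[i][self.hashes[i](key) % self.size] == key
                   for i in (0, 1))

    def insert(self, key):
        """Return None on success, or the key left without a home on failure."""
        if self.lookup(key):
            return None
        side = 0
        for _ in range(self.max_kicks):
            slot = self.hashes[side](key) % self.size
            evicted = self.tables[side][slot]
            self.tables[side][slot] = key
            if evicted is None:
                return None
            key = evicted               # the evicted key tries its other table
            side = 1 - side
        return key                      # caller should pick new hashes and rebuild

# Toy hash functions, chosen only to force collisions in the first table.
t = CuckooTable(4, lambda k: k, lambda k: k // 4)
for k in (0, 4, 8):
    t.insert(k)
print(t.lookup(0), t.lookup(4), t.lookup(8), t.lookup(12))
```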

Theoretically, Cuckoo hashing beats Robin Hood because the worst case look up time is constant, unlike in Robin Hood. Why is it still slower than Robin Hood in practice? As it turns out, the constant is both a blessing and a curse. When looking up missing items, there have to be two look ups; even when looking up existing items, on average there still have to be 1.5 look ups. To make matters worse, since the two spots are spatially uncorrelated, they usually cause two cache misses. Once again, cache performance makes a big difference in the overall speed.

Having said all that, performance in real life is a very complicated matter. Cache performance is just one of the unpredictable components. There are also various compiler optimizations, chip architectures, threading/synchronization issues or language designs that can affect performance, even given the same algorithm. Writing a fast program is all about profiling, fine tuning, and finding the balance*.

*My coworker once said, “at the end of almost any meeting, you can say ‘it’s all about the balance’, and everyone would agree with you.”

Variants of the Optimal Coding Problem

A Variant of Optimal Coding Problem
Prerequisites: binary prefix code trees, Huffman coding

Here’s a problem: say we are given n positive numbers, and you are allowed to each time pick two numbers (a and b) that are in the list, add them together and put it back to the list, and pay me an amount of dollars equal to the sum (a+b). Obviously your optimal strategy would be not to play this stupid game. Here’s the twist: what if I point a gun at your head and force you to play until we end up with only one number, then what would be your optimal strategy, and what would be the minimum amount of money you have to pay me?

At first glance this seems quite easy: of course you would just pick the smallest two numbers each time and add them, because that would be the best option you have each time. It would minimize your cost every round, and I will end up with the minimum amount possible. However, if you paid attention in your greedy algorithm lecture in college, you will notice the fallacy of this statement: picking the best option each time does not necessarily lead to the best overall outcome.

I claim that it is not obvious that this algorithm is right, not only because it is greedy, but also because this problem is equivalent to the classic optimal coding problem, which is not that easy. For those who haven’t heard, optimal coding problem is the problem where you’re given a bunch of symbols and the average frequencies for each of them, and you try to design a set of binary prefix codes for these symbols such that using this coding scheme, the average sequence of symbols will be represented by the shortest expected number of bits.

What? Yeah, this number summing and merging problem is equivalent to the optimal coding problem. I’ll spare you and myself a formal proof, but here’s an example to convince you. Say we have 3 numbers a, b and c, and we first sum a+b, then sum (a+b)+c. Your total cost will be 2a+2b+c. Now if we draw out the summation as a tree:

   a+b+c
    /  \
   /\   c
  a  b

(pardon my ascii art) you will see that a and b sit at depth=2, and c sits at depth=1. This gives another way to come up with the total cost: multiply each number with its tree depth, and sum these products together. This is always the same as the total cost, because for each number, as you go up a level in the tree, you add that number to the total cost once.

Now if you paid attention in your algorithm lecture on binary prefix code trees, you will see this looks exactly like a binary prefix code tree. To complete the analogy, say we are given three symbols A, B and C, and they have frequencies a, b and c. Had we given them the same binary code tree as above, the average code length weighted by frequencies will be 2a+2b+c. If you don’t see why, you should pause here and think until you get it (or not, you decide).

Now that we established these two problems are the same (same objective function, same solution space), it follows that the optimal algorithms for both are the same. The solution to the optimal coding problem is Huffman Coding, as spoiled from the very beginning of this post. Now, this algorithm is exactly the greedy algorithm above that we weren’t sure was right or not. So – I just spent this much time convincing you that algorithm might not be right, and then convincing you it is indeed correct. Yay progress!
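The greedy strategy is short enough to write down. Below is a sketch using Python's `heapq` module as the priority queue. It returns only the total cost of the merging game, which by the argument above is also the weighted code length of the Huffman tree. `greedy_merge_cost` is a name chosen for this illustration.

```python
import heapq

def greedy_merge_cost(numbers):
    """Total cost of repeatedly merging the two smallest numbers until one is left."""
    heap = list(numbers)
    heapq.heapify(heap)
    total = 0
    while len(heap) > 1:
        a = heapq.heappop(heap)      # smallest
        b = heapq.heappop(heap)      # second smallest
        total += a + b               # pay the sum
        heapq.heappush(heap, a + b)  # and put it back into the list
    return total

# With a = 1, b = 2, c = 3 the tree above is the greedy one: 2a + 2b + c = 9.
print(greedy_merge_cost([1, 2, 3]))  # 9
```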

Variants of that Variant

(Technically, that wasn’t really a variant because that was exactly the same problem.) After talking to a few people in real life about this, I found it hard to convince people that this innocent looking algorithm could possibly be wrong, and it certainly didn’t help that the algorithm ended up being right. One line I tried was: “even Shannon tried this problem and came up with a suboptimal Shannon-Fano coding scheme,” but people didn’t seem to appreciate that. In fact, it is quite stunning that when we turn the optimal coding problem into the summing problem, the solution seems so much more obvious.

So I came up with a second attempt: what if we tweak the problem a bit so that greedy wouldn’t work? In the original problem, when we merge two numbers, it costs us the sum of the two, and we reinsert the sum to the list of numbers. What if the cost function or the merge function is not an addition? For example, we can imagine changing the cost function to take minimum, or the merge function to return the square of the sum, etc. Then, would the greedy algorithm still work? Now people are much less sure – which kind of proves the point that they shouldn’t have been so confident in the first place.

But would it? It turns out, perhaps unsurprisingly, that in most cases it won’t work anymore. Just to raise a few counterexamples, (1 1 1 1) yields an optimal merging of 1 + (1 + (1 + 1)) if the cost function takes the minimum. (1 1 1 2 3) yields an optimal merging of ((1 + 1) + 2) + (1 + 3) if we take the maximum instead. And if we take the cost to be the square of the sum, for the case (1 1 2 2), the optimal way is (1 + 2) + (1 + 2). In all these cases, the greedy way will give a sub-optimal total cost. It happens that if both the cost and the merge functions are min or max, the greedy approach works, but those cases are a lot more trivial than the others.
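These counterexamples can be checked mechanically. The sketch below compares the greedy strategy against an exhaustive search over every merge order. The search is exponential and only meant for tiny inputs, and both function names are made up for this illustration. In all three cases the merge function stays plain addition and only the cost function changes.

```python
from itertools import combinations

def greedy_cost(nums, cost, merge):
    """Always merge the two smallest numbers."""
    nums = sorted(nums)
    total = 0
    while len(nums) > 1:
        a, b = nums[0], nums[1]
        total += cost(a, b)
        nums = sorted(nums[2:] + [merge(a, b)])
    return total

def best_cost(nums, cost, merge):
    """Try every possible pair at every step (brute force)."""
    nums = list(nums)
    if len(nums) <= 1:
        return 0
    best = None
    for i, j in combinations(range(len(nums)), 2):
        rest = [v for k, v in enumerate(nums) if k not in (i, j)]
        merged = rest + [merge(nums[i], nums[j])]
        candidate = cost(nums[i], nums[j]) + best_cost(merged, cost, merge)
        if best is None or candidate < best:
            best = candidate
    return best

add = lambda a, b: a + b
square_of_sum = lambda a, b: (a + b) ** 2
for nums, cost in [((1, 1, 1, 1), min), ((1, 1, 1, 2, 3), max), ((1, 1, 2, 2), square_of_sum)]:
    print(nums, greedy_cost(nums, cost, add), best_cost(nums, cost, add))
```

The three rows come out as greedy 4 versus best 3, greedy 11 versus best 10, and greedy 56 versus best 54. With addition as the cost as well, the two functions agree, as Huffman coding promises.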

At this point, it seems like we really lucked out that the optimal coding problem corresponds to a case where greedy works. Otherwise, computers would either spend so many more cycles computing the optimal codes for compressing files, or we would have ended up with larger file sizes for various kinds of compressed files.

To end this post with a bit of trivia: did you know that the Huffman tree in a gzip file is itself encoded using a Huffman code? Since every file has a different frequency count of symbols, each compressed block that uses dynamic Huffman codes has its own Huffman tree, which needs to be saved in the compressed file as well. That tree is stored as a list of code lengths, and this list is then compressed, again, with a second, much smaller Huffman code, whose own code lengths are written as plain fixed-width numbers at the start of the block so the decoder knows how to get started. Pretty meta.

Longest Sequence Without Duplicate Substring of Constant Length

Around 9 years ago I was working on a multiple choice exam and this problem came up. Recently I thought about it again and it’s actually very simple.

The problem is: say you are writing a sequence of letters A, B, C, D, with one constraint. Whatever substring of length 3 that already appeared in your sequence cannot appear again. What is the longest sequence length you can possibly come up with?

As an example, the sequence “ABCDACBB” is valid, but the sequence “ABCDABCB” is not, since “ABC” appears twice. By the same token, “AAAAB” is also invalid, since “AAA” appears twice.

There is an easy generalization of this problem to k choices of letters instead of 4, and m as the length of substring instead of 3. Quick maths: there are k^m substring combinations, so there can be k^m unique starting points, and the answer is upper bounded by k^m + m – 1.

9 years ago I didn’t know anything so I thought it was hard. But now it just looks like a graph traversal problem, which is either trivial or intractable (depending on whether it’s Eulerian or Hamiltonian). With this in mind, the problem is trivial. You can either read my solution below or spend a few minutes to figure it out.

Construct a directed graph where each node represents a unique string of length m - 1 made up of the k choices of letters. There are k outgoing edges and k incoming edges for each node. The k outgoing edges represent the k different ways you can append a letter to the current substring of length m - 1. So for example for the node “AB”, the outgoing edge “C” would point to the node “BC”, since after adding a letter “C”, the last m - 1 letters become “BC”. Then each edge corresponds to one unique substring of length m, and the problem becomes finding an Eulerian path in this graph. Since this graph is always connected and each node has the same in-degree as out-degree, the path always exists and must be a cycle, regardless of m and k, so the upper bound of k^m + m - 1 is always achieved. There will be many possible cycles as well. Now this problem is theoretically solved, and also programmatically solved.

As I later found out from my roommate’s math homework, this graph construction is called a de Bruijn graph. Of course everything has to have a name.
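To make “programmatically solved” concrete, here is one possible sketch. It walks the graph with Hierholzer's algorithm for Eulerian cycles, which is my choice here and not something prescribed by the construction; any Eulerian cycle finder would do. `longest_sequence` is a made-up name.

```python
from itertools import product

def longest_sequence(alphabet, m):
    """A sequence over `alphabet` in which every length-m string occurs exactly once."""
    k = len(alphabet)
    if m == 1:
        return "".join(alphabet)
    # Index of the next unused outgoing edge for every node (a string of length m - 1).
    next_edge = {"".join(p): 0 for p in product(alphabet, repeat=m - 1)}
    start = alphabet[0] * (m - 1)
    stack, circuit = [start], []
    while stack:                                 # Hierholzer's algorithm, iterative form
        node = stack[-1]
        if next_edge[node] < k:
            letter = alphabet[next_edge[node]]
            next_edge[node] += 1
            stack.append((node + letter)[1:])    # follow the edge labelled `letter`
        else:
            circuit.append(stack.pop())          # node is exhausted, emit it
    circuit.reverse()
    # The first node contributes m - 1 letters, every later node adds its last letter.
    return circuit[0] + "".join(node[-1] for node in circuit[1:])

seq = longest_sequence("ABCD", 3)
print(len(seq))  # 66, which is 4^3 + 3 - 1
```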

I wish I could go back in time and explain this to the little me.

Having fun with queues

Recently I’ve been reading Okasaki’s Purely Functional Data Structures. In the first few chapters, the book discusses several really interesting ideas, that I will attempt to summarize and introduce below.

TLDR: persistence; lazy evaluation; amortization; scheduling.

Before getting into algorithms, there are some prerequisites to understanding the context of functional programming.

Immutable objects, persistence

Normally in imperative programming, we have variables, instances of classes and pointers that we can read or write. For example we can define an integer and change its value.

int x = 4;
x = x + 1;
cout << x << endl;

However in the functional version of the same code, even though it looks as if it does the same thing, variables are not modified.

let x = 4 in
let x = x + 1 in
Printf.printf "%d\n" x

What the above OCaml code does, is it first binds the name x to the number 4, then binds the name x to the number 5. The old value is not overwritten but shadowed, which means the old value is theoretically still there, just that we cannot access it anymore, since we lost the name that refers to it. In purely functional code, everything is immutable.

But if we use a new name for the new value, we can keep the old value, obviously:

let x = 4 in
let y = x + 1 in
Printf.printf "%d %d\n" x y

Now let’s look at some nontrivial example, modifying a list.

let x = [ 1; 2; 3 ] in
let y = 0 :: x in
print_list x;
print_list y
(* 1 2 3
   0 1 2 3 *)

Note that the double colon operator adds a new value on the head of a singly linked list. What this means is that after adding a new element to the linked list, not only do we have a linked list of [ 0; 1; 2; 3 ] called y, but we also still have the old list of [ 1; 2; 3 ] called x. We did not modify the original linked list in any way; all the values and pointers in it are still unchanged.

Now this, in itself, is already interesting: first, you can have y = 0 :: x and z = -1 :: x, essentially creating 3 linked lists in total. But since they all share the same tail of length 3, only 5 memory locations (-1, 0, 1, 2, 3) are allocated, instead of 3+4+4 = 11. It is also worth the time to go through common list operations (reverse, map, reduce, fold, iter, take, drop, head, tail…) and verify that you can implement them in the same time complexity without making modifications to the original list. You cannot apply the same techniques to random access arrays, and that is why lists are the basic building blocks of functional programs the same way arrays are the basic building blocks of imperative programs.
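The sharing can be made visible outside of OCaml too. The sketch below imitates cons cells with immutable Python tuples. `cons` and `to_list` are helper names invented for the illustration, and `None` plays the role of the empty list.

```python
def cons(head, tail):
    """Prepend `head` to the linked list `tail`; tuples cannot be modified afterwards."""
    return (head, tail)

def to_list(cell):
    out = []
    while cell is not None:
        head, cell = cell
        out.append(head)
    return out

x = cons(1, cons(2, cons(3, None)))
y = cons(0, x)
z = cons(-1, x)
print(to_list(x), to_list(y), to_list(z))
print(y[1] is x and z[1] is x)   # both new lists point at the very same tail object
```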

Lazy evaluation, memoization

One of the advantages of immutable objects is this: for a pure function f that takes only a list as input, the output will always be the same as long as the list has the same physical address. Then we do not have to repeat the same work, thanks to memoization: after executing a pure function with an immutable input, we can write down the input-output pair in a hash table and simply look it up the next time. What is even better is that this can be done at the language/compiler level, so memoization is hidden from the programmer. A nice consequence is that we can simply write a recursion to solve problems, and it automatically turns into a DP solution.

Lazy evaluation capability is also built into many functional languages. The point of lazy evaluation is to delay computation as late as possible until we actually need the values. With this technique, you can define a sequence of infinite length, such as a sequence of all the natural numbers. There are two basic operations here: creating a lazy computation, and forcing it to get the output. This will be important later in our design of queues.
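The two basic operations can be sketched in a few lines of C++. This is only an illustration under my own naming (Lazy and force are not from the book or from any library), and it assumes T is default constructible to keep the code short. Note that the mutation is hidden inside the object: from the outside, force always returns the same value.

```cpp
#include <functional>
#include <utility>

// Hypothetical suspension type: wraps a computation and runs it at most once.
template <typename T>
class Lazy {
public:
    // Creating a lazy computation: nothing is evaluated here.
    explicit Lazy(std::function<T()> compute) : compute_(std::move(compute)) {}

    // Forcing: the first call runs the computation and memoizes the result,
    // later calls just return the cached value.
    const T& force() {
        if (!forced_) {
            value_ = compute_();
            forced_ = true;
            compute_ = nullptr;  // the closure is no longer needed
        }
        return value_;
    }

private:
    std::function<T()> compute_;
    T value_{};
    bool forced_ = false;
};
```

If the wrapped function increments a counter, the counter stays at 0 until the first force and stays at 1 no matter how many times force is called afterwards.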

Queue: 1st gen

Now let’s get to the point: implementing purely functional queues. The first challenge: actually have a queue, regardless of time complexity. Given building blocks of lists, how do we implement a queue? Note that it is not desirable to build from arrays: they are intrinsically mutable, so using them will not help to create an immutable data structure. Fortunately we have a classic algorithm for this: implement a queue with two lists, F and R. We pop from the front of F and push to the front of R, and whenever F becomes empty we do (F, R) = (List.rev R, []), i.e. reverse the list R and make it the new front. Here’s the obligatory leetcode link.

This implementation claims to have amortized costs of O(1) for push, pop and top operators. This is obvious for push and top, and pop takes a little bit of analysis. Sometimes when popping, we will reverse a linked list of n elements, which takes time O(n). However in order to trigger this O(n) operation, we need to already have done n O(1) operations. That means approximately every n operations, we can at most have done n*O(1) + O(n) work, which is still O(n). Therefore on average we still only had O(1) work per operation. This is a very brief description, assuming prior knowledge in readers.
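Here is a rough C++ sketch of the two-list queue, written in a persistent style: every operation returns a new Queue value and leaves its argument alone. All names (Node, cons, rev, check, push, pop, top, is_empty) are mine, and I assume pop and top are never called on an empty queue.

```cpp
#include <memory>
#include <utility>

struct Node {
    int head;
    std::shared_ptr<const Node> tail;
};
using List = std::shared_ptr<const Node>;  // nullptr is the empty list

List cons(int v, List rest) {
    return std::make_shared<Node>(Node{v, std::move(rest)});
}

// O(n) reversal that builds a fresh list and leaves the input untouched.
List rev(List l) {
    List out;
    for (; l; l = l->tail) out = cons(l->head, out);
    return out;
}

// Invariant: f is empty only if the whole queue is empty.
struct Queue {
    List f, r;
};

// Restore the invariant: (F, R) = (List.rev R, []) whenever F is empty.
Queue check(Queue q) {
    if (!q.f) return Queue{rev(q.r), nullptr};
    return q;
}

bool is_empty(const Queue& q) { return !q.f; }
Queue push(const Queue& q, int v) { return check(Queue{q.f, cons(v, q.r)}); }
int top(const Queue& q) { return q.f->head; }              // assumes non-empty
Queue pop(const Queue& q) { return check(Queue{q.f->tail, q.r}); }  // assumes non-empty
```

After pushing 1, 2, 3 into an empty queue q, top(q) is 1 and top(pop(q)) is 2, while q itself still answers 1. The old version never goes away.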

Amortization breaking down

However this amortization bound does not work for our functional version of queue. This is perhaps obvious to see: say we have one element in F, and we have n elements in R. These two lists together form a functional queue instance, called x. Calling y = Queue.pop x would be O(n), then calling y = Queue.pop (Queue.push x 1) would be another O(n). In fact, we can create many such O(n) operations. Obviously, the cost now for Queue.pop could be as bad as O(n), even in the sense of amortized cost. This is because, now that we can access the entire history of the data structure, we can force it to perform the worst-case operation over and over again, hence breaking the time complexity bound. Because of this way of repeating the expensive operation, it seems unlikely that persistent data structures in general can utilize amortization in run time analysis.

Queue: 2nd gen

The key idea to fixing the amortized cost is, spoiler alert, lazy evaluation. And this is supposedly the aha moment of this post as well. Instead of reversing R when F becomes empty, we “reverse the list R and append it to F” when R’s length becomes greater than F’s length. The operation of F = F @ (List.rev R), where @ means concatenation, is written in quotes, because it is computed lazily. The intuition is that, say we have F and R both of length n, and pushing one element on R (or popping one element from F) will create a lazy computation of reversing R in O(n) time. However, this reversal will not actually be carried out until O(n) more pops are executed, such that the first element of (List.rev R) becomes exposed. Then, the O(n) cost of reversal can be amortized to the O(n) steps forcing the reversal, hence all operations become O(1).

Let’s compare this amortization with the 1st gen amortization. How come this works and the 1st gen didn’t work? The key distinction is we need to amortize expensive operations into future operations, not those in the history. The old method of forcing an expensive operation to happen multiple times to break amortization does not work anymore, because we must call pop O(n) times before we force a lazy reversal of O(n) to happen, and once it is reversed, we memoize that result, so there is no way to force that O(n) to duplicate itself. Meanwhile, memoization of list reversal doesn’t help with the 1st gen design, because we can add any element to R to make it a different list, which would constitute a different input to the List.rev function.

Sketch of proof

Having a working data structure is great, but what’s better is a framework for proving the amortized cost rigorously for not only queues, but also other persistent data structures. Traditionally we have 2 methods of proving amortized cost, the banker’s method of gaining and spending credits and the physicist’s method of assigning potential to data structures. Both of them are similar in that they calculate how much credit we have accumulated in the history up till a point, and then we can spend them on expensive operations that follow.

To prove amortization into the future for persistent data structures, we accumulate debits and try to pay for them in the future one step at a time. Here, debits would be the suspended costs of lazy computation. Adding n debits is essentially saying, “hey I need to do something O(n) here, please execute O(n) other operations before asking for my result.” Hence, you cannot ask for a particular result (like the reversed list R) if your debit account is positive; you must pay off all your debt first.

Then, proving these amortization bounds is simply a matter of choosing the right debits and payments for each operator in the banker’s method, and choosing the right debit potential function for each instance of the data structure. There are rigorous results given in the book, which will not be reproduced here. The curious reader can… just look them up in the book.

Scheduling

Lastly, another important point: we don’t really need amortization at all, instead we can just make everything O(1) worst case. The idea is that instead of lazily delaying the work until it is actually needed, we just actually complete it one step at a time. If reversing the list takes n steps, and we are amortizing that cost on n operations, why don’t we just actually do one step per operation? Since the operations come after the lazy computation is created, we can for sure do that. It was not possible before when the operations being amortized on came before the expensive computation. Of course, the code is going to be a little trickier and has more performance overhead, but theoretically the queue with scheduling achieves a worst case time complexity the same as the ephemeral (non-persistent) counterpart, everything is O(1).

This is not always possible, as for some other data structures it is not easy to find out what pieces of expensive computations we need to pay off per operation. Fortunately for queues it is not hard to imagine how to implement one.

What I left out

Of course, this is still a TLDR version of all the technical details of the implementation, which are all explained at length in the book. The proofs are still nontrivial to complete even given the ideas, and it takes some effort to make sure the scheduling is efficient. I also left out all actual implementations of the data structure in ML code in the book, since explaining the syntax would be heavy and would not contribute to understanding the key ideas.

Anyway, this is about the amount I want to discuss in this post; it’s also interesting to think about implementing maps and sets (like the C++ STL ones) as functional data structures with the exact same asymptotic time complexity on common operators like add, remove, find for both and set, get for maps. Happy functional programming 🙂

Nerd sniping: prison break edition

First, relevant xkcd.

Nerd sniping is like a chain email; after being sniped, if you don’t nerd snipe 10 other people in 3 days, your IQ will drop by 30 points. – A wise man

Anyways here are three problems that sniped me throughout the years that I still remember. Solutions are in white font following the problems.

  1. There are 64 prisoners, 64 different types of hats, and one guard. The guard is going to play a game with these prisoners. First, the guard randomly picks a hat for each prisoner to put on. No prisoner can see his/her own hat, but each can see the hats of everyone else. Some of these hats might be of the same type. They all get one chance to guess their own hat type (individually without knowing other prisoners’ guesses), and if any prisoner gets it right, they all win the game. All prisoners can negotiate before starting the game, but cannot talk to each other afterwards. How do they win the game?
  2. There are 2 prisoners A and B, and one guard. The guard pulls out a chess board (8×8) and puts one coin on each square, either heads or tails. The guard points at a certain coin, and only A can see the choice (also the board with all coins). Then A gets to flip exactly one coin, then he leaves the room, and B comes in. By only looking at the 64 coins after A’s coin flip, can B figure out which coin the guard pointed at? Same conditions as before: negotiation only before the game starts.
  3. There are 2 prisoners A and B, and also the same guard. This time the guard pulls out 64 cards with the numbers 1-64 written on them, one number per card, then scrambles the order and puts them on a table, face up. A gets to see all cards and also can swap one pair of cards. Then all cards are concealed, keeping the order after A’s swap. B comes in, and the guard tells him a number between 1 and 64 inclusive, and B can try to find the corresponding card in 32 trials. In each trial, B can only reveal one card to see if it matches with the guard’s specified card. How can B find the card? Same conditions as before, blah blah blah.

Solutions:

Just kidding, I’m not going to write down the solutions. It wouldn’t exactly be nerd sniping if that’s the case :^)

Euler Path: 8 Lines Solution

The problem of Euler path marked a very fundamental moment in algorithm studies. When Euler took on the seven bridges of Königsberg problem, there was no mathematical tool to solve it, hence he created the tool: graph theory. It sounds like how Newton invented calculus to solve his gravity problems and how Bernoulli invented calculus of variations to solve his brachistochrone problem. As I quote Sutherland,

“Well, I didn’t know it was hard.”

Problem statement

Say you have arrived in a new country with a bunch of islands and one way bridges between some pairs of islands. As a tourist you would like to visit all the bridges once and only once. Abstractly, an Euler path is a path that traverses all the edges once and only once. Is it always possible? If possible, how do we find such a path?

Existence of path

Unlike many other problems, we can determine the existence of the path without actually finding the path. There are 2 conditions for it. The first has to do with in and out degrees, and the second is about the connectivity of the graph.

For a directed graph, the in degree of a vertex is the number of incoming edges, and the out degree is the number of outgoing edges. If we have an Euler path, then there will be a start and an end vertex. The start vertex will have out degree minus in degree equal to 1, and the end vertex will have in degree minus out degree equal to 1. Every other vertex will have equal out degree and in degree, because every time the path enters such a vertex it must also leave it. The only exception is when the start and end vertices are the same, in which case all the vertices have equal in and out degrees.

The second condition is that all vertices have to be weakly connected, meaning that by treating all edges as undirected edges, there exists a path between every pair of vertices. The only exception is for the vertices that have no edges – those can just be removed from the graph, and the Euler path trivially does not include them.

Given the above two properties, you can prove there is an Euler path by the following steps. First it is easy to see that if you start walking from the start vertex (out – in = 1) and removing edges as you walk through them, you can only end up at the end vertex (in – out = 1). Then, the remaining graph is full of cycles that can be visited through some vertex in the path we already have, and you just have to merge the cycles and the path to get the Euler path.

It is easier to see the other direction of the proof: if there is an Euler path, both conditions have to be met.

Implementing existence of path

As mentioned above, we code up the two conditions and check whether we can find a valid starting position, otherwise return -1. Vertices are numbered 0 to n-1. The graph is stored as an adjacency list, meaning that adj[i] has all the neighbors (edges go from i to those). So the out degree of i is adj[i].size().

int start(vector<vector<int> >& adj) {
    // condition 1: in and out degrees
    vector<int> deg(adj.size());
    for (int i = 0; i < adj.size(); i++) {
        deg[i] += adj[i].size();
        for (int x : adj[i])
            deg[x]--;
    }
    int ans = -1;
    for (int i = 0; i < deg.size(); i++)
        if ((ans != -1 && deg[i] != 0 && deg[i] != -1)
            || deg[i] > 1)
            return -1;
        else if (deg[i] == 1)
            ans = i;
    if (ans == -1)  // start and end vertices are the same 
        for (int i = 0; i < adj.size() && ans == -1; i++)
            if (!adj[i].empty())
                ans = i;
    if (ans == -1)  // there is no edge at all
        return -1;
    // condition 2: connectivity
    vector<bool> vis(adj.size());
    vector<int> bfs{ans};
    vis[ans] = true;
    for (int i = 0; i < bfs.size(); i++)
        for (int x : adj[bfs[i]])
            if (!vis[x]) {
                bfs.push_back(x);
                vis[x] = true;
            }
    for (int i = 0; i < adj.size(); i++)
        if (!vis[i] && !adj[i].empty())
            return -1;
    return ans;
}

That is slightly more clumsy than I would like it to be, but it should be clear. Basically just counting in and out degrees and running a BFS from the starting vertex.

Actually finding the path

As mentioned above in the sketch of proof, finding a path consists of the 3 steps:

  1. Walk from the start vertex, removing edges as we use them, until there is nowhere to go. Then we have a path from start to end.
  2. For the remaining edges, start at a vertex on the path and randomly walk until we go back to the same vertex. Then we have a cycle. Merge the cycle on the path. For example, if we had a path s->a->b->c->…->t and a cycle c->alpha->beta->c, we merge them to become s->a->b->c->alpha->beta->c->…->t.
  3. Repeat step 2 until there is no remaining edge.

Well, that does not sound very easy to write nor very efficient to run, if we literally implemented the above. How many lines of code would that be?

The 8 lines solution

The answer is 8. Here’s the code:

void euler(vector<vector<int> >& adj, vector<int>& ans, int pos) {
    while (!adj[pos].empty()) {
        int next = adj[pos].back();
        adj[pos].pop_back();
        euler(adj, ans, next);
    }
    ans.push_back(pos);
}

This is a very beautiful solution using recursion, in my opinion. I was quite surprised when I first saw this. Where is everything? Where is getting the main path? Where is getting the cycles? Where are we merging them?

The gist of this recursion is a post-order DFS, meaning that we visit the end of the graph first, and then backtrack. This comes from a very crucial observation: we can always be sure what could be at the end of the path, but not the front. It is important to know that the answer array is in reverse order of visit, i.e. we need to reverse it to get the Euler path. Let’s go through the stages of how the program works.

  1. Path from start to end vertex. The first time we push back is when we run out of outgoing edges, which can only be the case of the end vertex, with in-out = 1. In all other vertices, the numbers are the same so if you can go in, you can definitely go out of that vertex. Hence the first push back occurs with the end vertex, and at that point the program execution stack has the entire path.
  2. As we return from a recursive call of the function, we are essentially going back from the end to the start. If a vertex on the main path does not have any outgoing edge, we know we will visit it next, so we push it to the ans vector and return from the function.
  3. But if a vertex on the path does have an outgoing edge, that means there is at least one cycle including this vertex. Then in the next iteration of the while loop, we will visit one of the outgoing edges and start another round of recursion. Again, this recursion must end on the same vertex because it is the only one with in-out = 1. This recursion gets you a cycle and by the time it returns, the cycle would have been pushed to the ans array already, finishing the merge operation.

By studying this code, there is one interesting point to note. That is, the while loop will only be executed 0 to 2 times in any recursive call. It will only be 0 at the end vertex, and on the main path with no branches, it will be 1. On the main path with branches, it will only be 2 but not more regardless of the out degree, because surely the recursive call for cycle needs to end on that vertex but it will not end until all outgoing edges are used up. Therefore to the program, multiple cycles on one vertex is just one big cycle. On the other vertices on cycles, it works the same way whether they have branches or not.

In case you want to see how to run it, here is the main function I wrote to test it:

int main() {
    int n, q;
    cin >> n >> q;
    vector<vector<int> > adj(n);
    for (int i = 0; i < q; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
    }
    int s = start(adj);
    if (s == -1) {
        cout << "no path" << endl;
        return 0;
    }
    vector<int> ans;
    euler(adj, ans, s);
    reverse(ans.begin(), ans.end());
    for (int x : ans)
        cout << x << endl;
    return 0;
}

That’s it – a simple problem with a simple solution. Leetcode has one Euler path problem, and the algorithm in this blog post comes from the discussion of that problem. Everything is pretty much the same for undirected graphs – you just have to use different data structures to store the edges. The proof is mostly the same with the first condition now about odd/even number of edges at each vertex, as there is no distinction between in and out degree.
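One practical caveat of the recursive routine: the recursion depth can be as large as the number of edges, which may overflow the call stack on big graphs. Below is a sketch of the same post-order idea with an explicit stack. The name euler_iterative is mine, and unlike the recursive version it reverses the answer itself before returning. As before, it assumes the start vertex has already been validated and it consumes adj.

```cpp
#include <algorithm>
#include <vector>

// Iterative variant of the recursive routine, using an explicit stack.
// Edges are removed from adj as they are used. Returns the vertices in
// path order, so the caller does not need to reverse anything.
std::vector<int> euler_iterative(std::vector<std::vector<int>>& adj, int start) {
    std::vector<int> stack{start}, ans;
    while (!stack.empty()) {
        int pos = stack.back();
        if (adj[pos].empty()) {
            // Nowhere left to go: this is the post-order push_back.
            ans.push_back(pos);
            stack.pop_back();
        } else {
            // Follow one unused edge, exactly like the recursive call.
            int next = adj[pos].back();
            adj[pos].pop_back();
            stack.push_back(next);
        }
    }
    std::reverse(ans.begin(), ans.end());
    return ans;
}
```

For the graph with edges 0->1, 1->2, 2->0 and 0->3 (stored as adj[0] = {1, 3}, adj[1] = {2}, adj[2] = {0}), starting from vertex 0 this yields 0, 1, 2, 0, 3: the dead end 3 is recorded first, then the cycle through 1 and 2 is spliced in front of it.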

TIW: Subset Sum

Subset sum, aka the coin change problem, is a very specific problem that for some reason gets mentioned a lot. The problem is very simple: given some coins of different values, what values of items can we buy with a subset of the given coins? Sometimes the problem statement asks, can we buy an item of a certain price with our coins without change? If not, what is the minimum amount of change needed? What if we have a number of the same coins, or even an infinite number? Many of these variants can be solved with some tweaks to the same code. So let’s start with the most basic formulation: given a vector of positive integers, return a vector of all integers that are subset sums (i.e. there exists a subset in the input vector with a sum equal to that number).

There are many good ways to accomplish this task.

First attempt: iterative, boolean array

Let’s jump into the code directly:

vector<int> solution1(vector<int>& coins) {
    int total = 0;
    for (int x : coins)
        total += x;
    vector<bool> pos(total+1);
    pos[0] = true;
    for (int x : coins)
        for (int i = total; i >= 0; i--)
            if (pos[i])
                pos[i+x] = true;
    vector<int> ans;
    for (int i = 0; i <= total; i++)
        if (pos[i])
            ans.push_back(i);
    return ans; 
}

OK, so first we calculated the total values of the coins, then we know all possible subset sums are in the range [0, total]. Then we constructed a boolean vector with indices in the same range, to indicate what values we can form. The algorithm proceeds by adding a new coin every iteration. Say we used to be able to pay {0, 1, 2, 3, 4, 5, 6}, and we have a new coin of value 10. Then by adding the new coin in, we can pay {10, 11, 12, 13, 14, 15, 16}. But we can also not use the new coin and pay [0, 6], so we need to keep that in the original array. Then we combine the two.

The most important part here is the loop for (int i = total; i >= 0; i--). It counts down, because counting up will lead to a logic error. Think of it this way: you could already pay 0, and you add a new coin of value 2 in, so you set pos[2] = true. And then you count up and see pos[2] is true – you will then set pos[4] to be true, although you already used the new coin in that combination! Counting down avoids this problem because every new element you set to be true will not be visited again in this iteration.
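To see the difference for yourself, here is a small experiment with both loop directions side by side. The function reachable and its down flag are made up for this demonstration. The counting-up branch is deliberately the buggy one, and I cap the index at total so it cannot write out of bounds.

```cpp
#include <vector>

// Marks which sums in [0, total] are reachable. `down` picks the direction
// of the inner loop; counting up is only here to demonstrate the bug.
std::vector<bool> reachable(const std::vector<int>& coins, bool down) {
    int total = 0;
    for (int x : coins) total += x;
    std::vector<bool> pos(total + 1);
    pos[0] = true;
    for (int x : coins) {
        if (down) {
            // Correct: a freshly set pos[i + x] is never revisited.
            for (int i = total - x; i >= 0; i--)
                if (pos[i]) pos[i + x] = true;
        } else {
            // Buggy for this problem: pos[i + x] is visited later in the
            // same pass, so the same coin gets used again and again.
            for (int i = 0; i + x <= total; i++)
                if (pos[i]) pos[i + x] = true;
        }
    }
    return pos;
}
```

With coins {2, 5}, counting down marks exactly 0, 2, 5 and 7. Counting up also marks 4 and 6, which would need the coin of value 2 more than once.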

Second attempt: iterative, set of integers

vector<int> solution2(vector<int>& coins) {
    set<int> pos;
    pos.insert(0);
    for (int x : coins)
        for (auto it = pos.rbegin(); it != pos.rend(); it++)
            pos.insert(x+*it);
    return vector<int>(pos.begin(), pos.end());
}

The basic idea is the same, but instead of indicating val as possible by setting pos[val] = true, we insert it to the set. The code is much more concise.

Third attempt: recursive, set of integers

void soln3helper(vector<int>& coins, int idx, int val, set<int>& pos) {
    if (idx == coins.size()) {
        pos.insert(val);
        return;
    }
    soln3helper(coins, idx+1, val, pos);
    soln3helper(coins, idx+1, val+coins[idx], pos);
}

vector<int> solution3(vector<int>& coins) {
    set<int> pos;
    soln3helper(coins, 0, 0, pos);
    return vector<int>(pos.begin(), pos.end());
}

This is basically DFS, if you treat whether using a coin or not as an edge and a combination of coins as a vertex in the graph, you will get a tree. At the leaf nodes of the tree, insert the sum of the combination into the set.

Why the three approaches?

Each approach has its pros and cons. The second one has the least amount of code, and is easy to understand. The first one might be marginally more efficient for small test cases, because it only uses vectors but not any sets. Note however that the boolean vector could be huge if the coins are of really big values, in which case it would be very inefficient. The third one is intriguing because there is not any branch cutting. That means our runtime will always be the worst case runtime, 2^n where n is the number of coins. It incidentally takes care of negative coin values, which the other solutions do not. But is it really that bad? Let’s compare the runtimes:

First attempt: O(n*sum)

This is obvious from the nested loops. How big is sum though? The lower bound is n, when all the numbers are 1. In that case we get O(n^2) overall runtime. There is however not an upper bound – any coin can be arbitrarily big, and in that case we are kind of screwed.

Second attempt: O(n*2^n)

In the ith loop, we have at most 2^(i+1) items in the set. Looping through them and all insertions cost O(2^(i+1)*log(2^(i+1))) = O(i*2^i). Summing from i = 0 to i = n-1, we can see that each later term is larger than the previous term by 2 times at least, so we can drop all the front terms and only take the last, and we get the answer.

Third attempt: O(2^n)

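Very simple: on each level i, we have 2^i edges coming out, i.e. O(2^i) work. Summing them gives O(2^n) by the same argument. Strictly speaking this counts only the recursive calls: each of the 2^n leaves also performs one set insertion, which is logarithmic in the number of distinct sums found so far, so the bound holds up to that factor.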

From the runtime analyses, you can see either the first or the third algorithms could be optimal for different inputs. The second, however, is never asymptotically optimal. Even in the case of all 1s, it runs in O(n^2*log(n)). Hence it is most useful when you don’t really need that much performance but prefer pretty code instead (say, during an interview).

Fourth attempt: iterative, integer array

vector<int> solution4(vector<int>& coins) {
    vector<int> pos{0};
    for (int x : coins) {
        vector<int> newpos, next;
        for (int p : pos)
            newpos.push_back(p+x);
        int l = 0, r = 0;
        while (l < pos.size() || r < newpos.size()) {
            if (l == pos.size() ||
                (r != newpos.size() && pos[l] >= newpos[r])) {
                if (next.empty() || next.back() != newpos[r])
                    next.push_back(newpos[r]);
                r++;
            } else {
                if (next.empty() || next.back() != pos[l])
                    next.push_back(pos[l]);
                l++;
            }
        }
        pos.swap(next);
    }
    return pos;       
}

This is clearly a more elaborate approach to the problem. What we’re trying to achieve here is an O(2^n) algorithm that has a best case scenario of O(n^2), in the case of all 1s as input array. This is built mostly upon the second attempt, and we’re trying to eliminate the use of a set, since it inevitably gives us an extra time complexity factor. The way to accomplish this is to store the new values in a new vector, and merge the two vectors like in merge sort but removing all duplicates. In the ith iteration, we must only have O(2^i) amount of work, and summing through all i’s gives O(2^n).
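If you don’t feel like writing the merge by hand, the standard library already has the same operation: std::set_union merges two sorted ranges and emits a value that appears in both only once. Here is a sketch of the fourth attempt on top of it. The name subset_sums_union is mine. It relies on each vector being sorted and free of duplicates, which holds because we start from {0} and only ever shift by a constant and merge.

```cpp
#include <algorithm>
#include <iterator>
#include <vector>

// Same idea as the hand-written merge, delegated to std::set_union.
std::vector<int> subset_sums_union(const std::vector<int>& coins) {
    std::vector<int> pos{0};  // sorted, no duplicates
    for (int x : coins) {
        std::vector<int> shifted, next;
        shifted.reserve(pos.size());
        for (int p : pos) shifted.push_back(p + x);  // still sorted
        // Linear-time merge of two sorted ranges, common values kept once.
        std::set_union(pos.begin(), pos.end(),
                       shifted.begin(), shifted.end(),
                       std::back_inserter(next));
        pos.swap(next);
    }
    return pos;
}
```

For coins {1, 2, 5} this returns 0, 1, 2, 3, 5, 6, 7, 8.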

Multiple coins

That’s about it, but one last discussion about multiple coins or infinite coins. Say we have the coins {3, 4, 7, 19} and we want to pay for an item of price 25, can we do it? Yes, 3+3+19 = 25. How do we do this? There are a few approaches, but here is one of them. I will build it on attempt 2:

bool multipleCoins(vector<int>& coins, int val) {
    set<int> pos;
    pos.insert(0);
    for (int x : coins)
        for (auto it = pos.rbegin(); it != pos.rend(); it++)
            for (int v = x+*it; v <= val; v += x)
                pos.insert(v);
    return pos.count(val);
}

Instead of pushing in only one value, we keep pushing in more values with more coins of the same type until we go over the desired value. Of course we can immediately return true if we find val, which could potentially save a lot of work. But this is good enough as an example to show how to modify the above versions to account for multiple coins.
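For the infinite coins variant there is also a neat way to build on the first attempt instead: the counting-up loop, which was a bug when each coin could be used only once, is exactly what we want when a coin can be reused. A minimal sketch, assuming all coin values are positive and val is non-negative (the name can_pay_unlimited is mine):

```cpp
#include <vector>

// Unlimited supply of every coin: counting UP lets pos[i + x] be picked
// up again later in the same pass, i.e. the coin x is reused.
bool can_pay_unlimited(const std::vector<int>& coins, int val) {
    std::vector<bool> pos(val + 1);
    pos[0] = true;
    for (int x : coins)
        for (int i = 0; i + x <= val; i++)
            if (pos[i]) pos[i + x] = true;
    return pos[val];
}
```

With the coins {3, 4, 7, 19} this confirms that 25 can be paid, and that 5 cannot. The cost is O(n*val) regardless of how large the individual coins are.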

Brief Intro to Segment Tree

Something I just learned: the segment tree, a data structure more advanced and more general than the binary indexed tree. Even though I just learned it and might not be qualified to discuss it yet, I’m pretty excited so who cares. So here’s an intro to segment tree from a noob point of view. Most of the content comes from this Codeforces blog by Al.Cash, but that blog assumes more prior knowledge and I will attempt to explain it from scratch.

Motivation

Member binary indexed tree? I member. For range sum query, BIT supports logarithmic time complexity for updating an element and querying the sum of any range. The way querying from l to r was done was a 2-phase process: first find the prefix sum of l-1 and r, then subtract the two. Pretty straight forward.

But what if we want to take minimum or maximum of the range, instead of just taking the sum? With BIT, we can only query the range from the beginning to a certain point. It is trivial really, refer to the code below:

// binary indexed tree for taking maximum of the prefix [0, k], values assumed non-negative
void update(int v, int k, vector<int>& bit) {
    for (k++; k < bit.size(); k += k&(-k))
        bit[k] = max(bit[k], v);
}
int query(int k, vector<int>& bit) {
    int ans = 0;
    for (k++; k > 0; k -= k&(-k))
        ans = max(ans, bit[k]);
    return ans;
}

There are a few limitations with this approach. First, you can only update an element to a bigger integer, because there is no reverse operation for maximum like subtraction for addition. Second, you can only take maximum from the beginning to a certain element but not an arbitrary range. With summation that’s not a problem, as we can subtract sums of different ranges to get partial sums. But with taking minimum, you can’t just subtract minima can you?

Therefore, we need a data structure that is like BIT, but supports true range query, instead of prefix range query.

Basic specs

  1. Segment tree is like a binary indexed tree on steroids, all problems solved by BIT can be solved by segment tree.
  2. It takes 2 times the space of the original array (double of BIT).
  3. You can update any element of the original array in O(log(n)) time (same as BIT).
  4. You can look up any element of the original array in constant time (BIT takes log(n)).
  5. The original array can be of any data type (same as BIT).
  6. You can perform a range query of a certain associative function in O(log(n)) time (BIT can only make prefix range queries).

The last point is important. As long as a function is associative, you can use it with ST. It does not have to be commutative or invertible. A few examples:

  1. Addition. Just like BIT, given {1, 2, 3, 4, 5}, you can sum {2, 3, 4} in log(n) time, for example.
  2. Maximum. Given {1, 2, 3, 4, 5}, you can take maximum of {2, 3, 4} in log(n) time, for example.
  3. Matrix multiplication. Given {A1, A2, A3, A4, A5} of five matrices, you can get the product {A2, A3, A4} in log(n) time.

Associativity: say you have 3 variables a, b and c, and a function f that operates on two parameters. Then f(f(a, b), c) = f(a, f(b, c)) iff f is associative. Note that commutativity does not follow, f(a, b) does not necessarily equal f(b, a).

Associativity is important because for a range {a, b, c, d, e}, you can precompute {a, b} and {c, d, e}, then combine the two results.
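A quick way to convince yourself that associativity is enough and commutativity is not needed is string concatenation, which is associative but certainly not commutative. In the sketch below (concat_range is a name I made up), combining the precomputed pieces in order gives the right answer, while swapping them does not.

```cpp
#include <string>
#include <vector>

// Concatenate v[l..r) from left to right.
std::string concat_range(const std::vector<std::string>& v, int l, int r) {
    std::string acc;  // the empty string is the identity for concatenation
    for (int i = l; i < r; i++) acc += v[i];
    return acc;
}
```

For v = {"a", "b", "c", "d", "e"}, concat_range(v, 0, 2) + concat_range(v, 2, 5) equals concat_range(v, 0, 5), which is "abcde". Putting the two halves the other way round gives "cdeab". This is why the order of arguments to the merge function matters in the query code.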

Basic idea

Let’s borrow the sample from BIT: {1, 2, 3, 4, 5, 6, 7, 8}. Here’s a graph of what segment tree stores:

[Figure: array layout of the segment tree built on {1, 2, 3, 4, 5, 6, 7, 8}]

For comparison, here’s a graph of what binary indexed tree stores:

[Figure: array layout of the binary indexed tree built on the same array]

A few observations:

  1. ST is simply BIT with the whole table filled in, without any blanks.
  2. st[8..15] is our original array. In the second half of the ST array, we always store the original array.
  3. st[i] = st[i*2] + st[i*2+1]. Node i/2 is node i’s parent.
  4. Updating any element changes the same number of nodes in the tree.
  5. We can sum up ranges directly. For example {3, 4, 5, 6, 7} = st[5]+st[6]+st[14].

It is easy to see that ST is an extension of BIT and supports direct range queries while BIT doesn’t. The only thing left now is how to actually implement it.

Code

In general, the function does not have to be addition of integers, so I will abstract it to any function that takes 2 structs and returns a struct.

struct data {
    int num;
    data() {
        num = 0;
    }
    data(int n) {
        num = n;
    }
};
data merge(data a, data b) {
    return data(a.num+b.num);
}

For simplicity I still made it a wrapper of addition of numbers. You can make it any function with any data, either minimum of long longs, multiplication of matrices etc.

Then, given an array, we need to build the segment tree.

vector<data> build(vector<data>& v) {
    int n = v.size();
    vector<data> st(n*2);
    for (int i = 0; i < v.size(); i++)
        st[i+n] = v[i];
    for (int i = n-1; i > 0; i--)
        st[i] = merge(st[i*2], st[i*2+1]);
    return st;
}

Technically we do not need a build function if we have an update function, just like with BIT. We can just update the entries one by one. But that would take n*log(n) time, while this function is O(n), so this is not entirely useless.

The build function takes in the original array and makes an array double its size. Then, we copy the array to the second half of the tree. Constructing the tree is a bottom-up procedure, each time calculating the new sum from the two lower nodes (from observation 3). That’s it.

The update function is even simpler.

void update(data v, int k, vector<data>& st) {
    int n = st.size()/2;
    st[k+n] = v;
    for (int i = (k+n)/2; i > 0; i /= 2)
        st[i] = merge(st[i*2], st[i*2+1]);
}

The function body is only 4 lines. Technically you can make it 1 line, but why?? Here, we first update the entry in the second half of the tree, then move to its parent repeatedly by dividing the index by 2, until we reach 1, the root.
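
To make the path concrete, here is a small sketch that lists which tree indices an update writes to. The helper touched_nodes is hypothetical and only mirrors the index arithmetic of update; it does not modify any tree.

```cpp
#include <vector>

// Hypothetical helper: the indices written by update(v, k, st) for a tree over
// n elements, starting at the leaf k+n and ending at the root, index 1.
std::vector<int> touched_nodes(int k, int n) {
    std::vector<int> path;
    for (int i = k + n; i > 0; i /= 2) path.push_back(i);
    return path;
}
```

With n = 8 and k = 2 the path is 10, 5, 2, 1. For n = 8 every k gives a path of 4 nodes, which is observation 4.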

Now the actual hard part: queries. Given the range [l, r) from l to r-1, query the “sum” (could be product or any arbitrary function) of the range.

data query(int l, int r, vector<data>& st) {
    int n = st.size()/2;
    data ansl, ansr;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l%2 == 1) {
            ansl = merge(ansl, st[l]);
            l++;
        }
        if (r%2 == 1) {
            r--;
            ansr = merge(st[r], ansr);
        }
    }
    return merge(ansl, ansr);
}

I will try my best to explain, but I will not go through everything in fine detail because it is too tedious. Let’s say we want to sum {2, 3, 4, 5, 6}, i.e. l = 1, r = 6 (because v[1] = 2, v[6-1] = 6). First we add n to both l and r, so we have l = 9, r = 14. Referring to the graph above, our goal is to sum st[9], st[5] and st[6]. In the first iteration we pick up st[9], because l is odd. After adding it to the left sum, we add 1 to l to mark that this subtree is already accounted for, so a smaller range remains. For the next iteration, l and r are divided by 2 to go up a level in the tree, and become 5 and 7. Now both numbers are odd, so we merge st[5] into the left sum and, after decrementing r to 6, st[6] into the right sum. Then l becomes 6, and after dividing by 2 both l and r are 3, meaning that the range left to sum is empty, and the loop ends.
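
The walk-through can be reproduced with a small tracing sketch. The helper picked_nodes is hypothetical: it runs the same loop as query but records the indices instead of merging anything.

```cpp
#include <utility>
#include <vector>

// Hypothetical tracer for query(l, r, st) on a tree over n elements.
// first  = indices merged into ansl, in the order they are picked;
// second = indices merged into ansr, in the order they are picked
//          (query prepends these, so the final right part reads them in reverse).
std::pair<std::vector<int>, std::vector<int>> picked_nodes(int l, int r, int n) {
    std::vector<int> left, right;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l % 2 == 1) left.push_back(l++);
        if (r % 2 == 1) right.push_back(--r);
    }
    return {left, right};
}
```

For l = 1, r = 6, n = 8 this records 9 and 5 on the left and 6 on the right, the same three nodes as in the explanation. For l = 2, r = 7 it records 5 on the left and 14, then 6, on the right, which is the st[5]+st[6]+st[14] example from the observations.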

To be honest, I don’t 100% understand why the conditions are that l and r are odd. The idea could be that if l is even, both l and l+1 are in the range, and we would rather add the node at l/2 since it covers both, so we do nothing when l is even. The only case where l is even but l+1 is not in the range is r = l+1, but then r is odd and we add st[r-1], which is st[l]. Otherwise, if l is odd, its parent also covers l-1, which is outside the range, so we add st[l] by itself and move l to l+1. Everything is mirrored for r: r itself is not included in the range [l, r), so the actual end point is r-1, and checking that r is odd is the same as checking that r-1 is even. Just like with low bits in BIT, you don’t actually need to understand it to use it; I bet you can’t implement a red black tree either but you still use set<> like an algorithm master anyway.

Important point: There are two ans variables, ansl and ansr, and they take sums from both parts respectively. This is to maintain the order of computation, in case the merge function is not commutative. In this case however it does not make any difference.
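
To see why the two accumulators matter, here is a sketch with a merge that is associative but not commutative: string concatenation. All names here (concat, build_strings, query_two_accumulators, query_single_accumulator) are made up for the illustration; the first query follows the ansl/ansr scheme above, the second naively uses one accumulator.

```cpp
#include <string>
#include <vector>

// Hypothetical merge: concatenation, so the order of the operands matters.
std::string concat(const std::string& a, const std::string& b) { return a + b; }

std::vector<std::string> build_strings(const std::vector<std::string>& v) {
    int n = static_cast<int>(v.size());
    std::vector<std::string> st(2 * n);
    for (int i = 0; i < n; i++) st[i + n] = v[i];
    for (int i = n - 1; i > 0; i--) st[i] = concat(st[2 * i], st[2 * i + 1]);
    return st;
}

// Same scheme as query(): left pieces are appended, right pieces are prepended.
std::string query_two_accumulators(int l, int r, const std::vector<std::string>& st) {
    int n = static_cast<int>(st.size()) / 2;
    std::string ansl, ansr;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l % 2 == 1) ansl = concat(ansl, st[l++]);
        if (r % 2 == 1) { --r; ansr = concat(st[r], ansr); }
    }
    return concat(ansl, ansr);
}

// Naive variant with one accumulator: fine for sums, wrong order for strings.
std::string query_single_accumulator(int l, int r, const std::vector<std::string>& st) {
    int n = static_cast<int>(st.size()) / 2;
    std::string ans;
    for (l += n, r += n; l < r; l /= 2, r /= 2) {
        if (l % 2 == 1) ans = concat(ans, st[l++]);
        if (r % 2 == 1) ans = concat(ans, st[--r]);
    }
    return ans;
}
```

On the eight strings "a" to "h" with l = 2, r = 7, the two-accumulator version returns "cdefg", while the single-accumulator version returns "gcdef" because the right-hand piece "g" is picked up first.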

Array of arbitrary size

The above is explained with an array size of 8, which was sort of cheating, because you will most likely want an array of arbitrary size. Of course, you can add padding zeros at the end to make it a power of two.

n = v.size();
while (n != lowbit(n))  // n&(-n)
    n += lowbit(n);
v.resize(n);

That would do. This is because when n is a power of two, its low bit is itself. However this is not even needed; the original code, although designed for powers of two, will work for any vector size n. There is a short explanation on the original blog, but the full proof should be too complicated and not useful to know. We only need to know that it automagically works for any n, and happily copy paste code.
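
If you would rather not take that on faith, a brute-force comparison is cheap. This is only an empirical sketch for integer sums, not a proof, and the function name matches_naive is made up; it builds the tree for {1, 2, ..., n} without any padding and compares every range [l, r) against a plain loop.

```cpp
#include <vector>

// Hypothetical check: returns true if the bottom-up sum tree over {1, ..., n}
// agrees with a naive loop for every range [l, r), 0 <= l <= r <= n.
bool matches_naive(int n) {
    std::vector<long long> st(2 * n);
    for (int i = 0; i < n; i++) st[i + n] = i + 1;
    for (int i = n - 1; i > 0; i--) st[i] = st[2 * i] + st[2 * i + 1];
    for (int l = 0; l <= n; l++) {
        for (int r = l; r <= n; r++) {
            long long naive = 0;
            for (int i = l; i < r; i++) naive += i + 1;
            long long got = 0;
            // Same loop as query(), specialised to addition.
            for (int a = l + n, b = r + n; a < b; a /= 2, b /= 2) {
                if (a % 2 == 1) got += st[a++];
                if (b % 2 == 1) got += st[--b];
            }
            if (got != naive) return false;
        }
    }
    return true;
}
```

Calling it for a handful of sizes that are not powers of two (say 3, 5, 6, 7, 13) is a quick way to gain confidence before copy pasting.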

Sample: matrix multiplication

Build, update and query are copy pasted, so they are omitted in the sample.

struct data {
    int m, n;
    vector<vector<int> > A;
    data() {
        m = 0;
        n = 0;
    }
    data(const vector<vector<int> >& B) {
        A = B;
        m = A.size();
        n = A[0].size();
    }
    void print() {
        for (auto r : A) {
            for (auto c : r)
                cout << c << " ";
            cout << endl;
        }
    }
};
data merge(data a, data b) {
    if (a.m*a.n == 0)
        return b;
    if (b.m*b.n == 0)
        return a;
    int m = a.m, n = b.n, l = a.n;
    vector<vector<int> > C(m, vector<int>(n));
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            for (int k = 0; k < l; k++)
                C[i][j] += a.A[i][k]*b.A[k][j];
    return data(C);
}
vector<data> build(vector<data>& v) {...}
void update(data v, int k, vector<data>& st) {...}
data query(int l, int r, vector<data>& st) {...}
int main() {
    vector<data> v;
    v.push_back(data({{2, 0}, {0, 2}}));
    v.push_back(data({{1, 1, 4}, {4, 2, 2}}));
    v.push_back(data({{-2, 0}, {1, 4}, {1, 2}}));
    v.push_back(data({{0, 1}, {1, 0}}));
    vector<data> st = build(v);
    int l, r;
    while (cin >> l >> r)
        query(l, r, st).print();
    return 0;
}

You can mostly just copy paste the three functions and modify the data and merge definitions to fit your applications. I was motivated to study segment trees because of this problem on Codeforces. I was not able to do this problem during the contest, but neither could tourist, so whatever.

New Year and Old Subsequence

This is a slightly more advanced application of segment tree. The problem: given a string of digits, find the minimum number of digits that need to be removed such that the string has “2017” as a subsequence but not “2016”. If 2017 is not a subsequence, print -1. More precisely, a string of up to 200,000 characters is given, followed by up to 200,000 queries, each with a range l to r, and for each query we need to answer the minimum number of digits to remove such that the substring [l, r] has a subsequence of “2017” but not “2016”.

The algorithm is described here. The gist is that for an interval of the string, all we need to know is this: given that we already have a certain prefix of 2017, how many digits do we need to erase in this interval so that we end up with a longer prefix of 2017. For example, if the current digit is 6, then given “201” or “2017”, we will have to erase one digit (the 6) to ensure we have a prefix of “201” or “2017” without any “2016”. Otherwise the 6 doesn’t matter. Here’s my code.

struct data {
    unsigned int dp[5][5];
    data() {
        for (int j = 0; j < 5; j++)
            for (int i = 0; i <= j; i++)
                dp[j][i] = INT_MAX;
    }
    void clear() {
        for (int i = 0; i < 5; i++)
            dp[i][i] = 0;
    }
};

data merge(const data& a, const data& b) {
    data temp;
    for (int j = 0; j < 5; j++)
        for (int i = 0; i <= j; i++)
            for (int k = i; k <= j; k++)
                temp.dp[j][i] = min(temp.dp[j][i], a.dp[k][i]+b.dp[j][k]);
    return temp;
}

int main() {
    int n, q;
    string s;
    cin >> n >> q >> s;
    vector<data> st(2*n);
    for (int i = 0; i < n; i++) {
        st[i+n].clear();
        if (s[i] == '2') {
            st[i+n].dp[0][0] = 1;
            st[i+n].dp[1][0] = 0;
        } else if (s[i] == '0') {
            st[i+n].dp[1][1] = 1;
            st[i+n].dp[2][1] = 0;
        } else if (s[i] == '1') {
            st[i+n].dp[2][2] = 1;
            st[i+n].dp[3][2] = 0;
        } else if (s[i] == '7') {
            st[i+n].dp[3][3] = 1;
            st[i+n].dp[4][3] = 0;
        } else if (s[i] == '6') {
            st[i+n].dp[3][3] = st[i+n].dp[4][4] = 1;
        }
    }
    // build segment tree
    for (int i = n-1; i; i--)
        st[i] = merge(st[i<<1], st[i<<1|1]);
    while (q--) {
        int l, r;
        cin >> l >> r;
        // query from l-1 to r
        data ansl, ansr;
        ansl.clear();
        ansr.clear();
        for (l += n-1, r += n; l < r; l >>= 1, r >>= 1) {
            if (l&1) {
                ansl = merge(ansl, st[l]);
                l++;
            }
            if (r&1) {
                r--;
                ansr = merge(st[r], ansr);
            }
        }
        int ans = merge(ansl, ansr).dp[4][0];
        cout << (ans == INT_MAX ? -1 : ans) << endl;
    }
    return 0;
}

If you paid attention to the code, you will see I embedded the build and query functions in the main function. You will also notice that inside the struct I used a C-style integer array instead of C++-style vectors. There are also changes in details such as replacing *2 by <<1 (left shift) and +1 by |1 (bitwise or). I hate to do this, but the judge on Codeforces is very demanding and my normal coding style got TLE.
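
To sanity-check the per-digit transitions without the segment tree, here is a minimal sketch that folds the same merge over a whole string from left to right. The names digit_dp, merge_dp, leaf and min_deletions are made up for this illustration, and unlike the code above the default constructor here produces the "nothing happened" table (zeros on the diagonal).

```cpp
#include <algorithm>
#include <climits>
#include <string>

const unsigned int kInf = INT_MAX;  // marks an impossible transition

// dp[j][i]: fewest deletions in a stretch of digits so that a kept prefix of
// "2017" of length i before the stretch has length j after it (i <= j).
struct digit_dp {
    unsigned int dp[5][5];
    digit_dp() {
        for (int j = 0; j < 5; j++)
            for (int i = 0; i < 5; i++)
                dp[j][i] = (i == j) ? 0 : kInf;  // empty stretch: prefix unchanged
    }
};

// Same recurrence as merge() above: go from i to k in a, then from k to j in b.
digit_dp merge_dp(const digit_dp& a, const digit_dp& b) {
    digit_dp out;
    for (int j = 0; j < 5; j++)
        for (int i = 0; i <= j; i++) {
            unsigned int best = kInf;
            for (int k = i; k <= j; k++)
                best = std::min(best, a.dp[k][i] + b.dp[j][k]);
            out.dp[j][i] = best;
        }
    return out;
}

// Table for a single digit, following the same cases as the if/else chain above.
digit_dp leaf(char c) {
    digit_dp d;
    const std::string target = "2017";
    for (int i = 0; i < 4; i++)
        if (c == target[i]) {
            d.dp[i][i] = 1;      // stay at prefix length i: delete this digit
            d.dp[i + 1][i] = 0;  // keep it and extend the prefix
        }
    if (c == '6') d.dp[3][3] = d.dp[4][4] = 1;  // a 6 after "201" must go
    return d;
}

// Answer for the whole string, or -1 if "2017" cannot be obtained.
int min_deletions(const std::string& s) {
    digit_dp acc;
    for (char c : s) acc = merge_dp(acc, leaf(c));
    return acc.dp[4][0] == kInf ? -1 : static_cast<int>(acc.dp[4][0]);
}
```

For example, min_deletions("20166766") is 4, since all four 6s after the 1 have to be removed, and min_deletions("0166766") is -1 because there is no 2 to start from. The segment tree computes the same fold for arbitrary ranges without rescanning the string.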