Fast RNG in an Interval
https://arxiv.org/abs/1805.10941 - Fast Random Integer Generation in an Interval
I recently read this interesting little paper. The original is already quite readable, but perhaps I can give an even more accessible write-up.
Problem Statement
Say you want to generate a random number in the range [0, s), but you only have a random number generator that gives you a random number in the range [0, 2^L) (inclusive, exclusive). One simple thing you can do is to first generate a number x, divide it by s and take the remainder. Another thing you can do is to scale the range down with something like x * (s / 2^L), using some floating-point math, casting, or whatever works. Both ways give you a resulting integer in the specified range.
But these are not “correct”, in the sense that they don’t generate random integers with a uniform distribution. Say s = 3 and 2^L = 4: you will always end up with one number being generated with probability 1/2 and the other two with probability 1/4. Given 4 equally likely inputs, you just cannot convert them into 3 cases with equal probability. More generally, these simple approaches cannot work when s is not a power of 2, because 2^L equally likely inputs cannot be split evenly into s groups.
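To see the bias concretely, the small Python check below feeds every 2-bit input through both naive approaches once and counts the outcomes. It is a sketch: the function names are my own, and it uses integer arithmetic in place of floating point for the scaling approach. For s = 3, both approaches come out lopsided in the same way:

```python
from collections import Counter

def naive_mod(x: int, s: int) -> int:
    # Approach 1: divide by s and keep the remainder.
    return x % s

def naive_scale(x: int, s: int, L: int) -> int:
    # Approach 2: scale [0, 2^L) down to [0, s) as (x * s) / 2^L, rounded down.
    return (x * s) >> L

def outcome_counts(s: int, L: int):
    # Feed every L-bit input exactly once, as if each were equally likely.
    inputs = range(1 << L)
    mod_counts = Counter(naive_mod(x, s) for x in inputs)
    scale_counts = Counter(naive_scale(x, s, L) for x in inputs)
    return mod_counts, scale_counts

mod_counts, scale_counts = outcome_counts(s=3, L=2)
print(sorted(mod_counts.items()))    # [(0, 2), (1, 1), (2, 1)]
print(sorted(scale_counts.items()))  # [(0, 2), (1, 1), (2, 1)]
```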
First Attempt at Fixing Statistical Biases
To fix that, you will need to reject some numbers and try again. As in the above example, when you get the number 3, you draw again until you get a number from 0 to 2. Then all outcomes are equally likely.
More generally, you need to throw away 2^L mod s numbers, so that the count of the remaining numbers is divisible by s. Let’s call that number r, for remainder. So you can throw away the first r numbers and then use the first approach of taking the remainder, as shown in this first attempt (pseudocode):
r = (2^L - s) mod s // same as 2^L mod s; 2^L itself doesn't fit in L bits
x = rand()
while x < r do
    x = rand()
return x mod s
That’s a perfectly fine solution, and in fact it has been used in some popular standard libraries (e.g. GNU C++). However, division is a slow operation compared to others like multiplication, addition and branching, and in this function we are always doing two divisions (mod). If we can somehow cut down on our divisions, our function may run a lot faster.
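For reference, here is the first attempt as a runnable Python sketch. The function name and the `rand` callback are my own assumptions. `rand` is any zero-argument function returning a uniform integer in [0, 2^L). Because Python integers never overflow, the L-bit arithmetic is only simulated:

```python
import random

def uniform_reject_first(s: int, L: int, rand) -> int:
    """Uniform integer in [0, s), given rand() uniform in [0, 2^L).

    Rejects the first r = 2^L mod s values, so the remaining 2^L - r values
    split evenly into s residue classes. Always costs two % operations.
    """
    assert 0 < s <= (1 << L)
    # (2^L - s) mod s == 2^L mod s. In L-bit unsigned arithmetic 2^L - s is
    # just -s, so 2^L itself is never needed; Python can write it directly.
    r = ((1 << L) - s) % s
    x = rand()
    while x < r:          # reject the first r values and draw again
        x = rand()
    return x % s

rng = random.Random(2024)
print(uniform_reject_first(6, 32, lambda: rng.getrandbits(32)))
```

Passing the generator in as a callback keeps the function easy to test deterministically. For example, with s = 3 and L = 2 (so r = 1), feeding it the sequence 0, 0, 2 skips the two rejected zeros and returns 2.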
Reducing the Number of Divisions
It turns out we can do just that with a simple twist. Instead of getting rid of the first r numbers, we get rid of the last r numbers. And we can check whether x is among the last r numbers like so:
x = rand()
x_mod_s = x mod s
while x - x_mod_s > 2^L - s do
    x = rand()
    x_mod_s = x mod s
return x_mod_s
The greater-than comparison on line 3 is a little tricky. It’s mathematically the same as comparing x - x_mod_s + s with 2^L, but we write it this way because 2^L cannot be represented with L bits. So basically, the check says: if the next multiple of s after x is larger than 2^L, then x is in the last r numbers and must be thrown away. We never actually calculate r, but with a little cleverness we manage to do the same check.
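Here is the same idea in Python, again as a sketch with a hypothetical `rand` callback. Only the comparison changes. Instead of testing x < r, it tests whether the block of s values containing x runs past 2^L:

```python
def uniform_reject_last(s: int, L: int, rand) -> int:
    """Uniform integer in [0, s), given rand() uniform in [0, 2^L).

    Rejects the last r = 2^L mod s values without ever computing r: x is kept
    only if the next multiple of s after x, i.e. x - x mod s + s, is <= 2^L.
    """
    assert 0 < s <= (1 << L)
    limit = (1 << L) - s      # 2^L - s fits in L bits; 2^L does not
    x = rand()
    x_mod_s = x % s           # the one division every call pays for
    while x - x_mod_s > limit:
        x = rand()
        x_mod_s = x % s       # one more division per rejected draw
    return x_mod_s
```

With s = 3 and L = 2, a draw of 3 is the one rejected value: its block would be [3, 6), which runs past 4. The draws 0, 1 and 2 are all kept.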
How many divisions are we doing here? Well, at least one on line 2, plus one more for every extra pass through the loop. Since we’re rejecting less than half of the possible outcomes (we keep at least s and reject at most s - 1), we have at least a 1/2 chance of breaking out of the loop each time. That means the expected number of extra loops is at most 1 (0 * 1/2 + 1 * 1/4 + 2 * 1/8 + ... = 1). So the expected number of divisions is at worst 2, equal to that of the previous attempt. But most of the time the expected number is a lot closer to 1 (e.g. when s is small), so this can theoretically be almost a 2x speedup.
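The expected-value bound can also be checked exactly. This sketch uses a helper of my own, not one from the paper. It evaluates 1 / p, the mean of a geometric number of draws with acceptance probability p, for L = 32. The results are very close to 1 for small s and just under 2 when s is slightly above 2^31:

```python
from fractions import Fraction

def expected_divisions_reject_last(s: int, L: int) -> Fraction:
    """Exact expected number of % operations for the reject-last algorithm.

    Each draw costs one division and is kept with probability
    p = (2^L - r) / 2^L, where r = 2^L mod s. The number of draws is
    geometric, so its mean is 1 / p = 1 + (1 - p) / p.
    """
    n = 1 << L
    p = Fraction(n - n % s, n)
    return 1 / p

for s in (3, 1000, 2**31 + 1):
    print(s, float(expected_divisions_reject_last(s, 32)))
```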
So that’s pretty cool. But can we do even better?
Finally, Fast Random Integer
Remember that, besides taking remainders, there’s also the scaling approach x * (s / 2^L)? It turns out that if you rewrite it as (x * s) / 2^L, it becomes quite efficient to compute, because computers can “divide” by a power of two by just chopping bits off the right. Plus, a lot of hardware can produce the full 2L-bit result of multiplying two L-bit numbers, so we don’t have to worry about x * s overflowing. The mod-based approach inevitably needs at least one expensive division, but here we don’t, because the denominator is a power of 2. So this direction seems promising, but again we have to fix the statistical biases.
So let’s investigate how to do that with our toy example of s = 3, 2^L = 4. Let’s look at what happens to all possible values of x.
| x | s * x | (s * x) / 2^L | (s * x) mod 2^L |
|---|---|---|---|
| 0 | 0 | 0 | 0 |
| 1 | 3 | 0 | 3 |
| 2 | 6 | 1 | 2 |
| 3 | 9 | 2 | 1 |
Essentially, the products s * x fall into s intervals of size 2^L, and each interval maps to one single unique outcome. In this case, [0, 4) maps to 0, [4, 8) maps to 1, and [8, 12) maps to 2. From the third column, we have two cases mapping to 0, and we’d like to get rid of one of them.
Note that the fundamental reason behind this uneven distribution is that 2^L is not divisible by s, so contiguous ranges of 2^L numbers don’t all contain the same number of multiples of s. That means we can fix it by rejecting r numbers in each range! More specifically, if we reject the first r numbers in each interval, each interval is left with 2^L - r numbers, which is divisible by s, so every interval contains the same number of multiples of s. In the above example r = 1, and the mapping becomes [1, 4) maps to 0, [5, 8) maps to 1, and [9, 12) maps to 2. Fair and square!
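The interval argument can be checked mechanically. This Python sketch (the names are mine) rejects x whenever s * x lands in the first r numbers of its interval, and it reports which outcome every surviving x maps to. For s = 3 and 2^L = 4, only x = 0 is thrown away:

```python
def kept_mapping(s: int, L: int) -> dict:
    """Map each L-bit x to its outcome (s * x) >> L, or None if rejected.

    r = 2^L mod s. x is rejected when s * x falls in the first r numbers of
    its interval of size 2^L, i.e. when (s * x) mod 2^L < r.
    """
    n = 1 << L
    r = n % s
    mapping = {}
    for x in range(n):
        product = s * x
        mapping[x] = None if product % n < r else product >> L
    return mapping

print(kept_mapping(3, 2))  # {0: None, 1: 0, 2: 1, 3: 2}
```

Trying other values of s and L shows the same pattern: every outcome ends up with the same number of surviving inputs.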
Let’s put that in pseudocode:
r = (2^L - s) mod s
x = rand()
x_s = x * s
x_s_mod = lowest_n_bits x_s L // equivalent to x_s mod 2^L
while x_s_mod < r do
    x = rand()
    x_s = x * s
    x_s_mod = lowest_n_bits x_s L
return shift_right x_s L // equivalent to x_s / 2^L
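Here is a Python transcription of this pseudocode, as a sketch. `rand` is a hypothetical callback returning a uniform integer in [0, 2^L), `& mask` plays the role of lowest_n_bits, and `>> L` plays the role of shift_right:

```python
def uniform_multiply_always_r(s: int, L: int, rand) -> int:
    """Multiply-and-shift with rejection; always pays one division for r.

    Python ints never overflow, so x * s is the full 2L-bit product that
    hardware would give us in two registers.
    """
    assert 0 < s <= (1 << L)
    mask = (1 << L) - 1
    r = ((1 << L) - s) % s            # 2^L mod s: the one expensive division
    x_s = rand() * s
    while (x_s & mask) < r:           # lowest L bits == x_s mod 2^L
        x_s = rand() * s
    return x_s >> L                   # dropping the low L bits == x_s // 2^L
```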
Now that would work, and it would take exactly 1 expensive division (on line 1, to compute r) every single time. That beats both of the above algorithms! But wait, we can do even better! Since r < s, any x_s_mod that is at least s is also at least r and can be accepted right away. So we can first check x_s_mod against s, and only compute r when x_s_mod < s. This is the algorithm proposed in the paper. It looks something like this:
x = rand()
x_s = x * s
x_s_mod = lowest_n_bits x_s L
if x_s_mod < s then
    r = (2^L - s) mod s
    while x_s_mod < r do
        x = rand()
        x_s = x * s
        x_s_mod = lowest_n_bits x_s L
return shift_right x_s L
Now the number of expensive divisions is either 0 or 1, with some probability depending on s and 2^L. This looks clearly faster than the other algorithms, and experiments in the paper confirmed that. But as is often the case, performance comes at the cost of less readable code. In this case we’re also relying on hardware support for full multiplication results, so the code is less portable and in reality looks pretty low-level and messy. Go and Swift have adopted this approach, deciding the tradeoff is worth it, and according to the author’s blog (https://lemire.me/blog/2019/09/28/doubling-the-speed-of-stduniform_int_distribution-in-the-gnu-c-library/), C++ may also use it soon.
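Here is the paper's algorithm transcribed into Python for experimenting. This is only a sketch. Python's big integers stand in for the hardware full multiply, so it shows the logic rather than the speed. The extra `divisions` return value is my own addition to make the 0-or-1 division count visible, and `rand` is again a hypothetical callback returning a uniform integer in [0, 2^L):

```python
import random

def lemire_uniform(s: int, L: int, rand) -> tuple:
    """Sketch of the paper's algorithm: uniform integer in [0, s).

    rand() must return a uniform integer in [0, 2^L). Returns
    (value, divisions) so the 0-or-1 division count is visible.
    """
    assert 0 < s <= (1 << L)
    mask = (1 << L) - 1
    divisions = 0
    x_s = rand() * s                  # full 2L-bit product
    x_s_mod = x_s & mask              # low L bits: x_s mod 2^L
    if x_s_mod < s:                   # r < s, so x_s_mod >= s is always safe
        r = ((1 << L) - s) % s        # 2^L mod s, computed only when needed
        divisions += 1
        while x_s_mod < r:
            x_s = rand() * s
            x_s_mod = x_s & mask
    return x_s >> L, divisions        # high L bits: x_s // 2^L

rng = random.Random(1)
value, divs = lemire_uniform(10, 32, lambda: rng.getrandbits(32))
print(value, divs)
```

In a fixed-width language, `x_s` would come from a full-width multiply instead, with the high half as the result and the low half as `x_s_mod`.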
How Many Divisions Exactly?
There’s still one last part we haven’t figured out: we know the expected number of divisions is between 0 and 1, but what exactly is it? In other words, how many multiples of s in the range [0, s * 2^L) have a remainder less than s when divided by 2^L? Since each x is equally likely, the expected number of divisions is that count divided by 2^L. To people with more number theory background, this is probably obvious. But starting from scratch, it can take quite a lot of work to prove, so I’ll just sketch the intuition.
It’s a well-known fact that if p and q are co-prime (no common factors other than 1), then the numbers { 0, p mod q, 2p mod q, 3p mod q, ..., (q-1)p mod q } are exactly 0 to q-1, in some order. This is because if there were any repeated number, we would have a * p mod q = b * p mod q for some a > b, which implies (a - b) * p mod q = 0. But we know that 0 < a - b < q, and p has no common factor with q, so their product cannot be a multiple of q. So it’s impossible to have duplicates. There are q numbers and q possible remainders, so the multiples of p, taken mod q, hit each of 0 to q-1 exactly once.
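Here is a quick brute-force check of this fact in Python, as an illustration rather than a proof. For q = 8, the multiples of p hit every residue exactly once precisely when p is odd, i.e. co-prime to 8:

```python
from math import gcd

def multiples_mod(p: int, q: int) -> list:
    # Remainders of 0, p, 2p, ..., (q-1)p when divided by q.
    return [(i * p) % q for i in range(q)]

def hits_every_residue(p: int, q: int) -> bool:
    # True when the remainders are exactly 0..q-1 in some order.
    return sorted(multiples_mod(p, q)) == list(range(q))

print(multiples_mod(3, 8))  # [0, 3, 6, 1, 4, 7, 2, 5]
print(multiples_mod(6, 8))  # [0, 6, 4, 2, 0, 6, 4, 2]
for p in range(1, 9):
    # The two columns always agree: co-prime <=> every residue is hit once.
    print(p, gcd(p, 8) == 1, hits_every_residue(p, 8))
```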
Now if s and 2^L are co-prime (i.e. s is odd), then as x runs over [0, 2^L), the remainders (s * x) mod 2^L are exactly 0 to 2^L - 1. So exactly s of them, namely 0 to s - 1, are below s. That means the expected number of divisions in this case is s / 2^L.
If they aren’t co-prime, then s is divisible by some power of 2. Say s = s' * 2^k, where s' is odd. Then s * 2^(L-k) = s' * 2^L, which is 0 mod 2^L. So the multiples of s, taken mod 2^L, go back to 0 after every 2^(L-k) steps. Within each such run they hit every multiple of 2^k below 2^L exactly once (this is the co-prime fact again, applied to s' and 2^(L-k), then scaled by 2^k), and there are 2^k such runs. So if you tally how many times each remainder 0, 1, 2, ... occurs, the counts go 2^k, followed by 2^k - 1 zeros, rinse and repeat. How many are below s? There are s' nonzero counts below s = s' * 2^k, each equal to 2^k, so the total is, again unsurprisingly, s. So the expected number of divisions is still indeed s / 2^L.
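Both cases can be confirmed by brute force. The sketch below uses a helper name of my own. For every s up to 2^L with L = 8, it counts how many x make (s * x) mod 2^L < s, which is exactly when the paper's algorithm computes r. It also prints the remainder tally for the even example s = 12 = 3 * 2^2, 2^L = 32:

```python
from collections import Counter

def count_small_remainders(s: int, L: int) -> int:
    # How many x in [0, 2^L) give (s * x) mod 2^L < s, i.e. trigger the division?
    n = 1 << L
    return sum(1 for x in range(n) if (s * x) % n < s)

# Even case: s = 12 = 3 * 2^2 and 2^L = 32. Remainders are multiples of 4,
# each hit 4 times; the ones below 12 (0, 4, 8) account for 3 * 4 = 12 hits.
print(sorted(Counter((12 * x) % 32 for x in range(32)).items()))
print(count_small_remainders(12, 5))  # 12

# Odd or even, the count comes out to s, so the probability is s / 2^L.
print(all(count_small_remainders(s, 8) == s for s in range(1, 257)))  # True
```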
Final Thoughts
Earlier I said that each time you need to throw away 2^L mod s numbers to get an even distribution, but that’s not strictly necessary. For example, if s = 5 and 2^L = 8, you don’t have to fully reject the 3 leftover cases. Instead, you can save up that little bit of randomness for the next iteration. In the next iteration, say you land in 1 of the 3 cases again. Then, combined with the 3 cases you saved up last time, you are now in 1 of 9 equally likely events. If you are in the first 5, you can safely return that value without introducing bias; otherwise the remaining 4 cases can be saved in turn. However, this is only useful when generating the random bits is really expensive, which is generally not the case in non-cryptographic use cases.
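Here is one way to sketch this recycling idea in Python. It is my own illustration, not code from the paper. The pair (v, m) is the saved randomness: a value v that is uniform over m cases. With s = 5 and 2^L = 8, drawing 5 and then 6 returns 1 after only two draws, where plain rejection would have needed a third:

```python
import random

def uniform_recycling(s: int, L: int, rand) -> int:
    """Uniform integer in [0, s) that saves rejected randomness.

    Hypothetical sketch of the idea above, not an algorithm from the paper.
    Invariant: v is uniform in [0, m), independent of the path taken so far.
    """
    n = 1 << L
    assert 0 < s <= n
    r = n % s                  # rejected outcomes per fresh draw
    v, m = 0, 1                # nothing saved yet: a single, certain case
    while True:
        x = rand()
        if x < n - r:          # usual accept region; its size is a multiple of s
            return x % s
        # Rejected: x - (n - r) is uniform in [0, r). Fold it into the pool.
        v, m = v * r + (x - (n - r)), m * r
        t = m - m % s          # largest multiple of s not exceeding m
        if v < t:
            return v % s
        v, m = v - t, m - t    # keep the m % s leftover cases for next time

rng = random.Random(42)
print([uniform_recycling(5, 3, lambda: rng.getrandbits(3)) for _ in range(10)])
```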
One last note: we have established that the expected number of divisions is s / 2^L. As s gets close to 2^L, it seems like our code could become slower. But I think that’s not necessarily the case, because the time a division takes is probably variable as well, if the hardware uses any sort of short-circuiting at all. When s is close to 2^L, computing 2^L mod s is essentially one or two subtractions plus some branching, which can theoretically be done really fast. So, based on my educated guess (or pure speculation), s / 2^L growing isn’t a real concern.
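To make the "one or two subtractions" remark concrete, here is a toy Python function that computes 2^L mod s by repeated subtraction and counts the steps. It says nothing about how real division hardware behaves. It only shows that for any s > 2^L / 3 the quotient is at most 2, so at most two subtractions are needed:

```python
def mod_by_subtraction(n: int, s: int) -> tuple:
    # Compute n mod s by repeated subtraction, counting the subtractions.
    steps = 0
    while n >= s:
        n -= s
        steps += 1
    return n, steps

L = 32
n = 1 << L
for s in (n - 5, n // 2 + 1, n // 3 + 1):
    print(s, mod_by_subtraction(n, s))
```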