I have a really big database that I want to sample from. The traditional way to sample a set of values goes like this:

To sample *n* values from a set of values of size *T*, go through the values one at a time:

+ if *uniform() < n / T*, take the current value, and subtract 1 from *n*

+ subtract 1 from *T*

+ repeat until either *n* or *T* is 0

That will produce a “fair” or “unbiased” sample, in which every subset of *n* values is equally likely.
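For reference, the procedure above is the classic selection-sampling method (Knuth calls it Algorithm S). A minimal Python sketch, assuming we only need the *indices* of the chosen values; `selection_sample` and its `rng` parameter are illustrative names, not from any library:

```python
import random


def selection_sample(T, n, rng=random.random):
    """Return n sorted indices from range(T), every subset equally likely.

    Rolls one uniform number per index visited, so the cost grows with T.
    """
    if not 0 <= n <= T:
        raise ValueError("need 0 <= n <= T")
    picked = []
    for index in range(T):
        if n == 0:
            break  # all n values taken; the rest are never examined
        remaining = T - index  # items not yet examined, including this one
        # Take this item with probability n / remaining.
        if rng() < float(n) / remaining:
            picked.append(index)
            n -= 1
    return picked
```

Note the one `rng()` call per visited index: for 60 million rows that can mean up to 60 million draws.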

The problem is, if my original set is really huge (mine is 60 million and growing), I don’t have time to roll the die 60 million times. (!) Instead, I want to roll the die only *n* times: each roll tells me how many items to skip before picking the next one, and I keep skipping and picking until I have all *n* values.

So I thought about it and got some help from http://math.stackexchange.com.

It boils down to this:

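+ If I picked *n* items randomly *all at once*, where would the first one land? That is, how is *min({r_1 … r_n})* distributed, where each *r_i* is uniform on [0, 1)? A helpful fellow at math.stackexchange gave me this equation, which turns a single uniform random number *r* into a draw *x* from that distribution: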

$$x = 1 - (1 - r)^{1/n}$$

The long story (omitted here) is that the cumulative distribution of that minimum is $F(x) = 1 - (1 - x)^n$. Set $r = F(x)$ and solve for *x*. Pretty easy.
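A quick way to convince yourself of the equation is to compare it against the brute-force version: draw *n* uniforms and keep the smallest. A minimal sketch (the function names are illustrative); both sample means should land near $1/(n+1)$, the expected minimum of *n* uniforms:

```python
import random


def min_of_uniforms(n, rng=random.random):
    """One draw distributed like min(r_1, ..., r_n), via x = 1 - (1 - r)**(1/n)."""
    r = rng()
    return 1.0 - (1.0 - r) ** (1.0 / n)


def brute_force_min(n, rng=random.random):
    """The slow way: draw n uniforms and keep the smallest."""
    return min(rng() for _ in range(n))


if __name__ == "__main__":
    gen = random.Random(42)
    n, trials = 10, 100000
    fast = sum(min_of_uniforms(n, gen.random) for _ in range(trials)) / trials
    slow = sum(brute_force_min(n, gen.random) for _ in range(trials)) / trials
    # Both should be close to the theoretical mean 1 / (n + 1).
    print(fast, slow, 1.0 / (n + 1))
```

The fast version uses one random number per draw; the brute-force one uses *n*.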

Then….

+ If I generate a uniform random number and plug it in for *r*, the resulting *x* is distributed the same as *min({r_1 … r_n})*, i.e. exactly where the lowest item would fall. Voila! I’ve just simulated picking the first item as if I had randomly selected all *n*.

+ So I scale *x* to the part of the list that’s left, skip over that many items, pick the next one, and then….

+ Subtract 1 from *n* and repeat on the items after the pick, until *n* is 0
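In continuous terms (positions measured as fractions of the list), the steps above produce all *n* points in sorted order with one random number each: after a pick at position *p*, the remaining points are spread uniformly over [*p*, 1), so the same equation applies to that shrinking interval. A minimal sketch; `sorted_uniforms` is an illustrative name:

```python
import random


def sorted_uniforms(n, rng=random.random):
    """Yield n uniform points on [0, 1) in increasing order, one draw each."""
    position = 0.0
    while n > 0:
        r = rng()
        # Where the smallest of the n points still to come falls, as a
        # fraction of the room left between `position` and 1.
        step = 1.0 - (1.0 - r) ** (1.0 / n)
        position += (1.0 - position) * step
        yield position
        n -= 1
```

Multiplying each position by the list length and rounding down gives an index, but that rounding only approximates the exact discrete skip, and two nearby positions can round to the same index.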

That way, with a big database (like Mongo), I can `skip`, `find_one`, `skip`, `find_one`, and so on until I have all the items I need.
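A minimal sketch of that loop with pymongo, assuming `collection` is a `pymongo.collection.Collection`, that positions are counted in `_id` order, and that the collection doesn’t change while sampling. `find_one` passes extra keyword arguments such as `skip` and `sort` through to `find`. `fetch_sample` is an illustrative name, and `indices` is any sorted list of positions, for example the output of `sample(T, n)`:

```python
def fetch_sample(collection, indices):
    """Fetch the document at each position with one skip + find_one call."""
    docs = []
    for index in indices:
        # Sort on _id so a given position means the same document every call.
        doc = collection.find_one({}, skip=int(index), sort=[("_id", 1)])
        if doc is not None:  # None means the position is past the end
            docs.append(doc)
    return docs
```

MongoDB’s documentation for `cursor.skip()` notes that the server still scans from the start of the result set, so each call costs time proportional to its offset on the server; what this saves is the per-item die rolls and client round trips.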

The only problem I’m having is that my implementation favors the first and last elements in the list, but I can live with that.

In Python 2.7, my implementation looks like this:

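```python
import pprint

import numpy


def skip(n):
    """
    Produce a random number with the same distribution as
    min({r_1, ... r_n}) to see where the next smallest one is
    """
    r = numpy.random.uniform()
    return 1.0 - (1.0 - r) ** (1.0 / n)


def sample(T, n):
    """ Take n items from a list of size T """
    t = T  # items left to choose from (becomes a float after the first skip)
    i = 0  # current position in the list (also a float)
    while t > 0 and n > 0:
        # Scale the [0, 1) skip by t - n + 1 so the remaining picks still fit.
        s = skip(n) * (t - n + 1)
        i += s
        yield int(i) % T  # round the continuous position down to an index
        i += 1
        t -= s + 1
        n -= 1


if __name__ == '__main__':
    # Count how often each of 100 slots is picked over 10000 samples of 10.
    t = [0] * 100
    for c in xrange(10000):
        for i in sample(len(t), 10):
            try:
                t[i] += 1
            except IndexError:
                print c, i  # report any out-of-range index
    pprint.pprint(t)
```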
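A likely source of the bias is that each skip is drawn from a *continuous* distribution and then rounded down, which only approximates the exact discrete one. For example, with `T = 100` and `n = 10` in the code above, index 0 is picked whenever the first skip is below 1, which happens with probability $1 - (1 - 1/91)^{10} \approx 0.105$ rather than the fair $10/100$. One fix, swapped in here rather than taken from the math.stackexchange answer, is to draw each skip from its exact distribution: with *t* items left and *n* still needed, the chance that the next *s* items are all passed over is $\binom{t-s}{n} / \binom{t}{n}$. Inverting that with one uniform number per pick is the sequential method Jeffrey Vitter calls Algorithm A. A minimal sketch (`exact_sample` is an illustrative name) that should run under both Python 2.7 and 3:

```python
import random


def exact_sample(T, n, rng=random.random):
    """Yield n sorted, distinct indices from range(T), every subset equally likely.

    One uniform draw per picked item. The skip S before the next pick has
    P(S >= s) = C(t - s, n) / C(t, n), found by walking that product.
    """
    if not 0 <= n <= T:
        raise ValueError("need 0 <= n <= T")
    i = 0  # index of the next unexamined item
    t = T  # number of unexamined items
    while n > 0:
        v = rng()
        s = 0
        # quot holds P(S >= s + 1) = prod_{j=0..s} (t - n - j) / (t - j).
        quot = float(t - n) / t
        while quot > v:
            s += 1
            quot *= float(t - n - s) / (t - s)
        i += s
        yield i
        i += 1
        t -= s + 1
        n -= 1
```

The inner loop still does arithmetic proportional to each skip, so total work grows with *T*, but there are only *n* random draws. Vitter’s Algorithm D brings the expected cost down to roughly proportional to *n*, at the price of considerably more code.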