Jumping Through a Large Database for a Fair Sample

I have a really big database that I want to sample from. The traditional way to sample a set of values goes like this:

To sample n values from a set of T values, visit the values one at a time:

+ if uniform() < n / T, take the current value and subtract 1 from n
+ subtract 1 from T and move on to the next value
+ repeat until either n or T is 0
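For reference, here is a minimal Python 3 sketch of that one-pass method, which Knuth calls selection sampling (Algorithm S). The function name `selection_sample` and the `rng` parameter are my own. It returns indices rather than values, and it makes it plain why this costs up to one die roll per item visited.

```python
import random


def selection_sample(T, n, rng=random):
    """Pick n distinct indices from range(T), in increasing order.

    Visits every index once and keeps it with probability
    (values still needed) / (values not yet visited).
    """
    picked = []
    for i in range(T):
        unvisited = T - i  # includes index i itself
        if rng.random() < n / unvisited:
            picked.append(i)
            n -= 1
            if n == 0:  # nothing left to pick, stop early
                break
    return picked


print(selection_sample(100, 5, random.Random(42)))
```

When the number still needed equals the number left unvisited, the probability is 1, so the method always returns exactly n indices as long as n <= T.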

That produces a “fair” or “unbiased” sample: every possible set of n values is equally likely to be chosen.

The problem is that my original set is really huge (mine is 60 million records and growing), and I don’t have time to roll the die 60 million times (!). Instead, I want to roll the die exactly n times, once per item I pick: skip ahead, pick an item, skip ahead again, and so on until I have all n values.

So I thought about it, and got some help from http://math.stackexchange.com

It boils down to this:

+ If I picked n items randomly *all at once*, where would the first one land? That is, how is min({r_1 … r_n}) distributed? A helpful fellow at math.stackexchange boiled it down to this equation:

x = 1 - (1 - r) ** (1 / n)

The long story (omitted here) is that the cumulative distribution of that minimum is F(x) = 1 - (1 - x) ** n. Set F(x) = r and solve for x. Pretty easy.
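Spelled out, the omitted step is the standard inverse-transform argument, assuming the r_i are independent and uniform on [0, 1):

```latex
\begin{aligned}
P\bigl(\min\{r_1,\dots,r_n\} > x\bigr) &= \prod_{i=1}^{n} P(r_i > x) = (1 - x)^n \\
F(x) = P\bigl(\min\{r_1,\dots,r_n\} \le x\bigr) &= 1 - (1 - x)^n, \qquad 0 \le x \le 1 \\
F(x) = r \;\Rightarrow\; (1 - x)^n = 1 - r \;\Rightarrow\; x &= 1 - (1 - r)^{1/n}
\end{aligned}
```

Because 1 - r is itself uniform on (0, 1], x = 1 - r^{1/n} has the same distribution.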

Then….

+ If I generate a uniform random number and plug it in for r, the resulting x is distributed the same as min({r_1 … r_n}), i.e., the same way the lowest of the n items would fall. Voilà! I’ve just simulated picking the first item as if I had randomly selected all n.

+ So I turn x into a count by scaling it to the number of items left, skip over that many items in the list, pick the next one, and then….

+ Repeat on the items after it, with n reduced by 1, until n is 0
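A quick way to convince yourself of the “Voilà” step is to compare the formula against brute force. This Python 3 sketch (the names are my own) draws the minimum both ways. Both sample means should land near 1 / (n + 1), the exact mean of the smallest of n uniforms, and both medians near 1 - 0.5 ** (1 / n).

```python
import random
import statistics


def skip_fraction(n, rng):
    """Inverse-CDF draw with the distribution of min(r_1, ..., r_n)."""
    r = rng.random()
    return 1.0 - (1.0 - r) ** (1.0 / n)


def min_of_uniforms(n, rng):
    """Brute force: actually draw n uniforms and keep the smallest."""
    return min(rng.random() for _ in range(n))


rng = random.Random(1)
n, trials = 10, 50_000
formula = [skip_fraction(n, rng) for _ in range(trials)]
brute = [min_of_uniforms(n, rng) for _ in range(trials)]

# Exact values for comparison: mean 1/(n+1), median 1 - 0.5**(1/n).
print("mean  ", statistics.mean(formula), statistics.mean(brute), 1 / (n + 1))
print("median", statistics.median(formula), statistics.median(brute), 1 - 0.5 ** (1 / n))
```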

That way, with a big database (like Mongo), I can skip, find_one, skip, find_one, and so on until I have all the items I need.
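Here is a minimal sketch of that skip/find_one loop, assuming pymongo, where `Collection.find_one` passes extra keyword arguments such as `skip` and `sort` through to `find()`. The function name is hypothetical, and so is sorting on `_id`. Some fixed sort is needed so that position k means the same document on every call. The function takes absolute positions in increasing order, like the ones the `sample` generator in the code further down yields.

```python
def fetch_at_positions(collection, positions, sort_key="_id"):
    """Return the documents at the given 0-based positions of `collection`.

    `collection` is assumed to behave like a pymongo Collection, whose
    find_one() forwards skip= and sort= to find(). Each call skips from the
    start of the sorted result set, so `positions` are absolute, not gaps.
    """
    docs = []
    for pos in positions:
        doc = collection.find_one({}, skip=pos, sort=[(sort_key, 1)])  # 1 = ascending
        if doc is not None:  # the collection may have shrunk since it was counted
            docs.append(doc)
    return docs
```

With a real connection, you would call this as something like `fetch_at_positions(db.records, sample(db.records.count_documents({}), 10))`, where `db.records` is a placeholder collection name. The MongoDB documentation notes that `skip()` makes the server walk past every skipped document, so each call gets slower as the offset grows.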

The only problem I’m having is that my implementation favors the first and last elements of the list. But I can live with that.
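When n is small, Python 3's standard library already gives an exact answer to “which n positions?”. `random.sample` accepts a `range` directly, and the Python documentation recommends this for large populations because the full list is never built. Sorting the result gives positions you can visit front to back. It also makes a handy unbiased baseline when checking a skip-based sampler for the first/last bias. The helper names are mine.

```python
import random


def sorted_positions(T, n, rng=random):
    """Exactly uniform choice of n distinct positions out of range(T), sorted."""
    return sorted(rng.sample(range(T), n))


def to_skips(positions):
    """Turn sorted absolute positions into 'how many to skip before each pick'."""
    skips, prev = [], -1
    for p in positions:
        skips.append(p - prev - 1)
        prev = p
    return skips


positions = sorted_positions(60_000_000, 10, random.Random(0))
print(positions)
print(to_skips(positions))
```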

In Python 2.7, my implementation looks like this:

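from __future__ import print_function  # same code runs on 2.7 and 3

import pprint

import numpy


def skip(n):
    """
    Produce a random number in [0, 1) with the same distribution as
    min({r_1, ... r_n}) to see where the next smallest one is
    """
    r = numpy.random.uniform()
    # inverse of the CDF F(x) = 1 - (1 - x) ** n
    return 1.0 - (1.0 - r) ** (1.0 / n)


def sample(T, n):
    """
    Yield n indices, in increasing order, from a list of size T
    """
    t = T  # (fractional) number of items not yet passed over
    i = 0  # (fractional) position of the next unvisited item
    while t > 0 and n > 0:
        # scale the fraction to the t - n + 1 positions where this pick
        # can land and still leave room for the other n - 1 picks
        s = skip(n) * (t - n + 1)
        i += s
        yield int(i) % T
        i += 1  # step past the item just picked
        t -= s + 1
        n -= 1


if __name__ == '__main__':

    # histogram of how often each of 100 indices is picked over 10000
    # samples of size 10; a fair sampler gives each about 1000 hits
    t = [0] * 100
    for c in range(10000):
        for i in sample(len(t), 10):
            try:
                t[i] += 1
            except IndexError:
                print(c, i)

    pprint.pprint(t)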
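One likely source of the first/last bias is that the min-of-n formula describes continuous positions, while picking without replacement from t remaining items has a discrete gap distribution: the chance of skipping at least k items is C(t - k, n) / C(t, n). This Python 3 sketch draws that gap exactly by inverting its distribution with one uniform per pick. That is essentially the sequential search Vitter calls Algorithm A, swapped in here for the continuous formula, and the function names are mine. The search loop does one cheap multiplication per skipped item, so it can still add up to about T multiplications in total. It never touches the database for the skipped items, though, and the same histogram check as the harness above should come out flat apart from random noise.

```python
import random


def exact_skip(t, n, rng):
    """Items to skip before the next pick when n of t remaining are still needed.

    P(skip >= k) = C(t - k, n) / C(t, n) = prod_{j<k} (t - n - j) / (t - j),
    inverted by sequential search against a single uniform draw.
    """
    v = rng.random()
    s = 0
    tail = (t - n) / t  # P(skip >= 1)
    while tail > v:
        s += 1
        tail *= (t - n - s) / (t - s)  # now P(skip >= s + 1); reaches 0 at s = t - n
    return s


def exact_sample(T, n, rng=random):
    """n distinct indices from range(T), increasing, exactly uniform."""
    if not 0 <= n <= T:
        raise ValueError("need 0 <= n <= T")
    picks = []
    i = 0  # index of the next unvisited item
    t = T  # number of unvisited items (always >= n)
    while n > 0:
        s = exact_skip(t, n, rng)
        i += s
        picks.append(i)
        i += 1  # step past the item just picked
        t -= s + 1
        n -= 1
    return picks


# Same histogram check as the original harness, with a fixed seed.
rng = random.Random(0)
counts = [0] * 100
for _ in range(10000):
    for i in exact_sample(len(counts), 10, rng):
        counts[i] += 1
print(counts[0], counts[50], counts[-1])  # each should be near 1000
```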

Filed under computer algorithms, computer scaling, mongodb, python, utility
