蓄水池算法

蓄水池算法(Reservoir Sampling)

这个算法真的很奇妙,它的核心是一个数学证明。外延,或者说应用场景是:

  1. $C_n^k$,也就是从大小为 n 的样本集中随机取 k 个不同的样本
  2. 流式数据,或者说无法直接根据索引拿到数据(更加不可能一遍加载到内存)

算法描述

算法的描述其实很简单:维基百科:水塘抽样

1
2
3
4
5
6
7
问题描述:从包含n个不同的项目的集合S中随机选取k个不同的样本。
算法:
从S中取首k个放入[水塘]中
对每个S[j]项(j>=k,数组从0开始):
true随机产生一个范围从0到j的整数r
true若r<k则把水塘中的第r项换成S[j]项
最后得到的水塘就是抽样结果

这个算法保证了每一项最后可能存在于水塘中的概率都是一样的。

单看算法,你肯定不知道为什么是等概率,其实数学证明并不难,请看下面的证明:

数学证明

我们把样本分为两类:

  1. 一类是首 k 个,它们一开始就在水塘中
  2. 一类是其他,它们一开始并不在水塘中

我们发现两个简单的逻辑:

  1. 对于水塘中的样本,只要随机数不选到该样本,该样本就不会被替换
  2. 水塘的某个项一旦被替换,就不可能再回到水塘,不会出现被替换掉,然后再回到水塘的局面,这样就保证了问题不会进一步变得复杂。所以:某个项被保留的概率 = 被选中到水塘的概率 * 后续不被替换的概率

分类讨论,首 k 个样本最终存在于水塘中的概率,和其余样本最终存在于水塘中的概率:

  1. 首 k 个样本,随便选一个做研究对象。被选中到水塘的概率为:1。(数组从 1 开始)从 j=k+1 开始考虑替换,第一次不被替换的概率是$\frac{k}{k+1}$,第二次不被替换的概率是$\frac{k+1}{k+2}$,第三次...,一直到最后一次不被替换的概率是$\frac{n-1}{n}$。

    所以该项被保留的概率 = $1\times\frac{k}{k+1}\times\frac{k+1}{k+2}\times\frac{k+2}{k+3}\times\cdots\times\frac{n-1}{n}=\frac{k}{n}$
  2. 一开始不在水塘中的那一部分,随便选一个做研究对象。被选中到水塘的概率为:$\frac{k}{j}$,后续不被替换的概率$\frac{j}{j+1}$,一直到$\frac{n-1}{n}$。

    所以该项被保留的概率 = $\frac{k}{j}\times\frac{j}{j+1}\times\cdots\frac{n-1}{n}=\frac{k}{n}$

到此我们就证明了所以样本最终存在于水塘中的概率都是$\frac{k}{n}$,这也完全符合了我们的数学期望。

代码

弄个流式数据我们这里没有条件,只能用伪代码模拟一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public Data[] reservoirSampling(int k, DataStream dataStream){
Data[] reservoir = new int[k];

// init pool
for(int i=0;i<reservoir.length;i++){
reservoir[i] = dataStream.getCurrentData();
dataStream.toNext();
}

Random random = new Random();
for(int i=k;!dataStream.isFinish();i++){
int d = random.nextInt(i+1);
if(d<k){
reservoir[d] = dataStream.getCurrentData();
}
dataStream.toNext();
}

return reservoir;
}