0%

蓄水池算法

蓄水池采样算法(Reservoir Sampling)

蓄水池采样算法是非常常用的一种流式数据处理算法

问题

大致描述:

给出一个数据流,这个数据流的长度很大或未知,并且对该数据流中的数据只能访问一次。请写出一个随机选择算法,使得数据流中所有数据被选中的概率相等。

一些实际问题

  1. 从 100,000 分调查报告中抽取1000份进行统计。
  2. 从一本很厚的电话簿中抽取1000人进行姓氏统计。
  3. 从google搜索"Ken Thompson",从中抽取100个结果查看哪些是今年的。

这些都是很基本的采样问题。

既然说的采样问题,最重要的就是做到公平,也就是保证每个元素被采样到的概率是相同的。

对于第一个问题,我们已经知道数据的规模,通过算法生成[0, 100,000-1]间的随机数1000个,并且保证不重复即可。再取出对应的元素即可。

但是对于第二和第三个问题,我们不知道数据的整体规模是多大。可能有人会想到,可以先对数据进行一次遍历,计算出数据的规模N,然后按照第一题的方法采样即可。这当然可以,但是并不好。因为这可能需要遍历两次,需要花两次的时间。也可以尝试估算数据的规模,但是这样得到的采样数据可能并不平均。

问题严格定义

给你一个长度为N的链表。N很大,但你不知道N有多大。你的任务是从这N个元素中随机抽取k个元素。你只能遍历这个链表一次。你的算法必须保证取出的元素恰好有k个,且它们是完全随机的(出现概率相等的)。

解法

蓄水池算法:

蓄水池算法是针对从一个序列中随机抽取不重复的K个数,保证每个数被抽取到的概率都为K/N这个问题构建的。

做法:

首先构造一个可以容纳k个元素的蓄水池(数组),将序列前k个元素直接放入蓄水池数组中。

然后从第i = k+1个数据开始,以k/i(k<i<=n)的概率决定它是否进入到蓄水池中。蓄水池中的k个元素被替换出去的概率是相同的。

当遍历完所有元素之后,数组中剩下的元素即为所需采取的样本。

证明

对于第i个数(i <= k)。在k步之前,被选中的概率为1。当走到第k+1步时,被k+1个元素替换的概率=第k+1个元素被选中的概率i被替换的概率,即为 k/(k+1) 1/k = 1/(k+1)。则不被第k+1个元素替换的概率为1 - 1/(k+1) = k/(k+1)。依次类推,不被K+2个元素替换的概率为1-k/(k+2) * 1/k = (k+1)/(k+2)。则运行到第n步时,第i个数仍保留的概率=被选中的概率不被替换的概率,即: \[ 1 \times \frac {k}{k+1}\times \frac {k+1}{k+2}\times \frac {k+2}{k+3}\times ...\times \frac {n-1}{n} = \frac {k}{n} \] 对于第j个数(j>k)。我们知道,在第j步被选中的概率为k/j。不被j+1个元素替换的概率为1 - k/(j+1) 1/k = j/(j+1)。则运行到第n步时,第i个数仍保留的概率=被选中的概率*不被替换的概率,即: \[ \frac {k}{j}\times \frac {j}{j+1}\times \frac {j+1}{j+2}\times \frac {j+2}{j+3}\times ... \times \frac {n-1}{n} = \frac {k}{n} \] 所以对于中每个元素,被保留的概率都为k/n。

实现

python3版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# 蓄水池算法实现.py
import random


class ReservoirSampling(object):

def sample(self, node, k):
data = []
# 计数器
counter = 0
while node:
counter += 1
# 前k个元素直接放入
if counter <= k:
data.append(node)
else:
# 判断第j个元素是否留下
if random.randint(1, counter) <= k:
# 判断替换掉哪个元素
removed_idx = random.randint(0, k - 1)
# 替换该元素,放入新元素
data[removed_idx] = node
# 如果不留下,就继续
# 访问下一个node
node = next(node)
return data


def main():
class ListNode(object):
def __init__(self, val):
self.val = val
self.next = None

def __next__(self):
return self.next

head = ListNode(0)
cur = head
for i in range(1, 11):
cur.next = ListNode(i)
cur = cur.next
rs = ReservoirSampling()
res = rs.sample(head, 10)
for node in res:
print(node.val)


if __name__ == '__main__':
main()

实战题目

382. 链表随机节点