所有人都非常擅长均匀抽样,因为几乎所有的编程语言都内置了均匀分布中生成一个0到1的实数的方法,本文中我们将此方法记作...
那如果是从一个带有权重的集合里抽取呢?往往就没有那么简单了,特别是如果集合很大,而且里面的元素会变的情况。
比如说:在一个集合里个元素,每个元素有不同的权重,每个元素被抽中的概率为元素的权重占总权重的比例,现从这个集合中有放回的随机抽取k个元素,你会怎么做呢?
加权采样能发挥的本质,其实就是如何巧妙利用来实现相应的加权效果。
拍脑袋版
先介绍几个拍脑袋就能想出来的方法,这些方法不需要任何的经验,所有人都能想到。
放大集合法
比如说我们有一个序列[a,b,c,d]
中有 4
个元素,权重分别是 [1,2,3,4]
, 我们要从这个这个序列来进行加权采样,我们完全可以将权重转化为序列中的元素,然后将权重都置为1,把序列变成[a,b,b,c,c,c,d,d,d,d]
, 然后就可以愉快的用int(rand(0,1) * sum([1,2,3,4]))
获取元素下标来进行采样了。
这个方法非常简单有效,但是显然它在构造新序列和求和时都需要很大的开销,时间复杂度为。不过如果序列元素是稳定的,在构建了新序列和求好和之后采样效率会非常高,时间复杂度为
累加比例法
其实和上面那种方法差不多...
就是得到权重的累加和[1,3,6,10]
,然后愉快的来一发int(rand(0,1) *10 )
,再通过二分查找得到元素下标即可。 很显然,因为要求和,所以复杂度也是,如果序列稳定,求好了和的话就是.
调包法
在python标准库中提供了random.choices
这个抽样方法,是可以传入权重的,于是你十分愉快的拿起键盘一顿敲:
random.choices(population=[a,b,c,d], weights=[1,2,3,4], k=1)
十分简单的就实现了,仔细分析一下这个的实现,其实和上面的累加比例法是一样的。 当你传入累加权重 cum_weights
, 时间复杂度是, 不传的话因为要算,复杂度也为.
同样不适用于动态序列的情况。
高斯分布近似法
实际上就是通过一个均值为0的高斯分布去近似这个排好序的序列.
然后通过高斯分布来采样... 这样首先效率低, 而且不是精确基于分布的采样,另外如果加了元素发生变化还得调整 ,很是麻烦...所以肯定不用这个....
数据科学家版
np.random.choice
可能因为numpy
用的太娴熟了,很多数据工作在涉及到抽样,上来就是np.random.choice
,但是这玩意的时间复杂度更上面random
的方法一样是, 在供选择的集合很大时同样会有效率低下的问题...
以效率著称的numpy
居然效率不行???孤为之奈何....
那么有什么方法可以在动态序列上获得 的效果呢?
sum tree 算法
SumTree是一种特殊的二叉树,其中父节点的值等于其子节点的值之和,如下图所示,根节点的值是所有叶子的值的和:10 = 3+7,3 = 1+2,以此类推......
叶子节点代表元素,其上的值是元素的权重。也就是说,我们只要构造这样一棵树,就可以愉快的进行采样了。随便来一发rand(0,1)*10
, 比如说是2,我们从根节点的左孩子还是查找,
发现3比2大,于是往左边查,然后发现比3的左还是1大,就往3的左孩子找,就找到了2这个节点,对应的就是b
这个元素。 在构建好这棵树后,显然,我们每次更新元素的内容,只需要更新这个节点的父节点,也就是只要修改 个节点,查找过程和动态修改过程都是 , 于是我们终于找到了适合动态变化的集合的加权采样方法...
附上Paper原作者的SumTree实现:
import numpy
class SumTree:
write = 0
def __init__(self, capacity):
self.capacity = capacity
self.tree = numpy.zeros( 2*capacity - 1 )
self.data = numpy.zeros( capacity, dtype=object )
def _propagate(self, idx, change):
parent = (idx - 1) // 2
self.tree[parent] += change
if parent != 0:
self._propagate(parent, change)
def _retrieve(self, idx, s):
left = 2 * idx + 1
right = left + 1
if left >= len(self.tree):
return idx
if s <= self.tree[left]:
return self._retrieve(left, s)
else:
return self._retrieve(right, s-self.tree[left])
def total(self):
return self.tree[0]
def add(self, p, data):
idx = self.write + self.capacity - 1
self.data[self.write] = data
self.update(idx, p)
self.write += 1
if self.write >= self.capacity:
self.write = 0
def update(self, idx, p):
change = p - self.tree[idx]
self.tree[idx] = p
self._propagate(idx, change)
def get(self, s):
idx = self._retrieve(0, s)
dataIdx = idx - self.capacity + 1
return (idx, self.tree[idx], self.data[dataIdx])