1.简介
Prioritized DQN 是为了解决当在memory中均匀采样时候学习效率低下的问题。原因主要有两个:
1.我们想让new transition立马用于更新,因为这样的new experience对于explore很重要。
2.我们想让large td-error的transition立马用于更新(比如有99次失败的经历和1次成功的经历,我们希望立马学习这个成功的经历)
显然uniform sampling无法做到这两点。
于是便有了伟大的Prioritized Experience Replay.下面我将分享自己学习这篇论文的时候一些经验。请读完论文和简单介绍后,如有困惑,再阅读以下部分。
2.关键点
Prioritized DQN能够成功的主要原因有两个:sum tree这种数据结构带来的采样的O(log n)的高效率,和Weighted Importance sampling的正确估计。后者,我现在还没有完全搞明白原理。
我简单由谈下自己对于sum tree数据结构的理解。 sum tree存储的元素是样本的优先级,其思想是根据累积概率密度(因此叫sum)来抽取样本。从最左方开始,优先级累积逐渐增大,如果我们的段>左子孩子,(递归地)就在右子孩子中寻找(这时候要做减法,以便又是新的累积优先级)。
如果把累积优先级(离散地)画出来,我们就会发现,高优先级对应的直线段斜率最大,被抽取到的概率最大。(可以以下图为例,自己在每个段中取数字进行验证)。
3.代码解读
原代码注释较少,我这里列出几个点,方便大家阅读代码。
- 代码实现的是DQN, 而不是Double DQN。
- 在插入new transition更新sum tree的时候, 是根据新样本与原来位置的样本的优先级差来更新。(详见SumTree.add)
- 在memory中插入new transition的时候,给予new transition最大的优先级,因为我们想让new experience立马用于学习。(详见Memory.store)
- 在memroy中抽取n个samples后,我们会根据nn计算出来的TD-error来更新那些抽取到的样本的优先级,这样的话new transition就不会一直被学习。(详见Memory.batch_update)。
大家最好照着源码自己敲一编(时间大概2~3小时),我这里给出自己在搬砖过程中写的一点注释(也可以自己下载,照着看)。
import numpy as np
import tensorflow as tf
np.random.seed(1)
tf.set_random_seed(1)
class SumTree(object):
data_pointer = 0
def __init__(self, capacity):
self.capacity = capacity
self.tree = np.zeros(2 * capacity - 1)
# [------------ Parent nodes -------------][------ leaves to recode priority ----------]
# size: capacity - 1 size: capacity
self.data = np.zeros(capacity, dtype=object)
# [------------ data frame ---------------]
# size: capacity
# memory store_transition的时候使用
def add(self, p, data): # p is the new priority, data is transition
# memory batch_update的时候使用
def update(self, tree_idx, p):
# memory 分段采样的时候使用
def get_leaf(self, v):
parent_idx = 0
while True:
cl_idx = 2 * parent_idx + 1
cr_idx = cl_idx + 1
if cl_idx >= len(self.tree): # 此时parent就是叶子结点
leaf_idx = parent_idx
break
else:
if v <= self.tree[cl_idx]: # <= 左子孩子,就向左前进
parent_idx = cl_idx
else:
v -= self.tree[cl_idx] # > 右子孩子,需要重新当作一颗累积树,因此要减去左子孩子的值
parent_idx = cr_idx
data_idx = leaf_idx - (self.capacity + 1)
return leaf_idx, self.tree[leaf_idx], self.data[data_idx]
@property
def total_p(self):
return self.tree[0] # the root
class Memory(object):
epsilon = 0.01 # small amount to avoid zero priority
alpha = 0.6 # [0, 1] convet the importance of TD error to priority
beta = 0.4 # importance sampling, from intial value increasing to 1
beta_increment_per_sampling = 0.001
abs_err_upper = 1 # clipped abs error
def __init__(self, capacity):
self.tree = SumTree(capacity)
# 向sum tree的transitions 中加入 new transition
def store(self, transition):
# 从sum tree中采取n个样本
def sample(self, n):
# 更新采样过的样本的priority(基于abs_error)
def batch_update(self, tree_idx, abs_error):