tf.map_fn()函数定义如下:
tf.map_fn(
fn,
elems,
dtype=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
把函数当参数传进去,可以直接用lambda。将参数elems从第一维展开,进行map处理。一个简单的例子:
def fun1(a):
if a.shape == (1,2):
print("ok")
else:
raise Exception("shape error")
return a * 2
var1 = np.random.randint(10, size=(2,1,2))
var2 = np.random.randint(10, size=(1,2))
print(var1)
# fun1(var1) 执行错误 shape error
#fun1(var2) 可以执行
# 执行
rtn = tf.map_fn(fun1, var1)
# 结果打印
with tf.Session() as sess:
result = sess.run(rtn)
print(result)
最后的执行结果:
[[[3 3]]
[[5 0]]]
ok
[[[ 6 6]]
[[10 0]]]