Selecting data from a matrix is very convenient in NumPy. For example:
import numpy as np
a = np.random.random((5, 4))
indices = np.array([0, 2, 4])
print(a)
#array([[0.47122875, 0.37836802, 0.18210801, 0.341471 ],
# [0.56551837, 0.27328607, 0.50911876, 0.01179739],
# [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
# [0.09511502, 0.96345098, 0.6500849 , 0.04084285],
# [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
print(a[indices])
#array([[0.47122875, 0.37836802, 0.18210801, 0.341471 ],
# [0.75350208, 0.9967817 , 0.94043434, 0.15640884],
# [0.93815553, 0.04821088, 0.10792035, 0.27093746]])
This pulls rows 1, 3 and 5 (indices 0, 2 and 4) out of the matrix a.
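As a small sketch of what integer-array indexing allows beyond the example above: the index array may be out of order and may repeat entries, the same idea works on columns, and the result is a copy rather than a view. The matrix below is a deterministic stand-in for the random one (an assumption made so the values are easy to check).

```python
import numpy as np

# Deterministic 5x4 matrix so the results are easy to check by eye.
a = np.arange(20).reshape(5, 4)
indices = np.array([4, 0, 0])  # out of order, with a repeat

rows = a[indices]      # rows 5, 1, 1 -> shape (3, 4)
cols = a[:, [0, 2]]    # columns 1 and 3 -> shape (5, 2)

# Integer-array indexing returns a copy, so editing `rows` leaves `a` untouched.
rows[0, 0] = -1
print(rows.shape, cols.shape, a[4, 0])
```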
To take a single index along one dimension you can simply write tensor[:, 2], but to extract several non-contiguous indices in TensorFlow you need tf.gather.
import tensorflow as tf
# TensorFlow 1.x API: graph mode with an explicit session
sess = tf.Session()
b = tf.gather(tf.constant(a), indices)
print(sess.run(b))
# Output:
# [[0.47122875 0.37836802 0.18210801 0.341471  ]
#  [0.75350208 0.9967817  0.94043434 0.15640884]
#  [0.93815553 0.04821088 0.10792035 0.27093746]]
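tf.gather also takes an axis argument, and the index tensor does not have to be 1-D: the output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:]. The sketch below illustrates that shape rule with np.take, which follows the same rule. It is a NumPy stand-in chosen so it runs without TensorFlow, not TensorFlow code.

```python
import numpy as np

params = np.arange(20).reshape(5, 4)

# Gather along axis 0 with a 2-D index array:
# shape () + (2, 2) + (4,) = (2, 2, 4).
row_idx = np.array([[0, 2], [4, 4]])
by_row = np.take(params, row_idx, axis=0)

# Gather along axis 1 (columns): shape (5,) + (3,) = (5, 3).
col_idx = np.array([0, 2, 3])
by_col = np.take(params, col_idx, axis=1)

print(by_row.shape, by_col.shape)
```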
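tf.gather_nd allows indexing across several dimensions at once:
Picking elements of a matrix directly by coordinates (each index has as many entries as the tensor has dimensions):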
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
Taking the second row, then the first row:
indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]
Results for a 3-D tensor:
indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]
indices = [[0, 1], [1, 0]]
params = [[['a0', 'b0'], ['c0', 'd0']],
          [['a1', 'b1'], ['c1', 'd1']]]
output = [['c0', 'd0'], ['a1', 'b1']]
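The indices/params/output listings above can be checked with a small pure-Python reference. This is only a sketch of the lookup rule for nested lists (the helper name gather_nd is reused here for illustration and is not the TensorFlow implementation): each innermost list of indices is followed into params one dimension at a time, and whatever remains at that position is returned.

```python
def gather_nd(params, indices):
    """Pure-Python reference for the nested-list examples above."""
    def lookup(index):
        # Follow one index tuple into params, one dimension per entry.
        out = params
        for i in index:
            out = out[i]
        return out

    def walk(node):
        # The innermost lists of `indices` are the index tuples.
        if node and isinstance(node[0], int):
            return lookup(node)
        return [walk(child) for child in node]

    return walk(indices)


matrix = [['a', 'b'], ['c', 'd']]
cube = [[['a0', 'b0'], ['c0', 'd0']],
        [['a1', 'b1'], ['c1', 'd1']]]

print(gather_nd(matrix, [[0, 0], [1, 1]]))  # ['a', 'd']
print(gather_nd(matrix, [[1], [0]]))        # [['c', 'd'], ['a', 'b']]
print(gather_nd(cube, [[0, 1], [1, 0]]))    # [['c0', 'd0'], ['a1', 'b1']]
```

When an index tuple is shorter than the rank of params, the result keeps the remaining trailing dimensions, which is why [[1]] on the 3-D tensor returns a whole 2x2 slice.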
There is also tf.batch_gather, which is used as follows:
tf.batch_gather(params, indices, name=None)
Gather slices from params according to indices with leading batch dims.
This operation assumes that the leading (batch) dimensions of indices are dense, and it gathers along the axis of params that corresponds to the last dimension of indices.
tf.batch_gather computes the following:
result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]]
Therefore params should be a Tensor of shape [A1, ..., AN, B1, ..., BM], indices should be a Tensor of shape [A1, ..., AN-1, C], and result will be a Tensor of shape [A1, ..., AN-1, C, B1, ..., BM].
If indices is a 1-D tensor, the result is the same as that of tf.gather.
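To make the formula concrete, here is a minimal NumPy sketch of the simplest case, where params has no trailing B dimensions (M = 0) and a single batch dimension. np.take_along_axis is used as a stand-in because it applies the same rule, result[i, j] = params[i, indices[i, j]]; it is not the TensorFlow op itself. In newer TensorFlow releases the same behaviour is, as far as I know, exposed through the batch_dims argument of tf.gather.

```python
import numpy as np

# params: shape [A1, A2] = (2, 3); indices: shape [A1, C] = (2, 2).
params = np.array([[10, 11, 12],
                   [20, 21, 22]])
indices = np.array([[2, 0],
                    [1, 1]])

# result[i, j] = params[i, indices[i, j]]: each row has its own indices.
result = np.take_along_axis(params, indices, axis=-1)
print(result)

# With 1-D indices there is no batch dimension left, so this is a plain gather.
flat = np.array([5, 6, 7, 8])
idx = np.array([3, 0])
print(np.take_along_axis(flat, idx, axis=-1), flat[idx])
```

The difference from a plain gather is that every batch row uses its own row of indices, instead of one shared index list being applied to the whole tensor.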