index_select()函数有两种用法。
第一种是将被切片的函数作为参数传入index_select()中
torch.index_select(input, dim, index, out=None)
还有一种是调用张量内置的index_select()函数。
input.index_select(dim, index)
index_select()函数的作用是针对张量input,在它的dim维度上切取index指定的范围切片。
参数:
input:被操作的张量
dim:维度
index:一维Tensor,表示索引下标的范围
例如
import torch
a = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7]])
b = torch.index_select(a, 0, torch.tensor([1]))
print(b)
c = torch.index_select(a, 1, torch.tensor([1,3]))
print(c)
输出为
这里维度dim从0开始算,则b表示在第0维(即行)上,切下下标为1的行;c表示在第1维(即列)上,切下下标为1和3的列。