1. 普通的slice
In [2]: x = torch.arange(12).reshape(4,3)
In [3]: x
Out[3]:
tensor([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]])
In [4]: x.dtype
Out[4]: torch.float32
In [5]: y = x[2:, :]
In [6]: y
Out[6]:
tensor([[ 6., 7., 8.],
[ 9., 10., 11.]])
这个时候,变量x
和y
共享内存位置,如果将 y
的值改变, x
的值也会改变:
改变方式 1
In [15]: y[:,:] = 666
In [16]: y
Out[16]:
tensor([[ 666., 666., 666.],
[ 666., 666., 666.]])
In [17]: x
Out[17]:
tensor([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 666., 666., 666.],
[ 666., 666., 666.]])
改变方式 2
In [12]: y.fill_(0)
Out[12]:
tensor([[ 0., 0., 0.],
[ 0., 0., 0.]])
In [13]: y
Out[13]:
tensor([[ 0., 0., 0.],
[ 0., 0., 0.]])
In [14]: x
Out[14]:
tensor([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 0., 0., 0.],
[ 0., 0., 0.]])
2. Mask(dtype=torch.uint8) 作为slice的时候,不会有上述效果
In [2]: x = torch.arange(12).reshape(4, -1)
In [3]: x
Out[3]:
tensor([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]])
In [4]: mask = x > 5
In [5]: mask
Out[5]:
tensor([[ 0, 0, 0],
[ 0, 0, 0],
[ 1, 1, 1],
[ 1, 1, 1]], dtype=torch.uint8)
In [6]: y = x[mask]
In [7]: y
Out[7]: tensor([ 6., 7., 8., 9., 10., 11.])
mask的数据类型为 torch.uint8
, 用其作为slice的时候,得到的结果就会 展开成一个一维的数组, 并且改变 y
的值, x
的值也不会发生变化。
In [8]: y[:] = 0
In [9]: y
Out[9]: tensor([ 0., 0., 0., 0., 0., 0.])
In [10]: x
Out[10]:
tensor([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]])
In [11]: y.fill_(666)
Out[11]: tensor([ 666., 666., 666., 666., 666., 666.])
In [12]: y
Out[12]: tensor([ 666., 666., 666., 666., 666., 666.])
In [13]: x
Out[13]:
tensor([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]])