关键方法一览
方法 | 作用 | 区别 |
---|---|---|
cat | 合并 | 保持原有维度的数量 |
stack | 合并 | 原有维度数量加1 |
split | 分割 | 按照长度去分割 |
chunk | 分割 | 等分 |
要点细述
cat
cat
是concatenate(连接)的缩写,而不是指(猫)。作用是把2个tensor按照特定的维度连接起来。
要求:除被拼接的维度外,其他维度必须相同
Code Demo
import torch
a=torch.randn(3,4) #随机生成一个shape(3,4)的tensort
b=torch.randn(2,4) #随机生成一个shape(2,4)的tensor
torch.cat([a,b],dim=0)
#返回一个shape(5,4)的tensor
#把a和b拼接成一个shape(5,4)的tensor,
#可理解为沿着行增加的方向(即纵向)拼接
stack
stack
会增加一个新的维度,来表示拼接后的2个tensor,直观些理解的话,咱们不妨把一个2维的tensor理解成一张长方形的纸张,cat
相当于是把两张纸缝合在一起,形成一张更大的纸,而stack
相当于是把两张纸上下堆叠在一起。
要求:两个tensor拼接前的形状完全一致
Code Demo
a=torch.randn(3,4)
b=torch.randn(3,4)
c=torch.stack([a,b],dim=0)
#返回一个shape(2,3,4)的tensor,新增的维度2分别指向a和b
d=torch.stack([a,b],dim=1)
#返回一个shape(3,2,4)的tensor,新增的维度2分别指向相应的a的第i行和b的第i行
助记:
这里的关键词参数dim的理解和cat方法中有些区别。
cat方法中可以理解为原tensor的维度,dim=0,就是沿着原来的0轴进行拼接,dim=1,就是沿着原来的1轴进行拼接。
stack方法中的dim则是指向新增维度的位置,dim=0,就是在新形成的tensor的维度的第0个位置新插入维度
split
split
是根据长度去拆分tensor
Code Demo
a=torch.randn(3,4)
a.split([1,2],dim=0)
#把维度0按照长度[1,2]拆分,形成2个tensor,
#shape(1,4)和shape(2,4)
a.split([2,2],dim=1)
#把维度1按照长度[2,2]拆分,形成2个tensor,
#shape(3,2)和shape(3,2)
chunk
chunk
可以理解为均等分的split,但是当维度长度不能被等分份数整除时,虽然不会报错,但可能结果与预期的不一样,建议只在可以被整除的情况下运用
Code Demo
a=torch.randn(4,6)
a.chunk(2,dim=0)
#返回一个shape(2,6)的tensor
a.chunk(2,dim=1)
#返回一个shape(4,3)的tensor