2020-05-14pytorch之stack、cat、transpose、unqueeze等

stack
使用stack是为了保留两个信息: 序列(先后)和 张量矩阵信息。比如在循环神经网络中,网络的输出数据通常是:包含了n个数据大小[batch_size, num_outputs]的list,这个和[n, batch_size, num_outputs]是完全不一样的!!!!不利于计算,需要使用stack进行拼接,保留–[1.时间步]和–[2.张量的矩阵乘积属性]。

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

形式:
outputs = torch.stack(inputs, dim=0) → Tensor
重点

  1. 函数中的输入inputs只允许是list或tuple;且序列内部的张量元素,必须shape相等
    ----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape

  2. dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小
    例子:

x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x3 = torch.tensor([[13,23,33],[23,33,43]],dtype=torch.int)
x4 = torch.tensor([[14,24,34],[24,34,44]],dtype=torch.int)

inputs = [x1, x2, x3, x4]
In [19]: torch.stack(inputs, dim=0)
Out[19]: 
tensor([[[11, 21, 31],
         [21, 31, 41]],

        [[12, 22, 32],
         [22, 32, 42]],

        [[13, 23, 33],
         [23, 33, 43]],

        [[14, 24, 34],
         [24, 34, 44]]], dtype=torch.int32)

In [21]: torch.stack(inputs, dim=1)
Out[21]: 
tensor([[[11, 21, 31],
         [12, 22, 32],
         [13, 23, 33],
         [14, 24, 34]],

        [[21, 31, 41],
         [22, 32, 42],
         [23, 33, 43],
         [24, 34, 44]]], dtype=torch.int32)

In [20]: torch.stack(inputs, dim=2)
Out[20]: 
tensor([[[11, 12, 13, 14],
         [21, 22, 23, 24],
         [31, 32, 33, 34]],

        [[21, 22, 23, 24],
         [31, 32, 33, 34],
         [41, 42, 43, 44]]], dtype=torch.int32)

aa = torch.tensor([[[1,2,3],[4,5,6],[7, 8,9]]])
bb = torch.tensor([[[11, 21, 31],[41,51,61],[71,81,91]]])
cc = torch.tensor([[[101,201,301],[401,501,601],[701,801,901]]])

inputs1 = [aa, bb, cc]
In [29]: torch.stack(inputs1, dim=0)
Out[29]: 
tensor([[[[  1,   2,   3],
          [  4,   5,   6],
          [  7,   8,   9]]],

        [[[ 11,  21,  31],
          [ 41,  51,  61],
          [ 71,  81,  91]]],

        [[[101, 201, 301],
          [401, 501, 601],
          [701, 801, 901]]]])

In [30]: torch.stack(inputs1, dim=1)
Out[30]: 
tensor([[[[  1,   2,   3],
          [  4,   5,   6],
          [  7,   8,   9]],

         [[ 11,  21,  31],
          [ 41,  51,  61],
          [ 71,  81,  91]],

         [[101, 201, 301],
          [401, 501, 601],
          [701, 801, 901]]]])

In [31]: torch.stack(inputs1, dim=2)
Out[31]: 
tensor([[[[  1,   2,   3],
          [ 11,  21,  31],
          [101, 201, 301]],

         [[  4,   5,   6],
          [ 41,  51,  61],
          [401, 501, 601]],

         [[  7,   8,   9],
          [ 71,  81,  91],
          [701, 801, 901]]]])

In [32]: torch.stack(inputs1, dim=3)
Out[32]: 
tensor([[[[  1,  11, 101],
          [  2,  21, 201],
          [  3,  31, 301]],

         [[  4,  41, 401],
          [  5,  51, 501],
          [  6,  61, 601]],

         [[  7,  71, 701],
          [  8,  81, 801],
          [  9,  91, 901]]]])

In [33]: torch.stack(inputs1, dim=-1)
Out[33]: 
tensor([[[[  1,  11, 101],
          [  2,  21, 201],
          [  3,  31, 301]],

         [[  4,  41, 401],
          [  5,  51, 501],
          [  6,  61, 601]],

         [[  7,  71, 701],
          [  8,  81, 801],
          [  9,  91, 901]]]])

Cat
对数据沿着某一维度进行拼接。cat后数据的总维数不变.

In [34]: x = torch.randn(2,3)
In [35]: y = torch.randn(1,3)
In [37]: print(x, '\n', y)
tensor([[ 1.8932,  0.8820, -0.3152],
        [ 0.4488,  1.7583, -0.0939]]) 
 tensor([[-1.0298,  0.8602, -0.5422]])

In [38]: torch.cat((x, y), dim=0)
Out[38]: 
tensor([[ 1.8932,  0.8820, -0.3152],
        [ 0.4488,  1.7583, -0.0939],
        [-1.0298,  0.8602, -0.5422]])

transpose
transpose ,交换维度

In [39]: x = torch.randn(2, 3)
In [40]: print(x)
tensor([[ 1.5418,  0.8280, -0.8068],
        [-0.3803, -1.1618,  1.4929]])

In [41]: x.transpose(0, 1)
Out[41]: 
tensor([[ 1.5418, -0.3803],
        [ 0.8280, -1.1618],
        [-0.8068,  1.4929]])

In [42]: x.transpose(1, 0)
Out[42]: 
tensor([[ 1.5418, -0.3803],
        [ 0.8280, -1.1618],
        [-0.8068,  1.4929]])

permute
permute,适合多维数据,permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

In [43]: x = torch.randn(2,3,4)
In [44]: xp = x.permute(1, 0, 2)

In [45]: print(x)
tensor([[[-0.4044, -0.4237,  0.2973, -1.5864],
         [ 0.7312, -0.9954, -1.2718,  0.0916],
         [ 0.3418,  1.1162,  0.8982,  0.6203]],

        [[ 0.9823, -1.3540,  1.0551,  1.5960],
         [ 1.5930, -0.3035, -0.3781,  1.3462],
         [ 1.1224,  0.6163, -1.3140, -1.1987]]])

In [46]: print(xp)
tensor([[[-0.4044, -0.4237,  0.2973, -1.5864],
         [ 0.9823, -1.3540,  1.0551,  1.5960]],

        [[ 0.7312, -0.9954, -1.2718,  0.0916],
         [ 1.5930, -0.3035, -0.3781,  1.3462]],

        [[ 0.3418,  1.1162,  0.8982,  0.6203],
         [ 1.1224,  0.6163, -1.3140, -1.1987]]])

squeeze 和 unsqueeze
squeeze(dim_n), 压缩,即去掉元素数量为1的dim_n维度。同理unsqueeze(dim_n),增加dim_n维度,元素数量为1。

# 定义张量
import torch

b = torch.Tensor(2,1)
b.shape
Out[28]: torch.Size([2, 1])

# 不加参数,去掉所有为元素个数为1的维度
b_ = b.squeeze()
b_.shape
Out[30]: torch.Size([2])

# 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
b_ = b.squeeze(0)
b_.shape 
Out[32]: torch.Size([2, 1])

# 这样就可以了
b_ = b.squeeze(1)
b_.shape
Out[34]: torch.Size([2])

# 增加一个维度
b_ = b.unsqueeze(2)
b_.shape
Out[36]: torch.Size([2, 1, 1])

**self.scatter(dim, index, src) **
从张量src中按照index张量中指定的索引位置写入self张量的值。对于一个三维张量,self更新为:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

为了保证scatter填充的有效性,需要注意:
(1)self张量在dim方向上的长度不小于source张量,且在其它轴方向的长度与source张量一般相同。这里的一般是指:scatter操作本身有broadcast机制。
(2)index张量的shape一般与source ,从而定义了每个source元素的填充位置。这里的一般是指broadcast机制下的例外情况。

import torch
a = torch.arange(10).reshape(2,5).float()
b = torch.zeros(3, 5))
index = torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0]])
b_= b.scatter(dim=0, index=index,src=a)
print(b_)

# tensor([[0, 6, 0, 0, 9],
#        [0, 0, 2, 8, 0],
#        [5, 1, 7, 0, 4]])

a = torch.arange(10).reshape(2,5).float()
#tensor([[0., 1., 2., 3., 4.],
#        [5., 6., 7., 8., 9.]])
ind = torch.LongTensor([[1, 2, 1, 1, 2]])
c = b.scatter(0, ind, a)
#tensor([[0., 0., 0., 0., 0.],
#        [0., 0., 2., 3., 0.],
#        [0., 1., 0., 0., 4.]])

scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式,如:

labels = torch.LongTensor([1,3])
targets = torch.zeros(2, 5)
targets.scatter(dim=1, index=labels.unsqueeze(-1), src=torch.tensor(1))
# 注意dim=1,即逐样本的进行列填充
# 返回值为 tensor([[0, 1, 0, 0, 0],
#                 [0, 0, 0, 1, 0]])

gather
函数torch.gather(input, dim, index, out=None) → Tensor
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
对一个 3 维张量,输出可以定义为:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Parameters:

  • input (Tensor) – 源张量
  • dim (int) – 索引的轴
  • index (LongTensor) – 聚合元素的下标(index需要是torch.longTensor类型)
  • out (Tensor, optional) – 目标张量

使用说明举例:

  1. dim = 1
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  28.,  22.,  27.,   0.]],

        [[ 26.,  10.,  20.,  29.,  18.],
         [  5.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
                          [0,0,0,0,0],
                          [1,1,1,1,1]],
                        [[1,2,2,2,2],
                         [0,0,0,0,0],
                         [2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18.,  26.,  22.,   1.,   0.],
         [ 18.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.]],

        [[  5.,  29.,  10.,   0.,  22.],
         [ 26.,  10.,  20.,  29.,  18.],
         [ 10.,  29.,  10.,   0.,  22.]]])
可以看到沿着dim=1,也就是列的时候。输出tensor第一页内容,
第一行分别是 按照index指定的,
input tensor的第一页 
第一列的下标为0的元素 第二列的下标为1元素 第三列的下标为2的元素,第四列下标为0元素,
第五列下标为2元素
index-->0,1,2,0,2    output--> 18.,  26.,  22.,   1.,   0.
'''
  1. dim =2
c = torch.gather(a, 2,index)
print(c)
'''
tensor([[[ 18.,   5.,   7.,  18.,   7.],
         [  3.,   3.,   3.,   3.,   3.],
         [ 28.,  28.,  28.,  28.,  28.]],

        [[ 10.,  20.,  20.,  20.,  20.],
         [  5.,   5.,   5.,   5.,   5.],
         [ 10.,  10.,  10.,  10.,  10.]]])
dim = 2的时候就安装 行 聚合了。参照上面的举一反三。
'''
  1. dim = 0
index2 = torch.LongTensor([[[0,1,1,0,1],
                          [0,1,1,1,1],
                          [1,1,1,1,1]],
                        [[1,0,0,0,0],
                         [0,0,0,0,0],
                         [1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)
'''
tensor([[[ 18.,  10.,  20.,   1.,  18.],
         [  3.,  24.,  26.,  21.,   3.],
         [ 10.,  29.,  10.,   0.,  22.]],

        [[ 26.,   5.,   7.,   1.,   1.],
         [  3.,  26.,   9.,   7.,   9.],
         [ 10.,  29.,  22.,  27.,   0.]]])
这个有点特殊,dim = 0的时候(三维情况下),是从不同的页收集元素的。
这里举的例子只有两页。所有index在0,1两个之间选择。
输出的矩阵元素也是按照index的指定。分别在第一页和第二页之间跳着选的。
index [0,1,1,0,1]的意思就是。
在第一页选这个位置的元素,在第二页选这个位置的元素,在第二页选,第一页选,第二页选。
'''

转载或参考链接:
https://blog.csdn.net/excellent_sun/article/details/95175823
https://www.cnblogs.com/yifdu25/p/9399047.html
https://www.cnblogs.com/dogecheng/p/11938009.html
https://www.jianshu.com/p/5d1f8cd5fe31

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,311评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,339评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,671评论 0 342
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,252评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,253评论 5 371
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,031评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,340评论 3 399
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,973评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,466评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,937评论 2 323
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,039评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,701评论 4 323
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,254评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,259评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,485评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,497评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,786评论 2 345