参考 https://blog.csdn.net/loseinvain/article/details/79638183
https://blog.csdn.net/chengshuhao1991/article/details/78545723
输入两个二维数组如下:
a = tf.constant([[1,2,3],[3,4,5]]) # shape (2,3)
b = tf.constant([[7,8,9],[10,11,12]]) # shape (2,3)
tf.concat相当于numpy中的np.concatenate函数,用于将两个张量在某一个维度(axis)合并起来,例如:
ab1 = tf.concat([a,b], axis=0) # shape(4,3)
[[ 1 2 3]
[ 3 4 5]
[ 7 8 9]
[10 11 12]]
ab2 = tf.concat([a,b], axis=1) # shape(2,6)
[[ 1 2 3 7 8 9]
[ 3 4 5 10 11 12]]
tf.stack其作用类似于tf.concat,都是拼接两个张量,而不同之处在于,tf.concat拼接的是除了拼接维度axis外其他维度的shape完全相同的张量,并且产生的张量的阶数不会发生变化,而tf.stack则会在新的张量阶上拼接,产生的张量的阶数将会增加。
tf.stack()就是以指定的轴axis,将一个维度为R的张量数组转变成一个维度为R+1的张量。即将张量数组以指定的轴,提高一个维度。
假设要转变的张量数组values(如[x, y])的长度为N,其中的每个张量(如x, y)的形状为(A, B, C)。
如果轴axis=0,则转变后的张量的形状为(N, A, B, C)。
如果轴axis=1,则转变后的张量的形状为(A, N, B, C)。
如果轴axis=2,则转变后的张量的形状为(A, B, N, C)。其它情况依次类推。
a = tf.constant([[1,2,3],[3,4,5]])
b = tf.constant([[7,8,9],[10,11,12]])
ab3 = tf.stack([a,b], axis=0)
[[[ 1 2 3]
[ 3 4 5]]
[[ 7 8 9]
[10 11 12]]]
ab4 = tf.stack([a,b], axis=1)
[[[ 1 2 3]
[ 7 8 9]]
[[ 3 4 5]
[10 11 12]]]
ab5 = tf.stack([a,b], axis=2)
[[[ 1 7]
[ 2 8]
[ 3 9]]
[[ 3 10]
[ 4 11]
[ 5 12]]]
如
‘x’ is [[1,1,1,1],[2,2,2,2],[3,3,3,3]],形状是(3,4),维度为2
‘y’ is [[4,4,4,4],[5,5,5,5],[6,6,6,6]],形状是(3,4),维度为2
stack([x,y]) => [[[1,1,1,1],[2,2,2,2],[3,3,3,3]], [[4,4,4,4],[5,5,5,5],[6,6,6,6]]] # axis的值默认为0。输出的形状为(2, 3, 4)
stack([x,y],axis=1) => [[[1,1,1,1],[4,4,4,4]],[[2,2,2,2],[5,5,5,5]],[[3,3,3,3],[6,6,6,6]]] # axis的值为1。输出的形状为(3, 2, 4)
stack([x,y],axis=2) => [[[1,4],[1,4],[1,4],[1,4]],[[2,5],[2,5],[2,5],[2,5]],[[3,6],[3,6],[3,6],[3,6]]]# axis的值为2。输出的形状为(3, 4, 2)
axis可这样理解:stack要将一组N个相同形状的张量(如[x, y])提高一个维度。axis就是在和原来形状相同的张量里,将axis指定的维度里每一个元素用拼接后的数组代替。如axis=2,表示在指定的第2个维度(数值),将原来的每一个数值(如1), 用x和y对应位置的数值拼接而成的数组(如[1, 4])代替,即从(A, B)转变为(A, B, N)。
对两个二维数组的拼接,axis=0则表示在三维层面上进行拼接,操作单位为二维矩阵;axis=1则表示在二维层面上进行拼接,操作单位为行向量,一一对应的进行拼接;axis=2则表示在一维层面上进行拼接,操作单位为数值,进行point-wise的拼接。
而tf.unstack与tf.stack的操作相反,是将一个高阶数的张量在某个axis上分解为低阶数的张量,例如:
a1 = tf.unstack(ab3, axis=0)
[array([[1, 2, 3],
[3, 4, 5]], dtype=int32),
array([[7, 8, 9],
[10, 11, 12]], dtype=int32)]
a2 = tf.unstack(ab3, axis=1)
[array([[1, 2, 3],
[7, 8, 9]], dtype=int32),
array([[3, 4, 5],
[10, 11, 12]], dtype=int32)]
对tf.concat()的也可以做unstack操作
a3 = tf.unstack(ab1, axis=0)
[array([1, 2, 3], dtype=int32),
array([3, 4, 5], dtype=int32),
array([7, 8, 9], dtype=int32),
array([10, 11, 12], dtype=int32)]
a4 = tf.unstack(ab1, axis=2)
[array([[1, 7], [2, 8]], dtype=int32),
array([[3, 9], [3, 10]], dtype=int32),
array([[4, 11], [5, 12]], dtype=int32)]