tf.concat函数:函数功能比较简单,主要用于连接两个数组
参数:
values:需要连接的数组,注意数组的维度应该一致
axis:从哪个维度来连接数组
例子:
1.一维数组
import tensorflow as tf
if __name__ == "__main__":
a = [1,2,3]
b = [4,5,6]
c = tf.concat([a,b],0)
sess = tf.InteractiveSession()
print(sess.run(c)) #[1 2 3 4 5 6]
注意:axis参数不能超过数组的维度。如果超过数组的维度,如下:
c = tf.concat([a,b],1)
则会报,ValueError: Shape must be at least rank 2 but is rank 1 for 'concat',意思是数组至少是二维,axis才能为1。
2.二维数组
a = [[1,1],[2,2],[3,3]]
b = [[4,4],[5,5],[6,6]]
c = tf.concat([a,b],0)
print(sess.run(c))
"""
[[1 1]
[2 2]
[3 3]
[4 4]
[5 5]
[6 6]]
"""
c = tf.concat([a,b],1) #等价于tf.concat([a,b],-1)
print(sess.run(c))
"""
[[1 1 4 4]
[2 2 5 5]
[3 3 6 6]]
"""
3.三维数组
a = [[[1,1],[2,2]],[[3,3],[4,4]]]
b = [[[5,5]],[[6,6]]]
c = tf.concat([a,b],1)
print(sess.run(c))
"""
[[[1 1]
[2 2]
[5 5]]
[[3 3]
[4 4]
[6 6]]]
"""
a = [[1, 2], [3, 4]] b = [[5, 6]] c = np.concatenate((a, b), axis=None) """ [[1,2,3,4,5,6]] """
5.如何来判断数组是否在该个维度上的shape是相同的呢?
其实很简单,我们根据tf.concat的axis参数来去数组的[],0表示去掉最外面的一层,1去掉两层,以此类推,下面举例说明一下。
如:最后一个例子中的c = tf.concat([a,b],1),我们先将a去掉最外面两层[],变成了[1,1],[2,2]和[3,3],[4,4]],然后再将b去掉最外面两层[],变成了[5,5]和[6,6],此时再进行concat,可以发现此时的shape是相等的。
参考:
https://blog.csdn.net/sinat_29957455/article/details/86100641