np.expand_dims()用于扩展数组的形状
参数:
values:数组
axis:表示在该位置添加数据
例子:
原始数据:
import numpy as np
a = np.array([[[1,2,3],[4,5,6]]])
print(a)
print(a.shape)
"""
[[[1 2 3]
[4 5 6]]]
(1, 2, 3)
"""
1.np.expand_dims(a, axis=0)表示在0位置添加数据,转换结果如下:
b = np.expand_dims(a, axis=0)
print(b)
print(b.shape)
"""
[[[[1 2 3]
[4 5 6]]]]
(1, 1, 2, 3)
"""
2.np.expand_dims(a, axis=1)表示在1位置添加数据,转换结果如下:
b = np.expand_dims(a, axis=1)
print(b)
print(b.shape)
"""
[[[[1 2 3]
[4 5 6]]]]
(1, 1, 2, 3)
"""
3.np.expand_dims(a, axis=2)表示在2位置添加数据,转换结果如下:
b = np.expand_dims(a, axis=2)
print(b)
print(b.shape)
"""
[[[[1 2 3]]
[[4 5 6]]]]
(1, 2, 1, 3)
"""
4.np.expand_dims(a, axis=3)表示在3位置添加数据,转换结果如下:
b = np.expand_dims(a, axis=3)
print(b)
print(b.shape)
"""
[[[[1]
[2]
[3]]
[[4]
[5]
[6]]]]
(1, 2, 3, 1)
"""
5.能在(1,2,3)中插入的位置总共为4个,再添加就会出现以下的警告,要不然也会在后面某一处提示AxisError。
参考:
https://blog.csdn.net/hong615771420/article/details/83448878