-
np.argmax()
解释:接收两个参数,第一个为np数组,第二个为axis,在数组的第axis轴上求最大值,返回数组中最大值的索引值,当一组中同时出现几个最大值时,返回第一个最大值的索引值。看例子:
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
],
[
[21, 6, -5, 2],
[9, 36, 2, 8],
[3, 7, 79, 1]
]
])
b = np.argmax(a, axis = 0)
c = np.argmax(a, axis = 1)
d = np.argmax(a, axis = 2)
print(b)
print(c)
print(d)
输出为:
>>b
[[2 1 0 0]
[0 2 0 0]
[1 0 2 0]]
>>c
[[1 2 0 1]
[1 0 2 1]
[0 1 2 1]]
>>d
[[1 0 1]
[1 0 2]
[0 1 2]]
分析:对于一个3*3*4的矩阵,当axis = 0
时,在第一个维度上作比较,即三个矩阵作比较,返回的是一个3*4的矩阵,同理,axis = 1
时在第二个维度上作比较,返回的是一个3*4的矩阵,axis = 2
时在第三个维度上作比较,返回的是一个3*3的矩阵。可以发现,输出相较于输入总会减少一维,具体减少哪一维由axis决定,可以用这个来验证。
-
torch.max()
解释:传入两个参数,一个torch.tensor,一个dim,用法与np.max相似,不过这个返回两个tensor,第一个是沿着dim维的最大值,另一个是对应的索引。同时出现几个最大值时,返回最后一个最大值的索引值。