TensorFlow tf.argmax()函数

xiaoxiao2021-02-28  33

TensorFlow tf.argmax()函数

tf.argmax(input, axis=None, name=None, dimension=None) 对矩阵按行或列计算最大值 四个参数: 1.input:输入值 2.axis:可选值0表示按列,1表示按行求最大值 3.name 4.默认使用axis即可

重点说说axis参数的作用 举例说明

test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]]) np.argmax(test, 0)   #输出:array([3, 3, 1] np.argmax(test, 1)   #输出:array([2, 2, 0, 0]

解释:

# axis参数为0时: test[0] = array([1, 2, 3]) test[1] = array([2, 3, 4]) test[2] = array([5, 4, 3]) test[3] = array([8, 7, 2]) # output : [3, 3, 1]

此时输出的是每一列最大值所在的数组下标。输出的数组元素数量是原矩阵的列数

# axis参数为1时: test[0] = array([1, 2, 3]) #2 test[1] = array([2, 3, 4]) #2 test[2] = array([5, 4, 3]) #0 test[3] = array([8, 7, 2]) #0 # output : [2, 2, 0, 0]

此时输出的每一个数组中最大值所在的列号。输出的数组元素个数是原数组的数量,即原矩阵行数。

通过比较,我们可以看到,axis两个参数的区别是:0是每个数组对应位置之间的比较,而1则是数组内部元素之间的比较。

转载请注明原文地址: https://www.6miu.com/read-2624534.html

最新回复(0)