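np.argmax(A, axis=1) is a NumPy call that searches along the second axis (axis=1) of a 2-D array A. For each row, it returns the column index of that row's maximum value. If the maximum appears more than once in a row, the index of its first occurrence is returned.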
import numpy as np
A = np.array([[1, 3, 2], [4, 2, 1], [3, 5, 6]])
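# For each row, print the column index of its largest element
print(np.argmax(A, axis=1))
# output: [1 0 2]  (row 0: 3 at index 1, row 1: 4 at index 0, row 2: 6 at index 2)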
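The sketch below reuses the same array A to show a few related behaviours. With axis=0, argmax searches down each column. With no axis, it returns an index into the flattened array, and np.unravel_index converts that index back to a (row, column) pair. Ties resolve to the first occurrence. np.take_along_axis uses the row-wise indices to retrieve the maximum values themselves. The tie-breaking array B is a made-up example.

```python
import numpy as np

A = np.array([[1, 3, 2], [4, 2, 1], [3, 5, 6]])

# axis=0: for each column, the row index of its maximum
col_idx = np.argmax(A, axis=0)
print(col_idx)  # [1 2 2]

# No axis: index into the flattened array, then convert it back to (row, col)
flat_idx = np.argmax(A)
print(flat_idx)  # 8
row, col = np.unravel_index(flat_idx, A.shape)
print(int(row), int(col))  # 2 2

# Ties: the first occurrence of the maximum wins (hypothetical array B)
B = np.array([[7, 7, 1], [0, 2, 2]])
tie_idx = np.argmax(B, axis=1)
print(tie_idx)  # [0 1]

# Use the row-wise indices to pull out each row's maximum value
row_idx = np.argmax(A, axis=1)
row_max = np.take_along_axis(A, row_idx[:, np.newaxis], axis=1).ravel()
print(row_max)  # [3 4 6], same as A.max(axis=1)
```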