torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
t = torch.rand(3,5)
"""
tensor([[0.1501, 0.1785, 0.9239, 0.7364, 0.0742],
[0.4710, 0.9974, 0.9749, 0.4824, 0.2628],
[0.5910, 0.3075, 0.0327, 0.6995, 0.5297]])
"""
t.topk(k=2,dim=0,largest=True, sorted=True).values
"""
输出可以用.values调用topk的值,用.indices调用其索引。
注意形状会维持原数据形状。
例如本例中当对(3,5)取dim=0的top2时,shape为(2,5),原3被2替换。
tensor([[0.5910, 0.9974, 0.9749, 0.7364, 0.5297],
[0.4710, 0.3075, 0.9239, 0.6995, 0.2628]])
"""
t.topk(k=2,dim=1,largest=True, sorted=True).values
"""
shape=(3,2)
tensor([[0.9239, 0.7364],
[0.9974, 0.9749],
[0.6995, 0.5910]])
"""
t.topk(k=2,dim=0,largest=True, sorted=True)
"""
torch.return_types.topk(
values=tensor([[0.5910, 0.9974, 0.9749, 0.7364, 0.5297],
[0.4710, 0.3075, 0.9239, 0.6995, 0.2628]]),
indices=tensor([[2, 1, 1, 0, 2],
[1, 2, 0, 2, 1]]))
"""
网友评论