Pytorch 中topk的用法¶
In [14]:
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"a = torch.randint(0,10,[2,2,3])
a
a.topk(2,-1)
Out[14]:
tensor([[[4, 1, 5],[9, 2, 4]],[[1, 1, 1],[5, 1, 4]]])
Out[14]:
torch.return_types.topk(
values=tensor([[[5, 4],[9, 4]],[[1, 1],[5, 4]]]),
indices=tensor([[[2, 0],[0, 2]],[[0, 1],[0, 2]]]))
第一个返回结果为:在指定维度上选出topk个元素,并且依次从大到小排序,输出的形状为 (bach_size,seq_len,top_k) 第二个返回结果为: topk最大的元素索引,其形状为 (batch,seq_len,top_k)