In [15]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import torch.nn.functional as F
import torch
a = torch.tensor([[1,2],[1,2],[1,2],[1,2]])
F.one_hot(a) # 依次对张量i中的每个元素进行one hot 编码, 返回的形状为 (*, num_classes), 类似与 位置编码, 新增一个维度
Out[15]:
tensor([[[0, 1, 0],[0, 0, 1]],[[0, 1, 0],[0, 0, 1]],[[0, 1, 0],[0, 0, 1]],[[0, 1, 0],[0, 0, 1]]])
In [16]:
# 同样可以指定分类数
a = torch.tensor([[1,2],[1,2],[1,2],[1,2]])
F.one_hot(a, num_classes=4)
Out[16]:
tensor([[[0, 1, 0, 0],[0, 0, 1, 0]],[[0, 1, 0, 0],[0, 0, 1, 0]],[[0, 1, 0, 0],[0, 0, 1, 0]],[[0, 1, 0, 0],[0, 0, 1, 0]]])