sum 函数有两个定义:
torch.sum(input, *, dtype=None) → Tensor
torch.sum(input, dim, keepdim=False, *, dtype=None) → Tensor
In [2]:
import torch
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"a = torch.tensor([[1, 2, 3],[4,5,6]])
a.sum() # 返回标量类型总和数据
Out[2]:
tensor(21)
In [3]:
a = torch.tensor([[[1, 2, 3],[4,5,6]],[[1, 2, 3],[4,5,6]]])
a.sum(dim=-1) # 对求和的部分,维度会下降
a.sum(dim=-1, keepdim=True) # 将会保持维度不变
Out[3]:
tensor([[ 6, 15],[ 6, 15]])
Out[3]:
tensor([[[ 6],[15]],[[ 6],[15]]])