文章目录
- 报错信息
- 原因
- 代码示例
- 错误版
- 改正
报错信息
RuntimeError: expected scalar type Long but found Float
原因
nn.Linear需要作用于浮点数,这里可能输入了整数类型的张量作为参数。
代码示例
错误版
import torch
import torch.nn as nn
a = torch.tensor([1,2,3,4])
lin = nn.Linear(4,2)
b = lin(a)
print(b)
报错:
改正
import torch
import torch.nn as nn
a = torch.tensor([1,2,3,4])
lin = nn.Linear(4,2)
b = lin(a.float())
print(b)
把a转为float,结果为:
tensor([-1.1703, 0.0518], grad_fn=<AddBackward0>)