基于之前有C++基础,对于python中的一些函数的用法总会有些疑问。
例如,为什么python可以直接调用对象,而不是调用对象里的函数呢?
以下为包含__call__函数的类的调用
除此之外, 在PyTorch 中,所有继承自 nn.Module 的类都继承了一个特殊的 call() 方法。
# 使用ToTensor创建一个对象
tensor_trans=transforms.ToTensor()
# 这里使用了magic function,ToTensor类中包含了__call__()函数,所以可以直接将新建的对象实例当成函数使用,并传入参数
tensor_img=tensor_trans(img)
可以看到上面将ToTensor()创建的实例tensor_trans直接当函数在使用,直接向它里面传入了img参数,这样实际上会调用__call__()函数,得到的返回值也是__call__()函数的返回值
以下为另一种使用到__call__的情况(call自动调用call里面所包含的一些函数,例如forward)
实际上
class MM(nn.module):def __init__(self):super(MM,self).__init__()def forward(self,input):output=input+1return outputmm=MM()
output=mm(torch.tensor(1.0))
print(output)
上面的代码展示了在我们继承自nn.module类的自定义类中,定义了forward函数,那么为什么可以直接通过mm(torch.tensor(1.0))来直接调用类呢?
因为在PyTorch 中,所有继承自 nn.Module 的类都继承了一个特殊的 call() 方法,而call会自动调用它内部的forward函数。