pytorch报错(4)forward() missing 1 required positional argument: ‘x‘或者‘NoneType‘ object is not callable

解决:TypeErro: ‘NoneType’ object is not callable/forward()
TypeErro: forward() missing 1 required positional argument: 'x’

文章目录

  • 前言
  • 一、首先展示正确封装的代码
  • 二、两种错误
  • 1.TypeError: forward() missing 1 required positional argument: 'x'
  • 2.TypeErro: 'NoneType' object is not callable/forward()
  • 总结

  • 前言

    我们在构建自己的神经网络类时,经常要在现有的pytorch模型下修改,然后将修改好的类封装到一个新的py文件中,在封装过程中可能遇到如下两种错误:
    TypeErro: ‘NoneType’ object is not callable/forward()
    TypeErro: forward() missing 1 required positional argument: 'x’

    提示:以下是本篇文章正文内容,下面案例可供参考

    一、首先展示正确封装的代码

    以VGG16为例,我们经常会基于pytorch官网封装的torchvision.models.vgg16()进行修改,这里我们在vgg16 的最后添加一层全连接层,100输入,10输出,并封装为一个py文件,方便日后调用。

    #基于vgg16进行修改、封装
    #把改进的ImageNet网络存储起来
    import torchvision
    import torch
    from torch import nn
    
    ImageNet = torchvision.models.vgg16(pretrained=True, progress=True)
    ImageNet.classifier.add_module("linear", nn.Linear(1000, 10))
    class my_net(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = ImageNet
    
        def forward(self, x):
            output = self.model(x)
            return output
    
    if __name__ == '__main__':
        my_model2 = my_net()
        input = torch.ones((1, 3, 32, 32))
        output = my_model2(input)
        print(output.shape)
    

    这样写是没有报错的,可以整段复制,运行结果如下图(main函数只是测试一下网络是否可以正常输出)。

    接下来看一下错误的例子及其报出的错误。

    二、两种错误

    1.TypeError: forward() missing 1 required positional argument: ‘x’

    如果上文第12行self.model = ImageNet后面加了括号,即:self.model = ImageNet(),那么就会报错TypeError: forward() missing 1 required positional argument: ‘x’

    2.TypeErro: ‘NoneType’ object is not callable/forward()

    如果你把第8行 ImageNet.classifier.add_module(“linear”, nn.Linear(1000, 10)) 直接赋给第12行的self.model ,那么就会报错 TypeErro: ‘NoneType’ object is not callable/forward()

    总结

    pytorch中可能一个看起来很正常的括号就会引发错误。建议先在类class外修改好神经网络,然后直接调用修改好的模型名称。
    如果有帮助的话,请顺手点赞,谢谢。

    来源:香菜冰激凌

    物联沃分享整理
    物联沃-IOTWORD物联网 » pytorch报错(4)forward() missing 1 required positional argument: ‘x‘或者‘NoneType‘ object is not callable

    发表评论