forward()函数报错:接收到3个参数,但只需2个参数

问题描述:

在forward中明明正确数量的参数,却报错:forward() takes 2 positional arguments but 3 were given;

问题分析:

使用nn.Sequential()定义的网络,只接受单输入

例如:

self.backbone=nn.Sequential(nn.lstm(input_size=20, hidden_size=40, num_layers=2),

                                    nn.linear(in_features=40, out_features=2))

def forward(self, input):

        h0 = torch.randn(hidden_layers, batch_size, hidden)

        c0 = torch.randn(hidden_layers, batch_size, hidden)
        output, _ = self.backbone(input)  (对)

         output, _ = self.backbone(input, (h0, c0)   (错误,因为nn.Sequential()定义的网络,只接受单输入

物联沃分享整理
物联沃-IOTWORD物联网 » forward()函数报错:接收到3个参数,但只需2个参数

发表评论