解决RuntimeError报错:Sizes of tensors must match except in dimension

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list.

常见的模型报错,比方说pix2pix模型

In[18], line 84, in Generator.forward(self, x)

        82 bottleneck = self.bottleneck(d7)

        83 up1 = self.up1(bottleneck)

—> 84 up2 = self.up2(torch.cat([up1, d7], 1))

        85 up3 = self.up3(torch.cat([up2, d6], 1))

        86 up4 = self.up4(torch.cat([up3, d5], 1))

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list.

解决方案:

模型里面加一个函数

from torch import nn
import torch.nn.functional as F
class Generator(nn.Module):
    def __init__(self,*args):
        self.padder_size = 256

        '''
        模型该长啥样长啥样
        '''
    
    def forward(self,inp):
        B,C,H,W = inp.shape
        inp = self.check_image_size(inp)
        '''
        该怎么forward怎么forward
        '''
        
        return x[:,:,:H,:W]
    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x

padder_size根据最接近你数据集的来,这个函数是从GitHub – megvii-research/NAFNet: The state-of-the-art image restoration model without nonlinear activation functions.这个模型的代码里找的,本来是做pix2pix但是输入为300*300的时候就报错,256*256就不报错,后面发现是中间反卷积的时候输出形状和下采样的形状不一样,cat就不好使了,上了这个函数就好使了,但是会慢不少。

物联沃分享整理
物联沃-IOTWORD物联网 » 解决RuntimeError报错:Sizes of tensors must match except in dimension

发表评论