Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul

问题分析

  • 就像是字面意思那样,这个错误是因为模型中的 weights 没有被转移到 cuda 上,而模型的数据转移到了 cuda 上而造成的
  • 但是造成这个问题的原因却没有那么简单。
  • 绝大多数时候,造成这个的原因是因为你定义好模型之后,没有对模型进行 to(device) 而造成的,但是,也有可能,是因为你的模型在定义的时候,没有定义好,导致模型的一部分在加载的时候没有办法转移到 cuda上。
  • 细节举例

  • 比如我现在定义了一个模型 A,B,它们的结构如下:
  • # @Time : 2022/1/19 17:57 
    # @Author : PeinuanQin
    # @File : test.py
    import torch.nn as nn
    import torch
    import torch.utils.data as Data
    from tqdm import tqdm
    from torchvision import transforms,datasets
    import numpy as np
    import torchvision
    from torch.optim import lr_scheduler
    
    
    class A(nn.Module):
        def __init__(self):
            super(A,self).__init__()
            self.conv = nn.Conv2d(in_channels=3
                                  ,out_channels=8
                                  ,kernel_size=3)
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self,x):
            out = self.conv(x)
            out = self.relu(out)
            B_model = B()
            out = B_model(out)
            return out
    
    class B(nn.Module):
        def __init__(self):
            super(B,self).__init__()
            self.conv = nn.Conv2d(in_channels=8
                                  ,out_channels=16
                                  ,kernel_size=3)
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            out = self.conv(x)
            out = self.relu(out)
            return out
    
    
    
    

  • 这个时候就会报错,而报错的原因,就是因为 torch 的流程是这样的:
  • 首先将所有的模型加载,先从 A 开始,进入 A 的 init 中把所有的内容加载,然后,通过 main 函数中的 to(device) 操作,就把加载的所有内容和网络定义都放到 cuda 上了,但是注意!!!
  • 第二步开始训练,训练的过程中,都是通过 forward 函数来调用的,但是这个时候程序发现,当进入 A 的 forward 中运行的时候,出现了几个 B 的网络层,但是注意:这些 B 中定义的网络层,在网络加载的过程中可是没有出现在 A 的 __init__里面,也就理所当然地没有加载到 cuda上,因此在 A 的 forward 中出现的时候,B 的这几个网络层的 weight 依然在 cpu 上,这就导致了错误。
  • 改错思路

  • 将所有的内容都放到 cpu 上运行,即:
  • 但显然这是个治标不治本的方法,我们就没有办法使用 gpu 训练了,因此我们选择把所有的网络层(只要有参数需要训练的网络层)都放到 init 里面去定义,只在 forward 中写运行时的逻辑,即:
  • class A(nn.Module):
        def __init__(self):
            super(A,self).__init__()
            self.conv = nn.Conv2d(in_channels=3
                                  ,out_channels=8
                                  ,kernel_size=3)
            self.relu = nn.ReLU(inplace=True)
            self.b_module = B()
    
        def forward(self,x):
            out = self.conv(x)
            out = self.relu(out)
            out = self.b_module(out)
            return out
    
    class B(nn.Module):
        def __init__(self):
            super(B,self).__init__()
            self.conv = nn.Conv2d(in_channels=8
                                  ,out_channels=16
                                  ,kernel_size=3)
            self.relu = nn.ReLU(inplace=True)
    
        def forward(self, x):
            out = self.conv(x)
            out = self.relu(out)
            return out
    
    

    来源:暖仔会飞

    物联沃分享整理
    物联沃-IOTWORD物联网 » Pytorch避坑之:RuntimeError: Input type(torch.cuda.FloatTensor) and weight type(torch.FloatTensor) shoul

    发表评论