Pytorch学习笔记(三)——nn.Sequential的理解

nn.Sequential的理解

  • 一、源码剖析
  • 二、实战意义
  • 在定义CNN模型的时候看到有如下定义,其中讲解一下nn.Sequential

    class CNN(nn.Module):
       def __int__(self):
          super(CNN,self).__init__()
          self.conv1=nn.Sequential(nn.Conv2d(in_channels=1,
                               out_channels=16,
                               kernel_size=3,
                               stride=2,
                               padding=1),
                               torch.nn.BatchNorm2d(16),
                               nn.ReLU()
                               )
    
    

    一、源码剖析

    nn.Sequential是一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。
    nn.Sequential的源码可以看到如下:

       def __init__(self, *args):
            super(Sequential, self).__init__()
            if len(args) == 1 and isinstance(args[0], OrderedDict):
                for key, module in args[0].items():
                    self.add_module(key, module)
            else:
                for idx, module in enumerate(args):
                    self.add_module(str(idx), module)
                    
        def forward(self, input):
            for module in self:
                input = module(input)
            return input
    

    if len(args) == 1 and isinstance(args[0], OrderedDict) 其判断是否使用OrderedDict,如果自己定义了名称的话使用自定义的名称,否则将使用idx自动定义。
    forword()函数可知,sequential会使用for按顺序调用相应的模块。
    官方提供的例子如下:

    # Example of using Sequential
    model1 = nn.Sequential(
              nn.Conv2d(1,20,5),
              nn.ReLU(),
              nn.Conv2d(20,64,5),
              nn.ReLU()
            )
    print(model1)
    # Sequential(
    #   (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    #   (1): ReLU()
    #   (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
    #   (3): ReLU()
    # )
    
    # Example of using Sequential with OrderedDict
    import collections
    model2 = nn.Sequential(collections.OrderedDict([
              ('conv1', nn.Conv2d(1,20,5)),
              ('relu1', nn.ReLU()),
              ('conv2', nn.Conv2d(20,64,5)),
              ('relu2', nn.ReLU())
            ]))
    print(model2)
    # Sequential(
    #   (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
    #   (relu1): ReLU()
    #   (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
    #   (relu2): ReLU()
    # )
    
    

    二、实战意义

    为什么要使用nn.Sequential?假设我们要定义一个网络,其中一层是这样的:
    input–>Linear(input)–> nn.ReLU(input)–>Linear(input)–> nn.ReLU(input)–>Linear(input)

    class Net(nn.Module):
        def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
            super().__init__()
    
          	self.Linear1 = nn.Linear(in_dim, n_hidden_1),
    		self.Linear2=nn.Linear(n_hidden_1, n_hidden_2)
    		self.Linear3=nn.Linear(n_hidden_2, out_dim)
                
    
      	def forward(self, x):
          	out= self.Linear1(x)
          	out=torch.relu(out)
          	out=self.Linear2(out)
          	out=torch.relu(out)
          	out=self.Linear3(out)
          	return out
    

    此时可使用nn.Sequential化简操作

    class Net(nn.Module):
        def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
            super().__init__()
    
          	self.layer = nn.Sequential(
                nn.Linear(in_dim, n_hidden_1), 
                nn.ReLU(True),
                nn.Linear(n_hidden_1, n_hidden_2),
                nn.ReLU(True),
                # 最后一层不需要添加激活函数
                nn.Linear(n_hidden_2, out_dim)
                 )
    
      	def forward(self, x):
          	x = self.layer(x)
          	return x
    
    

    Reference:
    PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景
    pytorch系列7 —–nn.Sequential讲解

    来源:酒与花生米

    物联沃分享整理
    物联沃-IOTWORD物联网 » Pytorch学习笔记(三)——nn.Sequential的理解

    发表评论