Pytorch中nn.Module中self.register_buffer的解释

self.register_buffer作用解释

今天遇到了这样一种用法,self.register_buffer(‘name’,Tensor),该方法的作用在于定义一组参数。该组参数在模型训练时不会更新(即调用optimizer.step()后该组参数不会变化,只可人为地改变它们的值),但是该组参数又作为模型参数不可或缺的一部分。

实验

四种方式初始化模型中的参数

  1. 定义常见模型时的操作
  2. 使用register_buffer()定义一组参数
  3. 使用register_parameter()定义一组参数
  4. 使用python类的属性方式定义一组变量
import torch
import torch.nn as nn
from collections import OrderedDict

class Model(nn.Module):
	def __init__(self):
	super(Model,self).__init__()
	
	#(1)定义常见模型时的操作
	self.param_nn = nn.Sequential(OrderedDict([
		('conv',nn.Conv2d(1,1,3,bias=False)),
		('fc',nn.Linear(1,2,bias=False))
	]))
	
	#(2)使用register_buffer()定义一组参数
	self.register_buffer('reg_buf',torch.randn(1,2))

	#(3)使用register_parameter()定义一组参数
	self.register_parameter('reg_param',nn.Parameter(torch.randn(1,2)))

	#(4)使用python类的属性方式定义一组变量
	self.param_attr = torch.randn(1,2)

net = Model()
	

问题1:哪些参数会在模型训练时被更新?

因为定义优化器时会传入一个参数net.parameters,所以在模型训练时更新的参数可以通过list(net.named_parameters())查看

结果说明,只有方式(1)和方式(3)定义的参数可以被更新

问题2:模型中的参数到底有哪些?

模型中的所有参数都装在state_dict()中,所以可以通过net.state_dict()方式查看

结果说明,只有方式(4)的参数不在模型的参数列表,没有被模型训练时更新的参数reg_buf,依然在模型的参数列表里

self.register_buffer()的使用方法

  1. 传入参数:第一个参数传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数
  2. 在模型定义中调用:使用self.name方法,本例中就是self. reg_buf
  3. 在实例化模型后调用:使用net.buffers()方法。

其他知识

实际上,Pytorch定义的模型用OrderedDict()方式记录这三种类型,分别保存在self._modules, self._parameters 和self.buffer三个私有属性中

在模型实例化后可以用以下方法看三个私有属性中的变量
net.modules()
net.parameters()
net.buffers()

self._parameters 和net.parameters() 的返回值并不相同,self._parameters只记录了使用self.register_parameter()定义的参数,而net.parameters()返回所有可学习参数。

参考:
[1]Pytorchnn.Module中的self.register_buffer()解析

物联沃分享整理
物联沃-IOTWORD物联网 » Pytorch中nn.Module中self.register_buffer的解释

发表评论