Pytorch中nn.Module中self.register_buffer的解释
self.register_buffer作用解释
今天遇到了这样一种用法,self.register_buffer(‘name’,Tensor),该方法的作用在于定义一组参数。该组参数在模型训练时不会更新(即调用optimizer.step()后该组参数不会变化,只可人为地改变它们的值),但是该组参数又作为模型参数不可或缺的一部分。
实验
四种方式初始化模型中的参数
- 定义常见模型时的操作
- 使用register_buffer()定义一组参数
- 使用register_parameter()定义一组参数
- 使用python类的属性方式定义一组变量
import torch
import torch.nn as nn
from collections import OrderedDict
class Model(nn.Module):
def __init__(self):
super(Model,self).__init__()
#(1)定义常见模型时的操作
self.param_nn = nn.Sequential(OrderedDict([
('conv',nn.Conv2d(1,1,3,bias=False)),
('fc',nn.Linear(1,2,bias=False))
]))
#(2)使用register_buffer()定义一组参数
self.register_buffer('reg_buf',torch.randn(1,2))
#(3)使用register_parameter()定义一组参数
self.register_parameter('reg_param',nn.Parameter(torch.randn(1,2)))
#(4)使用python类的属性方式定义一组变量
self.param_attr = torch.randn(1,2)
net = Model()
问题1:哪些参数会在模型训练时被更新?
因为定义优化器时会传入一个参数net.parameters,所以在模型训练时更新的参数可以通过list(net.named_parameters())查看
结果说明,只有方式(1)和方式(3)定义的参数可以被更新
问题2:模型中的参数到底有哪些?
模型中的所有参数都装在state_dict()中,所以可以通过net.state_dict()方式查看
结果说明,只有方式(4)的参数不在模型的参数列表,没有被模型训练时更新的参数reg_buf,依然在模型的参数列表里
self.register_buffer()的使用方法
- 传入参数:第一个参数传入一个字符串,表示这组参数的名字,第二个就是tensor形式的参数
- 在模型定义中调用:使用self.name方法,本例中就是self. reg_buf
- 在实例化模型后调用:使用net.buffers()方法。
其他知识
实际上,Pytorch定义的模型用OrderedDict()方式记录这三种类型,分别保存在self._modules, self._parameters 和self.buffer三个私有属性中
在模型实例化后可以用以下方法看三个私有属性中的变量
net.modules()
net.parameters()
net.buffers()
self._parameters 和net.parameters() 的返回值并不相同,self._parameters只记录了使用self.register_parameter()定义的参数,而net.parameters()返回所有可学习参数。