探索loss.backward() 和optimizer.step()的关系并灵活运用

loss.backward() 和optimizer.step()的关系及灵活运用

在deep learning的模型训练中,我们经常看到如下的代码片段:

loss.backward()
optimizer.step()

那么,这两个函数到底是怎么联系在一起的呢?

loss.backward()的作用

我们都知道,loss.backward()函数的作用是根据loss来计算网络参数的梯度,其对应的输入默认为网络的叶子节点,即数据集内的数据,叶子节点如下图所示:

同样的,该梯度信息也可以用函数torch.autograd.grad()计算得到

x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)

z = x * x * y
z.backward()
print(x.grad)
>>> tensor(12.)
x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)

z = x * x * y
x_grad = torch.autograd.grad(outputs=z, inputs=x)
print(x_grad[0])
>>> tensor(12.)

以上内容引自https://zhuanlan.zhihu.com/p/279758736

optimizer.step()的作用

优化器的作用就是针对计算得到的参数梯度对网络参数进行更新,所以要想使得优化器起作用,主要需要两个东西:

  • 优化器需要知道当前的网络模型的参数空间
  • 优化器需要知道反向传播的梯度信息(即backward计算得到的信息)
  • 观察一下SGD方法中step()方法的源码

    def step(self, closure=None):
            """Performs a single optimization step.
            Arguments:
                closure (callable, optional): A closure that reevaluates the model
                    and returns the loss.
            """
            loss = None
            if closure is not None:
                loss = closure()
                
            for group in self.param_groups:
                weight_decay = group['weight_decay']
                momentum = group['momentum']
                dampening = group['dampening']
                nesterov = group['nesterov']
                
                for p in group['params']:
                    if p.grad is None:
                        continue
                    d_p = p.grad.data
                    if weight_decay != 0:
                        d_p.add_(weight_decay, p.data)
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = d_p.clone()
                        else:
                            buf = param_state['momentum_buffer']
                            buf.mul_(momentum).add_(1 - dampening, d_p)
                        if nesterov:
                            d_p = d_p.add(momentum, buf)
                        else:
                            d_p = buf
    
    		p.data.add_(-group['lr'], d_p)
    
            return loss
    

    我们可以看到里面有如下的代码

    for p in group['params']:
        if p.grad is None:
            continue
            d_p = p.grad.data
    

    说明,step()函数确实是利用了计算得到的梯度信息,且该信息是与网络的参数绑定在一起的,所以optimizer函数在读入是先导入了网络参数模型’params’,然后通过一个.grad()函数就可以轻松的获取他的梯度信息。

    如何验证该关系的正确性

    我们想通过改变梯度信息来验证该关系的正确性,即是否可以通过一次梯度下降后,再通过一次梯度上升来得到初始化的参数

    import torch
    import torch.nn as nn
    
    #  Check if we have a CUDA-capable device; if so, use it
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Will train on {}'.format(device))
    
    #  为了让参数恢复成初始化状态,使用最简单的SGD优化器
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    
    #  定义模型
    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            
            self.linear = nn.Linear(3,1)
            
        def forward(self, x):
            y = self.linear(x)
            return y
        
    #  载入模型与输入,并打印此时的模型参数
    x = (torch.rand(3)).to(device)
    net = CNN().to(device)
    print('the first output!')
    for name, parameters in net.named_parameters():
        print(name, ':', parameters)
        
    print('-------------------------------------------------------------------------------')    
    #  做梯度下降
    optimizer.zero_grad()
    y = net(x)
    loss = (1-y)**2
    
    loss.backward()
    optimizer.step()
    #  打印梯度信息
    for name, parameters in net.named_parameters():
        print(name, ':', parameters.grad)
    #  经过第一次更新以后,打印网络参数
    for name, parameters in net.named_parameters():
        print(name, ':', parameters)
        
    print('-------------------------------------------------------------------------------')
    #  我们直接将网络参数的梯度信息改为相反数来进行梯度上升
    for name, parameters in net.named_parameters():
        parameters.grad *= -1
    #  打印
    for name, parameters in net.named_parameters():
        print('the second output!')
        print(name, ':', parameters.grad)
    

    经过对比,我们发现最后的结果与我们的设想一样,网络参数恢复成初始化状态,因此可以证明optimizer.step()与loss.backward()之间的关系。

    来源:Mr Sorry

    物联沃分享整理
    物联沃-IOTWORD物联网 » 探索loss.backward() 和optimizer.step()的关系并灵活运用

    发表评论