pytorch训练中断后,如何在之前的断点处继续训练

我们在训练模型的时候经常出现各种问题导致训练中断,比方说断电,或者关机之类的导致电脑系统关闭,从而将模型训练中断,那么如何在模型中断后,能够保留之前的训练结果不被丢失,同时又可以继续之前的断点处继续训练?

首先在代码离需要保存模型,比方说我们模型设置训练5000轮,那么我们可以选择每100轮保存一次模型,这样的话,在训练的过程中就能保存下100,200,300.。。。等轮数时候的模型,那么当模型训练到400轮的时候突然训练中断,那么我们就可以通过加载400轮的参数来进行继续训练,其实这个过程就类似在预训练模型的基础上进行训练。下面简单粗暴上代码:

1、保存模型

torch.save(checkpoint, checkpoint_path)

其中checkpoint其实保存的就是模型的一些参数,比方说下面这种字典形式的保存所需的模型参数:

checkpoint = {
    'model': model_state_dict,
    'generator': generator_state_dict,
    'opt': model_opt,
    'optim': optim,
}

checkpoint_path则是表示保存的模型

checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)

save_checkpoint_steps是保存的间隔轮数,step是保存的轮数,比方说save_checkpoint_steps=100,那么step的取值就是100,200,300,400等,下面的代码解释step的取值由来。

if step % self.save_checkpoint_steps != 0:
    return
chkpt, chkpt_name = self._save(step)

其中_save函数就是实现了前面checkpoint的内容的保存。

模型的保存设置就此结束。

2、模型的加载

假如此时模型训练中断了,我们得在代码里设置一个参数,这个参数用来查找确定当前路径下是否有已存在得模型。

# 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        start_epoch = checkpoint['model_opt']
        optim=checkpoint['optim']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

或者设置一个变量train_from,若赋值已有模型得路径,则继续训练;若为None,那么从头训练。这块代码既可以用于训练中断,又可以用于使用预训练模型。

if opt.train_from:#是否存在预训练模型
    logger.info('Loading checkpoint from %s' % opt.train_from)
    checkpoint = torch.load(opt.train_from)#加载预训练模型的检查点
    model_opt = checkpoint['opt']
else:
    checkpoint = None
    model_opt = opt

加油,come on!

来源:程序小K

物联沃分享整理
物联沃-IOTWORD物联网 » pytorch训练中断后,如何在之前的断点处继续训练

发表评论