使用pytorch保存效果最好那个模型+加载模型

1 保存在验证集上表现最好的那一轮模型

1 验证集的作用就是监督训练是否过拟合;

一般默认验证集的损失值经历由下降到上升的阶段;

保存在验证集上损失最小的那个迭代模型,其泛化能力应该最好;

# 在训练部分计算验证集损失值,保存最小损失值对应的那个模型

model = BotRGCN()# 自定义模型实例化,()中可以传定义的参数
def train(epoch,min_loss):
    model.train()
    output = model() # 自动调用定义的forward函数,在()中传相应参数
    loss_train = loss(output[et.train_idx],de.labels[et.train_idx])
    acc_train = accuracy(output[et.train_idx],de.labels[et.train_idx])
    acc_val = accuracy(output[et.val_idx],de.labels[et.val_idx])
    # 计算损失值,做比较
    loss_val = loss(output[et.val_idx],de.labels[et.val_idx])
    optimizer.zero_grad()
    loss_train.backgrad()
    optimizer.step()
    if loss_val < min_loss
        min_loss = loss_val
        print("save model")
        # 保存模型语句
        torch.save(model.state_dict(),"model.pth")
    return loss_train, acc_train, acc_val, min_loss

if __name__ == "__main__":
    epochs = 100
    min_loss = 100
    for epoch in range(epochs):
        loss_train, acc_train, acc_val, min_loss = train(epoch,min_loss)

保存模型中state_dict 是状态字典;

        PyTorch 中,一个模型( torch.nn.Module )的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters())中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。

        模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm的 running_mean)。优化器对象(torch.optim)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数.

        由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。

        当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save() 来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。

2 加载模型,在测试集上测试模型效果

model = BotRGCN()

model.load_state_dict(torch.load('model.pth'))

model.eval()

test()

        在进行预测之前,必须调用 model.eval() 方法来将 dropout 和 batch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。 

        load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接 model.load_state_dict(PATH)

3 另一种保存与加载方法

加载保存整个模型

保存:

torch.save(model, 'model.pkl')

加载:

# Model class must be defined somewhere
model = torch.load('model.pkl')
model.eval()

        保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现;

        这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

物联沃分享整理
物联沃-IOTWORD物联网 » 使用pytorch保存效果最好那个模型+加载模型

发表评论