Pytorch保存和加载模型(load和load_state_dict)

Pytorch目前成为学术界最流行的DL框架,没有之一。很大程度上,简洁直观地操作有关。模型的保存和加载,于pytorch而言,也是很简单的。本文做了一个比较实验,方便大家理解。
首先,要清楚几个函数:torch.save,torch.load,state_dict(),load_state_dict()。
先举最简单的例子:

import torch

model = torch.load('my_model.pth')
torch.save(model, 'new_model.pth')

上面的代码非常直观,一载一存。但是有一个问题,这样保存的pth文件直接包含了整个模型的结构当你需要灵活加载模型参数时,比如只加载部分参数,那么这种情况保存的pth文件读取进来还得额外解析出“参数文件”

如果想更灵活对待咱们训练好的模型参数,咱们可以使用下面这个方法。pytorch把所有的模型参数用一个内部定义的dict进行保存,自称为“state_dict”。这个所谓的state_dict就是不带模型结构的模型参数了~
咱们的加载和保存就发生了一点微妙的变化:

import torch
model = MyModel() # init your model class, build the graph shape

state_dict = torch.load('model_state_dict.pth')
model.load_state_dict(state_dict)

torch.save(model.state_dict(), 'model_state_dict1.pth')

比较上面两段代码,咱们可以有一下结论:

pth文件既可能保存了模型的图结构,也有可能没保存;
加载没保存图结构的pth时,需要先初始化模型结构,即把架子搭好;
在保存模型的时候,如果不想保存图结构,可以单独保存model.state_dict()

  • 实验

  • import torch
    import torchvision.models as models
    
    model = models.vgg16(pretrained=True)
    torch.save(model.state_dict(), 'only_weights.pth')
    
    model_state_dict = torch.load('only_weights.pth')
    model1 = models.vgg16() # describe the graph shape
    model1.load_state_dict(model_state_dict)
    model1.eval()
    
    torch.save(model1, 'whole_model.pth')
    
    model2 = torch.load('whole_model.pth')
    model2.eval()
    
    # model3 = torch.load('only_weights.pth')
    # model3.eval()    # Error
    
    
    

    model3切换到eval()模式就会报错,原因是model3只包含weights而缺乏图结构~

  • torch.load_state_dict()函数的用法

  • 在Pytorch中构建好一个模型后,一般需要进行预训练权重中加载。torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中,操作方式如下所示:

    sd_net = torchvision.models.resnte50(pretrained=False)
    sd_net.load_state_dict(torch.load('*.pth'), strict=True)
    

    在本博文中重点关注的是 属性 strict; 当strict=True,要求预训练权重层数的键值与新构建的模型中的权重层数名称完全吻合;如果新构建的模型在层数上进行了部分微调,则上述代码就会报错:说key对应不上。

    此时,如果我们采用strict=False 就能够完美的解决这个问题。也即,与训练权重中与新构建网络中匹配层的键值就进行使用,没有的就默认初始化

    参考博文:
    https://blog.csdn.net/ChaoMartin/article/details/118686268

    https://blog.csdn.net/leviopku/article/details/123925804

    来源:失之毫厘,差之千里

    物联沃分享整理
    物联沃-IOTWORD物联网 » Pytorch保存和加载模型(load和load_state_dict)

    发表评论