pytorch获取全部权重参数、每一层权重参数

pytorch获取全部权重参数、每一层权重参数

首先需要安装torchsummary
在相应的虚拟环境下pip install torchsummary

1、打印每层参数信息:
summary(net,input_size,batch_size,device),

net:网络模型
input_size:网络输入图片的shape
batch_size:默认参数为-1
device:在gpu上还是cpu上运行,默认是cuda在gpu上运行,若想在cpu上运行,需将参数改为cpu。

eg.vgg16网络:
from models import VGG16_torch
model = vgg16()
summary(model,(3,32,32),device=‘cpu’)

2、根据需要,输出相应层的权重
首先查看每层对应的名称

model = vgg16()
for name in model.state_dict():
  print(name)


再根据名称输出相应层的权重

 print(model.state_dict()['layers.0.conv2d.weight'])


3、打印模块名字和参数大小

for name, parameters in model.named_parameters():  
    print(name, ';', parameters.size())

输出结果:

4、加载模型全部参数

import torch
y = torch.load('vgg16_baseline.t7')
print(y)

来源:柠檬树下你和我₰

物联沃分享整理
物联沃-IOTWORD物联网 » pytorch获取全部权重参数、每一层权重参数

发表评论