torchsummary和torchstat使用方法和结果分析

1 torchstat:查看模型的大小和浮动运算量

安装工具 pip install torchstat

使用例子

import torch
import torch.nn as nn
from torchstat import stat

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            nn.Linear(32 * 127 * 127, 1024),
            nn.ReLU(),
            nn.Linear(1024, 4)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = x.view(-1, 32 * 127 * 127)
        x = self.fc(x)

        return x
model=Net()
print(stat(model, (3, 256, 256)))

输出结果如下所示:

参数说明:

【params】
网络的参数量

【memory】
节点推理时候所需的内存

【Flops】
网络完成的浮点运算

【MAdd】
网络完成的乘加操作的数量。一次乘加=一次乘法+一次加法,所以可以粗略的认为
Flops ≈2*MAdd

【MemRead】
网络运行时,从内存中读取的大小

【MemWrite】
网络运行时,写入到内存中的大小

【MemR+W】
MemR+W = MemRead + MemWrite

2 torchsummary:查看模型结构和输入输出尺寸

torchsummary.summary(model, input_size, batch_size=-1, device="cuda")

功能:查看模型的信息,便于调试

  • model:pytorch 模型,必须继承自 nn.Module
  • input_size:模型输入 size,形状为 C,H ,W
  • batch_size:batch_size,默认为 -1,在展示模型每层输出的形状时显示的 batch_size
  • device:“cuda"或者"cpu”
  • 使用时需要注意,默认device=‘cuda’,如果是在‘cpu’,那么就需要更改。不匹配就会出现下面的错误:

    import torch
    from model.lenet import LeNet
    from torchsummary import summary
    
    # 模型
    
    lenet = LeNet(classes=2)
    
    print(summary(lenet, (3, 32, 32), device="cpu"))
    
    
    #输入结果如下所示:下述信息分别有模型每层的输出形状,每层的参数数量,总的参数数量,以及模型大小等信息。
    #由结果可得,模型的大小是0.23MB
    
    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Conv2d-1            [-1, 6, 28, 28]             456
                Conv2d-2           [-1, 16, 10, 10]           2,416
                Linear-3                  [-1, 120]          48,120
                Linear-4                   [-1, 84]          10,164
                Linear-5                    [-1, 2]             170
    ================================================================
    Total params: 61,326
    Trainable params: 61,326
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.01
    Forward/backward pass size (MB): 0.05
    Params size (MB): 0.23
    Estimated Total Size (MB): 0.30
    ----------------------------------------------------------------
    None
    
    物联沃分享整理
    物联沃-IOTWORD物联网 » torchsummary和torchstat使用方法和结果分析

    发表评论