PyTorch 打印网络模型结构

🤵 AuthorHorizon Max

编程技巧篇各种操作小结

🎇 机器视觉篇会变魔术 OpenCV

💥 深度学习篇简单入门 PyTorch

🏆 神经网络篇经典网络模型

💻 算法篇再忙也别忘了 LeetCode

文章目录

  • PyTorch 打印网络模型结构
  • 使用 Print() 函数打印网络
  • Tensorflow / Keras 打印网络
  • PyTorch summary打印网络结构的方法
  • PyTorch 打印网络模型结构

    使用 Print() 函数打印网络

    我们在使用PyTorch打印模型结构时都是这样操作的:

    model = simpleNet()
    print(model)
    

    打印结果:

    simpleNet(
      (layer1): Sequential(
        (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU()
      )
      (layer2): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU()
      )
      (layer3): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (3): ReLU()
      )
      (dropout): Dropout(p=0.5, inplace=False)
      (fc): Linear(in_features=1024, out_features=10, bias=True)
      (out): Linear(in_features=10, out_features=10, bias=True)
    )
    

    可以很容易发现这样打印出来的网络结构 不清晰 ,参数看起来都很 !

    如果是一个简单一点的网络可能影响不是很大,但当随着网络层数加深、结构复杂、参数量变大时,就会看得很难受 !

    Tensorflow / Keras 打印网络

    使用 model.summary() 函数打印出网络结构:

    model = MyNet()
    model.summary()
    

    对比上面可以看到网络结构 很清晰

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #   
    =================================================================
    dense_4 (Dense)              (None, 256)               25856
    _________________________________________________________________
    leaky_re_lu_3 (LeakyReLU)    (None, 256)               0
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 256)               1024
    _________________________________________________________________
    dense_5 (Dense)              (None, 512)               131584
    _________________________________________________________________
    leaky_re_lu_4 (LeakyReLU)    (None, 512)               0
    _________________________________________________________________
    batch_normalization_2 (Batch (None, 512)               2048
    _________________________________________________________________
    dense_6 (Dense)              (None, 1024)              525312
    _________________________________________________________________
    leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 1024)              4096
    _________________________________________________________________
    dense_7 (Dense)              (None, 784)               803600
    _________________________________________________________________
    reshape_1 (Reshape)          (None, 28, 28, 1)         0
    =================================================================
    Total params: 1,493,520
    Trainable params: 1,489,936
    Non-trainable params: 3,584
    _________________________________________________________________
    

    PyTorch summary打印网络结构的方法

    首先需要安装一个库文件 torchinfo

    pip install torchinfo
    
    conda install -c conda-forge torchinfo
    

    然后使用 summary 函数打印网络结构:

    model = simpleNet()
    batch_size = 64
    summary(model, input_size=(batch_size, 3, 32, 32))
    

    网络结构输出结果如下:

    ==========================================================================================
    Layer (type:depth-idx)                   Output Shape              Param #
    ==========================================================================================
    simpleNet                                --                        --
    ├─Sequential: 1-1                        [64, 16, 16, 16]          --
    │    └─Conv2d: 2-1                       [64, 16, 32, 32]          448
    │    └─BatchNorm2d: 2-2                  [64, 16, 32, 32]          32
    │    └─MaxPool2d: 2-3                    [64, 16, 16, 16]          --
    │    └─ReLU: 2-4                         [64, 16, 16, 16]          --
    ├─Sequential: 1-2                        [64, 32, 8, 8]            --
    │    └─Conv2d: 2-5                       [64, 32, 16, 16]          4,640
    │    └─BatchNorm2d: 2-6                  [64, 32, 16, 16]          64
    │    └─MaxPool2d: 2-7                    [64, 32, 8, 8]            --
    │    └─ReLU: 2-8                         [64, 32, 8, 8]            --
    ├─Sequential: 1-3                        [64, 64, 4, 4]            --
    │    └─Conv2d: 2-9                       [64, 64, 8, 8]            18,496
    │    └─BatchNorm2d: 2-10                 [64, 64, 8, 8]            128
    │    └─MaxPool2d: 2-11                   [64, 64, 4, 4]            --
    │    └─ReLU: 2-12                        [64, 64, 4, 4]            --
    ├─Dropout: 1-4                           [64, 1024]                --
    ├─Linear: 1-5                            [64, 10]                  10,250
    ├─Linear: 1-6                            [64, 10]                  110
    ==========================================================================================
    Total params: 34,168
    Trainable params: 34,168
    Non-trainable params: 0
    Total mult-adds (M): 181.82
    ==========================================================================================
    Input size (MB): 0.79
    Forward/backward pass size (MB): 29.37
    Params size (MB): 0.14
    Estimated Total Size (MB): 30.29
    ==========================================================================================
    

    更多详情可以参考 github 源码:torchinfo

    来源:Horizon Max

    物联沃分享整理
    物联沃-IOTWORD物联网 » PyTorch 打印网络模型结构

    发表评论