使用pytorch训练和预测时会分别使用到以下两行代码:

model.train()
model.eval()

后来想了解model.eval()的具体作用,在网上查找资料大都是以下原因:
模型中有BatchNormalization和Dropout,在预测时使用model.eval()后会将其关闭以免影响预测结果。

但是没有找到BN和Dropout是具体如何影响预测结果的,直到看到这篇博客中的内容才有所理解,个人理解如下:
1)训练过程中BN的变化。
在训练过程中BN会不断的计算均值和方差,训练结束后得到最终的均值和方差,在此处将其记为mean_train,variance_train。

2)预测过程中BN的变化。
预测过程中如果不使用model.eval()的话,BN层还是会根据输入的预测数据继续计算均值和方差,假设输入一条预测数据后,BN层计算得到其均值和方差分别为mean_test,variance_test,此时BN层的均值和方差则变成了(mean_train+mean_test),(variance_train+variance_test),相比于训练过程中的均值和方差发生了变化因此会导致预测结果发生变化。

如果使用model.eval()则BN层就不会再计算预测数据的均值和方差,即在预测过程中BN层的均值和方差就是训练过程得到的均值和方差mean_train,variance_train,此时预测结果就不会再发生变化。

3)训练过程中Dropout的变化
训练过程中依据设置的dropout比例会使一部分的网络连接不进行计算。

4)预测过程中Dropout的变化
预测过程中如果不使用model.eval()的话,依然会使一部分的网络连接不进行计算,而使用model.eval()后就是所有的网络连接均进行计算。

来源:嘿,兄弟,好久不见

物联沃分享整理
物联沃-IOTWORD物联网 » Pytorch model.eval()的作用

发表评论