net.train() 和 net.eval() 两个函数只要适用于Dropout与BatchNormalization的网络,会影响到训练过程中这两者的参数。
- net.train()时,训练时每个min-batch时都会根据情况进行上述两个参数的相应调整
- net.eval()时,由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有batch的均值和方差。所有Batch Normalization的训练和测试时的操作不同。
class Model1(nn.Module):
def __init__(self):
super(Model1, self).__init__()
self.dropout = nn.Dropout(0.5)
def forward(self, x):
return self.dropout(x)
m1 = Model1()
inputs = torch.ones(10)
print(inputs)
print(20 * '-' + "train model:" + 20 * '-' + '\r\n')
print(m1(inputs))
print(20 * '-' + "eval model:" + 20 * '-' + '\r\n')
m1.eval()
print(m1(inputs))
"""
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
--------------------train model:--------------------
tensor([0., 2., 0., 2., 2., 0., 0., 0., 2., 0.])
--------------------eval model:--------------------
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
"""