pytorch 如何加载部分预训练模型
pretrained_dict=torch.load(model_weight)
model_dict=myNet.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, vin pretrained_dict.items() if k inmodel_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
保存模型参数PyTorch 中保存模型的方式有许多种:
# 保存整个网络
torch.save(model, PATH)
# 保存网络中的参数, 速度快,占空间少
torch.save(model.state_dict(),PATH)
# 选择保存网络中的一部分参数或者额外保存其余的参数
torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(),
'optimizer': optimizer.state_dict(),
'alpha': loss.alpha,
'gamma': loss.gamma}, PATH)
读取模型参数
同样的,PyTorch 中读取模型参数的方式也有许多种:
# 读取整个网络
model = torch.load(PATH)
# 读取 Checkpoint 中的网络参数
model.load_state_dict(torch.load(PATH))
# 若 Checkpoint 中的网络参数与当前网络参数有部分不同,有以下两种方式进行加载:
# 1. 利用字典的 update 方法进行加载
Checkpoint = torch.load(Path)
model_dict = model.state_dict()
model_dict.update(Checkpoint)
model.load_state_dict(model_dict)
# 2. 利用 load_state_dict() 的 strict 参数进行部分加载
model.load_state_dict(torch.load(PATH), strict=False)
冻结部分模型参数,进行 fine-tuning
加载完 Pre-Trained Model 后,我们需要对其进行 Finetune。但是在此之前,我们往往需要冻结一部分的模型参数:
# 第一种方式
for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False
p.requires_grad = False
for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True
p.requires_grad = True
# 将需要 fine-tuning 的参数放入optimizer 中
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
# 第二种方式
optim_param = []
for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False
p.requires_grad = False
for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True
p.requires_grad = True
optim_param.append(p)
optimizer.SGD(optim_param, lr=1e-3) # 将需要 fine-tuning 的参数放入optimizer 中
模型训练与测试的设置
训练时,应调用 model.train() ;测试时,应调用 model.eval(),以及 with torch.no_grad():
model.train():使 model 变成训练模式,此时 dropout 和 batch normalization 的操作在训练起到防止网络过拟合的问题。
model.eval():PyTorch会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。不然的话,一旦测试集的 Batch Size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
with torch.no_grad():PyTorch 将不再计算梯度,这将使得模型 forward 的时候,显存的需求大幅减少,速度大幅提高。