一、.pt或.pth文件简介
解释:PyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息。
适用场景:需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型。
包含的参数:
model_state_dict:模型每一层可学习的节点的参数,比如weight/bias
optimizer_state_dict:模型的优化器中的参数
epoch:当前的训练轮数
loss:当前的损失值
二、查看.pt或.pth文件
方法1:
import torch
content = torch.load('model.pt',map_location=torch.device('cpu'))
print(content.keys())
print(content['model'])
方法2:
import torch
content = torch.load('model.pt',map_location=torch.device('cpu'))
## k 参数名,v 对应参数值
for k,v in content.items():
print(k,v)
方法3:
import torch
content = torch.load('model.pt',map_location=torch.device('cpu'))
for parameter in content.parameters():
print(parameter)
方法4:
import torch
content = torch.load("model.pt")
content.eval()
用数据对模型进行训练后得到了比较理想的模型,但在实际应用的时候不可能每次都先进行训练然后再使用,所以就得先将之前训练好的模型保存下来,然后在需要用到的时候加载一下直接使用。
本文参考的文章:
1.https://blog.csdn.net/weixin_44212848/article/details/124022462
2.https://blog.csdn.net/qq_27353621/article/details/126551086
3.https://zhuanlan.zhihu.com/p/620688513
4.https://zhuanlan.zhihu.com/p/422797058