Pytorch模型
在PyTorch中,模型保存和加载的常用格式有两种:pth和pkl。其中,pth格式是PyTorch专用的格式,可以直接加载到PyTorch中。而pkl格式是Python的pickle文件,可以保存任意Python对象,包括PyTorch模型。
import torch
import pickle
# 假设我们有一个训练好的模型
model = torch.nn.Linear(10, 2)
model.load_state_dict({
'weight': torch.tensor([[-1.6142, -0.1638, -0.5264, -0.2743, -0.3898, 0.5448, -0.3325, -0.3148, -0.2186, -0.3691],[-0.2477, -0.4999, 0.3698, 0.0896, 0.1587, 0.3773, 0.3378, 0.1277, 0.1679, -0.4554]]),
'bias': torch.tensor([-1.2683, 0.6465])
})
model.eval()
# 将模型保存为PKL文件
with open('model.pkl', 'wb') as f:
pickle.dump(model, f)
在上面的代码中,创建了一个简单的线性模型,然后使用load_state_dict
方法加载了预训练的权重和偏置项。最后,我们使用Python的pickle模块将模型保存到名为model.pkl
的文件中。
加载pkl文件中的模型
import torch
import pickle
# 从PKL文件中加载模型
with open('model.pkl', 'rb') as f:
model = pickle.load(f)
# 将模型转换为PyTorch模型对象
model = model.to(device) # 将模型移至指定设备上(例如GPU或CPU)
model.eval() # 设置模型为评估模式
使用pickle模块从名为model.pkl
的文件中加载模型。将模型移至指定的设备上(例如GPU或CPU),并将模型设置为评估模式。
作者:SteveChen 创建时间:2025-04-15 11:52
最后编辑:SteveChen 更新时间:2025-04-15 11:58
最后编辑:SteveChen 更新时间:2025-04-15 11:58