logo头像
Snippet 博客主题

Pytorch-模型保存&加载

Pytorch有三种方式保存和加载模型


Entire Model

最简单的方式是直接把模型保存到本地路径下

1
2
3
4
5
torch.save(model, PATH)

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

state_dict

state_dict是原生的python dict对象,它可以将模型的每个Layer映射为参数tensor.

In PyTorch, the learnable parameters (i.e. weights and biases) of an torch.nn.Module model are contained in the model’s parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict. Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer’s state, as well as the hyperparameters used.

1
2
3
4
5
torch.save(model.state_dict(), PATH)

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

TorchScript(Recommend)

上面两种方式都有一个缺陷,需要在执行环境中引入模型定义代码. TorchScript是真正被推荐用于大规模生产环境的方式,因为它能够做到训练和预测的分离.

  1. TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, and so many requests can be processed on the same instance simultaneously.
  2. This format allows us to save the whole model to disk and load it into another environment, such as in a server written in a language other than Python
  3. TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution
  4. TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.

save
将模型转化为二进制流之后可以保存在外部存储中

1
2
3
4
5
6
7
import torch
import io

model_scripted = torch.jit.script(model)
buffer = io.BytesIO()
torch.jit.save(model_scripted, buffer)
buffer.getvalue()

load

1
2
3
4
model_bytes = '从外部加载的二进制流'
buffer = io.BytesIO(model_bytes)
model = torch.jit.load(buffer)
model.eval()

GPU训练/CPU推断
当使用GPU训练模型但是却使用CPU做推断的时候,加载模型时需要指定执行的设备

1
2
3
4
5
6
device = torch.device('cpu')

model_bytes = '从外部加载的二进制流'
buffer = io.BytesIO(model_bytes)
model = torch.jit.load(buffer, map_location=device)
model.eval()

调用其它函数
通过jit保存的模型正常只有forward函数对外暴露,但是可以通过加@torch.jit.export注解的方式暴露自定义的函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

@torch.jit.export
def f():
return 'hello world!'

model.load_state_dict(torch.load(PATH, map_location=device))
model.f()

评论系统未开启,无法评论!