PyTorch模型加载与保存的最佳实践

一般来说PyTorch有两种保存和读取模型参数的方法。但这篇文章我记录了一种最佳实践,可以在加载模型时避免掉一些问题。

第一种方案是保存整个模型:

```

1<br></br>
|
torch.save(model_object, 'model.pth')<br></br>


第二种方法是保存模型网络参数:

| ```
|-----|
1&lt;br&gt;&lt;/br&gt;
``` | ```
torch.save(model_object.state_dict(), 'params.pth')&lt;br&gt;&lt;/br&gt;
``` |
加载的时候分别这样加载:

| ```
|-----|
1&lt;br&gt;&lt;/br&gt;
``` | ```
model = torch.load('model.pth')&lt;br&gt;&lt;/br&gt;
``` |
以及:

| ```
|-----|
1&lt;br&gt;&lt;/br&gt;
``` | ```
model_object.load_state_dict(torch.load('params.pth'))&lt;br&gt;&lt;/br&gt;
``` |
#  

注意到这个方案是因为模型在加载之后,loss会飙升之后再慢慢降回来。查阅有关分析之后,判定是优化器optimizer的问题。

如果模型的保存是为了恢复训练状态,那么可以考虑同时保存优化器optimizer的参数:

| ```
|-----|
1&lt;br&gt;&lt;/br&gt;2&lt;br&gt;&lt;/br&gt;3&lt;br&gt;&lt;/br&gt;4&lt;br&gt;&lt;/br&gt;5&lt;br&gt;&lt;/br&gt;6&lt;br&gt;&lt;/br&gt;7&lt;br&gt;&lt;/br&gt;
``` | ```
state = {&lt;br&gt;&lt;/br&gt;    'epoch': epoch,&lt;br&gt;&lt;/br&gt;    'net': model.state_dict(),&lt;br&gt;&lt;/br&gt;    'optimizer': optimizer.state_dict(),&lt;br&gt;&lt;/br&gt;    ...&lt;br&gt;&lt;/br&gt;}&lt;br&gt;&lt;/br&gt;torch.save(state, filepath)&lt;br&gt;&lt;/br&gt;
``` |
然后这样加载:

| ```
|-----|
1&lt;br&gt;&lt;/br&gt;2&lt;br&gt;&lt;/br&gt;3&lt;br&gt;&lt;/br&gt;4&lt;br&gt;&lt;/br&gt;
``` | ```
checkpoint = torch.load(model_path)&lt;br&gt;&lt;/br&gt;model.load_state_dict(checkpoint['net'])&lt;br&gt;&lt;/br&gt;optimizer.load_state_dict(checkpoint['optimizer'])&lt;br&gt;&lt;/br&gt;start_epoch =  checkpoint['epoch'] + 1&lt;br&gt;&lt;/br&gt;
``` |
如果模型的保存是为了方便以后进行validation和test,可以在加载完之后制定model.eval()固定dropout和BN层。

声明:该文章系转载,转载该文章的目的在于更广泛的传递信息,并不代表本网站赞同其观点,文章内容仅供参考。

本站是一个个人学习和交流平台,网站上部分文章为网站管理员和网友从相关媒体转载而来,并不用于任何商业目的,内容为作者个人观点, 并不代表本网站赞同其观点和对其真实性负责。

我们已经尽可能的对作者和来源进行了通告,但是可能由于能力有限或疏忽,导致作者和来源有误,亦可能您并不期望您的作品在我们的网站上发布。我们为这些问题向您致歉,如果您在我站上发现此类问题,请及时联系我们,我们将根据您的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。