博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
[Pytorch]Pytorch 保存模型与加载模型(转)
阅读量:5288 次
发布时间:2019-06-14

本文共 3079 字,大约阅读时间需要 10 分钟。

转自:

目录:

  • 保存模型与加载模型
  • 冻结一部分参数,训练另一部分参数
  • 采用不同的学习率进行训练

1.保存模型与加载

简单的保存与加载方法:

# 保存整个网络 torch.save(net, PATH) # 保存网络中的参数, 速度快,占空间少 torch.save(net.state_dict(),PATH) #-------------------------------------------------- #针对上面一般的保存方法,加载的方法分别是: model_dict=torch.load(PATH) model_dict=model.load_state_dict(torch.load(PATH))

然而,在实验中往往需要保存更多的信息,比如优化器的参数,那么可以采取下面的方法保存:

torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN, 'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma}, checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')

以上包含的信息有,epochID, state_dict, min loss, optimizer, 自定义损失函数的两个参数;格式以字典的格式存储。

加载的方式:

def load_checkpoint(model, checkpoint_PATH, optimizer): if checkpoint != None: model_CKPT = torch.load(checkpoint_PATH) model.load_state_dict(model_CKPT['state_dict']) print('loading checkpoint!') optimizer.load_state_dict(model_CKPT['optimizer']) return model, optimizer

其他的参数可以通过以字典的方式获得

但是,但是,我们可能修改了一部分网络,比如加了一些,删除一些,等等,那么需要过滤这些参数,加载方式:

def load_checkpoint(model, checkpoint, optimizer, loadOptimizer): if checkpoint != 'No': print("loading checkpoint...") model_dict = model.state_dict() modelCheckpoint = torch.load(checkpoint) pretrained_dict = modelCheckpoint['state_dict'] # 过滤操作 new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} model_dict.update(new_dict) # 打印出来,更新了多少的参数 print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict))) model.load_state_dict(model_dict) print("loaded finished!") # 如果不需要更新优化器那么设置为false if loadOptimizer == True: optimizer.load_state_dict(modelCheckpoint['optimizer']) print('loaded! optimizer') else: print('not loaded optimizer') else: print('No checkpoint is included') return model, optimizer

2.冻结部分参数,训练另一部分参数

1)添加下面一句话到模型中

for p in self.parameters(): p.requires_grad = False

比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话

class RESNET_MF(nn.Module): def init(self, model, pretrained): super(RESNET_MF, self).__init__() self.resnet = model(pretrained) for p in self.parameters(): p.requires_grad = False self.f = SpectralNorm(nn.Conv2d(2048, 512, 1)) self.g = SpectralNorm(nn.Conv2d(2048, 512, 1)) self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1)) ...

同时在优化器中添加:filter(lambda p: p.requires_grad, model.parameters())

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)

2) 参数保存在有序的字典中,那么可以通过查找参数的名字对应的id值,进行冻结

查找的代码:

model_dict = torch.load('net.pth.tar').state_dict() dict_name = list(model_dict) for i, p in enumerate(dict_name): print(i, p)

保存一下这个文件,可以看到大致是这个样子的:

0 gamma 1 resnet.conv1.weight 2 resnet.bn1.weight 3 resnet.bn1.bias 4 resnet.bn1.running_mean 5 resnet.bn1.running_var 6 resnet.layer1.0.conv1.weight 7 resnet.layer1.0.bn1.weight 8 resnet.layer1.0.bn1.bias 9 resnet.layer1.0.bn1.running_mean ....

同样在模型中添加这样的代码:

for i,p in enumerate(net.parameters()): if i < 165: p.requires_grad = False

在优化器中添加上面的那句话可以实现参数的屏蔽

转载于:https://www.cnblogs.com/kk17/p/10074188.html

你可能感兴趣的文章
struts2.X心得2--第一个struts2案例分析以及整合c3p0连接数据库案例分析
查看>>
xml中处理大于小与符号
查看>>
网络七层模型&&网络数据包
查看>>
JavaScript基础---获取元素的属性(title,style,width)
查看>>
Django的认证系统
查看>>
简单了解HashCode()
查看>>
闭包理解
查看>>
asp.net C#后台实现下载文件的几种方法(全)
查看>>
设计模式之命令模式
查看>>
js原型链部分详细使用说明案例
查看>>
JavaScript字符串去除空格
查看>>
ODAC(V9.5.15) 学习笔记(四)TOraDataSet
查看>>
相机-imu外参校准总结
查看>>
数据分析之Pandas(三) DataFrame入门
查看>>
CSS 基础
查看>>
MySQL用户变量的用法
查看>>
HDU 2002 计算球体积
查看>>
Java第八次作业 1502 马 帅
查看>>
大数据时代,百货行业信息化将如何变革?
查看>>
“互联网+”下的数据化运营和技术架构
查看>>