本站消息

站长简介/公众号


站长简介:高级软件工程师,曾在阿里云,每日优鲜从事全栈开发工作,利用周末时间开发出本站,欢迎关注我的微信公众号:程序员总部,程序员的家,探索程序员的人生之路!分享IT最新技术,关注行业最新动向,让你永不落伍。了解同行们的工资,生活工作中的酸甜苦辣,谋求程序员的最终出路!

  价值13000svip视频教程,python大神匠心打造,零基础python开发工程师视频教程全套,基础+进阶+项目实战,包含课件和源码

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2020-12(9)

2021-01(59)

pytorch加载模型错误 RuntimeError: Error(s) in loading state_dict for Model: Missing key(s) in state_dict

发布于2021-10-18 00:16     阅读(237)     评论(0)     点赞(1)     收藏(3)



模型在保存时侯以键对值保存,同时在加载时根据现在网络的键值查找模型对应的键值,然后加载。一般报错是因为模型和网络的键值不匹配。

1、最常见的问题是键值多了或者少了 module.

此种情况是模型在DataParallel或者DDP训练后保存的键值有module. ,对应的网络的键值则没有module.

1)可以通过:

model = nn.DataParallel(model)

将模型的键值加上module.

2) 也可以通过遍历模型的键对值修改键值。

   如:加载模型时删除多余的module.  代码如下

  1. state_dict = torch.load(load_path)
  2. for key, param in state_dict.items():
  3. if key.startswith('module.'): #键值包含‘module.’ 则删除
  4. state_dict[key[7:]] = param
  5. state_dict.pop(key)
  6. net.load_state_dict(state_dict)

2、详解load_state_dict(state_dict, False)的False参数

很多教程说名字不匹配直接添加False参数即可,但是这里需要注意一个大坑。

如果模型的键值和网络的键值完全不匹配,那么模型就没有加载预训练参数,虽然不再报错。

该False参数作用在于 非严格匹配加载模型,可以下面几种情况进行分析

1)模型包含网络的部分参数

比如说模型是resnet101模型,你现在的网络是resnet50。再假设resnet50的参数名包含在resnet101的参数中,那么直接使用False会为你的网络resnet50加载键值相同的参数。这样就避免了对resnet101的每个键对值进行循环匹配,看是否是resnet50需要的。

2)模型完全不包含网络的参数

情况如1,模型有100个参数,都包含'module.' ,网络也有100个参数,都没有'module.' 。这种情况下如果参数设置为False,会发现没有任何键值能匹配上,因此网络就不会加载任何参数。

3)再介绍一个False使用场景

比如蒸馏网络PISR中,教师网络包含Encoder和Decoder两部分,学生网络由其中的Decoder部分组成,所以在训练学生网络时,如果要加载教师网络保存的预训练模型,设置False会自动识别Decoder部分键值相同,然后加载。

综上,设置False参数后依旧是按照键值查询加载参数的,有多少键值匹配,就加载多少模型的参数。

 

3、只要参数尺寸相同,就能加载

比如说我有一个10层网络的模型,还有一个3层的网络。我想把其中第9层的参数加载到现在网络的1层。如果参数的尺寸相同,就可以遍历键对值。将参数加载到想要的键值中。

  1. state_dict = torch.load(load_path)
  2. new_state_dict = []
  3. for key, param in state_dict.items():
  4. if 'conv9' in key: # 如果找到conv9对应的参数,将其键值替换为网络的键
  5. new_state_dict[key.replace('conv9', 'conv1')] = param
  6. net.load_state_dict(new_state_dict)

原文链接:https://blog.csdn.net/longshaonihaoa/article/details/120770446







所属网站分类: 技术文章 > 博客

作者:lg

链接:https://www.pythonheidong.com/blog/article/1060687/4ede7a15d64e31ac45ab/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

1 0
收藏该文
已收藏

评论内容:(最多支持255个字符)