错误原因

  • 使用了 torch.load() 加载模型

解决办法

  • 使用 model.load_state_dict(torch.load(./model.pt,map_location=‘cpu’),strict=False)

  • map_location:当模型训练的时候用的 gpu 但是加载的时候用的是 cpu 环境,这个时候要进行映射

  • strict=False 否则容易报错:Unexpected key(s) in state_dict: “lstm.weight_ih_l3”, “lstm.weight_hh_l3”…

总结

如无必要,尽量用 load_state_dict 的方式来加载模型,这种方式更稳定
用 torch.load() 的话,很容易报各种各样的错误

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐