昇思25天学习打卡营第8天|保存与加载
☀️ 最近报名参加了昇思25天学习打卡训练营
☀️ 第1天初步学习了MindSpore的基本操作
☀️ 第2天初步学习了张量Tensor
☀️ 第3天初步学习了数据集Dataset
☀️ 第4天初步学习了数据变换Transforms
☀️ 第5天初步学习了网络构建
☀️ 第6天初步学习了函数式自动微分
☀️ 第7天初步学习了模型训练
☀️ 第8天学习 初学入门 / 初学教程 / 09-保存与加载
1. 教程与代码
上一章节主要介绍了如何调整超参数,并进行网络模型训练。在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,本章节我们将介绍如何保存与加载模型。
首先还是导入库
import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor
定义一个构建神经网络的函数:
def network():
# 使用SequentialCell创建一个顺序模型,这是一种堆叠多个层的线性模型
model = nn.SequentialCell(
# 将输入数据展平成一维数组,假设输入数据的形状为[batch_size, 28, 28],即28x28的图像
nn.Flatten(),
# 第一个全连接层(Dense),将展平后的数据(长度为28*28)映射到512个神经元上
nn.Dense(28*28, 512),
# 激活函数,ReLU(Rectified Linear Unit),用于增加非线性
nn.ReLU(),
# 第二个全连接层,将前一层的512个输出映射到另一组512个神经元上
nn.Dense(512, 512),
# 再次应用ReLU激活函数
nn.ReLU(),
# 第三个全连接层,也是输出层,将前一层的512个输出映射到10个神经元上(假设是一个10类分类问题)
nn.Dense(512, 10)
)
# 返回构建好的模型
return model
1.1 保存和加载模型权重
保存模型使用save_checkpoint接口,传入网络和指定的保存路径:
model = network()
mindspore.save_checkpoint(model, "model.ckpt")
要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpoint和load_param_into_net方法加载参数。
model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
输出:
[]
param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。
上面输出结果表示我们的所有参数都加载成功了。
1.2 保存和加载MindIR
除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。
model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。
已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。
nn.GraphCell仅支持图模式。
mindspore.set_context(mode=mindspore.GRAPH_MODE)
graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)
输出:
(1, 10)
2. 小结
今天学习了保存和加载模型超参,内容比较简单。

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链
更多推荐

所有评论(0)