☀️ 最近报名参加了昇思25天学习打卡训练营
☀️ 第1天初步学习了MindSpore的基本操作
☀️ 第2天初步学习了张量Tensor
☀️ 第3天初步学习了数据集Dataset
☀️ 第4天初步学习了数据变换Transforms
☀️ 第5天初步学习了网络构建
☀️ 第6天初步学习了函数式自动微分
☀️ 第7天初步学习了模型训练
☀️ 第8天学习 初学入门 / 初学教程 / 09-保存与加载

1. 教程与代码

上一章节主要介绍了如何调整超参数,并进行网络模型训练。在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,本章节我们将介绍如何保存与加载模型。

首先还是导入库

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

定义一个构建神经网络的函数:

def network():  
    # 使用SequentialCell创建一个顺序模型,这是一种堆叠多个层的线性模型  
    model = nn.SequentialCell(  
        # 将输入数据展平成一维数组,假设输入数据的形状为[batch_size, 28, 28],即28x28的图像  
        nn.Flatten(),  
        # 第一个全连接层(Dense),将展平后的数据(长度为28*28)映射到512个神经元上  
        nn.Dense(28*28, 512),  
        # 激活函数,ReLU(Rectified Linear Unit),用于增加非线性  
        nn.ReLU(),  
        # 第二个全连接层,将前一层的512个输出映射到另一组512个神经元上  
        nn.Dense(512, 512),  
        # 再次应用ReLU激活函数  
        nn.ReLU(),  
        # 第三个全连接层,也是输出层,将前一层的512个输出映射到10个神经元上(假设是一个10类分类问题)  
        nn.Dense(512, 10)  
    )  
    # 返回构建好的模型  
    return model   

1.1 保存和加载模型权重

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

model = network()
mindspore.save_checkpoint(model, "model.ckpt")

要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpoint和load_param_into_net方法加载参数。

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

输出:

[]

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

上面输出结果表示我们的所有参数都加载成功了。

1.2 保存和加载MindIR

除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。

nn.GraphCell仅支持图模式。

mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

输出:

(1, 10)

2. 小结

今天学习了保存和加载模型超参,内容比较简单。

在这里插入图片描述

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐