如题,我通过以下代码加载测试集并在训练中验证准确率,想请问这个时候是不是其实只有部分测试数据在验证,测试集不能加.repeat和.batch

class EvalCallBack(Callback):
    def __init__(self, model, eval_dataset, eval_per_epoch, epoch_per_eval):
        self.model = model
        self.eval_dataset = eval_dataset
        self.eval_per_epoch = eval_per_epoch
        self.epoch_per_eval = epoch_per_eval
 
 
    def on_train_epoch_end(self, run_context):
        cb_param = run_context.original_args()
        cur_epoch = cb_param.cur_epoch_num
        if cur_epoch % self.eval_per_epoch == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.epoch_per_eval["epoch"].append(cur_epoch)
            self.epoch_per_eval["acc"].append(acc["Accuracy"])
            print(acc)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

****************************************************解答*****************************************************

你是使用的哪个数据处理的接口,mindspore数据处理接口都支持支持repeat操作,这是所有的数据处理接口。是否在测试集上使用batch和repeat根据自己实际情况而定。

https://www.mindspore.cn/docs/zh-CN/r1.9/api_python/mindspore.dataset.html?highlight=dataset

eg:可借鉴bert网络训练接推理的操作。

       https://gitee.com/mindspore/models/blob/r1.9/official/nlp/bert/run_classifier.py

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐