import os
import sys
import torch
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("gpu", type=str, choices=["0", "1", "2", "3", "4", "5", "6", "7"])
args = parser.parse_args()

os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID"  # 使得gpu硬件序号与程序序号一直对应
#os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu  # 选择要使用的gpu,空string代表不使用gpu  '/device:GPU:X'


device = torch.device('cuda:{}'.format(args.gpu))
model = torch.load(checkpoint_path,map_location=device)

注意:当使用 os.environ['CUDA_VISIBLE_DEVICES']指定某个GPU后, torch.device('cuda:{}'.format(args.gpu))是对指定的gpu们进行重新排序后的每个GPU的编号

如,当指定 os.environ['CUDA_VISIBLE_DEVICES'] = “5,6,7”后,

device = torch.device('cuda:2')实际上是用的GPU-7,而不是GPU-2.

当指定 os.environ['CUDA_VISIBLE_DEVICES'] = “5,6,7”后,device = torch.device('cuda:5')会报错,因为没有这个GPU。

所以建议:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu  

device = torch.device('cuda:{}'.format(args.gpu))

不要同时使用,二者只使用一个就可以了。

参考:https://blog.csdn.net/iamjingong/article/details/85308600 

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐