CANN Model Deployment and Optimization in Practice: The Complete Pipeline from Training to Production
This article walks through application examples based on several repositories from the CANN open-source community.
CANN organization: https://atomgit.com/cann
runtime repository: https://atomgit.com/cann/runtime
ge repository: https://atomgit.com/cann/ge
Introduction
Once a model has been trained, how do you deploy it to production? How do you optimize inference performance? How do you build a highly available inference service?
This article walks through the complete pipeline, from model export, conversion, and optimization to service deployment, to help you build a high-performance inference system on Ascend NPUs.
Model Export
1. Exporting a PyTorch Model
import torch
import torch_npu
# Load the trained model
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()
# Option 1: export to TorchScript
dummy_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save('model_traced.pt')
# Option 2: export to ONNX
torch.onnx.export(
model,
dummy_input,
'model.onnx',
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
},
opset_version=11,
do_constant_folding=True
)
print("模型导出完成")
2. Validating the Exported Model
import onnx
import numpy as np
# Load the ONNX model
onnx_model = onnx.load('model.onnx')
# Check the model structure
onnx.checker.check_model(onnx_model)
print("ONNX model check passed")
# Inspect model inputs and outputs
print("\nInputs:")
for input in onnx_model.graph.input:
    print(f"  name: {input.name}")
    print(f"  shape: {[d.dim_value for d in input.type.tensor_type.shape.dim]}")
print("\nOutputs:")
for output in onnx_model.graph.output:
    print(f"  name: {output.name}")
    print(f"  shape: {[d.dim_value for d in output.type.tensor_type.shape.dim]}")
Model Optimization
1. Quantization
import torch
import torch.quantization as quantization
def quantize_model(model, calibration_data):
    """Post-training static quantization."""
    # Set the quantization config
    model.qconfig = quantization.get_default_qconfig('fbgemm')
    # Prepare the model (insert observers)
    model_prepared = quantization.prepare(model, inplace=False)
    # Calibrate with representative data
    model_prepared.eval()
    with torch.no_grad():
        for data in calibration_data:
            model_prepared(data)
    # Convert to the quantized model
    model_quantized = quantization.convert(model_prepared, inplace=False)
    return model_quantized

# Usage
calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
quantized_model = quantize_model(model, calibration_data)
# Save the quantized model
torch.save(quantized_model.state_dict(), 'model_quantized.pth')
2. Pruning
import torch.nn.utils.prune as prune
def prune_model(model, amount=0.3):
    """Apply L1 unstructured pruning."""
    # Prune every convolutional and linear layer
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            prune.l1_unstructured(module, name='weight', amount=amount)
            # Make the pruning permanent (remove the reparametrization)
            prune.remove(module, 'weight')
    return model

# Usage
pruned_model = prune_model(model, amount=0.3)
print("About 30% of the conv/linear weights are now zero")
3. Knowledge Distillation
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Hard-label loss
        hard_loss = self.ce_loss(student_logits, labels)
        # Soft-label loss (knowledge distillation)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
        # Combined loss
        loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
        return loss

def distill_model(teacher_model, student_model, train_loader, epochs=10):
    """Train the student against the teacher's soft targets."""
    teacher_model.eval()
    student_model.train()
    optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
    criterion = DistillationLoss(temperature=3.0, alpha=0.5)
    for epoch in range(epochs):
        for data, labels in train_loader:
            data, labels = data.npu(), labels.npu()
            # Teacher forward pass (no gradients)
            with torch.no_grad():
                teacher_logits = teacher_model(data)
            # Student forward pass
            student_logits = student_model(data)
            # Compute the distillation loss
            loss = criterion(student_logits, teacher_logits, labels)
            # Update the student
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
    return student_model
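A minimal usage sketch, assuming a recent torchvision for the teacher/student stand-ins and an existing train_loader; substitute your own architectures and data:
import torchvision

# Hypothetical teacher/student pair for illustration only
teacher = torchvision.models.resnet50(weights='IMAGENET1K_V1').npu()
student = torchvision.models.mobilenet_v2(num_classes=1000).npu()
student = distill_model(teacher, student, train_loader, epochs=10)
torch.save(student.state_dict(), 'student_distilled.pth')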
Inference Services
1. Single-Model Inference Service
import torch
import torch_npu
from flask import Flask, request, jsonify
import numpy as np
from PIL import Image
import io
app = Flask(__name__)
class InferenceService:
    def __init__(self, model_path):
        # Load the TorchScript model onto the NPU
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.model = self.model.npu()
        # Preprocessing pipeline
        from torchvision import transforms
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def preprocess(self, image_bytes):
        """Decode and preprocess an image."""
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        image = self.transform(image)
        image = image.unsqueeze(0)  # Add the batch dimension
        return image

    def predict(self, image_bytes):
        """Run inference."""
        # Preprocess
        image = self.preprocess(image_bytes)
        image = image.npu()
        # Inference
        with torch.no_grad():
            output = self.model(image)
        # Postprocess
        probabilities = torch.softmax(output, dim=1)
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        # Move back to the CPU
        top5_prob = top5_prob.cpu().numpy()[0]
        top5_idx = top5_idx.cpu().numpy()[0]
        # Build the result
        results = []
        for prob, idx in zip(top5_prob, top5_idx):
            results.append({
                'class_id': int(idx),
                'probability': float(prob)
            })
        return results

# Create the service instance
service = InferenceService('model_traced.pt')

@app.route('/predict', methods=['POST'])
def predict():
    """Inference endpoint."""
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400
    image_file = request.files['image']
    image_bytes = image_file.read()
    try:
        results = service.predict(image_bytes)
        return jsonify({'predictions': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/health', methods=['GET'])
def health():
    """Health check."""
    return jsonify({'status': 'healthy'})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
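A simple way to exercise the service, assuming the requests library is available; the URL and the test image filename are placeholders:
import requests

# Health check
print(requests.get('http://localhost:8080/health').json())
# Send an image for classification
with open('test.jpg', 'rb') as f:
    resp = requests.post('http://localhost:8080/predict', files={'image': f})
print(resp.json())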
2. Batched Inference Service
import queue
import threading
import time
class BatchInferenceService:
    def __init__(self, model_path, batch_size=32, max_wait_time=0.1):
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.model = self.model.npu()
        self.batch_size = batch_size
        self.max_wait_time = max_wait_time
        # Request queue and result map
        self.request_queue = queue.Queue()
        self.response_dict = {}
        # Start the batching worker thread
        self.batch_thread = threading.Thread(target=self._batch_worker, daemon=True)
        self.batch_thread.start()

    def _batch_worker(self):
        """Background thread that groups requests into batches."""
        while True:
            batch_requests = []
            batch_ids = []
            # Collect requests until the batch is full or the wait times out
            start_time = time.time()
            while len(batch_requests) < self.batch_size:
                try:
                    timeout = self.max_wait_time - (time.time() - start_time)
                    if timeout <= 0:
                        break
                    request_id, data = self.request_queue.get(timeout=timeout)
                    batch_requests.append(data)
                    batch_ids.append(request_id)
                except queue.Empty:
                    break
            if not batch_requests:
                continue
            # Run one batched forward pass
            batch_tensor = torch.cat(batch_requests, dim=0).npu()
            with torch.no_grad():
                outputs = self.model(batch_tensor)
            # Hand each result back to the waiting caller
            for i, request_id in enumerate(batch_ids):
                self.response_dict[request_id] = outputs[i:i+1]

    def predict(self, data):
        """Submit one request and wait for its result."""
        request_id = id(data)
        # Enqueue the request
        self.request_queue.put((request_id, data))
        # Poll until the worker has published the result
        while request_id not in self.response_dict:
            time.sleep(0.001)
        # Fetch and remove the result
        result = self.response_dict.pop(request_id)
        return result

# Usage
batch_service = BatchInferenceService('model_traced.pt', batch_size=32)

@app.route('/predict_batch', methods=['POST'])
def predict_batch():
    """Batched inference endpoint."""
    image_file = request.files['image']
    image_bytes = image_file.read()
    # Preprocess (reuse the single-model service's pipeline)
    image = service.preprocess(image_bytes)
    # Batched inference
    output = batch_service.predict(image)
    # Postprocess
    probabilities = torch.softmax(output, dim=1)
    top5_prob, top5_idx = torch.topk(probabilities, 5)
    results = []
    for prob, idx in zip(top5_prob[0], top5_idx[0]):
        results.append({
            'class_id': int(idx),
            'probability': float(prob)
        })
    return jsonify({'predictions': results})
3. Multi-Model Service
class MultiModelService:
    def __init__(self, model_configs):
        """
        model_configs: {
            'model_name': {
                'path': 'model.pt',
                'device': 0
            }
        }
        """
        self.models = {}
        self.devices = {}
        for name, config in model_configs.items():
            model = torch.jit.load(config['path'])
            model.eval()
            model = model.npu(config['device'])
            self.models[name] = model
            self.devices[name] = config['device']

    def predict(self, model_name, data):
        """Run inference on the named model."""
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not found")
        model = self.models[model_name]
        # Move the input to the same NPU the model lives on
        data = data.npu(self.devices[model_name])
        with torch.no_grad():
            output = model(data)
        return output

# Usage
multi_service = MultiModelService({
    'resnet50': {'path': 'resnet50.pt', 'device': 0},
    'mobilenet': {'path': 'mobilenet.pt', 'device': 1},
    'efficientnet': {'path': 'efficientnet.pt', 'device': 2}
})

@app.route('/predict/<model_name>', methods=['POST'])
def predict_multi(model_name):
    """Multi-model inference endpoint."""
    image_file = request.files['image']
    image_bytes = image_file.read()
    # Preprocess (device placement is handled inside predict)
    image = service.preprocess(image_bytes)
    # Inference
    output = multi_service.predict(model_name, image)
    # Postprocess
    probabilities = torch.softmax(output, dim=1)
    top5_prob, top5_idx = torch.topk(probabilities, 5)
    results = []
    for prob, idx in zip(top5_prob[0], top5_idx[0]):
        results.append({
            'class_id': int(idx),
            'probability': float(prob)
        })
    return jsonify({'predictions': results})
Performance Optimization
1. Warmup
def warmup_model(model, input_shape, num_iterations=10):
    """Warm up the model with dummy inputs."""
    dummy_input = torch.randn(*input_shape).npu()
    for _ in range(num_iterations):
        with torch.no_grad():
            _ = model(dummy_input)
    torch.npu.synchronize()
    print("Model warmup finished")

# Usage
warmup_model(model, (1, 3, 224, 224))
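To see what warmup buys you, a small benchmarking sketch (timings are illustrative; torch.npu.synchronize() ensures the timer measures completed NPU work, as in the warmup code above):
import time

def benchmark(model, input_shape, num_iterations=50):
    """Measure the average per-inference latency in milliseconds."""
    dummy_input = torch.randn(*input_shape).npu()
    torch.npu.synchronize()
    start = time.time()
    for _ in range(num_iterations):
        with torch.no_grad():
            _ = model(dummy_input)
    torch.npu.synchronize()
    return (time.time() - start) / num_iterations * 1000

print(f"Average latency: {benchmark(model, (1, 3, 224, 224)):.2f} ms")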
2. Dynamic Batching
class DynamicBatchService:
    def __init__(self, model_path, supported_batch_sizes=[1, 4, 8, 16, 32]):
        self.models = {}
        # Load the model once and warm it up for every supported batch size
        base_model = torch.jit.load(model_path)
        base_model.eval()
        base_model = base_model.npu()
        for bs in supported_batch_sizes:
            dummy_input = torch.randn(bs, 3, 224, 224).npu()
            # Warmup
            for _ in range(10):
                with torch.no_grad():
                    _ = base_model(dummy_input)
            self.models[bs] = base_model
        self.supported_batch_sizes = sorted(supported_batch_sizes)

    def predict(self, data_list):
        """Run inference on a list of single-sample tensors."""
        batch_size = len(data_list)
        # Pick the smallest supported batch size that fits the request
        selected_bs = min(
            [bs for bs in self.supported_batch_sizes if bs >= batch_size],
            default=self.supported_batch_sizes[-1]
        )
        # Pad the batch up to the selected size
        if batch_size < selected_bs:
            padding = [data_list[0]] * (selected_bs - batch_size)
            data_list = data_list + padding
        # Inference
        batch_tensor = torch.cat(data_list, dim=0).npu()
        model = self.models[selected_bs]
        with torch.no_grad():
            outputs = model(batch_tensor)
        # Return only the results for the real requests
        return outputs[:batch_size]
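A brief usage sketch in the same style as the other services (inputs are random placeholders):
dynamic_service = DynamicBatchService('model_traced.pt')
# Five requests arrive together and are padded up to the supported batch size of 8
inputs = [torch.randn(1, 3, 224, 224) for _ in range(5)]
outputs = dynamic_service.predict(inputs)
print(outputs.shape)  # e.g. torch.Size([5, 1000]) for a 1000-class classifier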
3. Caching
from functools import lru_cache
import hashlib
class CachedInferenceService:
    def __init__(self, model_path, cache_size=1000):
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.model = self.model.npu()
        self.cache_size = cache_size
        self.cache = {}

    def _compute_hash(self, data):
        """Hash the raw tensor bytes to build a cache key."""
        return hashlib.md5(data.cpu().numpy().tobytes()).hexdigest()

    def predict(self, data):
        """Inference with result caching."""
        # Compute the cache key
        data_hash = self._compute_hash(data)
        # Cache hit: return the stored result
        if data_hash in self.cache:
            return self.cache[data_hash]
        # Cache miss: run inference
        with torch.no_grad():
            output = self.model(data)
        # Update the cache, evicting the oldest entry when full
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[data_hash] = output
        return output
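A small sketch of the cache-hit path, reusing the TorchScript model from the earlier sections; the second call with the same tensor returns from the cache without touching the NPU:
import time

cached_service = CachedInferenceService('model_traced.pt')
sample = torch.randn(1, 3, 224, 224).npu()
for attempt in ('miss', 'hit'):
    start = time.time()
    _ = cached_service.predict(sample)
    torch.npu.synchronize()
    print(f"cache {attempt}: {(time.time() - start) * 1000:.2f} ms")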
Monitoring and Logging
1. Performance Monitoring
import time
from collections import deque
class PerformanceMonitor:
    def __init__(self, window_size=100):
        self.latencies = deque(maxlen=window_size)
        self.throughputs = deque(maxlen=window_size)
        self.start_time = time.time()
        self.total_requests = 0

    def record_request(self, latency, batch_size=1):
        """Record one request."""
        self.latencies.append(latency)
        self.throughputs.append(batch_size / latency)
        self.total_requests += batch_size

    def get_stats(self):
        """Return latency/throughput statistics over the sliding window."""
        if not self.latencies:
            return {}
        return {
            'avg_latency': sum(self.latencies) / len(self.latencies),
            'p50_latency': sorted(self.latencies)[len(self.latencies) // 2],
            'p95_latency': sorted(self.latencies)[int(len(self.latencies) * 0.95)],
            'p99_latency': sorted(self.latencies)[int(len(self.latencies) * 0.99)],
            'avg_throughput': sum(self.throughputs) / len(self.throughputs),
            'total_requests': self.total_requests,
            'uptime': time.time() - self.start_time
        }

# Usage
monitor = PerformanceMonitor()

# A monitored variant of the /predict endpoint, registered under its own rule
# so it can coexist with the handler defined earlier
@app.route('/predict_monitored', methods=['POST'])
def predict_with_monitor():
    image_bytes = request.files['image'].read()
    start_time = time.time()
    # Inference
    result = service.predict(image_bytes)
    # Record the latency
    latency = time.time() - start_time
    monitor.record_request(latency)
    return jsonify({'predictions': result})

@app.route('/metrics', methods=['GET'])
def metrics():
    """Performance metrics endpoint."""
    stats = monitor.get_stats()
    return jsonify(stats)
2. Logging
import logging
from datetime import datetime
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('inference.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# A logged variant of the /predict endpoint, registered under its own rule
@app.route('/predict_logged', methods=['POST'])
def predict_with_logging():
    request_id = datetime.now().strftime('%Y%m%d%H%M%S%f')
    logger.info(f"Request {request_id} started")
    try:
        image_bytes = request.files['image'].read()
        start_time = time.time()
        result = service.predict(image_bytes)
        latency = time.time() - start_time
        logger.info(f"Request {request_id} completed in {latency:.3f}s")
        return jsonify({'predictions': result})
    except Exception as e:
        logger.error(f"Request {request_id} failed: {str(e)}")
        return jsonify({'error': str(e)}), 500
Docker Deployment
1. Dockerfile
FROM ascendhub.huawei.com/public-ascendhub/ascend-pytorch:23.0.RC3-ubuntu18.04
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install -r requirements.txt
# Copy code and models
COPY . .
# Expose the service port
EXPOSE 8080
# Start the service
CMD ["python", "app.py"]
2. docker-compose.yml
version: '3.8'
services:
  inference:
    build: .
    ports:
      - "8080:8080"
    volumes:
      - ./models:/app/models
    environment:
      - MODEL_PATH=/app/models/model.pt
      - BATCH_SIZE=32
    deploy:
      resources:
        reservations:
          devices:
            - driver: ascend
              device_ids: ['0']
              capabilities: [npu]
Summary
Key points of CANN model deployment:
- Model export: PyTorch → ONNX/TorchScript
- Model optimization: quantization, pruning, distillation
- Inference serving: single-model, batched, multi-model
- Performance optimization: warmup, dynamic batching, caching
- Monitoring and logging: performance metrics, request logs
- Containerization: Docker deployment
With these techniques you can build a high-performance, highly available inference service.
Related Links
runtime repository: https://atomgit.com/cann/runtime
ge repository: https://atomgit.com/cann/ge
CANN organization: https://atomgit.com/cann
The Ascend computing industry is a full-stack AI computing infrastructure, industry application, and services ecosystem built on the HUAWEI Ascend processor series and its base software, covering the Ascend processors, the hardware series, CANN, AI computing frameworks, application enablement, development toolchains, operations and maintenance tools, and industry applications and services across the whole industry chain.