This article walks through application examples based on several repositories in the CANN open-source community.

CANN organization: https://atomgit.com/cann

runtime repository: https://atomgit.com/cann/runtime

ge repository: https://atomgit.com/cann/ge

Preface

Once a model has been trained, how do you deploy it to production? How do you optimize inference performance? How do you build a highly available inference service?

This article covers the complete pipeline from model export, conversion, and optimization through to service deployment, helping you build a high-performance inference system on Ascend NPUs.

Model Export

1. Exporting the PyTorch Model

import torch
import torch_npu

# Load the trained model
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()

# Option 1: export to TorchScript
dummy_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save('model_traced.pt')

# Option 2: export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    },
    opset_version=11,
    do_constant_folding=True
)

print("Model export finished")
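
A quick sanity check right after export catches tracing problems early. The snippet below is a minimal sketch that compares the traced model's output against the original eager model on the same dummy input (it assumes the model returns a single tensor; the tolerance is an arbitrary choice).

# Sanity check (sketch): the traced model should match the eager model on the same input
with torch.no_grad():
    ref_out = model(dummy_input)
    traced_out = traced_model(dummy_input)

max_diff = (ref_out - traced_out).abs().max().item()
print(f"Max difference between eager and traced outputs: {max_diff:.6f}")
assert max_diff < 1e-5, "Traced model diverges from the original model"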

2. Validating the Exported Model

import onnx

# Load the ONNX model
onnx_model = onnx.load('model.onnx')

# Check the model structure
onnx.checker.check_model(onnx_model)
print("ONNX model check passed")

# Inspect model inputs and outputs (dynamic dims show up as dim_param, e.g. 'batch_size')
print("\nInputs:")
for inp in onnx_model.graph.input:
    print(f"  name: {inp.name}")
    print(f"  shape: {[d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]}")

print("\nOutputs:")
for out in onnx_model.graph.output:
    print(f"  name: {out.name}")
    print(f"  shape: {[d.dim_param or d.dim_value for d in out.type.tensor_type.shape.dim]}")
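
For deployment through CANN's offline inference path, the ONNX model is typically converted to an offline .om model with the ATC tool before being loaded via AscendCL. The sketch below simply wraps an atc command line in Python; treat the exact flag values (framework=5 for ONNX, the SoC version string, the input shape) as assumptions to be checked against your CANN toolkit version and hardware.

import subprocess

# Sketch: convert the ONNX model to an Ascend offline model (.om) with ATC.
# Flag values are assumptions; consult `atc --help` for your CANN version.
atc_cmd = [
    "atc",
    "--model=model.onnx",
    "--framework=5",                      # 5 = ONNX
    "--output=model",                     # produces model.om
    "--input_shape=input:1,3,224,224",
    "--soc_version=Ascend310P3",          # replace with your actual SoC version
]
subprocess.run(atc_cmd, check=True)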

Model Optimization

1. Quantization

import torch
import torch.quantization as quantization

def quantize_model(model, calibration_data):
    """Post-training static quantization with PyTorch's eager-mode API"""
    # Set the quantization config ('fbgemm' targets x86 CPU backends; this step runs on CPU)
    model.qconfig = quantization.get_default_qconfig('fbgemm')
  
    # Prepare the model by inserting observers
    model_prepared = quantization.prepare(model, inplace=False)
  
    # Calibrate with representative data
    model_prepared.eval()
    with torch.no_grad():
        for data in calibration_data:
            model_prepared(data)
  
    # Convert to a quantized model
    model_quantized = quantization.convert(model_prepared, inplace=False)
  
    return model_quantized

# Usage
calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
quantized_model = quantize_model(model, calibration_data)

# Save the quantized model
torch.save(quantized_model.state_dict(), 'model_quantized.pth')
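
Before shipping the smaller checkpoint, it is worth confirming that quantization actually pays off. The sketch below compares the on-disk size of the float and quantized checkpoints and measures top-1 agreement on the calibration data; note that this eager-mode PyTorch quantization executes on the CPU.

import os

# Sketch: compare checkpoint size and output agreement (runs on CPU)
torch.save(model.state_dict(), 'model_fp32.pth')
fp32_mb = os.path.getsize('model_fp32.pth') / 1e6
int8_mb = os.path.getsize('model_quantized.pth') / 1e6
print(f"FP32 checkpoint: {fp32_mb:.1f} MB, INT8 checkpoint: {int8_mb:.1f} MB")

agree = 0
with torch.no_grad():
    for data in calibration_data:
        ref = model(data).argmax(dim=1)
        quant = quantized_model(data).argmax(dim=1)
        agree += (ref == quant).sum().item()

print(f"Top-1 agreement on calibration data: {agree / len(calibration_data):.2%}")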

2. Pruning

import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    """Apply L1 unstructured pruning to all conv and linear layers"""
    # Prune the weights of every Conv2d and Linear layer
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            prune.l1_unstructured(module, name='weight', amount=amount)
            # Make the pruning permanent by removing the reparameterization
            prune.remove(module, 'weight')
  
    return model

# Usage
pruned_model = prune_model(model, amount=0.3)
print("About 30% of conv/linear weights are now zero (unstructured pruning does not shrink tensor shapes)")
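
Since L1 unstructured pruning only zeroes weights without changing tensor shapes, it helps to measure the resulting sparsity directly rather than trusting the nominal ratio. A minimal sketch:

def report_sparsity(model):
    """Report the fraction of zero-valued weights in conv and linear layers"""
    zero, total = 0, 0
    for module in model.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            zero += (module.weight == 0).sum().item()
            total += module.weight.numel()
    print(f"Weight sparsity: {zero / total:.2%}")

report_sparsity(pruned_model)  # roughly 30% with amount=0.3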

3. Knowledge Distillation

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
  
    def forward(self, student_logits, teacher_logits, labels):
        # Hard-label loss (standard cross entropy against the ground truth)
        hard_loss = self.ce_loss(student_logits, labels)

        # Soft-label loss (knowledge distillation term)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
      
        # Weighted combination of hard and soft losses
        loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
      
        return loss

def distill_model(teacher_model, student_model, train_loader, epochs=10):
    """Train the student to match both the ground-truth labels and the teacher's softened outputs"""
    teacher_model.eval()
    student_model.train()
  
    optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
    criterion = DistillationLoss(temperature=3.0, alpha=0.5)
  
    for epoch in range(epochs):
        for data, labels in train_loader:
            data, labels = data.npu(), labels.npu()
          
            # Teacher forward pass (no gradients)
            with torch.no_grad():
                teacher_logits = teacher_model(data)
          
            # Student forward pass
            student_logits = student_model(data)
          
            # Distillation loss
            loss = criterion(student_logits, teacher_logits, labels)
          
            # Update the student
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
      
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
  
    return student_model
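
A usage sketch for the distillation loop, assuming a larger pretrained teacher and a smaller student; TeacherNet, StudentNet, and train_loader here are hypothetical placeholders for your own model classes and a DataLoader yielding (images, labels).

# Usage sketch: TeacherNet / StudentNet / train_loader are placeholders
teacher = TeacherNet().npu()
teacher.load_state_dict(torch.load('teacher.pth'))

student = StudentNet().npu()
student = distill_model(teacher, student, train_loader, epochs=10)

torch.save(student.state_dict(), 'student_distilled.pth')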

Inference Serving

1. Single-Model Inference Service

import torch
import torch_npu
from flask import Flask, request, jsonify
import numpy as np
from PIL import Image
import io

app = Flask(__name__)

class InferenceService:
    def __init__(self, model_path):
        # Load the TorchScript model and move it to the NPU
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.model = self.model.npu()
      
        # Image preprocessing pipeline
        from torchvision import transforms
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
  
    def preprocess(self, image_bytes):
        """Decode the image bytes and apply the preprocessing transforms"""
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        image = self.transform(image)
        image = image.unsqueeze(0)  # Add the batch dimension
        return image
  
    def predict(self, image_bytes):
        """Run inference on a single image"""
        # Preprocess
        image = self.preprocess(image_bytes)
        image = image.npu()
      
        # Forward pass
        with torch.no_grad():
            output = self.model(image)
      
        # Postprocess: softmax and top-5
        probabilities = torch.softmax(output, dim=1)
        top5_prob, top5_idx = torch.topk(probabilities, 5)
      
        # Move results back to the CPU
        top5_prob = top5_prob.cpu().numpy()[0]
        top5_idx = top5_idx.cpu().numpy()[0]
      
        # Build the response
        results = []
        for prob, idx in zip(top5_prob, top5_idx):
            results.append({
                'class_id': int(idx),
                'probability': float(prob)
            })
      
        return results

# Create the service instance
service = InferenceService('model_traced.pt')

@app.route('/predict', methods=['POST'])
def predict():
    """Inference endpoint"""
    if 'image' not in request.files:
        return jsonify({'error': 'No image provided'}), 400
  
    image_file = request.files['image']
    image_bytes = image_file.read()
  
    try:
        results = service.predict(image_bytes)
        return jsonify({'predictions': results})
    except Exception as e:
        return jsonify({'error': str(e)}), 500

@app.route('/health', methods=['GET'])
def health():
    """Health check endpoint"""
    return jsonify({'status': 'healthy'})

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
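
Once the service is running, it can be exercised with a plain HTTP client. A sketch using the requests library (the image path is a placeholder):

import requests

# Client sketch: POST an image to the running service
with open('test.jpg', 'rb') as f:
    resp = requests.post('http://localhost:8080/predict', files={'image': f})

print(resp.status_code)
print(resp.json())  # {'predictions': [{'class_id': ..., 'probability': ...}, ...]}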

2. Batched Inference Service

import queue
import threading
import time

class BatchInferenceService:
    def __init__(self, model_path, batch_size=32, max_wait_time=0.1):
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.model = self.model.npu()
      
        self.batch_size = batch_size
        self.max_wait_time = max_wait_time
      
        # Request queue and response map
        self.request_queue = queue.Queue()
        self.response_dict = {}
      
        # Start the batching worker thread
        self.batch_thread = threading.Thread(target=self._batch_worker, daemon=True)
        self.batch_thread.start()
  
    def _batch_worker(self):
        """Worker thread that groups incoming requests into batches"""
        while True:
            batch_requests = []
            batch_ids = []
          
            # Collect requests until the batch is full or the wait budget expires
            start_time = time.time()
            while len(batch_requests) < self.batch_size:
                try:
                    timeout = self.max_wait_time - (time.time() - start_time)
                    if timeout <= 0:
                        break
                  
                    request_id, data = self.request_queue.get(timeout=timeout)
                    batch_requests.append(data)
                    batch_ids.append(request_id)
                except queue.Empty:
                    break
          
            if not batch_requests:
                continue
          
            # Run the whole batch through the model
            batch_tensor = torch.cat(batch_requests, dim=0).npu()
          
            with torch.no_grad():
                outputs = self.model(batch_tensor)
          
            # Dispatch results to the waiting callers
            for i, request_id in enumerate(batch_ids):
                self.response_dict[request_id] = outputs[i:i+1]
  
    def predict(self, data):
        """Submit a request and wait for its result"""
        request_id = id(data)
      
        # Enqueue the request
        self.request_queue.put((request_id, data))
      
        # Busy-wait for the result (an event-based variant is sketched at the end of this section)
        while request_id not in self.response_dict:
            time.sleep(0.001)
      
        # Pop and return the result
        result = self.response_dict.pop(request_id)
      
        return result

# Usage
batch_service = BatchInferenceService('model_traced.pt', batch_size=32)

@app.route('/predict_batch', methods=['POST'])
def predict_batch():
    """Batched inference endpoint"""
    image_file = request.files['image']
    image_bytes = image_file.read()
  
    # Preprocess
    image = service.preprocess(image_bytes)
  
    # Batched inference
    output = batch_service.predict(image)
  
    # Postprocess
    probabilities = torch.softmax(output, dim=1)
    top5_prob, top5_idx = torch.topk(probabilities, 5)
  
    results = []
    for prob, idx in zip(top5_prob[0], top5_idx[0]):
        results.append({
            'class_id': int(idx),
            'probability': float(prob)
        })
  
    return jsonify({'predictions': results})
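
The predict method above polls response_dict in a busy loop, which burns CPU under load. A common refinement is to block each caller on a per-request threading.Event that the batch worker sets once the result is ready; the sketch below shows only that synchronization pattern, and assumes the batch worker is changed to call finish() instead of writing to response_dict directly.

import threading

# Sketch: per-request Event so callers block instead of polling
_events = {}
_results = {}

def submit(request_queue, request_id, data):
    """Caller side: enqueue a request and block until its result arrives"""
    _events[request_id] = threading.Event()
    request_queue.put((request_id, data))
    _events[request_id].wait()
    _events.pop(request_id)
    return _results.pop(request_id)

def finish(request_id, result):
    """Worker side: store the result and wake the waiting caller"""
    _results[request_id] = result
    _events[request_id].set()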

3. Multi-Model Service

class MultiModelService:
    def __init__(self, model_configs):
        """
        model_configs: {
            'model_name': {
                'path': 'model.pt',
                'device': 0
            }
        }
        """
        self.models = {}
        self.devices = {}

        for name, config in model_configs.items():
            model = torch.jit.load(config['path'])
            model.eval()
            model = model.npu(config['device'])
            self.models[name] = model
            self.devices[name] = config['device']
  
    def predict(self, model_name, data):
        """Run inference with the named model on its own device"""
        if model_name not in self.models:
            raise ValueError(f"Model {model_name} not found")

        model = self.models[model_name]
        # Move the input to the same NPU device as the model
        data = data.npu(self.devices[model_name])

        with torch.no_grad():
            output = model(data)
      
        return output

# Usage
multi_service = MultiModelService({
    'resnet50': {'path': 'resnet50.pt', 'device': 0},
    'mobilenet': {'path': 'mobilenet.pt', 'device': 1},
    'efficientnet': {'path': 'efficientnet.pt', 'device': 2}
})

@app.route('/predict/<model_name>', methods=['POST'])
def predict_multi(model_name):
    """Multi-model inference endpoint"""
    image_file = request.files['image']
    image_bytes = image_file.read()
  
    # Preprocess (device placement is handled inside the multi-model service)
    image = service.preprocess(image_bytes)
  
    # Inference
    output = multi_service.predict(model_name, image)
  
    # Postprocess
    probabilities = torch.softmax(output, dim=1)
    top5_prob, top5_idx = torch.topk(probabilities, 5)
  
    results = []
    for prob, idx in zip(top5_prob[0], top5_idx[0]):
        results.append({
            'class_id': int(idx),
            'probability': float(prob)
        })
  
    return jsonify({'predictions': results})

Performance Optimization

1. Warmup

def warmup_model(model, input_shape, num_iterations=10):
    """Warm up the model so the first real requests do not pay one-time initialization costs"""
    dummy_input = torch.randn(*input_shape).npu()
  
    for _ in range(num_iterations):
        with torch.no_grad():
            _ = model(dummy_input)
  
    torch.npu.synchronize()
    print("Warmup finished")

# Usage
warmup_model(model, (1, 3, 224, 224))
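
After warmup it is straightforward to measure steady-state latency. A small benchmark sketch; torch.npu.synchronize() makes sure the timing covers actual device execution rather than just kernel launches.

import time

def benchmark_model(model, input_shape, num_iterations=100):
    """Measure average inference latency after warmup"""
    dummy_input = torch.randn(*input_shape).npu()

    torch.npu.synchronize()
    start = time.time()
    for _ in range(num_iterations):
        with torch.no_grad():
            _ = model(dummy_input)
    torch.npu.synchronize()

    avg_ms = (time.time() - start) / num_iterations * 1000
    print(f"Average latency over {num_iterations} runs: {avg_ms:.2f} ms")

benchmark_model(model, (1, 3, 224, 224))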

2. Dynamic Batching

class DynamicBatchService:
    def __init__(self, model_path, supported_batch_sizes=[1, 4, 8, 16, 32]):
        self.models = {}
      
        # Warm up the model at each supported batch size
        base_model = torch.jit.load(model_path)
        base_model.eval()
        base_model = base_model.npu()
      
        for bs in supported_batch_sizes:
            dummy_input = torch.randn(bs, 3, 224, 224).npu()
          
            # Warmup runs
            for _ in range(10):
                with torch.no_grad():
                    _ = base_model(dummy_input)
          
            self.models[bs] = base_model
      
        self.supported_batch_sizes = sorted(supported_batch_sizes)
  
    def predict(self, data_list):
        """Pad the request list to a supported batch size and run inference"""
        batch_size = len(data_list)
      
        # Pick the smallest supported batch size that fits the request
        selected_bs = min(
            [bs for bs in self.supported_batch_sizes if bs >= batch_size],
            default=self.supported_batch_sizes[-1]
        )
      
        # Pad up to the selected batch size
        if batch_size < selected_bs:
            padding = [data_list[0]] * (selected_bs - batch_size)
            data_list = data_list + padding
      
        # Inference
        batch_tensor = torch.cat(data_list, dim=0).npu()
        model = self.models[selected_bs]
      
        with torch.no_grad():
            outputs = model(batch_tensor)
      
        # Return only the results for the real requests
        return outputs[:batch_size]
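
A usage sketch: five preprocessed inputs arrive together, the service pads them up to the nearest supported batch size (8 here) and returns only the five real results.

# Usage sketch
dynamic_service = DynamicBatchService('model_traced.pt')
pending_inputs = [torch.randn(1, 3, 224, 224) for _ in range(5)]

outputs = dynamic_service.predict(pending_inputs)
print(outputs.shape)  # first dimension is 5, the number of real requests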

3. Caching

import hashlib

class CachedInferenceService:
    def __init__(self, model_path, cache_size=1000):
        self.model = torch.jit.load(model_path)
        self.model.eval()
        self.model = self.model.npu()
      
        self.cache_size = cache_size
        self.cache = {}
  
    def _compute_hash(self, data):
        """Hash the raw bytes of the input tensor"""
        return hashlib.md5(data.cpu().numpy().tobytes()).hexdigest()
  
    def predict(self, data):
        """Inference with a simple fixed-size result cache"""
        # Hash the input
        data_hash = self._compute_hash(data)
      
        # Cache hit
        if data_hash in self.cache:
            return self.cache[data_hash]
      
        # Cache miss: run inference
        with torch.no_grad():
            output = self.model(data)
      
        # Update the cache, evicting the oldest insertion when full (FIFO, not true LRU)
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
      
        self.cache[data_hash] = output
      
        return output
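
The eviction above is FIFO: it always drops the oldest insertion, even if that entry is still being hit frequently. If recency matters, an OrderedDict gives a true LRU cache in a few extra lines; a sketch that could back CachedInferenceService instead of the plain dict:

from collections import OrderedDict

class LRUCache:
    """Minimal LRU cache keyed by the input hash"""

    def __init__(self, max_size=1000):
        self.max_size = max_size
        self._store = OrderedDict()

    def get(self, key):
        if key not in self._store:
            return None
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, key, value):
        self._store[key] = value
        self._store.move_to_end(key)
        if len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict the least recently used entry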

Monitoring and Logging

1. Performance Monitoring

import time
from collections import deque

class PerformanceMonitor:
    def __init__(self, window_size=100):
        self.latencies = deque(maxlen=window_size)
        self.throughputs = deque(maxlen=window_size)
        self.start_time = time.time()
        self.total_requests = 0
  
    def record_request(self, latency, batch_size=1):
        """Record one request"""
        self.latencies.append(latency)
        self.throughputs.append(batch_size / latency)
        self.total_requests += batch_size
  
    def get_stats(self):
        """Return aggregate latency and throughput statistics"""
        if not self.latencies:
            return {}
      
        return {
            'avg_latency': sum(self.latencies) / len(self.latencies),
            'p50_latency': sorted(self.latencies)[len(self.latencies) // 2],
            'p95_latency': sorted(self.latencies)[int(len(self.latencies) * 0.95)],
            'p99_latency': sorted(self.latencies)[int(len(self.latencies) * 0.99)],
            'avg_throughput': sum(self.throughputs) / len(self.throughputs),
            'total_requests': self.total_requests,
            'uptime': time.time() - self.start_time
        }

# Usage
monitor = PerformanceMonitor()

@app.route('/predict', methods=['POST'])
def predict_with_monitor():
    start_time = time.time()
    image_bytes = request.files['image'].read()

    # Inference
    result = service.predict(image_bytes)

    # Record latency
    latency = time.time() - start_time
    monitor.record_request(latency)

    return jsonify({'predictions': result})

@app.route('/metrics', methods=['GET'])
def metrics():
    """Expose performance metrics"""
    stats = monitor.get_stats()
    return jsonify(stats)

2. Logging

import logging
from datetime import datetime

# Configure logging to write to both a file and the console
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('inference.log'),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger(__name__)

@app.route('/predict', methods=['POST'])
def predict_with_logging():
    request_id = datetime.now().strftime('%Y%m%d%H%M%S%f')

    logger.info(f"Request {request_id} started")

    try:
        start_time = time.time()
        image_bytes = request.files['image'].read()
        result = service.predict(image_bytes)
        latency = time.time() - start_time

        logger.info(f"Request {request_id} completed in {latency:.3f}s")

        return jsonify({'predictions': result})

    except Exception as e:
        logger.error(f"Request {request_id} failed: {str(e)}")
        return jsonify({'error': str(e)}), 500

Docker Deployment

1. Dockerfile

FROM ascendhub.huawei.com/public-ascendhub/ascend-pytorch:23.0.RC3-ubuntu18.04

WORKDIR /app

# Install Python dependencies
COPY requirements.txt .
RUN pip install -r requirements.txt

# Copy code and models
COPY . .

# Expose the service port
EXPOSE 8080

# Start the service
CMD ["python", "app.py"]

2. docker-compose.yml

version: '3.8'

services:
  inference:
    build: .
    ports:
      - "8080:8080"
    volumes:
      - ./models:/app/models
    environment:
      - MODEL_PATH=/app/models/model.pt
      - BATCH_SIZE=32
    # NPU device reservation: the exact syntax below depends on your container runtime.
    # With Ascend Docker Runtime, NPUs are commonly selected via the ASCEND_VISIBLE_DEVICES
    # environment variable or explicit /dev/davinci* device mappings instead.
    deploy:
      resources:
        reservations:
          devices:
            - driver: ascend
              device_ids: ['0']
              capabilities: [npu]

Summary

Key points for deploying models with CANN:

  • Model export: PyTorch → ONNX/TorchScript
  • Model optimization: quantization, pruning, distillation
  • Inference serving: single-model, batched, multi-model
  • Performance optimization: warmup, dynamic batching, caching
  • Monitoring and logging: performance metrics, request logs
  • Containerization: Docker deployment

With these techniques you can build a high-performance, highly available inference service.

Related Links

runtime repository: https://atomgit.com/cann/runtime

ge repository: https://atomgit.com/cann/ge

CANN organization: https://atomgit.com/cann
