image.png

我们已经将算子的单核与多核性能优化到了一个非常高的水平。但所有这些优化,都建立在一个重要的前提之上:算子的输入形状(Shape)在编译时是固定的。然而,在真实的AI应用中,尤其是NLP和推荐模型,输入的形状常常是动态变化的(例如,处理不同长度的句子,batch size可变)。

这篇文章,我们将挑战CANN算-子开发中一个极为重要且复杂的领域,它决定了一个算子是否具备“实战”能力——动态Shape(Dynamic Shape)支持


1. 新的挑战:当世界不再是静态的

在我们之前所有的讨论中,我们都隐式地做了一个假设:算子的输入Tensor形状,比如(Batch, Channel, Height, Width),在编译算子代码时是完全已知的。这使得我们可以进行各种编译期优化,甚至将某些维度作为模板参数写死在代码里。

但在真实的生产环境中,静态Shape的假设往往不成立:

  • NLP领域:文本输入的句子长度(sequence_length)几乎总是在变化。
  • 推荐系统:为了提升吞吐,batch_size可能会根据请求量动态调整。
  • CV领域:模型可能需要处理不同分辨率的图片。

如果一个算子只支持静态Shape,那么每遇到一个新的输入形状,就需要重新编译一次算子。这在推理服务中是完全不可接受的。因此,让算子支持动态Shape,是其从“实验品”走向“产品化”的关键一步。

2. CANN的答案:编译期“规则制定”与运行期“按图索骥”

如何处理这种“未知”?CANN框架给出了一套优雅的两阶段解决方案,将问题分解到编译期运行期分别处理。

  • 编译期 (Compile Time):此时在Host侧(CPU)执行。我们不知道具体的维度值,但我们知道形状的计算规则InferShape函数就在这个阶段大显身手。它的任务不是计算具体值,而是建立一个符号化的形状推导表达式
  • 运行期 (Run Time):此时在Device侧(NPU AICore)执行。当一个具体形状的Tensor输入到模型中时,框架会根据InferShape建立的规则,计算出具体的输出形状,并将这些具体的维度值传递给Kernel函数。Kernel函数再根据这些值,执行计算。

类比:GPS导航

  • InferShape:就像是你在出发前,用地图软件规划的行车路线。它告诉你:“从A到B,你需要先走G15高速,然后转S32省道…”。它定义了规则,但不知道实时路况。
  • Kernel函数:就像是你开车时,车载导航的实时指引。它接收到实时的GPS信号(具体的维度值),然后告诉你:“前方2公里后,在xx出口下高速”。它是在执行已经规划好的路线。
    在这里插入图片描述
3. InferShape的“魔法”:从具体计算到符号推导

InferShape是实现动态Shape的核心。它运行在编译期,输入的是一个带有“未知数”(通常表示为-1)的Shape对象,输出的也是一个带有“未知数”的Shape对象。

静态Shape下的InferShape(简单回顾):
假设实现一个(N, C, H, W) -> (N, C, H/2, W/2)的降采样算子,且所有维度已知。

// 伪代码
graph::Shape input_shape = op.GetInputDesc(0).GetShape();
int N = input_shape.GetDim(0); // N=8
int C = input_shape.GetDim(1); // C=3
// ...
graph::Shape output_shape({N, C, H/2, W/2});
op.GetOutputDesc(0).SetShape(output_shape);

这里的N, C, H, W都是具体的整数。

动态Shape下的InferShape
现在,假设NH是动态的,输入Shape可能是(-1, 3, -1, 224)

// 动态Shape下的伪代码
graph::Shape input_shape = op.GetInputDesc(0).GetShape(); // 获取带有未知数的Shape
// 即使GetDim获取的是-1,Shape对象内部也保留了其符号信息
std::vector<int64_t> dims = input_shape.GetDims(); 

// 我们直接对这个带有未知数的dims进行规则操作
dims[2] = dims[2] / 2; // 'H' 维度除以2
dims[3] = dims[3] / 2; // 'W' 维度除以2

// 用新的符号化dims构建输出Shape
graph::Shape output_shape(dims);
op.GetOutputDesc(0).SetShape(output_shape);

这里的关键在于,我们操作的不再是简单的int,而是Shape对象及其dims向量。CANN的图编译器会理解这些操作,并构建一个形状推导的子图。当运行时输入一个具体的Shape(8, 3, 448, 224),框架会自动执行这个子图,推导出输出Shape(8, 3, 224, 112)

4. Kernel的“应变”:从硬编码到运行时参数

InferShape制定了规则,那么Kernel如何在运行时拿到具体的维度值呢?答案是:通过Tiling数据结构

当框架准备在NPU上执行Kernel时,它已经知道了本次运行的具体输入/输出形状。它会将这些信息,连同计算出的Tiling方案,一起打包到一个传递给Kernel的结构体中(我们通常称之为TilingData)。

代码对比:

静态Shape下的Kernel(硬编码或模板):

// 假设模板参数写死了维度
template<int N, int C, int H, int W>
__aicore__ void my_kernel(...) {
    for (int n = 0; n < N; ++n) { // 直接使用编译期常量N
        // ...
    }
}

动态Shape下的Kernel(运行时读取):

// tiling结构体由框架在运行时填充
struct TilingData {
    int32_t N, C, H, W; // 存储本次运行的具体维度
    // ... 其他Tiling信息
};

__aicore__ void my_kernel(uint8_t* tiling_data_ptr, ...) {
    // 1. 从传入的指针中,获取本次运行的TilingData
    auto* tiling = reinterpret_cast<TilingData*>(tiling_data_ptr);
    
    // 2. 从tiling结构体中读取具体的维度值
    int current_N = tiling->N;
    int current_H = tiling->H;
    
    // 3. 使用这些运行时变量进行计算
    for (int n = 0; n < current_N; ++n) {
        // ...
    }
}

通过这种方式,我们的Kernel变得极具通用性。它不再依赖于任何写死的维度信息,而是完全由运行时传入的TilingData驱动。一份编译好的Kernel代码,就可以处理所有符合InferShape规则的输入形状。

5. 总结:动态Shape是衡量算子成熟度的试金石

支持动态Shape,是对CANN算子开发者综合能力的全面考验。它要求我们:

  1. 具备分层思维:清晰地区分编译期(Host侧,InferShape)和运行期(Device侧,Kernel)的职责。
  2. 理解符号计算:掌握如何在InferShape中对未知的Shape对象进行合规的操作,以建立正确的推导规则。
  3. 掌握运行时参数传递:懂得如何在Kernel中,从TilingData里解析出本次运行所需的具体维度信息。

一个只支持静态Shape的算子,可能是一个性能优异的“赛道车”,只能在特定的赛道上飞驰。而一个支持动态Shape的算子,才是一辆能够应对各种复杂路况的“全地形越野车”,真正具备了在复杂AI模型中服役的能力。


昇腾CANN训练营第二季,火热报名中!

从零到一,精通算子开发!🚀

2025昇腾CANN训练营第二季重磅回归,无论你是AI新手还是进阶开发者,这里都有为你量身打造的课程:

  • 零基础入门:轻松掌握算子开发基础。

  • 进阶实战特辑:挑战高阶技巧,码力全开。

  • 开发者案例分享:借鉴实战经验,少走弯路。

【专属福利】
✅ 官方权威认证:通过考核,赢取 Ascend C算子中级认证 证书!
🎁 社区惊喜好礼:完成任务,解锁精美社区周边!

名额有限,立即锁定席位!
🔗 报名链接https://www.hiascend.com/developer/activities/cann20252

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐