vllm 注册自定义算子与Attention kernel源码分析

综述

本文章会介绍vllm两个关于CUDA的部分:

  1. 了解的基础知识。
  2. 自定义的CUDA算子注册到vllm中供框架中调用的实现步骤。
  3. 解析vllm调用cuda的attention代码。

必备技术和知识

  • 了解和学习cuda,c++(>=c++17)和python的基本语法(模板类)。
  • 掌握最简单Element Wise的kernel和launch的写法。
  • 掌握self-attention,paged-attention模块内部计算流程。
  • 了解pybind基本写法。

注册算子到vllm的实现步骤

要将算子注册到python中有如下操作:

  • 所有的算子文件存入(vllm/csrc)中,这里我们假设要实现融合算子paged_attention_v1.cu,将此文件存入到(vllm/csrc/attention/paged_attention_v1.cu)。
  • 因为是写入新文件,就需要在CMakeList.txt,查找如下位置填入新文件的路径。
set(VLLM_EXT_SRC
  "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
  "csrc/cache_kernels.cu"
  "csrc/attention/paged_attention_v1.cu" // 添加文件路径
  "csrc/attention/paged_attention_v2.cu"
  ...)
  • 要对文件中的kernel进行binding。首先在(vllm/csrc/ops.h)中增加函数定义。
// 添加这个kernel的定义
void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
  • 其次(vllm/csrc/torch_bindings.cpp)中增加其pybind,详细的写法需了解pybind。
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
  • 完成以上步骤,我们需要将kernel封装成统一的python的api,在(vllm/vllm/_custom_ops.py)中填入
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: torch.Tensor | None,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v1(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
  • 完成以上操作,就可以在vllm中调用此算子了。

vllm/csrc/attention/attention_kernel.cuh代码解析

在这个头文件中,核心算子是paged_attention_kernel,其他函数都是在此基础上调用其算子的。我会尽量介绍围绕paged-attention的算法的代码(有很小部分是在此基础上做了调优,但不影响整个算子),从高到低的视角来解析代码。

  • 首先此算子实现的paged-attention,其中包含工程细节有:
    1. 分别对应Q,K和V的global memory->shared memory->registers
    2. 分别对应Q,K和V的threadIdx分配合适内存和物理内存。
    3. 实现self-attention计算过程

将Q加载到共享内存中,但QK的计算是在寄存器中进行

相关变量
const scalar_t* __restrict__ q;     // [num_seqs, num_heads, head_size]
const int q_stride                  // 一个seq长度。       
const int seq_idx = blockIdx.y;     // Grid: (num_heads, num_seqs, max_num_partitions).
const int head_idx = blockIdx.x;    // Grid: (num_heads, num_seqs, max_num_partitions).
int HEAD_SIZE;                      // 一个头的长度。
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;   // 一个thread_group组内中每个thread分配的部分头长度数。
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); // vec是存储q或者k的,vec的个数是由max(16/(组内thread个数*单个大小),1)。
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;  // thread组内中每个thread分配的vec数。
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;          // thread_group组的序号。
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;       //组内thread的偏移量,也可以理解成thread组内序号。 
相关代码
// 把 Q 向量(注意力计算中的查询向量)按 “线程分组大小” 做 “间隔式拆分”—— 比如分组大小为 4 时,每个线程只处理自己序号间隔 4 的 Q 向量(线程 0 管 0/4/8…,线程 1 管 1/5/9…)。
// 因为 Q 是从 QKV 合并张量里拆出来的,所以这些 Q 向量在显存中可能不是连续存储的(需要注意内存访问的效率问题)。
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; // q:[num_seqs, num_heads, head_size]
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];       // 按照thread_group组内thread数作为行,它的每个thread的vecs数作为列。
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;  // 遍历一个group中的所有query的vecs,将其加载到寄存器中。
    i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] =                      // thread_group_offset:表面是指组内thread的偏移量,也可以理解成thread组内id。
    *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads(); 

以上的操作可以理解成,当执行到这个kernel时:

  1. 找到正在运行的线程要处理的Q的子块,并且按照分组找到属于线程所有的seq_q
  2. 根据要求在共享内存中设置合适的Q_vecs,按照其分组顺序遍历存储到对应到位置中,此目的可以加快后续的计算速度,这里存在bank_conflict,有合并访存的优化空间

将K加载到共享内存中,并且做 S=Q*KT

相关变量
const int* __restrict__ block_tables;   // [num_seqs, max_num_blocks_per_seq]
const int max_num_blocks_per_seq;       // 每个seq最大block数
const int partition_idx = blockIdx.z;   // // Grid: (num_heads, num_seqs, max_num_partitions).
const int num_blocks_per_partition =
    USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;        // 每个分块(USE_PARTITIONING)最大block数。
const int start_block_idx =
    USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;  // 有分块存在,block起始位置与分块的序号有关,像sliding window attention是根据当前q近邻的m个作为kv cache从而提高速度,适用于长序列。
const int end_block_idx =
    MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);    // block结束位置。
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; // 当前seq对应的block_table。
/*
 当前seq_idx的对应的block_table(block table是paged attention中的概念,
 旨在将 KV Cache按照block的结构打散在显存中,通过table映射,从而直接提高显存的利用率)。
*/
相关代码
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    ......
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // 与Q相同,将要做计算的K存入寄存器中。
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr =            // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        } else {
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
              k_vec_quant, *k_scale);
        }
      }

      // 计算Q*K,这里用了流水进行加速
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);

      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

      if (thread_group_offset == 0) {
        const bool mask = token_idx >= seq_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // 找到qk_max
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }
  1. 代码中包含一层的循环嵌套,外层循环整体实现S=Q*KT
  2. 内层循环是将K的子块加载到共享内存中。
  3. 接着循环后是Q*KT
  4. 找到当前thread中多个q中最大的qk_max,为后续safe softmax提供计算数据。

实现safe softmax(softmax的工程实现版)

相关代码
// 之前的qk_max是thread对应多个q的最大值,现在是在WARP_SIZE中找到最大的qk_max。
// 这个写法是利用归约特性,提高速度,并存取到red_smem中。
// 如果使用右移操作实现除以2的操作,可再提高速度。
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();


// 在所有的NUM_WARPS中找到最大的qk_max。
// NUM_WARPS->WARP_SIZE->thread,定义范围逐级缩小,safe softmax需要找到全局最大值,所以最终要在NUM_WARPS中找到。
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);


float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

//计算 safe softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;       // 将结果保存到logits中。
}
__syncthreads();

计算O=S*V

相关变量
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; // 创建一个vector,长度是V_VEC_SIZE,每个大小是scalar_t
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;

constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; // 一行V_VEC数量。
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;       // 一次迭代能并行处理的行数。
constexpr int NUM_ROWS_PER_THREAD =
    DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);  
相关代码
  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // 与前面K类似,根据当前有关的V cache做计算。
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec; 
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
                                                           start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                           kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        } else {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
                                                                    *v_scale);
        }
        if (block_idx == num_seq_blocks - 1) {
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
          }
        }
        // O=S*V
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // 归约求和
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  __syncthreads();
  1. 这里是循环的嵌套,内部的循环实现通过v当前的index去查找cache的数据。
  2. 外层循环则是,将得到V与前面得到的S做,O=S*V。
  3. 归约求和实现将O存入accs中。

归约存储结果并从寄存器存储到全局内存

相关代码

  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  if (warp_idx == 0) {
    scalar_t* out_ptr =
        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }

计算已经结束,通过归约将计算完数据从寄存器逐步转移到全局内存:
先到共享内存,看第一个for循环中,可以分成上下半区。

以一个 Block 内有 8 个 Warp(NUM_WARPS=8)、注意力头维度 HEAD_SIZE=64 为例,展示分治归约的完整过程:

第 1次循环:i=8,mid=4(8个Warp->4个Warp)

  • 上半区:4、5、6、7号Warp->分别将结果写入共享内存0、1、2、3号地址;
  • 同步后,下半区:0、1、2、3号Warp->分别读取共享内存0、1、2、3号地址,与自身结果累加;
  • 结果:仅剩03号Warp持有有效结果,47号Warp退出后续归约。

第2次循环:i=4,mid=2(4个Warp->2个Warp)

  • 上半区:2、3号Warp->将结果写入共享内存0、1号地址;
  • 同步后,下半区:0、1号Warp->读取共享内存0、1号地址,与自身结果累加;
  • 结果:仅剩01号Warp持有有效结果,23号Warp退出后续归约。

第3次循环:i=2,mid=1(2个Warp->1个Warp)

  • 上半区:1号Warp->将结果写入共享内存0号地址;
  • 同步后,下半区:0号Warp->读取共享内存0号地址,与自身结果累加;
  • 结果:仅剩0号Warp持有全局最终归约结果,循环结束(i=1,不满足i>1)。
最后将结果存储到全局内存的out_ptr中。

总结

  1. 第一部份实现将算子注册到vllm,我们可以应用到自己写的推理框架中。
  2. 第二部分介绍的代码,没有实际的介绍到vllm核心的paged attention部分,并且也有一定的改进空间,但我们也可以作为学习源码flash attention的入门(当然它的源码主要还是用cute来实现,这里是cuda实现)。
  3. 如果存在纰漏,欢迎指正。
Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐