在数据并行(Data Parallel, DP)训练中,若各 DP 节点的序列总长度不均衡,会导致计算量少的节点提前完成等待,形成「木桶效应」。该特性通过装箱算法均衡各 DP 节点的序列总长度,减少节点间等待时间,提升分布式训练效率。详细原理请参考:数据并行负载均衡(DP Batch Balance)。下面我们对该特性的代码实现做详细的介绍。

相关代码

mindspeed_rl/utils/seqlen_balancing.py

mindspeed_rl/trainer/utils/transfer_dock.py

应用时机

在从TransferDock中读取数据时候处理

实现逻辑

默认的数据采样算法

系统提供了一个默认的采样indexes的算法(不是DP Batch Balance),主要通过batch_balencing_sampler实现:

  1. 如果没有指定target_seq_len,就采取随机采样(实际现在默认是不指定target_seq_len的)
  2. 如果指定了target_seq_len,会计算每个usable_index对应的experience_columns的数据总长度,数据长度越接近target_seq_len,被采样的概率越大

#mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.batch_balencing_sampler

def batch_balencing_sampler(

    self, experience_columns, usable_indexes, experience_count, target_seq_len=None):

    #如果没有指定target_seq_len,就采取随机采样

    if target_seq_len is None:

        weights = torch.ones(len(usable_indexes))

    #如果指定了target_seq_len,会计算每个usable_index对应的experience_columns的数据总长度

    #数据长度越接近target_seq_len,被采样的概率越大

    else:

        seq_len = torch.tensor(

            [

                sum([self.experience_data[key][idx].numel() for key in experience_columns])

                for idx in usable_indexes

            ]

        )

        weights = torch.sigmoid(1 / (torch.abs(seq_len - target_seq_len) + 0.001), dim=0)

sampled_indexes_idx = torch.multinomial(weights, experience_count, \

replacement=False).tolist()

    sampled_indexes = [int(usable_indexes[i]) for i in sampled_indexes_idx]

    return sampled_indexes

DP Batch Balance

DP Batch Balance主要通过batch_seqlen_balance_sampler实现:

  1. 如果usable_indexes和experience_count长度相等,直接返回,也就是说可用indexes和要用的长度一样,不用采样
  2. 计算某个index的experience_column数据总长度的时候,根据表batch_seqlen_balance_mapper决定不同consumer使用不同的experience_column字段。
  3. 核心算法通过get_seqlen_balanced_partitions实现,使用堆排序装箱算法,按序列长度从大到小依次分配至当前总长度最小的分组,确保各 DP 节点的总序列长度均衡
batch_seqlen_balance_mapper

#mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.__init__

#支持DP Batch Balance采样的阶段,以及计算长度使用的columns

self.batch_seqlen_balance_mapper = {

    "ref_log_prob": ["prompt_length", "response_length"],

    "actor_log_prob": ["prompt_length", "response_length"],

    "reward_scores": ["prompt_length", "response_length"],

    "actor_train": ["prompt_length", "response_length"]

}

batch_seqlen_balance_sampler

#mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.batch_seqlen_balance_sampler

def batch_seqlen_balance_sampler(self, consumer, usable_indexes, experience_count, get_n_samples=False):

    from mindspeed_rl.utils.seqlen_balancing import get_seqlen_balanced_partitions

    if len(usable_indexes) == experience_count:

        sampled_indexes = [int(usable_indexes[i]) for i in range(experience_count)]

        return sampled_indexes

    #计算usable_indexes中每个index的experience_column数据总长度

    #根据表batch_seqlen_balance_mapper决定不同consumer使用不同的experience_column字段

    seq_len_columns = self.batch_seqlen_balance_mapper.get(consumer)

    if get_n_samples:

        seq_len_list = [

            sum([self.experience_data[key][idx * self.n_samples_per_prompt + addition].item()

                    for addition in range(self.n_samples_per_prompt) for key in seq_len_columns])

            for idx in usable_indexes

        ]

    ......

    k_partitions = len(seq_len_list) // experience_count

    #使用堆排序装箱算法,确保各DP节点的总序列长度均衡

    sampled_indexes_idx = get_seqlen_balanced_partitions(seq_len_list, k_partitions, equal_size=True)

    if len(sampled_indexes_idx) > 0:

        sampled_indexes = [int(usable_indexes[i]) for i in sampled_indexes_idx[0]]

    else:

        sampled_indexes = None

    return sampled_indexes

get_seqlen_balanced_partitions

#mindspeed_rl.utils.seqlen_balancing.get_seqlen_balanced_partitions

def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):

    """get order of seq lengths to make partitions balanced, this is

        used in balancing sum of seq length across dp ranks and micro batches

    Parameters:

        seqlen_list (List[int]):

            seq lengths of each items

        k_partitions (int):

            resulting number of partitions

        equal_size (bool):

            if True, number of items in each partitions must be equal.

            if False, only consider balancing the sum, each partition can have

            variable number of items

    Returns:

        partitions (List[List[int]]):

    """      

    def _check_and_sort_partitions(partitions):

        #对各分组做了排序

        seen_idx = set()

        sorted_partitions = [None] * k_partitions

        for i, partition in enumerate(partitions):

            for idx in partition:

                seen_idx.add(idx)

            sorted_partitions[i] = sorted(partition)

        return sorted_partitions

    #调用heapq_partition将seqlen_list分成k_partitions组,equal_size=True要求各分组的元素数目一样

    partitions = heapq_partition(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)

    return _check_and_sort_partitions(partitions)

def heapq_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):

    equal_part_num = len(seqlen_list) // k_partitions

    #seqlen_list按照seq长度从大到小排列

    sorted_seqlen = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)], reverse=True)

    # 堆初始化: 每个group存[组元素当前总长度, 元素数, group的index, 元素]

    groups = [[0, 0, i, []] for i in range(k_partitions)]

    heapq.heapify(groups)

    partitions = []

    #按照seqlen从大到小循环

    for seqlen, i in sorted_seqlen:

        #堆中pop出总长度最小的group

        current_group = heapq.heappop(groups)  

        #将最大的seq的信息填入组

        current_group[3].append(i)

        current_group[0] += seqlen

        current_group[1] += 1

        #如果要求各分组的元素数目一样,需要做特殊处理

        if equal_size:

            if current_group[1] < equal_part_num:

                heapq.heappush(groups, current_group)

            else:

                partitions.append(current_group[3])

        #否则将current_group重新加入堆

        else:

            heapq.heappush(groups, current_group)

    partitions.extend([group[3] for group in groups])

    ......

    return partitions

使用示例

mindspeed_rl/trainer/utils/transfer_dock.py

get_experience

GRPOTransferDock的get_experience中在获取indexes采样的时候调用了_sample_ready_index_n_samples或者_sample_ready_index。

#mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.get_experience

if get_n_samples:

    if experience_count % self.n_samples_per_prompt != 0:

        raise ValueError(

            f"get_n_samples need experience_count:{experience_count} must be divisible by "

            f"n_samples_per_prompt: {self.n_samples_per_prompt}"

        )

    indexes = self._sample_ready_index_n_samples(

        consumer, experience_count, experience_columns,

        use_batch_seqlen_balance=use_batch_seqlen_balance

    )

else:

    indexes = self._sample_ready_index(

        consumer, experience_count, experience_columns,

        use_batch_seqlen_balance=use_batch_seqlen_balance

    )

_sample_ready_index

_sample_ready_index_n_samples或者_sample_ready_index中在获取indexes采样的时候都调用了batch_seqlen_balance_sampler或者batch_balencing_sampler。

#mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock._sample_ready_index_n_samples

def _sample_ready_index_n_samples(

    self,

    consumer: str,

    experience_count: int,

    experience_columns: List[str],

    target_seq_len: int = None,

    use_batch_seqlen_balance: bool = False

) -> Optional[List[int]]:

    experience_count_n_samples = experience_count // self.n_samples_per_prompt

    ......

    #根据not_consumed_indexes和data_ready_indexes计算usable_indexes

    usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]

    #使用DP Batch Balance采样indexes

    if consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(

            usable_indexes) % experience_count_n_samples == 0:

        sampled_indexes_n_sample = self.batch_seqlen_balance_sampler(

            consumer, usable_indexes, experience_count_n_samples, get_n_samples=True

        )

        if not sampled_indexes_n_sample:

            return None

    #使用默认的indexes采样方法

    else:

        sampled_indexes_n_sample = self.batch_balencing_sampler(

            experience_columns,

            usable_indexes,

            experience_count_n_samples,

            target_seq_len,

        )

    ......

    #标记采样的indexes的experience_consumer_status[consumer]

    self.experience_consumer_status[consumer][sampled_indexes] = 1

    return sampled_indexes

Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐