Code Walkthrough of the Data-Parallel Load Balancing Feature in Ascend MindSpeed RL
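In data-parallel (Data Parallel, DP) training, if the total sequence length differs across DP ranks, the ranks with less work finish early and sit idle waiting for the others. This is the "barrel effect": each step is only as fast as the slowest rank. This feature uses a bin-packing algorithm to balance the total sequence length across DP ranks, which reduces waiting between ranks and improves distributed training efficiency. For the underlying principle, see: Data Parallel Load Balancing (DP Batch Balance). Below we walk through the code that implements this feature.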
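The barrel effect can be made concrete with a small calculation. This minimal sketch assumes, as a simplification, that a rank's step time is proportional to its total token count. Real cost also depends on attention, which grows faster than linearly with sequence length. The lengths and the helper names `rank_loads` and `idle_time` are hypothetical illustrations, not MindSpeed RL code.

```python
from typing import List


def rank_loads(seqlens: List[int], assignment: List[List[int]]) -> List[int]:
    """Total sequence length handled by each DP rank (hypothetical helper)."""
    return [sum(seqlens[i] for i in idxs) for idxs in assignment]


def idle_time(loads: List[int]) -> List[int]:
    """How long each rank waits for the slowest rank, in token units."""
    slowest = max(loads)
    return [slowest - load for load in loads]


# Illustrative lengths only; not taken from MindSpeed RL.
seqlens = [900, 850, 300, 250, 200, 100, 80, 40]
naive = [[0, 1, 2, 3], [4, 5, 6, 7]]     # first half / second half
balanced = [[0, 3, 4, 7], [1, 2, 5, 6]]  # 4 items per rank, similar totals

print(rank_loads(seqlens, naive), idle_time(rank_loads(seqlens, naive)))
# [2300, 420] [0, 1880]
print(rank_loads(seqlens, balanced), idle_time(rank_loads(seqlens, balanced)))
# [1390, 1330] [0, 60]
```

Both splits give every rank four sequences. With the naive split, the second rank still idles for 1880 token units per step. With the balanced split, it idles for only 60.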
Related code
mindspeed_rl/utils/seqlen_balancing.py
mindspeed_rl/trainer/utils/transfer_dock.py
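When it applies
Applied when data is read from the TransferDock.
Implementation
Default data sampling algorithm
The system provides a default algorithm for sampling indexes (this is not DP Batch Balance). It is implemented mainly by batch_balencing_sampler: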
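- If target_seq_len is not specified, indexes are sampled uniformly at random. In practice, target_seq_len is currently not specified by default.
- If target_seq_len is specified, the total data length of experience_columns is computed for each usable_index. The closer a length is to target_seq_len, the higher its probability of being sampled.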
```python
# mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.batch_balencing_sampler
import torch


def batch_balencing_sampler(
    self, experience_columns, usable_indexes, experience_count, target_seq_len=None
):
    # Without target_seq_len, sample uniformly at random
    if target_seq_len is None:
        weights = torch.ones(len(usable_indexes))
    # With target_seq_len, compute the total length of experience_columns for each usable index;
    # the closer the length is to target_seq_len, the higher the sampling probability
    else:
        seq_len = torch.tensor(
            [
                sum([self.experience_data[key][idx].numel() for key in experience_columns])
                for idx in usable_indexes
            ]
        )
        # torch.sigmoid takes no dim argument; the dim=0 in the quoted source would raise a TypeError
        weights = torch.sigmoid(1 / (torch.abs(seq_len - target_seq_len) + 0.001))

    sampled_indexes_idx = torch.multinomial(weights, experience_count, replacement=False).tolist()
    sampled_indexes = [int(usable_indexes[i]) for i in sampled_indexes_idx]
    return sampled_indexes
```
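To see how strong the length preference is, the weight rule can be restated in plain Python. This sketch assumes nothing about TransferDock beyond the quoted formula. `sampling_weights` and `sample_without_replacement` are hypothetical names. The second function mimics `torch.multinomial(..., replacement=False)` by drawing one index at a time in proportion to the remaining weights.

```python
import math
import random
from typing import List, Optional


def sampling_weights(seq_lens: List[int], target_seq_len: Optional[int]) -> List[float]:
    """Pure-Python restatement of the weight rule in batch_balencing_sampler."""
    if target_seq_len is None:
        # No target: every usable index gets the same weight (uniform sampling).
        return [1.0] * len(seq_lens)
    # sigmoid(1 / (|len - target| + 0.001)): the argument is always positive,
    # so every weight lies in (0.5, 1]; closer lengths get larger weights.
    return [1.0 / (1.0 + math.exp(-1.0 / (abs(s - target_seq_len) + 0.001))) for s in seq_lens]


def sample_without_replacement(weights: List[float], k: int, rng: random.Random) -> List[int]:
    """Draw k distinct positions, each draw proportional to the remaining weights."""
    remaining = list(range(len(weights)))
    chosen = []
    for _ in range(k):
        r = rng.random() * sum(weights[i] for i in remaining)
        acc = 0.0
        pick = remaining[-1]  # fallback against floating-point round-off
        for i in remaining:
            acc += weights[i]
            if r < acc:
                pick = i
                break
        remaining.remove(pick)
        chosen.append(pick)
    return chosen


# Illustrative data: four usable indexes and their total lengths.
usable_indexes = [4, 7, 9, 12]
seq_lens = [512, 1000, 1010, 2048]
w = sampling_weights(seq_lens, target_seq_len=1000)
picked = [usable_indexes[i] for i in sample_without_replacement(w, 2, random.Random(0))]
print([round(x, 4) for x in w], picked)
```

The sigmoid argument 1 / (|len - target| + 0.001) is always positive, so every weight lies between 0.5 and 1. Even a length far from target_seq_len keeps at least half the weight of a perfect match. The preference is therefore a mild bias rather than a hard filter.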
DP Batch Balance
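DP Batch Balance is implemented mainly in batch_seqlen_balance_sampler:
- If the length of usable_indexes equals experience_count, the indexes are returned directly. Every available index is needed, so there is nothing to sample.
- To compute the total experience_column length of an index, the table batch_seqlen_balance_mapper determines which experience_column fields each consumer uses.
- The core algorithm is get_seqlen_balanced_partitions, a heap-based greedy bin-packing algorithm. Sequences are taken from longest to shortest, and each one is assigned to the group whose current total length is smallest. This keeps the total sequence length balanced across DP ranks.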
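The greedy rule in the last bullet can be stated precisely. The notation below is introduced here and is not taken from the source:

- $l_1,\dots,l_n$ are the per-index lengths.
- $k$ is k_partitions.
- $S_j$ is the running total of group $P_j$.
- $A$ is the set of groups still on the heap. With equal_size=True, these are the groups with $|P_j| < \lfloor n/k \rfloor$. Otherwise, $A$ contains all groups.

This is the classic longest-processing-time-first (LPT) heuristic. It approximates the objective but does not guarantee its minimum.

$$
S_j = \sum_{i \in P_j} l_i, \qquad
\min_{P_1,\dots,P_k} \ \max_{1 \le j \le k} S_j, \qquad
|P_j| = \lfloor n/k \rfloor \ \text{(when equal\_size=True)}
$$

$$
l_{\sigma(1)} \ge l_{\sigma(2)} \ge \dots \ge l_{\sigma(n)}, \qquad
j^\ast = \arg\min_{j \in A} \,\bigl(S_j,\ |P_j|,\ j\bigr), \qquad
P_{j^\ast} \leftarrow P_{j^\ast} \cup \{\sigma(t)\},\ \ S_{j^\ast} \leftarrow S_{j^\ast} + l_{\sigma(t)}
$$

The minimum is taken lexicographically over (total, count, group index). This matches the ordering of the heap entries `[total, count, index, items]` in heapq_partition.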
batch_seqlen_balance_mapper
```python
# mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.__init__
# Stages that support DP Batch Balance sampling, and the columns used to compute lengths
self.batch_seqlen_balance_mapper = {
    "ref_log_prob": ["prompt_length", "response_length"],
    "actor_log_prob": ["prompt_length", "response_length"],
    "reward_scores": ["prompt_length", "response_length"],
    "actor_train": ["prompt_length", "response_length"],
}
```
batch_seqlen_balance_sampler
```python
# mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.batch_seqlen_balance_sampler
def batch_seqlen_balance_sampler(self, consumer, usable_indexes, experience_count, get_n_samples=False):
    from mindspeed_rl.utils.seqlen_balancing import get_seqlen_balanced_partitions

    # All usable indexes are needed: return them without sampling
    if len(usable_indexes) == experience_count:
        sampled_indexes = [int(usable_indexes[i]) for i in range(experience_count)]
        return sampled_indexes

    # Compute the total experience_column length of each index in usable_indexes;
    # batch_seqlen_balance_mapper decides which experience_column fields each consumer uses
    seq_len_columns = self.batch_seqlen_balance_mapper.get(consumer)
    if get_n_samples:
        seq_len_list = [
            sum([self.experience_data[key][idx * self.n_samples_per_prompt + addition].item()
                 for addition in range(self.n_samples_per_prompt)
                 for key in seq_len_columns])
            for idx in usable_indexes
        ]
    # ......
    k_partitions = len(seq_len_list) // experience_count
    # Heap-based greedy bin-packing keeps the total sequence length of each DP rank balanced
    sampled_indexes_idx = get_seqlen_balanced_partitions(seq_len_list, k_partitions, equal_size=True)
    if len(sampled_indexes_idx) > 0:
        sampled_indexes = [int(usable_indexes[i]) for i in sampled_indexes_idx[0]]
    else:
        sampled_indexes = None
    return sampled_indexes
```
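The get_n_samples=True branch measures one prompt at a time by summing the lengths of all of its samples. It then picks k_partitions so that each partition holds exactly experience_count prompts. A minimal sketch of that length computation follows.

Assumptions:

- Plain Python lists stand in for the tensors, so `.item()` is dropped.
- The mapper is abridged to one consumer.
- `prompt_group_lengths` and the toy data are hypothetical.
- The row layout (the samples of prompt `idx` occupy rows `idx * n_samples_per_prompt + addition`) is taken from the quoted code.

```python
from typing import Dict, List, Optional

# Same shape as batch_seqlen_balance_mapper, abridged to a single consumer.
batch_seqlen_balance_mapper = {"actor_train": ["prompt_length", "response_length"]}


def prompt_group_lengths(experience_data: Dict[str, List[int]], consumer: str,
                         usable_indexes: List[int], n_samples_per_prompt: int) -> Optional[List[int]]:
    """Total length of each prompt group, mirroring the get_n_samples=True branch."""
    columns = batch_seqlen_balance_mapper.get(consumer)
    if columns is None:
        return None  # consumer is not configured for DP Batch Balance
    # The samples of prompt idx occupy rows idx * n .. idx * n + n - 1.
    return [
        sum(experience_data[key][idx * n_samples_per_prompt + addition]
            for addition in range(n_samples_per_prompt)
            for key in columns)
        for idx in usable_indexes
    ]


# Toy data: 4 prompts x 2 samples = 8 rows (illustrative numbers).
experience_data = {
    "prompt_length": [10, 10, 20, 20, 5, 5, 30, 30],
    "response_length": [100, 80, 40, 60, 300, 20, 10, 10],
}
seq_len_list = prompt_group_lengths(experience_data, "actor_train", [0, 1, 2, 3], 2)
experience_count = 2  # prompts wanted per call (experience_count_n_samples upstream)
k_partitions = len(seq_len_list) // experience_count
print(seq_len_list, k_partitions)  # [200, 140, 330, 80] 2
```

Here k_partitions is 2, so get_seqlen_balanced_partitions splits the four prompt groups into two groups of two. For these numbers, the greedy pass yields `[[0, 1], [2, 3]]` (totals 340 and 410). The sampler keeps only `sampled_indexes_idx[0]`, so this call would return the prompts at positions 0 and 1 of usable_indexes.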
get_seqlen_balanced_partitions
```python
# mindspeed_rl.utils.seqlen_balancing.get_seqlen_balanced_partitions
from typing import List


def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    """get order of seq lengths to make partitions balanced, this is
    used in balancing sum of seq length across dp ranks and micro batches
    Parameters:
        seqlen_list (List[int]):
            seq lengths of each items
        k_partitions (int):
            resulting number of partitions
        equal_size (bool):
            if True, number of items in each partitions must be equal.
            if False, only consider balancing the sum, each partition can have
            variable number of items
    Returns:
        partitions (List[List[int]]): item indexes of each partition
    """

    def _check_and_sort_partitions(partitions):
        # Collect every assigned index and sort the indexes within each partition
        seen_idx = set()
        sorted_partitions = [None] * k_partitions
        for i, partition in enumerate(partitions):
            for idx in partition:
                seen_idx.add(idx)
            sorted_partitions[i] = sorted(partition)
        return sorted_partitions

    # Split seqlen_list into k_partitions groups with heapq_partition;
    # equal_size=True requires every group to hold the same number of items
    partitions = heapq_partition(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size)
    return _check_and_sort_partitions(partitions)
```
```python
import heapq
from typing import List


def heapq_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool):
    equal_part_num = len(seqlen_list) // k_partitions
    # Sort seqlen_list by sequence length, longest first
    sorted_seqlen = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)], reverse=True)
    # Heap initialization: each group stores [current total length, item count, group index, items]
    groups = [[0, 0, i, []] for i in range(k_partitions)]
    heapq.heapify(groups)
    partitions = []
    # Iterate from the longest sequence to the shortest
    for seqlen, i in sorted_seqlen:
        # Pop the group with the smallest total length
        current_group = heapq.heappop(groups)
        # Put the current (longest remaining) sequence into that group
        current_group[3].append(i)
        current_group[0] += seqlen
        current_group[1] += 1
        # If every group must hold the same number of items, retire groups once they are full
        if equal_size:
            if current_group[1] < equal_part_num:
                heapq.heappush(groups, current_group)
            else:
                partitions.append(current_group[3])
        # Otherwise push current_group back onto the heap
        else:
            heapq.heappush(groups, current_group)
    partitions.extend([group[3] for group in groups])
    # ......
    return partitions
```
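A short trace helps show what heapq_partition returns. To keep the example self-contained, `greedy_partition` below is a hypothetical, condensed restatement of the visible part of heapq_partition. The elided "......" portion is omitted. The input reuses illustrative lengths, not data from MindSpeed RL.

```python
import heapq
from typing import List


def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool) -> List[List[int]]:
    """Condensed restatement of the visible heapq_partition logic (hypothetical name)."""
    cap = len(seqlen_list) // k_partitions
    # Heap entries: [total length, item count, group id, item indexes].
    groups = [[0, 0, g, []] for g in range(k_partitions)]
    heapq.heapify(groups)
    partitions = []
    for seqlen, i in sorted(((s, i) for i, s in enumerate(seqlen_list)), reverse=True):
        group = heapq.heappop(groups)  # currently lightest group
        group[3].append(i)
        group[0] += seqlen
        group[1] += 1
        if equal_size and group[1] >= cap:
            partitions.append(group[3])  # full: retire it from the heap
        else:
            heapq.heappush(groups, group)
    partitions.extend(group[3] for group in groups)
    return partitions


seqlens = [900, 850, 300, 250, 200, 100, 80, 40]
parts = greedy_partition(seqlens, 2, equal_size=True)
print(parts)  # [[1, 2, 5, 6], [0, 3, 4, 7]]
print([sum(seqlens[i] for i in p) for p in parts])  # [1330, 1390]
```

With equal_size=True, partitions are emitted in the order their groups fill up. Partition 0, which batch_seqlen_balance_sampler returns, is therefore whichever group filled first, not necessarily heap group 0. The greedy maximum of 1390 is close to, but not exactly, the best equal-size split for this input. That split is {900, 300, 100, 80} versus {850, 250, 200, 40}, with totals 1380 and 1340.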
Usage examples
mindspeed_rl/trainer/utils/transfer_dock.py
get_experience
When GRPOTransferDock.get_experience samples indexes, it calls either _sample_ready_index_n_samples or _sample_ready_index.
```python
# mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock.get_experience
if get_n_samples:
    if experience_count % self.n_samples_per_prompt != 0:
        raise ValueError(
            f"get_n_samples need experience_count:{experience_count} must be divisible by "
            f"n_samples_per_prompt: {self.n_samples_per_prompt}"
        )
    indexes = self._sample_ready_index_n_samples(
        consumer,
        experience_count,
        experience_columns,
        use_batch_seqlen_balance=use_batch_seqlen_balance
    )
else:
    indexes = self._sample_ready_index(
        consumer,
        experience_count,
        experience_columns,
        use_batch_seqlen_balance=use_batch_seqlen_balance
    )
```
_sample_ready_index_n_samples
Both _sample_ready_index_n_samples and _sample_ready_index call either batch_seqlen_balance_sampler or batch_balencing_sampler to sample indexes. The excerpt below is from _sample_ready_index_n_samples.
```python
# mindspeed_rl.trainer.utils.transfer_dock.GRPOTransferDock._sample_ready_index_n_samples
from typing import List, Optional


def _sample_ready_index_n_samples(
    self,
    consumer: str,
    experience_count: int,
    experience_columns: List[str],
    target_seq_len: int = None,
    use_batch_seqlen_balance: bool = False
) -> Optional[List[int]]:
    experience_count_n_samples = experience_count // self.n_samples_per_prompt
    # ......
    # Compute usable_indexes from not_consumed_indexes and data_ready_indexes
    usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0]

    # Sample indexes with DP Batch Balance
    if consumer in self.batch_seqlen_balance_mapper and use_batch_seqlen_balance and len(
            usable_indexes) % experience_count_n_samples == 0:
        sampled_indexes_n_sample = self.batch_seqlen_balance_sampler(
            consumer, usable_indexes, experience_count_n_samples, get_n_samples=True
        )
        if not sampled_indexes_n_sample:
            return None
    # Otherwise fall back to the default index sampler
    else:
        sampled_indexes_n_sample = self.batch_balencing_sampler(
            experience_columns,
            usable_indexes,
            experience_count_n_samples,
            target_seq_len,
        )
    # ......
    # Mark the sampled indexes in experience_consumer_status[consumer]
    self.experience_consumer_status[consumer][sampled_indexes] = 1
    return sampled_indexes
```
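Whether DP Batch Balance is used at all comes down to the divisibility check in get_experience plus the three-part condition in _sample_ready_index_n_samples. The sketch below folds both checks into one function. `uses_seqlen_balance` and its parameter names are hypothetical, and `num_usable_indexes` stands for `len(usable_indexes)`. When it returns False, the default batch_balencing_sampler is used instead.

```python
from typing import Dict, List


def uses_seqlen_balance(consumer: str, mapper: Dict[str, List[str]], use_batch_seqlen_balance: bool,
                        num_usable_indexes: int, experience_count: int, n_samples_per_prompt: int) -> bool:
    """Mirror of the branch condition in _sample_ready_index_n_samples (hypothetical helper)."""
    if experience_count % n_samples_per_prompt != 0:
        # Same check that get_experience performs before sampling.
        raise ValueError("experience_count must be divisible by n_samples_per_prompt")
    experience_count_n_samples = experience_count // n_samples_per_prompt
    return bool(consumer in mapper
                and use_batch_seqlen_balance
                and num_usable_indexes % experience_count_n_samples == 0)


mapper = {"actor_train": ["prompt_length", "response_length"]}
# 16 samples with 8 per prompt -> 2 prompts per call; 8 usable indexes divide evenly.
print(uses_seqlen_balance("actor_train", mapper, True, 8, 16, 8))  # True
```

Any one failing condition routes sampling to the default sampler. That happens if the consumer is missing from the mapper, the flag is off, or the usable count is not a multiple of experience_count_n_samples.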