DeepSeek DeepEP Notes (1): low-latency dispatch

Background

To optimize latency, low-latency mode casts the data to fp8 and sends it directly between GPUs, instead of the normal kernel's two-stage approach of first sending over the network to the same-indexed GPU on the remote node and then forwarding over NVLink. Furthermore, the normal dispatch consists of two kernels: notify_dispatch, which transmits meta information, and dispatch, which transmits the actual data; low-latency mode also skips the notify step. The price is much higher GPU memory usage, and it has to be used together with DeepSeek's GEMM library.

Usage

Taking the demo from the GitHub repo as an example, usage is fairly simple.

First call get_low_latency_rdma_size_hint to obtain num_rdma_bytes, which is how much space the Buffer will need; then initialize the Buffer and run low_latency_dispatch:

def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer:
    # NOTES: the low-latency mode will consume much more space than the normal mode
    # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
    global _buffer
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts)

    # Allocate a buffer if not existed or not enough buffer size
    if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes:
        # NOTES: for best performance, the QP number **must** be equal to the number of the local experts
        assert num_experts % group.size() == 0
        _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size())
    return _buffer

def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int):
    global _buffer

    # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay)
    recv_hidden_states, recv_expert_count, handle, event, hook = \
        _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
                                     async_finish=False, return_recv_hook=True)

    # NOTES: the actual tensor will not be received only if you call `hook()`,
    # it is useful for double-batch overlapping, but **without any SM occupation**
    # If you don't want to overlap, please set `return_recv_hook=False`

    # Later, you can use our GEMM library to do the computation with this specific format
    return recv_hidden_states, recv_expert_count, handle, event, hook

Getting the buffer size

num_max_dispatch_tokens_per_rank is the maximum number of tokens one rank may dispatch. Since the data is cast to fp8, one token message consists of the embedding (hidden bytes, one byte per element in fp8), the scales, and the token index, for num_bytes_per_dispatch_msg bytes in total. Let's focus on send_buffer_bytes here; the factor of 2 is to allow two different batches to be processed in parallel. send_buffer_bytes is simple: it just reserves space for the maximum number of tokens, i.e. num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg. Because low-latency mode sends directly without computing the actual layout first, the worst case on the receive side is that every rank routes all of its tokens to the same expert, so the receive buffer needs num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg bytes in total.

size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
    auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes;
    return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES;
}

LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
    const int num_scales = hidden / 128;
    const int num_local_experts = num_experts / num_ranks;
    EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);

    size_t num_bytes_per_dispatch_msg = hidden + num_scales * sizeof(float) + sizeof(int4);
    size_t num_bytes_per_combine_msg = sizeof(int4) + hidden * sizeof(nv_bfloat16);

    // Send buffer
    size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
    size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
    size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes);
    EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0);
    total_bytes += send_buffer_bytes * 2;

    size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
    size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
    size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
    EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
    total_bytes += recv_buffer_bytes * 2;
    ...
}
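To get a feel for the numbers, here is a minimal standalone sketch (my own, not DeepEP code) that plugs assumed DeepSeek-V3-like values into the size formulas above: hidden = 7168, 128 max dispatch tokens per rank, 16 ranks, 256 experts, with sizeof(int4) = 16 and sizeof(nv_bfloat16) = 2 hard-coded so it compiles without CUDA headers:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    // All values below are assumptions for illustration only (roughly DeepSeek-V3-sized).
    const size_t hidden = 7168, num_max_tokens = 128, num_experts = 256;
    const size_t sizeof_int4 = 16, sizeof_bf16 = 2;   // stand-ins for sizeof(int4), sizeof(nv_bfloat16)

    const size_t num_scales = hidden / 128;                                         // one float scale per 128 channels
    const size_t dispatch_msg = hidden + num_scales * sizeof(float) + sizeof_int4;  // fp8 payload + scales + index
    const size_t combine_msg  = sizeof_int4 + hidden * sizeof_bf16;

    const size_t send_bytes = std::max(num_max_tokens * dispatch_msg,
                                       num_experts * num_max_tokens * combine_msg);
    const size_t recv_bytes = std::max(num_experts * num_max_tokens * dispatch_msg,
                                       num_experts * num_max_tokens * combine_msg);
    const double total_gib = double((send_bytes + recv_bytes) * 2) / (1 << 30);     // x2 for two parallel batches
    std::printf("dispatch msg: %zu B, send buf: %zu B, recv buf: %zu B, total: %.2f GiB\n",
                dispatch_msg, send_bytes, recv_bytes, total_gib);
    return 0;
}

Under these assumptions this prints roughly 1.75 GiB of symmetric RDMA buffer per rank (the real LowLatencyLayout adds count and signaling buffers on top), which is the memory cost mentioned in the background section.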

Buffer creation and initialization

First a Buffer is created. This mainly allocates some memory that low-latency mode does not use, so we can skip it for now. In low-latency mode, IBGDA also has to be enabled; then rank 0 obtains an NVSHMEM unique ID via get_local_nvshmem_unique_id (similar to an NCCL communicator ID), and all ranks obtain this ID through an all-gather.

def __init__(self, group: dist.ProcessGroup,
             num_nvl_bytes: int = 0, num_rdma_bytes: int = 0,
             low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None:
    self.rank = group.rank()
    self.group_size = group.size()
    self.group = group
    self.num_nvl_bytes = num_nvl_bytes
    self.num_rdma_bytes = num_rdma_bytes
    self.low_latency_mode = low_latency_mode
    self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)

    root_unique_id = None
    if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
        # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA"
        if low_latency_mode:
            assert num_qps_per_rank > 0
            os.environ['NVSHMEM_DISABLE_P2P'] = '1'
            os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
            os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
            os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
            # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
            os.environ['NVSHMEM_QP_DEPTH'] = '1024'
            # NOTES: NVSHMEM initialization requires at least 256 MiB
            os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}'
            # NOTES: make sure AR (Adaptive Routing) is turned off while running normal kernels, as we cannot verify AR status in the code

        # Synchronize using the root ID
        nvshmem_unique_ids = [None, ] * self.group_size
        if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
            root_unique_id = self.runtime.get_local_nvshmem_unique_id()
        dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
        root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]

    # Make CPP runtime available
    self.runtime.sync(device_ids, ipc_handles, root_unique_id)
    assert self.runtime.is_available()

Then sync is executed. Since num_nvl_bytes is 0, the NVLink-related code is omitted here. sync calls internode::init, which initializes NVSHMEM; at that point the inter-node RDMA connections have been established.

void Buffer::sync(const std::vector<int> &device_ids,
                  const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles,
                  const std::optional<pybind11::bytearray>& root_unique_id_opt) {
    EP_HOST_ASSERT(not is_available());

    // Sync NVSHMEM handles and allocate memory
    if (num_rdma_bytes > 0) {
        // Initialize NVSHMEM
        EP_HOST_ASSERT(root_unique_id_opt.has_value());
        std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
        auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
        std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
        auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
        auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
        EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode));
        internode::barrier();

        // Allocate
        rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);

        // Clean buffer (mainly for low-latency mode)
        CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));

        // Barrier
        internode::barrier();
        CUDA_CHECK(cudaDeviceSynchronize());
    }

    // Ready to use
    available = true;
}

alloc then allocates NVSHMEM symmetric memory via nvshmem_align. The ibgda_initialize_recv_queue kernel is launched to initialize the receive side of each connection, and the receive WRs are posted.

dispatch

The figure below shows an example: two nodes, two GPUs per node, and two experts per GPU.

Figure 1: role assignment

The dispatch kernel assigns roles to the warps. Except for the data-sending phase, the role assignment is always the same: each SM contains kNumWarpGroups warp groups, each group has kNumWarpsPerGroup warps, and each warp group is responsible for one expert.

Figure 2

So the expert a warp is responsible for is computed as:

const auto warp_group_id = warp_id / kNumWarpsPerGroup;
const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
const auto responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
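To make the mapping concrete for the Figure 1 topology (8 experts in total), here is a tiny host-side sketch of the same arithmetic. The configuration values are assumptions for illustration only, not DeepEP's actual template parameters:

#include <cstdio>

int main() {
    // Assumed for illustration: 8 experts (Figure 1), 2 warp groups per SM,
    // 2 warps per group -> 4 SMs cover all experts, 4 warps per SM.
    const int kNumWarpGroups = 2, kNumWarpsPerGroup = 2, num_sms = 4;
    for (int sm_id = 0; sm_id < num_sms; ++sm_id)
        for (int warp_id = 0; warp_id < kNumWarpGroups * kNumWarpsPerGroup; ++warp_id) {
            const int warp_group_id = warp_id / kNumWarpsPerGroup;
            const int sub_warp_id = warp_id % kNumWarpsPerGroup;
            const int responsible_expert_idx = sm_id * kNumWarpGroups + warp_group_id;
            std::printf("sm %d warp %d -> group %d, sub-warp %d, expert %d\n",
                        sm_id, warp_id, warp_group_id, sub_warp_id, responsible_expert_idx);
        }
    return 0;
}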

Data sending

During data sending the role assignment is slightly different, as shown in the figure below: within each SM the last warp is responsible for sending counts (call it the count warp), while the other warps are responsible for sending the data (call them data warps).

Figure 3

Each SM processes one token at a time, so in the first round SM 0 handles token 0 and SM 1 handles token 1; within SM 0, each warp handles one topk_idx entry, e.g. warp 0 of SM 0 handles topk_idx[0][0] in the figure.

Let's first look at the data-warp logic. The message that gets sent is laid out as embedding + scales + token_idx:

__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    if (warp_id < num_warps - 1) {
        constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16);
        EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
        EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization");
        const auto num_threads = (num_warps - 1) * 32;
        const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;

        for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
            const auto x_int4 = reinterpret_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
            const auto rdma_x_int2 = reinterpret_cast<int2*>(reinterpret_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
            const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_int2) + kHidden);
            const auto rdma_x_src_idx = reinterpret_cast<int*>(rdma_x_scales + num_scales);
            ...
        }
    }
    ...
}

The current SM reads token number token_idx, whose embedding is x_int4. As described above, num_topk warps within the SM read the topk matrix to find which experts the current token will be sent to. rdma_x_int2 points to the region of rdma_x that stores the embedding after it has been cast to fp8.
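Based on the pointer arithmetic above, the per-token staging message in rdma_x can be pictured with the following offsets. This is a standalone sketch of my reading of the quoted code; the assumed hidden size and the names are not DeepEP identifiers:

#include <cstddef>
#include <cstdio>

// Per-token dispatch message layout implied by the pointers above (illustrative only):
//   [ kHidden bytes of fp8 payload | num_scales floats | int source token index | padding to int4 ]
int main() {
    const std::size_t kHidden = 7168;                      // assumed hidden size
    const std::size_t num_scales = kHidden / 128;          // one scale per 128 channels
    const std::size_t payload_off = 0;                     // rdma_x_int2
    const std::size_t scales_off  = payload_off + kHidden; // rdma_x_scales
    const std::size_t src_idx_off = scales_off + num_scales * sizeof(float);  // rdma_x_src_idx
    const std::size_t msg_bytes   = kHidden + num_scales * sizeof(float) + 16; // num_bytes_per_dispatch_msg, sizeof(int4) = 16
    std::printf("payload@%zu scales@%zu src_idx@%zu msg=%zu bytes\n",
                payload_off, scales_off, src_idx_off, msg_bytes);
    return 0;
}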

__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    if (dst_expert_idx >= 0) {
        int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
        slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
        const auto dst_rank = dst_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
        const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_int2);
        const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
                             dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                             rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
                             slot_idx * num_bytes_per_msg;
        if (dst_rank != rank) {
            nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
        } else {
            // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
            const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
            const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
            UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
        }

        // Increase counter after finishing
        __syncwarp();
        lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
    }
    ...
}

dst_expert_idx is the expert this token must be sent to. Because different SMs process different tokens, several of them may be sending to the same expert at the same time, so lane 0 of the warp obtains a slot index for that expert via an atomic increment, which prevents writes from overwriting each other, and then broadcasts slot_idx to the whole warp with __shfl_sync. Each rank's receive buffer is laid out as in the figure below, where each blue rectangle is num_max_dispatch_tokens_per_rank * num_bytes_per_msg bytes. Suppose the current rank is rank 2 and the token must go to expert 3 (ep3): the destination is rank 1, and dst_expert_local_idx is 1.

Figure 4

Assuming slot_idx is 0, the token will be written to the following position in rank 1's RDMA buffer, i.e. the start of the rank-2 slice of local expert 1.
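Plugging the figure's numbers into the dst_ptr computation above gives the byte offset inside the receiver's rdma_recv_x. A small sketch with illustrative values (4 ranks and 2 local experts as in Figure 1; the 128-token cap and the 7408-byte message size are assumed, matching the earlier sizing example):

#include <cstddef>
#include <cstdio>

// Offset of a dispatched token inside the receiver's rdma_recv_x,
// mirroring the dst_ptr computation in the kernel above.
std::size_t recv_offset(std::size_t dst_expert_local_idx, std::size_t src_rank, std::size_t slot_idx,
                        std::size_t num_ranks, std::size_t num_max_tokens, std::size_t num_bytes_per_msg) {
    return dst_expert_local_idx * num_ranks * num_max_tokens * num_bytes_per_msg +
           src_rank * num_max_tokens * num_bytes_per_msg +
           slot_idx * num_bytes_per_msg;
}

int main() {
    // Figure 4/5 example: rank 2 sends to global expert 3 -> dst_rank 1, local expert 1, slot 0.
    std::printf("offset = %zu bytes\n", recv_offset(1, 2, 0, 4, 128, 7408));
    return 0;
}

With these numbers the offset lands at the start of the 7th block of num_max_dispatch_tokens_per_rank slots: local expert 1's region begins after local expert 0's four rank slices, and rank 2's slice is the third one inside it, matching Figure 5.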

Figure 5

The token is then sent with a single RDMA write via nvshmemi_ibgda_put_nbi_warp (or a plain warp copy if the destination is the local rank). After the send is issued, atomic_finish_counter_per_expert[dst_expert_idx] is incremented with an atomic add, indicating that one more token send has completed.

Synchronizing the sends

Because the data is sent with RDMA writes, the remote side cannot tell when it has arrived, so after the data writes an additional RDMA write-with-immediate is needed to let the receiver know the data has been sent. However, the data sends are spread across multiple SMs, and SMs cannot synchronize with each other through a sync-style primitive, so DeepEP designs its own synchronization mechanism: the write-with-immediate is issued to the peer only after both the data sends have finished and some memory initialization is done. The data sends are performed by the data warps, while the initialization is handled by the count warp of SM 0.

We mentioned atomic_finish_counter_per_expert earlier: atomic_finish_counter_per_expert[x] tracks how far the work for expert x has progressed and is initialized to 0. Each time a data warp issues one send, it atomically adds 1 to atomic_finish_counter_per_expert[x]. The count warp atomically adds FINISHED_SUM_TAG once its initialization is done, then computes sum, the total number of tokens the current rank must send to expert x, and atomically adds (FINISHED_SUM_TAG - sum). FINISHED_SUM_TAG is guaranteed to be larger than sum, so regardless of the order in which the warps run, once a poller observes the value 2 * FINISHED_SUM_TAG, both the data sends and the memory initialization are complete.
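The arithmetic behind the 2 * FINISHED_SUM_TAG check can be sanity-checked with a trivial standalone sketch (the tag value and token count below are made up; the only requirement is that the tag exceeds any possible sum):

#include <cassert>

int main() {
    const int FINISHED_SUM_TAG = 1024;   // assumed; must exceed any possible token count
    const int sum = 3;                   // tokens this rank sends to expert x
    int counter = 0;                     // plays the role of atomic_finish_counter_per_expert[x]

    // The three kinds of contributions, in any interleaving:
    counter += FINISHED_SUM_TAG;         // count warp: initialization finished
    counter += FINISHED_SUM_TAG - sum;   // count warp: expected token count published
    for (int i = 0; i < sum; ++i)
        counter += 1;                    // data warps: one increment per token actually sent

    // Only when all contributions have landed does the counter reach exactly
    // 2 * FINISHED_SUM_TAG, no matter the ordering of the warps.
    assert(counter == 2 * FINISHED_SUM_TAG);
    return 0;
}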

Now for the count-warp logic. SM 0 initializes next_clean, and once that is done it adds FINISHED_SUM_TAG to atomic_finish_counter_per_expert. The count warp then counts how many tokens go to each expert: one SM is responsible for kNumWarpGroups experts, so from sm_id it knows its expert range. It walks over topk_idx to see how many tokens belong to the experts it maintains, accumulating into expert_count[kNumWarpGroups], where expert_count[x] holds how many tokens will be sent to that expert. After each thread finishes its share of the traversal, warp_reduce_sum produces the warp-wide total per expert, completing the token count for the experts this SM maintains; lane 0 then atomically adds FINISHED_SUM_TAG - sum to atomic_finish_counter_per_expert.

__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    int expert_count[kNumWarpGroups] = {0};
    const auto expert_begin_idx = sm_id * kNumWarpGroups;
    const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);

    // Per lane count
    #pragma unroll 8
    for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
        auto idx = static_cast<int>(__ldg(topk_idx + i));
        if (idx >= expert_begin_idx and idx < expert_end_idx)
            expert_count[idx - expert_begin_idx] ++;
    }

    // Warp reduce
    #pragma unroll
    for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
        auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
        if (lane_id == 0) {
            shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
            atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
        }
    }
    ...
}
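For reference, what the count warp computes is just a per-expert histogram of topk_idx; each SM computes only the slice belonging to its own expert range. A sequential host sketch of the full histogram, with a made-up toy routing for the 8-expert Figure 1 setup:

#include <cstdio>
#include <vector>

int main() {
    // Assumed toy routing: 4 tokens, top-2 experts each, 8 experts in total (Figure 1 setup).
    const int num_experts = 8;
    const std::vector<int> topk_idx = {0, 3,   3, 5,   1, 3,   6, 7};  // [num_tokens, num_topk], flattened

    std::vector<int> tokens_per_expert(num_experts, 0);
    for (int idx : topk_idx)
        if (idx >= 0)                      // negative entries mean "not routed"
            ++tokens_per_expert[idx];

    // e.g. expert 3 receives 3 tokens from this rank; this is the `sum` the count warp
    // publishes via shared_num_tokens_sent_per_expert and FINISHED_SUM_TAG - sum.
    for (int e = 0; e < num_experts; ++e)
        std::printf("expert %d: %d tokens\n", e, tokens_per_expert[e]);
    return 0;
}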

Once the data has been sent, each rank notifies the other ranks that it is done: the first thread of warp 0 in each warp group polls atomic_finish_counter_per_expert, and when it reads FINISHED_SUM_TAG * 2 (data sends finished and memory cleaning done), it calls nvshmemi_ibgda_rma_p to send the number of tokens actually sent, num_tokens_sent, to the peer via an RDMA write with immediate.

Notifying the receiver

As before, each warp group maintains one expert, and thread 0 of warp 0 in that group performs the receiver notification. It first polls atomic_finish_counter_per_expert to check whether the data for its expert, responsible_expert_idx, has been fully sent; once the value equals FINISHED_SUM_TAG * 2, the send is complete, so it sends the token count num_tokens_sent with an RDMA write with immediate via nvshmemi_ibgda_rma_p, and then posts the receive WRs.

__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    // Issue count sends
    if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
        const auto dst_rank = responsible_expert_idx / num_local_experts;
        const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
        const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];

        // Wait local sends issued and send expert counts
        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
        if (dst_rank != rank) {
            nvshmemi_ibgda_rma_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1,
                                 dst_rank, dst_expert_local_idx, 0);
            nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
        } else {
            st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1);
        }

        // Clean workspace for next use
        atomic_counter_per_expert[responsible_expert_idx] = 0;
        atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
    }
    ...
}

Data receiving

Next comes the receive path. On the receiving side, with the receive-buffer structure from Figure 5, a warp group again corresponds to one expert and handles one block of the receive buffer. Thread 0 of warp 1 in each warp group polls the CQ (warp 1 is used so that this overlaps with warp 0, which performs the receiver notification), waiting for the write-with-immediate from the rank associated with that expert slice. Once a CQE is polled, it reads the number of tokens to receive, num_recv_tokens.

__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    // Shared between sub-warps in warp groups
    __shared__ int shared_num_recv_tokens[kNumWarpGroups], shared_recv_token_begin_idx[kNumWarpGroups];

    // Wait tokens to arrive
    // NOTES: using sub-warp 1 to overlap with sub-warp 0
    int num_recv_tokens, recv_token_begin_idx;
    EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
    if (sub_warp_id == 1 and lane_id == 0) {
        if (src_rank != rank) {
            nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
            num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank);
            EP_DEVICE_ASSERT(num_recv_tokens != 0);
        } else {
            while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
        }
        num_recv_tokens = -num_recv_tokens - 1;
        recv_token_begin_idx = atomicAdd(atomic_counter_per_local_expert + local_expert_idx, num_recv_tokens);
        shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
        shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
        recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
    }
    asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
    num_recv_tokens = shared_num_recv_tokens[warp_group_id];
    recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
    ...
}
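Note the count travels as -num_tokens_sent - 1 and is decoded as -value - 1 above. My reading of this (consistent with the zero-initialized buffer and the loop that spins on 0): 0 is reserved to mean "count not written yet", and the offset-by-one keeps a legitimate count of 0 from colliding with it. A trivial sketch of the encoding:

#include <cassert>

// Encode/decode used for rdma_recv_count: 0 is reserved for "no count written yet".
int encode_count(int num_tokens_sent) { return -num_tokens_sent - 1; }  // 0 -> -1, 5 -> -6, ...
int decode_count(int wire_value)      { return -wire_value - 1; }

int main() {
    for (int n = 0; n < 8; ++n) {
        assert(encode_count(n) != 0);                 // never collides with the cleared buffer
        assert(decode_count(encode_count(n)) == n);   // round-trips correctly
    }
    return 0;
}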

After obtaining num_recv_tokens, it is added to atomic_counter_per_local_expert[local_expert_idx] with an atomic add, reserving num_recv_tokens slots in the output for the incoming data. Then the copy begins, one warp per token. recv_x_int4 is the output address, recv_token_begin_idx is the position reserved by the atomic add, and the source address is computed the same way as on the send side: responsible_expert_idx determines src_rank, the yellow region in the figure below. After copying the data, the source token index is written to recv_src_info.

Figure 6

__global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * 32, 1) void dispatch() {
    ...
    for (int i = sub_warp_id; i < num_recv_tokens; i += kNumWarpsPerGroup) {
        // Copy data
        // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
        const auto src = reinterpret_cast<int4*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
        const auto dst = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
        UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst, src, ld_nc_global, st_na_global);

        // Copy scales
        const auto src_scales = reinterpret_cast<float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + kHidden);
        const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
        const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
        auto scale_0 = lane_id < num_scales ? ld_nc_global(src_scales + lane_id) : 0;
        auto scale_1 = (lane_id + 32) < num_scales ? ld_nc_global(src_scales + lane_id + 32) : 0;
        lane_id < num_scales ? dst_scales[lane_id * scale_stride] = scale_0 : 0.0f;
        (lane_id + 32) < num_scales ? dst_scales[(lane_id + 32) * scale_stride] = scale_1 : 0.0f;

        // Copy source info
        const auto src_src_idx = reinterpret_cast<int*>(src_scales + num_scales);
        if (lane_id == 0)
            recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
        __syncwarp();
    }
}
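One detail worth calling out in the scale copy: dst_scales is indexed with a stride of num_ranks * num_max_dispatch_tokens_per_rank, which suggests the output scale tensor is stored transposed, with the scale index as the slow dimension and the token slot as the fast one. The sketch below writes out that addressing; it is my reading of the quoted indexing, not an official layout description:

#include <cstddef>
#include <cstdio>

// Flat index of scale `scale_i` for token slot `token_slot` inside one local expert's
// recv_x_scales slice, as implied by `dst_scales[lane_id * scale_stride]` above.
std::size_t scale_index(std::size_t scale_i, std::size_t token_slot,
                        std::size_t num_ranks, std::size_t num_max_tokens) {
    const std::size_t scale_stride = num_ranks * num_max_tokens;  // total token slots per local expert
    return scale_i * scale_stride + token_slot;                   // shape: [num_scales, num_ranks * num_max_tokens]
}

int main() {
    // Assumed for illustration: 4 ranks, 128 max tokens per rank, token slot 5, scale 3.
    std::printf("flat index = %zu\n", scale_index(3, 5, 4, 128));
    return 0;
}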