From f96f85424b1518b47a92609d1915f3582b5d813e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 4 Sep 2025 14:52:37 +0000 Subject: [PATCH 1/8] [Kernels] Overlap shared experts with combine instead of dispatch Signed-off-by: Bill Nell --- .../fused_moe/deepep_ht_prepare_finalize.py | 50 ++++++++++-- .../fused_moe/deepep_ll_prepare_finalize.py | 43 ++++++++++- .../layers/fused_moe/modular_kernel.py | 77 ++++++++++++++++--- .../layers/fused_moe/pplx_prepare_finalize.py | 39 +++++++++- 4 files changed, 188 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 92cbb1742974..0cb46a73e77b 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -245,7 +245,7 @@ def prepare( quant_config) return receiver() - def finalize( + def _finalize( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -253,7 +253,8 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + do_async: bool, + ) -> Optional[Callable]: assert self.handle is not None @@ -276,7 +277,46 @@ def finalize( topk_weights=None, config=self._get_combine_config(), previous_event=None, - async_finish=False, + async_finish=do_async, allocate_on_comm_stream=False) - # Respect inplace outputs. - output.copy_(combined_x, non_blocking=True) + + if do_async: + + def _receiver(): + event.current_stream_wait() + # Respect inplace outputs. + output.copy_(combined_x, non_blocking=True) + + return lambda: _receiver() + else: + # Respect inplace outputs. + output.copy_(combined_x, non_blocking=True) + return None + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> Callable: + receiver = self._finalize(output, fused_expert_output, topk_weights, + topk_ids, apply_router_weight_on_input, + weight_and_reduce_impl, True) + assert receiver is not None + return receiver + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + self._finalize(output, fused_expert_output, topk_weights, topk_ids, + apply_router_weight_on_input, weight_and_reduce_impl, + False) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 61f8297f0f14..d192e0f97cf6 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -203,7 +203,7 @@ def prepare( hook() return receiver() - def finalize( + def _finalize( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -211,7 +211,8 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + do_async: bool, + ) -> Callable: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") @@ -240,3 +241,41 @@ def finalize( if recv_hook is not None: dbo_register_recv_hook(recv_hook) 
dbo_yield() + + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> Callable: + return self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + True, + ) + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + False, + ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index efaa9cc058e4..ce48663f5034 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -213,7 +213,8 @@ def prepare( def supports_async(self) -> bool: """ - Indicates whether or not this class implements prepare_async. + Indicates whether or not this class implements prepare_async and + finalize_async. """ return False @@ -281,6 +282,43 @@ def finalize( """ raise NotImplementedError + @abstractmethod + def finalize_async( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + ) -> Callable: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output but do not wait for results from other workers. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. + - weight_and_reduce_impl: An optional TopKWeightAndReduce + implementation. + + Returns a callback that when invoked waits for results from other + workers and has the same return signature as `finalize`, e.g. + + receiver = obj.finalize_async(output, ...) + ... output not valid yet ... + receiver() + ... output valid here ... + + is equivalent to: + + obj.finalize(output, ...) + """ + raise NotImplementedError + @property @abstractmethod def activation_format(self) -> FusedMoEActivationFormat: @@ -455,7 +493,7 @@ def apply( - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights: A map of row to expert weights. Some implementations - choose to do weight application. + choose to do weight application. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first MoE layer. 
@@ -907,16 +945,35 @@ def forward( apply_router_weight_on_input=apply_router_weight_on_input, ) - self.prepare_finalize.finalize( - output, - fused_out, - topk_weights, - topk_ids, - apply_router_weight_on_input, - self.fused_experts.finalize_weight_and_reduce_impl(), - ) + shared_output: Optional[torch.Tensor] = None + + if (not self.prepare_finalize.supports_async() + or self.shared_experts is None): + self.prepare_finalize.finalize( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + else: + receiver = self.prepare_finalize.finalize_async( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + assert self.shared_experts is not None + shared_output = self.shared_experts(a1) + receiver() if self.shared_experts is None: return output else: + assert shared_output is not None return shared_output, output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index b8c1c14317c4..55ca1fd71077 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -276,7 +276,7 @@ def prepare( hook() return receiver() - def finalize( + def finalize_async( self, output: torch.Tensor, fused_expert_output: torch.Tensor, @@ -284,7 +284,7 @@ def finalize( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + ) -> Callable: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") @@ -307,8 +307,39 @@ def finalize( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) + topk_ids_u32 = topk_ids.view(dtype=torch.uint32) + self.a2a.combine(out_tokens=output, - indices=topk_ids.view(dtype=torch.uint32), + indices=topk_ids_u32, weights=topk_weights, expert_y=fused_expert_output, - bound_m=bound_m) + bound_m=bound_m, + do_send=True, + do_recv=False) + + return lambda: self.a2a.combine(out_tokens=output, + indices=topk_ids_u32, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=False, + do_recv=True) + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + receiver = self.finalize_async( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + ) + receiver() From 1b9d4446f9ca8e79c6db54799f7db4086c1ce604 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 4 Sep 2025 15:00:05 +0000 Subject: [PATCH 2/8] don't allow inplace if shared_experts are present Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ce48663f5034..4743d54e1ddc 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -850,7 +850,10 @@ def forward( """ a1 = hidden_states - output = a1 
if inplace else torch.zeros_like(a1) + if inplace and self.shared_experts is None: + output = a1 + else: + output = torch.zeros_like(a1) local_num_experts = w1.size(0) if global_num_experts == -1: From 7e9d423515d4505ceb86cd2ed4f7129222fe18f6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 4 Sep 2025 15:41:31 +0000 Subject: [PATCH 3/8] update signature Signed-off-by: Bill Nell --- .../layers/fused_moe/deepep_ll_prepare_finalize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index d192e0f97cf6..52fa741a2c6a 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -212,7 +212,7 @@ def _finalize( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, do_async: bool, - ) -> Callable: + ) -> Optional[Callable]: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") From ec4bd9e7ceeb4ffd77ee756969ca56525a961da5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 4 Sep 2025 17:40:04 +0000 Subject: [PATCH 4/8] fix lint Signed-off-by: Bill Nell --- .../layers/fused_moe/deepep_ll_prepare_finalize.py | 4 +++- vllm/model_executor/layers/fused_moe/modular_kernel.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 52fa741a2c6a..8bc87c51b681 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -251,7 +251,7 @@ def finalize_async( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> Callable: - return self._finalize( + receiver = self._finalize( output, fused_expert_output, topk_weights, @@ -260,6 +260,8 @@ def finalize_async( weight_and_reduce_impl, True, ) + assert receiver is not None + return receiver def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 4743d54e1ddc..5c342b2bded4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -282,7 +282,6 @@ def finalize( """ raise NotImplementedError - @abstractmethod def finalize_async( self, output: torch.Tensor, From 2e5beaef58b7837677b0317fc821f4d0e079136e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 9 Sep 2025 23:40:08 +0000 Subject: [PATCH 5/8] back out naive fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d22bb253f4a7..ad0acfb73f24 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1839,6 +1839,9 @@ def forward_impl( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_config.use_flashinfer_cutlass_kernels) + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits) # If there are shared experts but we are not using a modular kernel, the # 
shared experts must be called here @@ -1849,10 +1852,6 @@ def forward_impl( else: shared_output = None - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1885,9 +1884,8 @@ def forward_impl( final_hidden_states, ) - def reduce_output(states: torch.Tensor, - do_combine: bool = True) -> torch.Tensor: - if do_naive_dispatch_combine and do_combine: + def reduce_output(states: torch.Tensor) -> torch.Tensor: + if do_naive_dispatch_combine: states = get_ep_group().combine(states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): @@ -1896,11 +1894,10 @@ def reduce_output(states: torch.Tensor, return states if self.shared_experts is None: - assert not isinstance(final_hidden_states, tuple) return reduce_output(final_hidden_states) else: return ( - reduce_output(final_hidden_states[0], do_combine=False), + reduce_output(final_hidden_states[0]), reduce_output(final_hidden_states[1]), ) From dd50c885770d79cfaea71b26cde4298ce50ee410 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 16 Sep 2025 21:24:12 +0000 Subject: [PATCH 6/8] rebase on dbo Signed-off-by: Bill Nell --- .../fused_moe/deepep_ll_prepare_finalize.py | 20 ++++++------- .../layers/fused_moe/modular_kernel.py | 28 +++++++++---------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 8bc87c51b681..3a6c57ad2ebd 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -12,8 +12,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, - dbo_maybe_run_recv_hook, - dbo_register_recv_hook, dbo_yield) + dbo_maybe_run_recv_hook) # DeepEP kernels quantize dispatch inputs in 128 element chunks. 
DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -218,7 +217,7 @@ def _finalize( ), ("Weight application and reduction happens in the combine kernel.") a2a_idx = dbo_current_ubatch_id() - do_recv_hook = dbo_enabled() + do_recv_hook = dbo_enabled() or do_async handle = self.handles[a2a_idx] assert handle is not None @@ -238,9 +237,8 @@ def _finalize( zero_copy=False, return_recv_hook=do_recv_hook, out=output) - if recv_hook is not None: - dbo_register_recv_hook(recv_hook) - dbo_yield() + + return recv_hook def finalize_async( self, @@ -251,17 +249,17 @@ def finalize_async( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> Callable: - receiver = self._finalize( + recv_hook = self._finalize( output, fused_expert_output, topk_weights, topk_ids, apply_router_weight_on_input, weight_and_reduce_impl, - True, + do_async=True, ) - assert receiver is not None - return receiver + assert recv_hook is not None + return recv_hook def finalize( self, @@ -279,5 +277,5 @@ def finalize( topk_ids, apply_router_weight_on_input, weight_and_reduce_impl, - False, + do_async=False, ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5c342b2bded4..f2fb79e56958 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -858,17 +858,11 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts - shared_output: torch.Tensor - if not self.prepare_finalize.supports_async(): # We shouldn't be running an a2a kernel that doesn't # support async prepare/finalize assert not dbo_enabled() - # Run shared experts serially with dispatch. - if self.shared_experts is not None: - shared_output = self.shared_experts(a1) - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, @@ -896,9 +890,6 @@ def forward( self.fused_experts.quant_config, ) - if self.shared_experts is not None: - shared_output = self.shared_experts(a1) - # If DBO is being used, register the hook with the ubatch context # and call it in dbo_maybe_run_recv_hook instead of passing it to # the receiver. 
@@ -949,8 +940,9 @@ def forward( shared_output: Optional[torch.Tensor] = None - if (not self.prepare_finalize.supports_async() - or self.shared_experts is None): + if not self.prepare_finalize.supports_async(): + assert not dbo_enabled() + self.prepare_finalize.finalize( output, fused_out, @@ -962,7 +954,7 @@ def forward( if self.shared_experts is not None: shared_output = self.shared_experts(a1) else: - receiver = self.prepare_finalize.finalize_async( + recv_hook = self.prepare_finalize.finalize_async( output, fused_out, topk_weights, @@ -970,9 +962,15 @@ def forward( apply_router_weight_on_input, self.fused_experts.finalize_weight_and_reduce_impl(), ) - assert self.shared_experts is not None - shared_output = self.shared_experts(a1) - receiver() + + if self.shared_experts is not None: + shared_output = self.shared_experts(a1) + + if recv_hook is not None: + dbo_register_recv_hook(recv_hook) + dbo_yield() + if not dbo_enabled(): + recv_hook() if self.shared_experts is None: return output From 140015df7eaddd7ada3a2eacbe52240785456f24 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 16 Sep 2025 21:26:03 +0000 Subject: [PATCH 7/8] fix layer.py merge Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ad0acfb73f24..d22bb253f4a7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1839,9 +1839,6 @@ def forward_impl( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_config.use_flashinfer_cutlass_kernels) - if do_naive_dispatch_combine: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) # If there are shared experts but we are not using a modular kernel, the # shared experts must be called here @@ -1852,6 +1849,10 @@ def forward_impl( else: shared_output = None + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch( + hidden_states, router_logits) + # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, @@ -1884,8 +1885,9 @@ def forward_impl( final_hidden_states, ) - def reduce_output(states: torch.Tensor) -> torch.Tensor: - if do_naive_dispatch_combine: + def reduce_output(states: torch.Tensor, + do_combine: bool = True) -> torch.Tensor: + if do_naive_dispatch_combine and do_combine: states = get_ep_group().combine(states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): @@ -1894,10 +1896,11 @@ def reduce_output(states: torch.Tensor) -> torch.Tensor: return states if self.shared_experts is None: + assert not isinstance(final_hidden_states, tuple) return reduce_output(final_hidden_states) else: return ( - reduce_output(final_hidden_states[0]), + reduce_output(final_hidden_states[0], do_combine=False), reduce_output(final_hidden_states[1]), ) From be2cd7604d8d09794ee70ec29218a76036e87732 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 17 Sep 2025 18:06:33 +0000 Subject: [PATCH 8/8] add assert Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f2fb79e56958..28d8d975c639 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -966,8 +966,8 @@ def forward( if self.shared_experts is not None: shared_output = self.shared_experts(a1) - if recv_hook is not None: - dbo_register_recv_hook(recv_hook) + assert recv_hook is not None + dbo_register_recv_hook(recv_hook) dbo_yield() if not dbo_enabled(): recv_hook()
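
Editor's note: the series above moves the shared-expert computation so that it overlaps with the combine (finalize) communication rather than with dispatch, by splitting finalize into a send phase plus a receiver callback. Below is a minimal, runnable sketch of that split-phase pattern. It is not vLLM's implementation: ToyPrepareFinalize, run_shared_experts, and the toy "combine" (a local sum over the topk dimension) are illustrative stand-ins for the real DeepEP/pplx prepare-finalize classes and FusedMoEModularKernel.forward.

from typing import Callable, Optional

import torch


class ToyPrepareFinalize:
    """Stand-in for a FusedMoEPrepareAndFinalize-style class."""

    def supports_async(self) -> bool:
        # Real subclasses return True only when both prepare_async and
        # finalize_async are implemented (see the modular_kernel.py change).
        return True

    def finalize_async(self, output: torch.Tensor,
                       fused_expert_output: torch.Tensor) -> Callable[[], None]:
        # "Send" phase: start the combine without waiting for other workers.
        # Here the combine is just a local reduction over the topk dimension.
        combined = fused_expert_output.sum(dim=1)

        def receiver() -> None:
            # In the real kernels this waits on an event / recv hook; only
            # after it returns is the output tensor valid.
            output.copy_(combined)

        return receiver

    def finalize(self, output: torch.Tensor,
                 fused_expert_output: torch.Tensor) -> None:
        # Synchronous path: send, then immediately receive.
        self.finalize_async(output, fused_expert_output)()


def run_shared_experts(a1: torch.Tensor) -> torch.Tensor:
    # Placeholder for self.shared_experts(a1); any dense computation works.
    return a1 * 2.0


def forward(pf: ToyPrepareFinalize, a1: torch.Tensor,
            fused_out: torch.Tensor, shared_experts: bool
            ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
    # Mirrors the control flow added to FusedMoEModularKernel.forward: the
    # shared experts run between the combine's send and receive phases so
    # they overlap with communication instead of running serially.
    output = torch.zeros_like(a1)
    shared_output: Optional[torch.Tensor] = None

    if not pf.supports_async():
        pf.finalize(output, fused_out)
        if shared_experts:
            shared_output = run_shared_experts(a1)
    else:
        receiver = pf.finalize_async(output, fused_out)
        if shared_experts:
            shared_output = run_shared_experts(a1)  # overlaps with combine
        receiver()  # output is only valid after this returns

    return shared_output, output


if __name__ == "__main__":
    m, topk, k = 4, 2, 8
    a1 = torch.randn(m, k)
    fused_out = torch.randn(m, topk, k)
    shared, out = forward(ToyPrepareFinalize(), a1, fused_out,
                          shared_experts=True)
    print(shared.shape, out.shape)

The same shape is what the rebase-on-dbo patch reuses for dual-batch overlap: when DBO is enabled, the receiver returned by finalize_async is registered as a recv hook via dbo_register_recv_hook() and run later by the ubatch machinery, instead of being invoked inline.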