     tensor_model_parallel_all_reduce,
     tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.logger import logger
 from vllm.utils import direct_register_custom_op
 
 import vllm_ascend.envs as envs_ascend
 
 
 def _maybe_chunk_residual_impl(x: torch.Tensor,
                                residual: torch.Tensor) -> torch.Tensor:
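+    # get_forward_context() asserts that a context has been set and raises
+    # AssertionError otherwise (e.g. when the op runs outside a forward
+    # pass); in that case there is no padding metadata, so the residual
+    # is returned unchanged.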
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return residual
+
     if x.size(0) != residual.size(0):
-        flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
         assert flashcomm_v1_enabled is True, (
             "Currently, this situation only occurs "
             "when flashcomm_v1 is enabled")
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             residual = F.pad(residual, (0, 0, 0, pad_size))
         tp_size = get_tensor_model_parallel_world_size()
@@ -31,19 +38,31 @@ def _maybe_chunk_residual_impl(x: torch.Tensor,
 
 def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
                                            label: bool) -> torch.Tensor:
-    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
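+    # Same guard as in _maybe_chunk_residual_impl: without a forward
+    # context there is nothing to gather or unpad, so return x as-is.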
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return x
+
+    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
     if flashcomm_v1_enabled and label:
         x = tensor_model_parallel_all_gather(x, 0)
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             x = x[:-pad_size, :]
     return x
 
 
 def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
-    flashcomm_v1_enabled = get_forward_context().flashcomm_v1_enabled
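+    # Without a forward context the pad size is unknown, so fall back to
+    # a plain all-reduce instead of pad + reduce-scatter.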
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return tensor_model_parallel_all_reduce(x)
+
+    flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
     if flashcomm_v1_enabled:
-        pad_size = get_forward_context().pad_size
+        pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
         return tensor_model_parallel_reduce_scatter(x, 0)
@@ -53,7 +72,12 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
 
 def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                           prefix: str) -> None:
-    forward_context = get_forward_context()
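+    # Prefetching relies on state carried by the forward context
+    # (model_instance, layer_idx, prefetch_stream); skip when unavailable.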
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     model_instance = forward_context.model_instance
@@ -67,9 +91,9 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     prefetch_stream.wait_stream(torch.npu.current_stream())
 
     with torch.npu.stream(prefetch_stream):
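+        # The prefetch size limit is now read from the renamed
+        # VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE env var; the local is
+        # lower-cased since it is not a module-level constant.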
-        MLP_GATE_UP_PREFETCH_SIZE = envs_ascend.MLP_GATE_UP_PREFETCH_SIZE
+        mlp_gate_up_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
         torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.gate_up_proj.weight, \
-                               x_dependency, MLP_GATE_UP_PREFETCH_SIZE)
+                               x_dependency, mlp_gate_up_prefetch_size)
     return
 
 
@@ -79,7 +103,12 @@ def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
 
 
 def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
-    forward_context = get_forward_context()
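+    # Same context guard as _maybe_prefetch_mlp_gate_up_proj_impl.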
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     forward_context.prefetch_mlp_down_proj = True
@@ -91,9 +120,9 @@ def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
     prefetch_stream.wait_stream(torch.npu.current_stream())
 
     with torch.npu.stream(prefetch_stream):
-        MLP_DOWN_PREFETCH_SIZE = envs_ascend.MLP_DOWN_PREFETCH_SIZE
+        mlp_down_prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
         torch_npu.npu_prefetch(model_instance.model.layers[layer_idx].mlp.down_proj.weight, \
-                               x_dependency, MLP_DOWN_PREFETCH_SIZE)
+                               x_dependency, mlp_down_prefetch_size)
     forward_context.layer_idx += 1
     return
 
@@ -104,12 +133,17 @@ def _maybe_prefetch_mlp_down_proj_impl_fake(
 
 
 def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
-    forward_context = get_forward_context()
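+    # If no forward context exists, no prefetch was issued and there is
+    # nothing to wait for.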
+    try:
+        forward_context = get_forward_context()
+    except AssertionError:
+        logger.info("Forward context is None, skipping the operation.")
+        return
+
     if not forward_context.prefetch_mlp_enabled:
         return
     if forward_context.prefetch_mlp_gate_up_proj or \
             forward_context.prefetch_mlp_down_proj:
-        prefetch_stream = get_forward_context().prefetch_stream
+        prefetch_stream = forward_context.prefetch_stream
         # wait until prefetch done
         torch.npu.current_stream().wait_stream(prefetch_stream)
         forward_context.prefetch_mlp_gate_up_proj = False