@@ -1832,6 +1832,31 @@ def get_finished_kv_transfer(
                 scheduler_output.finished_req_ids)
         return None, None
 
+    def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
+        if skip_attn:
+            attn_metadata = None
+        else:
+            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
+            attn_metadata = None
+        return attn_metadata
+
+    def _generate_dummy_run_hidden_states(self, with_prefill,
+                                          is_torchair_compile, input_ids,
+                                          positions, attn_metadata, num_tokens,
+                                          intermediate_tensors, inputs_embeds):
+        maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
+        hidden_states = self.model(input_ids=input_ids,
+                                   positions=positions,
+                                   intermediate_tensors=intermediate_tensors,
+                                   inputs_embeds=inputs_embeds)
+        if self.use_aux_hidden_state_outputs:
+            hidden_states, _ = hidden_states
+        else:
+            hidden_states = hidden_states
+        if self.use_spec_decode and isinstance(self.drafter, EagleProposer):
+            self.drafter.dummy_run(num_tokens)
+        return hidden_states
+
     @torch.inference_mode()
     def _dummy_run(
         self,
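Note: the base _build_attention_metadata above intentionally returns None in both branches; the torchair-graph handling removed from _dummy_run in the next hunk presumably moves into a graph-specific runner. A minimal sketch of such an override, assuming a hypothetical NPUTorchairModelRunner subclass of the surrounding NPUModelRunner class and reusing the attn_metadata_builder.build_torchair_graph_dummy call from the removed code:

    # Sketch only, not part of this change; the subclass name is illustrative.
    class NPUTorchairModelRunner(NPUModelRunner):

        def _build_attention_metadata(self, with_prefill, num_reqs, skip_attn):
            # From the removed NOTE: in torchair graph mode without prefill we
            # cannot skip attention metadata, or the graph would recompile.
            if self.torchair_graph_enabled and not with_prefill:
                return self.attn_metadata_builder.build_torchair_graph_dummy(
                    num_reqs=num_reqs, num_actual_tokens=1)
            return super()._build_attention_metadata(with_prefill, num_reqs,
                                                     skip_attn)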
@@ -1868,20 +1893,11 @@ def _dummy_run(
         if self.is_kv_producer:
             with_prefill = True
 
-        # NOTE: If torchair graph mode and not with_prefill,
-        # we can't skip_attn, it will cause graph recompile.
-        if self.torchair_graph_enabled and not with_prefill:
-            attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_reqs, num_actual_tokens=1)
-        elif skip_attn:
-            attn_metadata = None
-        else:
-            # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
-            attn_metadata = None
+        attn_metadata = self._build_attention_metadata(with_prefill, num_reqs,
+                                                       skip_attn)
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -1917,61 +1933,10 @@ def _dummy_run(
                     in_profile_run=self.in_profile_run,
                     num_actual_tokens=0,
             ):
-                model_kwargs = {}
-                if self.torchair_graph_enabled and not with_prefill:
-                    # Only mark static while compiling
-                    if is_torchair_compile:
-                        torch._dynamo.mark_static(input_ids)
-                        torch._dynamo.mark_static(positions)
-                        torch._dynamo.mark_static(
-                            attn_metadata.decode.block_table)
-                        torch._dynamo.mark_static(
-                            attn_metadata.decode.input_positions)
-                        torch._dynamo.mark_static(
-                            get_forward_context().mc2_mask)
-                        if hasattr(attn_metadata.decode, "sin"):
-                            torch._dynamo.mark_static(attn_metadata.decode.sin)
-                            torch._dynamo.mark_static(attn_metadata.decode.cos)
-                        torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                        if self.speculative_config:
-                            torch._dynamo.mark_static(
-                                attn_metadata.decode.attn_mask)
-                        for kv in self.kv_caches:
-                            assert isinstance(
-                                kv, tuple), "kv_cache must be a tuple"
-                            torch._dynamo.mark_static(kv[0])
-                            torch._dynamo.mark_static(kv[1])
-
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_NZ)
-
-                    compiled_model = self._get_torchair_lazy_compiled_model(
-                        num_tokens)
-                    model_kwargs["kv_caches"] = self.kv_caches
-                    model_kwargs["attn_metadata"] = attn_metadata
-                    hidden_states = compiled_model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=None,
-                        **model_kwargs,
-                    )
-                else:
-                    maybe_converting_weight_acl_format(self.model,
-                                                       ACL_FORMAT_FRACTAL_ND)
-
-                    hidden_states = model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds)
-                if self.use_aux_hidden_state_outputs:
-                    hidden_states, _ = hidden_states
-                else:
-                    hidden_states = hidden_states
-                if self.use_spec_decode and isinstance(
-                        self.drafter, EagleProposer):
-                    self.drafter.dummy_run(num_tokens)
+                hidden_states = self._generate_dummy_run_hidden_states(
+                    with_prefill, is_torchair_compile, input_ids, positions,
+                    attn_metadata, num_tokens, intermediate_tensors,
+                    inputs_embeds)
                 if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                     assert isinstance(self.drafter, MtpProposer)
                     self.drafter.dummy_run(
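Note: the same pattern applies to the forward pass. The torchair compiled-model branch removed above (mark_static of the dummy inputs, the ACL_FORMAT_FRACTAL_NZ weight conversion, and the lazily compiled model) is no longer reachable from the base _dummy_run. A condensed sketch of how a graph-mode runner could supply it by overriding the new hook; the subclass name is illustrative, and several details of the removed block (mc2_mask, sin/cos, the speculative-decode mask, aux hidden states) are deliberately elided:

    # Sketch only, assuming a hypothetical NPUTorchairModelRunner subclass.
    class NPUTorchairModelRunner(NPUModelRunner):

        def _generate_dummy_run_hidden_states(self, with_prefill,
                                              is_torchair_compile, input_ids,
                                              positions, attn_metadata,
                                              num_tokens, intermediate_tensors,
                                              inputs_embeds):
            if not (self.torchair_graph_enabled and not with_prefill):
                # Eager path: defer to the base implementation.
                return super()._generate_dummy_run_hidden_states(
                    with_prefill, is_torchair_compile, input_ids, positions,
                    attn_metadata, num_tokens, intermediate_tensors,
                    inputs_embeds)

            if is_torchair_compile:
                # Only mark static while compiling (as in the removed block).
                for tensor in (input_ids, positions,
                               attn_metadata.slot_mapping,
                               attn_metadata.decode.block_table,
                               attn_metadata.decode.input_positions):
                    torch._dynamo.mark_static(tensor)
                for kv in self.kv_caches:
                    torch._dynamo.mark_static(kv[0])
                    torch._dynamo.mark_static(kv[1])

            maybe_converting_weight_acl_format(self.model,
                                               ACL_FORMAT_FRACTAL_NZ)
            compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
            return compiled_model(input_ids=input_ids,
                                  positions=positions,
                                  intermediate_tensors=intermediate_tensors,
                                  inputs_embeds=None,
                                  kv_caches=self.kv_caches,
                                  attn_metadata=attn_metadata)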