|
62 | 62 | from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata |
63 | 63 | from vllm_ascend.platform import NPUPlatform |
64 | 64 | from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler |
| 65 | +from vllm_ascend.utils import ProfileExecuteDuration |
65 | 66 | from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer |
66 | 67 |
|
67 | 68 | if TYPE_CHECKING: |
@@ -663,36 +664,38 @@ def _process_reqs( |
663 | 664 | with set_forward_context(attn_metadata, |
664 | 665 | self.vllm_config, |
665 | 666 | num_tokens=num_input_tokens): |
666 | | - model_kwargs = {} |
667 | | - if self.enable_torchair_graph_mode: |
668 | | - model_kwargs["kv_caches"] = self.kv_caches |
669 | | - model_kwargs["attn_metadata"] = attn_metadata |
670 | | - if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: |
671 | | - torch._dynamo.mark_static(input_ids) |
672 | | - torch._dynamo.mark_static(positions) |
673 | | - torch._dynamo.mark_static(attn_metadata.decode.block_table) |
674 | | - torch._dynamo.mark_static(attn_metadata.decode.input_positions) |
675 | | - torch._dynamo.mark_static(attn_metadata.slot_mapping) |
676 | | - for kv in self.kv_caches: |
677 | | - if isinstance(kv, tuple): |
678 | | - torch._dynamo.mark_static(kv[0]) |
679 | | - torch._dynamo.mark_static(kv[1]) |
680 | | - hidden_states = self.compile_model( |
681 | | - input_ids=input_ids, |
682 | | - positions=positions, |
683 | | - intermediate_tensors=intermediate_tensors, |
684 | | - inputs_embeds=None, |
685 | | - **model_kwargs, |
686 | | - ) |
687 | | - else: |
688 | | - assert self.model is not None |
689 | | - hidden_states = self.model( |
690 | | - input_ids=input_ids, |
691 | | - positions=positions, |
692 | | - intermediate_tensors=intermediate_tensors, |
693 | | - inputs_embeds=None, |
694 | | - **model_kwargs, |
695 | | - ) |
| 667 | + with ProfileExecuteDuration().capture_async("forward"): |
| 668 | + model_kwargs = {} |
| 669 | + if self.enable_torchair_graph_mode: |
| 670 | + model_kwargs["kv_caches"] = self.kv_caches |
| 671 | + model_kwargs["attn_metadata"] = attn_metadata |
| 672 | + if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly: |
| 673 | + torch._dynamo.mark_static(input_ids) |
| 674 | + torch._dynamo.mark_static(positions) |
| 675 | + torch._dynamo.mark_static(attn_metadata.decode.block_table) |
| 676 | + torch._dynamo.mark_static( |
| 677 | + attn_metadata.decode.input_positions) |
| 678 | + torch._dynamo.mark_static(attn_metadata.slot_mapping) |
| 679 | + for kv in self.kv_caches: |
| 680 | + if isinstance(kv, tuple): |
| 681 | + torch._dynamo.mark_static(kv[0]) |
| 682 | + torch._dynamo.mark_static(kv[1]) |
| 683 | + hidden_states = self.compile_model( |
| 684 | + input_ids=input_ids, |
| 685 | + positions=positions, |
| 686 | + intermediate_tensors=intermediate_tensors, |
| 687 | + inputs_embeds=None, |
| 688 | + **model_kwargs, |
| 689 | + ) |
| 690 | + else: |
| 691 | + assert self.model is not None |
| 692 | + hidden_states = self.model( |
| 693 | + input_ids=input_ids, |
| 694 | + positions=positions, |
| 695 | + intermediate_tensors=intermediate_tensors, |
| 696 | + inputs_embeds=None, |
| 697 | + **model_kwargs, |
| 698 | + ) |
696 | 699 |
|
697 | 700 | use_spec_decode = len( |
698 | 701 | scheduler_output.scheduled_spec_decode_tokens) > 0 |
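The forward pass above is now wrapped in ProfileExecuteDuration().capture_async("forward") from vllm_ascend.utils. The helper's internals are not part of this diff; the sketch below is only a rough illustration of the capture-then-report pattern it appears to follow. The class name DurationCollector, the process-wide singleton, and the wall-clock timing are assumptions made for illustration; the real helper presumably times NPU execution, and the "async" in capture_async likely refers to asynchronous device work, since every call site uses a plain synchronous with-statement.

import time
from contextlib import contextmanager


class DurationCollector:
    """Rough stand-in for ProfileExecuteDuration (name and internals assumed)."""

    _instance = None

    def __new__(cls):
        # One shared collector per process so nested captures such as
        # "prepare input and forward", "forward" and "post process"
        # land in the same tag -> seconds map.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._durations = {}
        return cls._instance

    @contextmanager
    def capture_async(self, tag):
        # Wall-clock timing stands in for whatever device-event timing
        # the real helper uses on the NPU.
        start = time.perf_counter()
        try:
            yield
        finally:
            self._durations[tag] = time.perf_counter() - start

    def pop_captured_sync(self, prefix):
        # Drain everything captured for this step and emit one summary line.
        captured, self._durations = dict(self._durations), {}
        summary = ", ".join(f"{tag}: {sec * 1000:.2f} ms"
                            for tag, sec in captured.items())
        print(f"[{prefix}] {summary}")

Under that reading, a step looks like: enter DurationCollector().capture_async("forward") around the model call, then once the step's output is assembled call DurationCollector().pop_captured_sync("Decode") to report the collected timings, which mirrors how execute_model uses the real helper in the second hunk below.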
@@ -885,103 +888,113 @@ def execute_model( |
885 | 888 | scheduler_output: "SchedulerOutput", |
886 | 889 | intermediate_tensors: Optional[IntermediateTensors] = None, |
887 | 890 | ) -> Union[ModelRunnerOutput, torch.Tensor]: |
888 | | - self._update_states(scheduler_output) |
889 | | - if not scheduler_output.total_num_scheduled_tokens: |
890 | | - # Return empty ModelRunnerOutput if there's no work to do. |
891 | | - return EMPTY_MODEL_RUNNER_OUTPUT |
892 | | - (attn_metadata, hidden_states, spec_decode_metadata, positions, |
893 | | - num_scheduled_tokens, |
894 | | - sample_indices) = (self._process_reqs(scheduler_output, |
895 | | - intermediate_tensors)) |
896 | | - logits = self.model.compute_logits(hidden_states[sample_indices], None) |
897 | | - |
898 | | - # Apply structured output bitmasks if present |
899 | | - if scheduler_output.grammar_bitmask is not None: |
900 | | - logits = self.apply_grammar_bitmask(scheduler_output, logits) |
901 | | - |
902 | | - # Sample the next token and get logprobs if needed. |
903 | | - sampling_metadata = self.input_batch.sampling_metadata |
904 | | - if spec_decode_metadata is None: |
905 | | - sampler_output = self.sampler( |
906 | | - logits=logits, |
907 | | - sampling_metadata=sampling_metadata, |
908 | | - ) |
909 | | - else: |
910 | | - # When indexing with a tensor (bonus_logits_indices), PyTorch |
911 | | - # creates a new tensor with separate storage from the original |
912 | | - # logits tensor. This means any in-place operations on bonus_logits |
913 | | - # won't affect the original logits tensor. |
914 | | - bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] |
915 | | - sampler_output = self.sampler( |
916 | | - logits=bonus_logits, |
917 | | - sampling_metadata=sampling_metadata, |
918 | | - ) |
919 | | - bonus_token_ids = sampler_output.sampled_token_ids |
| 891 | + with ProfileExecuteDuration().capture_async( |
| 892 | + "prepare input and forward"): |
| 893 | + self._update_states(scheduler_output) |
| 894 | + if not scheduler_output.total_num_scheduled_tokens: |
| 895 | + # Return empty ModelRunnerOutput if there's no work to do. |
| 896 | + return EMPTY_MODEL_RUNNER_OUTPUT |
| 897 | + (attn_metadata, hidden_states, spec_decode_metadata, positions, |
| 898 | + num_scheduled_tokens, |
| 899 | + sample_indices) = (self._process_reqs(scheduler_output, |
| 900 | + intermediate_tensors)) |
| 901 | + |
| 902 | + with ProfileExecuteDuration().capture_async("post process"): |
| 903 | + logits = self.model.compute_logits(hidden_states[sample_indices], |
| 904 | + None) |
| 905 | + |
| 906 | + # Apply structured output bitmasks if present |
| 907 | + if scheduler_output.grammar_bitmask is not None: |
| 908 | + logits = self.apply_grammar_bitmask(scheduler_output, logits) |
| 909 | + |
| 910 | + # Sample the next token and get logprobs if needed. |
| 911 | + sampling_metadata = self.input_batch.sampling_metadata |
| 912 | + if spec_decode_metadata is None: |
| 913 | + sampler_output = self.sampler( |
| 914 | + logits=logits, |
| 915 | + sampling_metadata=sampling_metadata, |
| 916 | + ) |
| 917 | + else: |
| 918 | + # When indexing with a tensor (bonus_logits_indices), PyTorch |
| 919 | + # creates a new tensor with separate storage from the original |
| 920 | + # logits tensor. This means any in-place operations on bonus_logits |
| 921 | + # won't affect the original logits tensor. |
| 922 | + bonus_logits = logits[ |
| 923 | + spec_decode_metadata.bonus_logits_indices] |
| 924 | + sampler_output = self.sampler( |
| 925 | + logits=bonus_logits, |
| 926 | + sampling_metadata=sampling_metadata, |
| 927 | + ) |
| 928 | + bonus_token_ids = sampler_output.sampled_token_ids |
| 929 | + |
| 930 | + # Just like `bonus_logits`, `target_logits` is a new tensor with |
| 931 | + # separate storage from the original `logits` tensor. Therefore, |
| 932 | + # it is safe to update `target_logits` in place. |
| 933 | + target_logits = logits[ |
| 934 | + spec_decode_metadata.target_logits_indices] |
| 935 | + output_token_ids = self.rejection_sampler( |
| 936 | + spec_decode_metadata, |
| 937 | + None, # draft_probs |
| 938 | + target_logits, |
| 939 | + bonus_token_ids, |
| 940 | + sampling_metadata, |
| 941 | + ) |
| 942 | + sampler_output.sampled_token_ids = output_token_ids |
920 | 943 |
|
921 | | - # Just like `bonus_logits`, `target_logits` is a new tensor with |
922 | | - # separate storage from the original `logits` tensor. Therefore, |
923 | | - # it is safe to update `target_logits` in place. |
924 | | - target_logits = logits[spec_decode_metadata.target_logits_indices] |
925 | | - output_token_ids = self.rejection_sampler( |
926 | | - spec_decode_metadata, |
927 | | - None, # draft_probs |
928 | | - target_logits, |
929 | | - bonus_token_ids, |
| 944 | + # TODO(woosuk): The following loop can be slow since it iterates over |
| 945 | + # the requests one by one. Optimize. |
| 946 | + for i, req_id in enumerate(self.input_batch.req_ids): |
| 947 | + req_state = self.requests[req_id] |
| 948 | + seq_len = (req_state.num_computed_tokens + |
| 949 | + scheduler_output.num_scheduled_tokens[req_id]) |
| 950 | + if seq_len < req_state.num_tokens: |
| 951 | + # Ignore the sampled token. |
| 952 | + # Rewind the generator state as if the token was not sampled. |
| 953 | + generator = self.input_batch.generators.get(i) |
| 954 | + if generator is not None: |
| 955 | + generator.set_offset(generator.get_offset() - 4) |
| 956 | + |
| 957 | + # NOTE: NPU -> CPU Sync happens here. |
| 958 | + # Move as many CPU operations as possible before this sync point. |
| 959 | + logprobs_tensors = sampler_output.logprobs_tensors |
| 960 | + logprobs_lists = logprobs_tensors.tolists() \ |
| 961 | + if logprobs_tensors is not None else None |
| 962 | + |
| 963 | + # Get the valid generated tokens. |
| 964 | + sampled_token_ids = sampler_output.sampled_token_ids |
| 965 | + max_gen_len = sampled_token_ids.shape[-1] |
| 966 | + if max_gen_len == 1: |
| 967 | + # No spec decode tokens. |
| 968 | + valid_sampled_token_ids = sampled_token_ids.tolist() |
| 969 | + else: |
| 970 | + # Includes spec decode tokens. |
| 971 | + valid_sampled_token_ids = self.rejection_sampler.parse_output( |
| 972 | + sampled_token_ids, |
| 973 | + self.input_batch.vocab_size, |
| 974 | + ) |
| 975 | + |
| 976 | + spec_token_ids = self._get_spec_token_ids( |
| 977 | + valid_sampled_token_ids, |
930 | 978 | sampling_metadata, |
| 979 | + scheduler_output, |
| 980 | + spec_decode_metadata, |
| 981 | + positions, |
| 982 | + num_scheduled_tokens, |
| 983 | + hidden_states, |
| 984 | + attn_metadata, |
931 | 985 | ) |
932 | | - sampler_output.sampled_token_ids = output_token_ids |
933 | 986 |
|
934 | | - # TODO(woosuk): The following loop can be slow since it iterates over |
935 | | - # the requests one by one. Optimize. |
936 | | - for i, req_id in enumerate(self.input_batch.req_ids): |
937 | | - req_state = self.requests[req_id] |
938 | | - seq_len = (req_state.num_computed_tokens + |
939 | | - scheduler_output.num_scheduled_tokens[req_id]) |
940 | | - if seq_len < req_state.num_tokens: |
941 | | - # Ignore the sampled token. |
942 | | - # Rewind the generator state as if the token was not sampled. |
943 | | - generator = self.input_batch.generators.get(i) |
944 | | - if generator is not None: |
945 | | - generator.set_offset(generator.get_offset() - 4) |
946 | | - |
947 | | - # NOTE: NPU -> CPU Sync happens here. |
948 | | - # Move as many CPU operations as possible before this sync point. |
949 | | - logprobs_tensors = sampler_output.logprobs_tensors |
950 | | - logprobs_lists = logprobs_tensors.tolists() \ |
951 | | - if logprobs_tensors is not None else None |
952 | | - |
953 | | - # Get the valid generated tokens. |
954 | | - sampled_token_ids = sampler_output.sampled_token_ids |
955 | | - max_gen_len = sampled_token_ids.shape[-1] |
956 | | - if max_gen_len == 1: |
957 | | - # No spec decode tokens. |
958 | | - valid_sampled_token_ids = sampled_token_ids.tolist() |
959 | | - else: |
960 | | - # Includes spec decode tokens. |
961 | | - valid_sampled_token_ids = self.rejection_sampler.parse_output( |
962 | | - sampled_token_ids, |
963 | | - self.input_batch.vocab_size, |
| 987 | + model_runner_output = ModelRunnerOutput( |
| 988 | + req_ids=self.input_batch.req_ids, |
| 989 | + req_id_to_index=self.input_batch.req_id_to_index, |
| 990 | + sampled_token_ids=valid_sampled_token_ids, |
| 991 | + spec_token_ids=spec_token_ids, |
| 992 | + logprobs=logprobs_lists, |
| 993 | + prompt_logprobs_dict={}, |
964 | 994 | ) |
965 | 995 |
|
966 | | - spec_token_ids = self._get_spec_token_ids( |
967 | | - valid_sampled_token_ids, |
968 | | - sampling_metadata, |
969 | | - scheduler_output, |
970 | | - spec_decode_metadata, |
971 | | - positions, |
972 | | - num_scheduled_tokens, |
973 | | - hidden_states, |
974 | | - attn_metadata, |
975 | | - ) |
976 | | - |
977 | | - model_runner_output = ModelRunnerOutput( |
978 | | - req_ids=self.input_batch.req_ids, |
979 | | - req_id_to_index=self.input_batch.req_id_to_index, |
980 | | - sampled_token_ids=valid_sampled_token_ids, |
981 | | - spec_token_ids=spec_token_ids, |
982 | | - logprobs=logprobs_lists, |
983 | | - prompt_logprobs_dict={}, |
984 | | - ) |
| 996 | + capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill" |
| 997 | + ProfileExecuteDuration().pop_captured_sync(capture_name) |
985 | 998 | return model_runner_output |
986 | 999 |
|
987 | 1000 | def _profile_multimodal(self) -> None: |
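The comments kept above for bonus_logits and target_logits rely on a PyTorch guarantee worth spelling out: indexing a tensor with an index tensor (advanced indexing) materializes a new tensor with its own storage, so in-place edits to the slice never reach the original logits. A short self-contained check, where the shapes and index values are made up purely for illustration:

import torch

logits = torch.zeros(4, 8)
bonus_logits_indices = torch.tensor([1, 3])

# Advanced indexing with an index tensor returns a new tensor with
# separate storage from logits.
bonus_logits = logits[bonus_logits_indices]
bonus_logits.add_(100.0)  # in-place update touches only the copy

assert logits.sum().item() == 0.0                    # original logits untouched
assert bonus_logits.data_ptr() != logits.data_ptr()  # separate storage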
|