|
@@ -15,7 +15,6 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
-#
 
 import gc
 import os
@@ -1410,86 +1409,77 @@ def _get_spec_token_ids( |
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
-        elif self.speculative_config.method == "eagle":
-            raise NotImplementedError("Eagle Is Not Supported Yet.")
-        elif self.speculative_config.method == "eagle3":
+        elif self.use_eagle:
             assert isinstance(self.drafter, EagleProposer)
-            if self.speculative_config.use_eagle():
-                next_token_ids: list[int] = []
-                for i, token_ids in enumerate(valid_sampled_token_ids):
-                    if token_ids:
-                        # Common case.
-                        next_token_id = token_ids[-1]
-                    else:
-                        # Partial prefill (rare case).
-                        # Get the next token id from the request state.
-                        req_id = self.input_batch.req_ids[i]
-                        req_state = self.requests[req_id]
-                        seq_len = (
-                            req_state.num_computed_tokens +
-                            scheduler_output.num_scheduled_tokens[req_id])
-
-                        next_token_id = req_state.get_token_id(seq_len)
-                    next_token_ids.append(next_token_id)
-                next_token_ids = torch.tensor(next_token_ids,
-                                              dtype=torch.int32,
-                                              device=self.device)
-                eagle_attn_metadata = attn_metadata[
-                    self.drafter.attn_layer_name]
-                num_input_tokens = scheduler_output.total_num_scheduled_tokens
-                if spec_decode_metadata is None:
-                    # input_ids can be None for multimodal models.
-                    target_token_ids = self.input_ids[:num_scheduled_tokens]
-                    target_positions = positions[:num_scheduled_tokens]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat([
-                            h[:num_scheduled_tokens] for h in aux_hidden_states
-                        ],
-                                                         dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[:
-                                                             num_scheduled_tokens]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping
-                    cu_num_tokens = eagle_attn_metadata.query_start_loc
+            next_token_ids: list[int] = []
+            for i, token_ids in enumerate(valid_sampled_token_ids):
+                if token_ids:
+                    # Common case.
+                    next_token_id = token_ids[-1]
                 else:
-                    num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                    num_rejected_tokens = [
-                        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                        for i, n in enumerate(num_draft_tokens)
-                    ]
-                    num_rejected_tokens = torch.tensor(
-                        num_rejected_tokens,
-                        dtype=torch.int32,
-                        device=self.device,
-                    )
-                    num_tokens = num_scheduled_tokens - sum(
-                        num_rejected_tokens)
-                    cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                        eagle_attn_metadata.query_start_loc,
-                        num_rejected_tokens, num_tokens)
-                    target_token_ids = self.input_ids[token_indices]
-                    target_positions = positions[token_indices]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat(
-                            [h[token_indices] for h in aux_hidden_states],
-                            dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[token_indices]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping[
-                        token_indices]
-
-                positions = self.positions[:num_input_tokens]
-                draft_token_ids = self.drafter.propose(
-                    target_token_ids=target_token_ids,
-                    target_positions=target_positions,
-                    target_hidden_states=target_hidden_states,
-                    target_slot_mapping=target_slot_mapping,
-                    next_token_ids=next_token_ids,
-                    cu_num_tokens=cu_num_tokens,
-                    block_table=eagle_attn_metadata.block_tables,
-                    sampling_metadata=sampling_metadata,
+                    # Partial prefill (rare case).
+                    # Get the next token id from the request state.
+                    req_id = self.input_batch.req_ids[i]
+                    req_state = self.requests[req_id]
+                    seq_len = (req_state.num_computed_tokens +
+                               scheduler_output.num_scheduled_tokens[req_id])
+
+                    next_token_id = req_state.get_token_id(seq_len)
+                next_token_ids.append(next_token_id)
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+            num_input_tokens = scheduler_output.total_num_scheduled_tokens
+            if spec_decode_metadata is None:
+                # input_ids can be None for multimodal models.
+                target_token_ids = self.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat(
+                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
+                        dim=-1)
+                else:
+                    target_hidden_states = hidden_states[:num_scheduled_tokens]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
+            else:
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
+                num_rejected_tokens = [
+                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                    for i, n in enumerate(num_draft_tokens)
+                ]
+                num_rejected_tokens = torch.tensor(
+                    num_rejected_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
                 )
-                spec_token_ids = draft_token_ids.tolist()
+                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
+                    eagle_attn_metadata.query_start_loc, num_rejected_tokens,
+                    num_tokens)
+                target_token_ids = self.input_ids[token_indices]
+                target_positions = positions[token_indices]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat(
+                        [h[token_indices] for h in aux_hidden_states], dim=-1)
+                else:
+                    target_hidden_states = hidden_states[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]
+
+            positions = self.positions[:num_input_tokens]
+            draft_token_ids = self.drafter.propose(
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                target_slot_mapping=target_slot_mapping,
+                next_token_ids=next_token_ids,
+                cu_num_tokens=cu_num_tokens,
+                block_table=eagle_attn_metadata.block_tables,
+                sampling_metadata=sampling_metadata,
+            )
+            spec_token_ids = draft_token_ids.tolist()
         elif self.speculative_config.method == 'deepseek_mtp':
             assert isinstance(self.drafter, MtpProposer)
             spec_token_ids = self._generate_mtp_token_ids(
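For reference, the rejected-token arithmetic this hunk carries over unchanged can be checked in isolation. The sketch below uses made-up token ids and counts (none of these values come from vllm-ascend itself): a request that scheduled `n > 0` draft tokens can emit at most `n + 1` output tokens (the accepted drafts plus one bonus token from the target model), so the shortfall against `valid_sampled_token_ids` is the number of rejected drafts.

```python
# Hypothetical per-request inputs mirroring the shapes used in the hunk:
# three requests, of which the third ran a plain decode step (no drafts).
num_draft_tokens = [3, 3, 0]
valid_sampled_token_ids = [
    [101, 102],            # 2 of a possible 3 + 1 tokens accepted
    [201, 202, 203, 204],  # all drafts plus the bonus token accepted
    [301],                 # no drafts scheduled -> nothing to reject
]

# Same list comprehension as in the diff: a request that proposed n > 0
# drafts could have produced n + 1 tokens, so the shortfall is rejections.
num_rejected_tokens = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
assert num_rejected_tokens == [2, 0, 0]
```

These per-request counts are what `drafter.prepare_inputs` consumes to trim the rejected positions out of the flattened token batch before the drafter proposes the next round of speculative tokens.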
@@ -2001,10 +1991,11 @@ def load_model(self) -> None: |
             pass
         if self.drafter:
             logger.info("Loading drafter model...")
-            if self.use_aux_hidden_state_outputs:
+            if self.use_eagle:
                 self.drafter.load_model(self.model)
-                self.model.set_aux_hidden_state_layers(
-                    self.model.get_eagle3_aux_hidden_state_layers())
+                if self.use_aux_hidden_state_outputs:
+                    self.model.set_aux_hidden_state_layers(
+                        self.model.get_eagle3_aux_hidden_state_layers())
             else:
                 self.drafter.load_model()
         if self.lora_config:
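As a sanity check on the new gating, here is a minimal, self-contained sketch with stand-in stubs (the stub classes and the `load_drafter` helper are hypothetical, not vllm-ascend APIs): any eagle-style drafter is handed the target model at load time, and only a drafter that consumes auxiliary hidden states (eagle3) additionally registers the aux layers on the target model.

```python
# Stub standing in for the target model; the layer indices are made up.
class StubModel:
    def get_eagle3_aux_hidden_state_layers(self):
        return (2, 15, 29)

    def set_aux_hidden_state_layers(self, layers):
        print(f"aux hidden states tapped at layers {layers}")

# Stub standing in for an eagle-style drafter.
class StubEagleDrafter:
    def load_model(self, target_model=None):
        print(f"drafter loaded, target={'shared' if target_model else 'none'}")

def load_drafter(drafter, model, use_eagle, use_aux_hidden_state_outputs):
    # Mirrors the control flow introduced in the hunk above.
    if use_eagle:
        drafter.load_model(model)
        if use_aux_hidden_state_outputs:
            model.set_aux_hidden_state_layers(
                model.get_eagle3_aux_hidden_state_layers())
    else:
        drafter.load_model()

# eagle3: shares the target model and registers aux layers.
load_drafter(StubEagleDrafter(), StubModel(),
             use_eagle=True, use_aux_hidden_state_outputs=True)
# plain eagle: shares the target model but skips aux-layer registration.
load_drafter(StubEagleDrafter(), StubModel(),
             use_eagle=True, use_aux_hidden_state_outputs=False)
```

The effect of the change is that eagle and eagle3 now share one load path, with the eagle3-only aux-layer registration gated on `use_aux_hidden_state_outputs` instead of deciding which `load_model` overload to call.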
|