|
15 | 15 | # limitations under the License. |
16 | 16 | # This file is a part of the vllm-ascend project. |
17 | 17 | # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py |
18 | | -# |
19 | 18 |
|
20 | 19 | import gc |
21 | 20 | import os |
@@ -1334,86 +1333,77 @@ def _get_spec_token_ids( |
1334 | 1333 | assert isinstance(self.drafter, NgramProposer) |
1335 | 1334 | spec_token_ids = self._generate_draft_token_ids( |
1336 | 1335 | valid_sampled_token_ids, sampling_metadata) |
1337 | | - elif self.speculative_config.method == "eagle": |
1338 | | - raise NotImplementedError("Eagle Is Not Supported Yet.") |
1339 | | - elif self.speculative_config.method == "eagle3": |
| 1336 | + elif self.use_eagle: |
1340 | 1337 | assert isinstance(self.drafter, EagleProposer) |
1341 | | - if self.speculative_config.use_eagle(): |
1342 | | - next_token_ids: list[int] = [] |
1343 | | - for i, token_ids in enumerate(valid_sampled_token_ids): |
1344 | | - if token_ids: |
1345 | | - # Common case. |
1346 | | - next_token_id = token_ids[-1] |
1347 | | - else: |
1348 | | - # Partial prefill (rare case). |
1349 | | - # Get the next token id from the request state. |
1350 | | - req_id = self.input_batch.req_ids[i] |
1351 | | - req_state = self.requests[req_id] |
1352 | | - seq_len = ( |
1353 | | - req_state.num_computed_tokens + |
1354 | | - scheduler_output.num_scheduled_tokens[req_id]) |
1355 | | - |
1356 | | - next_token_id = req_state.get_token_id(seq_len) |
1357 | | - next_token_ids.append(next_token_id) |
1358 | | - next_token_ids = torch.tensor(next_token_ids, |
1359 | | - dtype=torch.int32, |
1360 | | - device=self.device) |
1361 | | - eagle_attn_metadata = attn_metadata[ |
1362 | | - self.drafter.attn_layer_name] |
1363 | | - num_input_tokens = scheduler_output.total_num_scheduled_tokens |
1364 | | - if spec_decode_metadata is None: |
1365 | | - # input_ids can be None for multimodal models. |
1366 | | - target_token_ids = self.input_ids[:num_scheduled_tokens] |
1367 | | - target_positions = positions[:num_scheduled_tokens] |
1368 | | - if self.use_aux_hidden_state_outputs: |
1369 | | - target_hidden_states = torch.cat([ |
1370 | | - h[:num_scheduled_tokens] for h in aux_hidden_states |
1371 | | - ], |
1372 | | - dim=-1) |
1373 | | - else: |
1374 | | - target_hidden_states = hidden_states[: |
1375 | | - num_scheduled_tokens] |
1376 | | - target_slot_mapping = eagle_attn_metadata.slot_mapping |
1377 | | - cu_num_tokens = eagle_attn_metadata.query_start_loc |
| 1338 | + next_token_ids: list[int] = [] |
| 1339 | + for i, token_ids in enumerate(valid_sampled_token_ids): |
| 1340 | + if token_ids: |
| 1341 | + # Common case. |
| 1342 | + next_token_id = token_ids[-1] |
1378 | 1343 | else: |
1379 | | - num_draft_tokens = spec_decode_metadata.num_draft_tokens |
1380 | | - num_rejected_tokens = [ |
1381 | | - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 |
1382 | | - for i, n in enumerate(num_draft_tokens) |
1383 | | - ] |
1384 | | - num_rejected_tokens = torch.tensor( |
1385 | | - num_rejected_tokens, |
1386 | | - dtype=torch.int32, |
1387 | | - device=self.device, |
1388 | | - ) |
1389 | | - num_tokens = num_scheduled_tokens - sum( |
1390 | | - num_rejected_tokens) |
1391 | | - cu_num_tokens, token_indices = self.drafter.prepare_inputs( |
1392 | | - eagle_attn_metadata.query_start_loc, |
1393 | | - num_rejected_tokens, num_tokens) |
1394 | | - target_token_ids = self.input_ids[token_indices] |
1395 | | - target_positions = positions[token_indices] |
1396 | | - if self.use_aux_hidden_state_outputs: |
1397 | | - target_hidden_states = torch.cat( |
1398 | | - [h[token_indices] for h in aux_hidden_states], |
1399 | | - dim=-1) |
1400 | | - else: |
1401 | | - target_hidden_states = hidden_states[token_indices] |
1402 | | - target_slot_mapping = eagle_attn_metadata.slot_mapping[ |
1403 | | - token_indices] |
1404 | | - |
1405 | | - positions = self.positions[:num_input_tokens] |
1406 | | - draft_token_ids = self.drafter.propose( |
1407 | | - target_token_ids=target_token_ids, |
1408 | | - target_positions=target_positions, |
1409 | | - target_hidden_states=target_hidden_states, |
1410 | | - target_slot_mapping=target_slot_mapping, |
1411 | | - next_token_ids=next_token_ids, |
1412 | | - cu_num_tokens=cu_num_tokens, |
1413 | | - block_table=eagle_attn_metadata.block_tables, |
1414 | | - sampling_metadata=sampling_metadata, |
| 1344 | + # Partial prefill (rare case). |
| 1345 | + # Get the next token id from the request state. |
| 1346 | + req_id = self.input_batch.req_ids[i] |
| 1347 | + req_state = self.requests[req_id] |
| 1348 | + seq_len = (req_state.num_computed_tokens + |
| 1349 | + scheduler_output.num_scheduled_tokens[req_id]) |
| 1350 | + |
| 1351 | + next_token_id = req_state.get_token_id(seq_len) |
| 1352 | + next_token_ids.append(next_token_id) |
| 1353 | + next_token_ids = torch.tensor(next_token_ids, |
| 1354 | + dtype=torch.int32, |
| 1355 | + device=self.device) |
| 1356 | + eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] |
| 1357 | + num_input_tokens = scheduler_output.total_num_scheduled_tokens |
| 1358 | + if spec_decode_metadata is None: |
| 1359 | + # input_ids can be None for multimodal models. |
| 1360 | + target_token_ids = self.input_ids[:num_scheduled_tokens] |
| 1361 | + target_positions = positions[:num_scheduled_tokens] |
| 1362 | + if self.use_aux_hidden_state_outputs: |
| 1363 | + target_hidden_states = torch.cat( |
| 1364 | + [h[:num_scheduled_tokens] for h in aux_hidden_states], |
| 1365 | + dim=-1) |
| 1366 | + else: |
| 1367 | + target_hidden_states = hidden_states[:num_scheduled_tokens] |
| 1368 | + target_slot_mapping = eagle_attn_metadata.slot_mapping |
| 1369 | + cu_num_tokens = eagle_attn_metadata.query_start_loc |
| 1370 | + else: |
| 1371 | + num_draft_tokens = spec_decode_metadata.num_draft_tokens |
| 1372 | + num_rejected_tokens = [ |
| 1373 | + n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 |
| 1374 | + for i, n in enumerate(num_draft_tokens) |
| 1375 | + ] |
| 1376 | + num_rejected_tokens = torch.tensor( |
| 1377 | + num_rejected_tokens, |
| 1378 | + dtype=torch.int32, |
| 1379 | + device=self.device, |
1415 | 1380 | ) |
1416 | | - spec_token_ids = draft_token_ids.tolist() |
| 1381 | + num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) |
| 1382 | + cu_num_tokens, token_indices = self.drafter.prepare_inputs( |
| 1383 | + eagle_attn_metadata.query_start_loc, num_rejected_tokens, |
| 1384 | + num_tokens) |
| 1385 | + target_token_ids = self.input_ids[token_indices] |
| 1386 | + target_positions = positions[token_indices] |
| 1387 | + if self.use_aux_hidden_state_outputs: |
| 1388 | + target_hidden_states = torch.cat( |
| 1389 | + [h[token_indices] for h in aux_hidden_states], dim=-1) |
| 1390 | + else: |
| 1391 | + target_hidden_states = hidden_states[token_indices] |
| 1392 | + target_slot_mapping = eagle_attn_metadata.slot_mapping[ |
| 1393 | + token_indices] |
| 1394 | + |
| 1395 | + positions = self.positions[:num_input_tokens] |
| 1396 | + draft_token_ids = self.drafter.propose( |
| 1397 | + target_token_ids=target_token_ids, |
| 1398 | + target_positions=target_positions, |
| 1399 | + target_hidden_states=target_hidden_states, |
| 1400 | + target_slot_mapping=target_slot_mapping, |
| 1401 | + next_token_ids=next_token_ids, |
| 1402 | + cu_num_tokens=cu_num_tokens, |
| 1403 | + block_table=eagle_attn_metadata.block_tables, |
| 1404 | + sampling_metadata=sampling_metadata, |
| 1405 | + ) |
| 1406 | + spec_token_ids = draft_token_ids.tolist() |
1417 | 1407 | elif self.speculative_config.method == 'deepseek_mtp': |
1418 | 1408 | assert isinstance(self.drafter, MtpProposer) |
1419 | 1409 | spec_token_ids = self._generate_mtp_token_ids( |
@@ -1797,7 +1787,7 @@ def load_model(self) -> None: |
1797 | 1787 | self.model = get_model(vllm_config=self.vllm_config) |
1798 | 1788 | if self.drafter: |
1799 | 1789 | logger.info("Loading drafter model...") |
1800 | | - if self.use_aux_hidden_state_outputs: |
| 1790 | + if self.use_eagle: |
1801 | 1791 | self.drafter.load_model(self.model) |
1802 | 1792 | self.model.set_aux_hidden_state_layers( |
1803 | 1793 | self.model.get_eagle3_aux_hidden_state_layers()) |
|
0 commit comments