|
41 | 41 | from vllm.model_executor.layers.vocab_parallel_embedding import ( |
42 | 42 | ParallelLMHead, VocabParallelEmbedding) |
43 | 43 | from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| 44 | +from vllm.model_executor.models.interfaces import SupportsPP |
| 45 | +from vllm.model_executor.models.utils import ( |
| 46 | + AutoWeightsLoader, is_pp_missing_parameter, |
| 47 | + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) |
44 | 48 | from vllm.model_executor.sampling_metadata import SamplingMetadata |
45 | 49 | from vllm.sequence import IntermediateTensors |
46 | 50 | from vllm.transformers_utils.configs import FlexOlmoConfig |
47 | 51 |
|
48 | | -from .interfaces import SupportsPP |
49 | | -from .utils import (is_pp_missing_parameter, |
50 | | - make_empty_intermediate_tensors_factory, make_layers, |
51 | | - maybe_prefix) |
52 | | - |
53 | 52 | logger = init_logger(__name__) |
54 | 53 |
|
55 | 54 |
|
@@ -307,6 +306,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
307 | 306 |
|
308 | 307 | config = vllm_config.model_config.hf_config |
309 | 308 | assert isinstance(config, FlexOlmoConfig) |
| 309 | + self.config = config |
310 | 310 |
|
311 | 311 | self.vocab_size = config.vocab_size |
312 | 312 |
|
@@ -359,58 +359,6 @@ def forward( |
359 | 359 | hidden_states = self.norm(hidden_states) |
360 | 360 | return hidden_states |
361 | 361 |
|
362 | | - |
363 | | -class FlexOlmoForCausalLM(nn.Module, SupportsPP): |
364 | | - |
365 | | - fall_back_to_pt_during_load = False |
366 | | - |
367 | | - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
368 | | - super().__init__() |
369 | | - config = vllm_config.model_config.hf_config |
370 | | - assert isinstance(config, FlexOlmoConfig) |
371 | | - quant_config = vllm_config.quant_config |
372 | | - self.config = config |
373 | | - self.quant_config = quant_config |
374 | | - self.model = FlexOlmoModel(vllm_config=vllm_config, |
375 | | - prefix=maybe_prefix(prefix, "model")) |
376 | | - self.lm_head = ParallelLMHead(config.vocab_size, |
377 | | - config.hidden_size, |
378 | | - quant_config=quant_config, |
379 | | - prefix=maybe_prefix(prefix, "lm_head")) |
380 | | - self.logits_processor = LogitsProcessor(config.vocab_size) |
381 | | - self.sampler = get_sampler() |
382 | | - |
383 | | - self.make_empty_intermediate_tensors = ( |
384 | | - self.model.make_empty_intermediate_tensors) |
385 | | - |
386 | | - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
387 | | - return self.model.get_input_embeddings(input_ids) |
388 | | - |
389 | | - def forward( |
390 | | - self, |
391 | | - input_ids: torch.Tensor, |
392 | | - positions: torch.Tensor, |
393 | | - intermediate_tensors: Optional[IntermediateTensors] = None, |
394 | | - inputs_embeds: Optional[torch.Tensor] = None, |
395 | | - ) -> Union[torch.Tensor, IntermediateTensors]: |
396 | | - hidden_states = self.model(input_ids, positions, intermediate_tensors, |
397 | | - inputs_embeds) |
398 | | - return hidden_states |
399 | | - |
400 | | - def compute_logits(self, hidden_states: torch.Tensor, |
401 | | - sampling_metadata: SamplingMetadata) -> torch.Tensor: |
402 | | - logits = self.logits_processor(self.lm_head, hidden_states, |
403 | | - sampling_metadata) |
404 | | - return logits |
405 | | - |
406 | | - def sample( |
407 | | - self, |
408 | | - logits: Optional[torch.Tensor], |
409 | | - sampling_metadata: SamplingMetadata, |
410 | | - ) -> Optional[SamplerOutput]: |
411 | | - next_tokens = self.sampler(logits, sampling_metadata) |
412 | | - return next_tokens |
413 | | - |
414 | 362 | def load_weights(self, weights: Iterable[tuple[str, |
415 | 363 | torch.Tensor]]) -> set[str]: |
416 | 364 | stacked_params_mapping = [ |
@@ -508,3 +456,58 @@ def load_weights(self, weights: Iterable[tuple[str, |
508 | 456 | weight_loader(param, loaded_weight) |
509 | 457 | loaded_params.add(name) |
510 | 458 | return loaded_params |
| 459 | + |
| 460 | + |
| 461 | +class FlexOlmoForCausalLM(nn.Module, SupportsPP): |
| 462 | + |
| 463 | + fall_back_to_pt_during_load = False |
| 464 | + |
| 465 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 466 | + super().__init__() |
| 467 | + config = vllm_config.model_config.hf_config |
| 468 | + assert isinstance(config, FlexOlmoConfig) |
| 469 | + quant_config = vllm_config.quant_config |
| 470 | + self.quant_config = quant_config |
| 471 | + self.model = FlexOlmoModel(vllm_config=vllm_config, |
| 472 | + prefix=maybe_prefix(prefix, "model")) |
| 473 | + self.lm_head = ParallelLMHead(config.vocab_size, |
| 474 | + config.hidden_size, |
| 475 | + quant_config=quant_config, |
| 476 | + prefix=maybe_prefix(prefix, "lm_head")) |
| 477 | + self.logits_processor = LogitsProcessor(config.vocab_size) |
| 478 | + self.sampler = get_sampler() |
| 479 | + |
| 480 | + self.make_empty_intermediate_tensors = ( |
| 481 | + self.model.make_empty_intermediate_tensors) |
| 482 | + |
| 483 | + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
| 484 | + return self.model.get_input_embeddings(input_ids) |
| 485 | + |
| 486 | + def forward( |
| 487 | + self, |
| 488 | + input_ids: torch.Tensor, |
| 489 | + positions: torch.Tensor, |
| 490 | + intermediate_tensors: Optional[IntermediateTensors] = None, |
| 491 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 492 | + ) -> Union[torch.Tensor, IntermediateTensors]: |
| 493 | + hidden_states = self.model(input_ids, positions, intermediate_tensors, |
| 494 | + inputs_embeds) |
| 495 | + return hidden_states |
| 496 | + |
| 497 | + def compute_logits(self, hidden_states: torch.Tensor, |
| 498 | + sampling_metadata: SamplingMetadata) -> torch.Tensor: |
| 499 | + logits = self.logits_processor(self.lm_head, hidden_states, |
| 500 | + sampling_metadata) |
| 501 | + return logits |
| 502 | + |
| 503 | + def sample( |
| 504 | + self, |
| 505 | + logits: Optional[torch.Tensor], |
| 506 | + sampling_metadata: SamplingMetadata, |
| 507 | + ) -> Optional[SamplerOutput]: |
| 508 | + next_tokens = self.sampler(logits, sampling_metadata) |
| 509 | + return next_tokens |
| 510 | + |
| 511 | + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): |
| 512 | + loader = AutoWeightsLoader(self) |
| 513 | + return loader.load_weights(weights) |
0 commit comments