[torch.compile] A simple solution to recursively compile a loaded model: using phi3-small as an example #8398

Open · wants to merge 4 commits into base: main
243 changes: 243 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -48,6 +48,242 @@
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available

# _NN_MODULE_TORCH_COMPILE_CONFIGS maps a top-level module class to a
# dictionary of its child module classes and their corresponding
# torch.compile arguments. This allows specific child modules to be
# compiled with different compiler settings.
# For example, if _NN_MODULE_TORCH_COMPILE_CONFIGS = {
#     Phi3SmallForCausalLM: {
#         Phi3SmallMLP: (mlp_compile_args, mlp_compile_kwargs),
#         Phi3LongRoPEScaledRotaryEmbedding: (
#             rotary_compile_args,
#             rotary_compile_kwargs,
#         ),
#     },
# }, where
#     mlp_compile_args = () and mlp_compile_kwargs = {"dynamic": True}
#     rotary_compile_args = () and rotary_compile_kwargs = {"dynamic": True}
# then every instance of Phi3SmallMLP and Phi3LongRoPEScaledRotaryEmbedding
# in a Phi3SmallForCausalLM model will be compiled in place using
# `instance`.compile(dynamic=True).
_NN_MODULE_TORCH_COMPILE_CONFIGS: Dict[Type[torch.nn.Module], Dict[
    Type[torch.nn.Module],
    # args and kwargs passed to torch.nn.Module.compile
    Tuple[Tuple[Any, ...], Dict[str, Any]]]] = {}


def register_module_to_compile(
    top_level_module: Type[torch.nn.Module],
    module: Type[torch.nn.Module],
    torch_compile_args: Tuple[Any, ...],
    torch_compile_kwargs: Dict[str, Any],
) -> None:
"""
Registers a neural network module with specific compilation arguments
and keyword arguments. This function registers a given `module` under
a `top_level_module` to be compiled using the provided `torch_compile_args`
and `torch_compile_kwargs`. If the `top_level_module` is not already in
the compilation configuration dictionary, it initializes an entry for it.

Args:
top_level_module (torch.nn.Module): The top-level neural network
module whose child modules will be compiled.
module (torch.nn.Module): The specific neural network module to
be registered for compilation. All instances of this module
(unless they were compiled during top-down scanning of child
modules) in the `top_level_module` will be compiled.
torch_compile_args (Tuple[Any, ...]): Positional arguments passed to
the torch.compile. See torch.nn.Module.compile for more information.
torch_compile_kwargs (Dict[str, Any]): Keyword arguments torch.compile.
See torch.nn.Module.compile for more information.

Example:
```python
import torch.nn as nn

# Define a simple module
class MyModule(nn.Module):
def forward(self, x):
return x * 2

# Define a top-level module
class TopLevelModule(nn.Module):
def __init__(self):
super().__init__()
self.child = MyModule()

def forward(self, x):
return self.child(x)

# Create instances of the modules
top_level_module = TopLevelModule()
my_module = top_level_module.child

# Register the module for compilation
register_module_to_compile(
top_level_module,
my_module,
torch_compile_args=(),
torch_compile_kwargs={"dynamic": True}
)

# Generate example input
x = torch.rand(1, 1)

# Run the model
top_level_module(x)
```

"""
# Check if the top-level module is already in the compile
# configuration dictionary
if top_level_module not in _NN_MODULE_TORCH_COMPILE_CONFIGS:
# If not, initialize an empty dictionary for it
_NN_MODULE_TORCH_COMPILE_CONFIGS[top_level_module] = {}

# Register the module with its compile arguments and keyword arguments
_NN_MODULE_TORCH_COMPILE_CONFIGS[top_level_module][module] = (
torch_compile_args,
torch_compile_kwargs,
)
logger.info(
"Registering module %s for compilation in %s "
"using torch.compile with args=%s "
"and kwargs=%s.",
module,
top_level_module,
torch_compile_args,
torch_compile_kwargs,
)


def unregister_module_to_compile(top_level_module: Type[torch.nn.Module],
                                 module: Type[torch.nn.Module]) -> None:
"""
Unregisters a torch.nn.Module from being compiled.
This function removes the specified `module` from the list of modules
that are to be compiled when the `top_level_module` is loaded.

Also see `register_module_to_compile` and `compile_child_modules`
for more information.

Args:
top_level_module (torch.nn.Module): The top-level module
whose child modules will be compiled.
module (torch.nn.Module): The specific module to be
unregistered from compilation.

Example:
```python
import torch
import torch.nn as nn

# Define a simple module
class MyModule(nn.Module):
def forward(self, x):
return x * 2

# Create instances of the module
top_level_module = nn.Sequential(MyModule(), MyModule())
my_module = MyModule()

# Register the module for compilation
register_module_to_compile(top_level_module, my_module, torch.compile)

# Unregister the module from compilation
unregister_module_to_compile(top_level_module, my_module)
```
"""
if top_level_module in _NN_MODULE_TORCH_COMPILE_CONFIGS:
logger.info(
"Unregistering module %s from compilation in %s.",
module,
top_level_module,
)
_NN_MODULE_TORCH_COMPILE_CONFIGS[top_level_module].pop(module, None)
else:
logger.info(
"Module %s not found in %s for unregistration.",
module,
top_level_module,
)


def compile_child_modules(
    module: torch.nn.Module,
    compilation_configs: Optional[Dict[Type[torch.nn.Module],
                                       Tuple[Tuple[Any, ...],
                                             Dict[str, Any]]]] = None,
    compilation_counters: Optional[Dict[Type[torch.nn.Module], int]] = None,
) -> Dict[Type[torch.nn.Module], int]:
"""
    Recursively compiles child modules of a given `module`.
    This function performs a depth-first traversal of the
    child modules of the provided `module` and compiles each
    one whose type is found in the
    `_NN_MODULE_TORCH_COMPILE_CONFIGS[type(module)]` dictionary.
    A matching child module is compiled in place via
    torch.nn.Module.compile. Once a child module is compiled,
    it is not traversed further, so its own child modules
    will not be compiled separately.

Args:
module: The parent module whose child
modules are to be compiled.
compilation_configs:
A dictionary containing compilation configurations
for different module types. Defaults to None.
compilation_counters:
A dictionary to keep track of the number of times
each module type has been compiled. Defaults to None.
    Returns:
        A dictionary mapping each compiled module type to the
        number of its instances that were compiled.
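
    Example (a minimal sketch, reusing the hypothetical `TopLevelModule`
    and `MyModule` classes from the `register_module_to_compile`
    docstring):
        ```python
        register_module_to_compile(TopLevelModule, MyModule, (),
                                   {"dynamic": True})
        model = TopLevelModule()
        counters = compile_child_modules(model)
        assert counters == {MyModule: 1}
        ```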
"""
# Get the compilation configuration for the module
# and propagate it to all its child modules.
if compilation_configs is None:
# This line should only be executed once
# when processing the top-level module.
compilation_configs = _NN_MODULE_TORCH_COMPILE_CONFIGS.get(
type(module), {})

# Create global counters to count the number
# of times a module is compiled.
if compilation_counters is None:
# This line should only be executed once when
# processing the top-level module.
compilation_counters = collections.defaultdict(int)

for child_name, child in module.named_children():
        if (type(child) in compilation_configs
                and child._compiled_call_impl is None):
# This child module is registered for compilation
# and has not been compiled yet.
# Once compiled, this child module and its child modules
# will not be compiled again.
logger.info("Compiling %s in %s.", child_name, module)
compilation_counters[type(child)] += 1
child.compile(
# args
*compilation_configs[type(child)][0],
# kwargs
**compilation_configs[type(child)][1],
)
else:
# This child module is not registered for compilation,
# but its child modules may be. Let's check them.
logger.info(
"Not compiling %s in %s but checking its children.",
child_name,
module,
)
compile_child_modules(
child,
compilation_configs=compilation_configs,
compilation_counters=compilation_counters,
)
return compilation_counters


@contextmanager
def device_loading_context(module: torch.nn.Module,
@@ -412,6 +648,13 @@ def load_model(self, *, model_config: ModelConfig,
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)

compiled_module_counts = compile_child_modules(model)
logger.info(
"Number of compiled torch.nn.Module's: %s",
compiled_module_counts,
)

return model.eval()


25 changes: 22 additions & 3 deletions vllm/model_executor/models/phi3_small.py
@@ -1,4 +1,5 @@
import math
import os
from typing import Iterable, List, Optional, Tuple

import torch
@@ -15,10 +16,12 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding import (
Phi3LongRoPEScaledRotaryEmbedding, get_rope)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.loader import register_module_to_compile
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
@@ -50,12 +53,16 @@ def weight_loader(self, param: torch.nn.Parameter,
return load_column_parallel_weight(param, loaded_weight)


@torch.jit.script
# When COMPILE_CHILD_NN_MODULES=1, this will be compiled by
# torch.compile because it is called inside an nn.Module
# that is registered for compilation.
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)


@torch.jit.script
# When COMPILE_CHILD_NN_MODULES=1, this will be compiled by
# torch.compile because it is called inside an nn.Module
# that is registered for compilation.
def gegelu(input, limit: Optional[float] = None):
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
if limit is not None:
@@ -451,3 +458,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)


# Register module classes for compilation if the environment variable is set.
# The model will be compiled after being loaded and quantized.
if os.environ.get("COMPILE_CHILD_NN_MODULES", "0") == "1":
register_module_to_compile(Phi3SmallForCausalLM, torch.nn.LayerNorm, (),
{"dynamic": True})
register_module_to_compile(Phi3SmallForCausalLM, Phi3SmallMLP, (),
{"dynamic": True})
register_module_to_compile(Phi3SmallForCausalLM,
Phi3LongRoPEScaledRotaryEmbedding, (),
{"dynamic": True})