Commit 974b420: Refine infra
wschin committed Sep 24, 2024 (1 parent: c0d30e9)

Showing 2 changed files with 202 additions and 51 deletions.
vllm/model_executor/model_loader/loader.py (195 additions, 46 deletions)

@@ -44,73 +44,220 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 
-# _NN_MODULE_TO_COMPILE[torch.nn.Linear] = torch.compile means that after
-# torch.nn.Linear is loaded, it will be compiled using torch.compile.
-# Search for _NN_MODULE_TO_COMPILE in the codebase to see how it is used.
-_NN_MODULE_TO_COMPILE: Dict[torch.nn.Module,
-                            Callable[[torch.nn.Module], Callable]] = {}
+# _NN_MODULE_TORCH_COMPILE_CONFIGS maps a top-level module to a dictionary of
+# its child modules and their corresponding compilation configurations, so
+# that specific child modules can be compiled with different arguments.
+# For example, if
+#   _NN_MODULE_TORCH_COMPILE_CONFIGS = {
+#       Phi3SmallForCausalLM: {
+#           Phi3SmallMLP: (mlp_compile_args, mlp_compile_kwargs),
+#           Phi3LongRoPEScaledRotaryEmbedding:
+#               (rotary_compile_args, rotary_compile_kwargs),
+#       },
+#   }
+# where
+#   mlp_compile_args = () and mlp_compile_kwargs = {"dynamic": True}
+#   rotary_compile_args = () and rotary_compile_kwargs = {"dynamic": True}
+# then every instance of Phi3SmallMLP and Phi3LongRoPEScaledRotaryEmbedding
+# in a Phi3SmallForCausalLM model is compiled as if by
+# torch.compile(instance, dynamic=True).
+_NN_MODULE_TORCH_COMPILE_CONFIGS: Dict[torch.nn.Module, Dict[
+    torch.nn.Module,
+    # (args, kwargs) passed to torch.nn.Module.compile
+    Tuple[Tuple[Any, ...], Dict[str, Any]], ], ] = {}


 def register_module_to_compile(
+        top_level_module: torch.nn.Module,
         module: torch.nn.Module,
-        compiler: Callable[[torch.nn.Module], Callable],
+        torch_compile_args: Tuple[Any, ...],
+        torch_compile_kwargs: Dict[str, Any],
 ) -> None:
"""
Registers a neural network module with a specified compiler function.
This function ensures that the given module is not already registered
with a compiler. If the module is not registered, it associates the
module with the provided compiler function. All torch.nn.Module instances
registered with this function will be compiled using the provided compiler
function when being loaded.
Registers a neural network module with specific compilation arguments and keyword arguments.
This function registers a given `module` under a `top_level_module` to be compiled using
the provided `torch_compile_args` and `torch_compile_kwargs`. If the `top_level_module`
is not already in the compilation configuration dictionary, it initializes an entry for it.
Args:
module (torch.nn.Module): The neural network module to be registered.
Example:
class MyModule(torch.nn.Module):
def forward(self, x):
return x * 2
my_module = MyModule()
compiler (Callable[[torch.nn.Module], Callable]): A function that
takes a neural network module
and returns a compiled version of that module.
Example:
import functools
customized_compiler = functools.partial(
torch.compile, dynamic=True
)
top_level_module (torch.nn.Module): The top-level neural network module whose child
modules will be compiled.
module (torch.nn.Module): The specific neural network module to be registered for
compilation. All instances of this module (unless they were compiled during
top-down scanning of child modules) in the `top_level_module` will be compiled.
torch_compile_args (Tuple[Any, ...]): Positional arguments passed to the torch.compile.
See torch.nn.Module.compile for more information.
torch_compile_kwargs (Dict[str, Any]): Keyword arguments torch.compile.
See torch.nn.Module.compile for more information.
Example:
```python
import torch.nn as nn
# Define a simple module
class MyModule(nn.Module):
def forward(self, x):
return x * 2
# Define a top-level module
class TopLevelModule(nn.Module):
def __init__(self):
super().__init__()
self.child = MyModule()
def forward(self, x):
return self.child(x)
# Create instances of the modules
top_level_module = TopLevelModule()
my_module = top_level_module.child
# Register the module for compilation
register_module_to_compile(
top_level_module,
my_module,
torch_compile_args=(),
torch_compile_kwargs={"dynamic": True}
)
# Generate example input
x = torch.rand(1, 1)
# Run the model
top_level_module(x)
```
"""
-    _NN_MODULE_TO_COMPILE[module] = compiler
+    # Check if the top-level module is already in the compile configuration
+    # dictionary.
+    if top_level_module not in _NN_MODULE_TORCH_COMPILE_CONFIGS:
+        # If not, initialize an empty dictionary for it.
+        _NN_MODULE_TORCH_COMPILE_CONFIGS[top_level_module] = {}
+
+    # Register the module with its compile arguments and keyword arguments.
+    _NN_MODULE_TORCH_COMPILE_CONFIGS[top_level_module][module] = (
+        torch_compile_args,
+        torch_compile_kwargs,
+    )
+    logger.info(
+        f"Registering module {module} for compilation in {top_level_module} "
+        f"using torch.compile with args={torch_compile_args} and "
+        f"kwargs={torch_compile_kwargs}.")


-def unregister_module_to_compile(module: torch.nn.Module) -> None:
-    _NN_MODULE_TO_COMPILE.pop(module, None)
+def unregister_module_to_compile(top_level_module: torch.nn.Module,
+                                 module: torch.nn.Module) -> None:
+    """
+    Unregisters a torch.nn.Module from being compiled.
+    This function removes the specified `module` from the set of modules
+    that are compiled when the `top_level_module` is loaded.
+    Also see `register_module_to_compile` and `compile_child_modules`
+    for more information.
+    Args:
+        top_level_module (torch.nn.Module): The top-level module
+            whose child modules will be compiled.
+        module (torch.nn.Module): The specific module to be
+            unregistered from compilation.
+    Example:
+        ```python
+        import torch
+        import torch.nn as nn
+        # Define a simple module
+        class MyModule(nn.Module):
+            def forward(self, x):
+                return x * 2
+        # Register the module class for compilation under a top-level class
+        register_module_to_compile(nn.Sequential, MyModule, (),
+                                   {"dynamic": True})
+        # Unregister it again
+        unregister_module_to_compile(nn.Sequential, MyModule)
+        ```
+    """
+    if top_level_module in _NN_MODULE_TORCH_COMPILE_CONFIGS:
+        logger.info(
+            f"Unregistering module {module} from compilation in "
+            f"{top_level_module}")
+        _NN_MODULE_TORCH_COMPILE_CONFIGS[top_level_module].pop(module, None)
+    else:
+        logger.info(
+            f"Module {module} not found in {top_level_module} for "
+            f"unregistration")


-def compile_child_modules(module):
-    """
-    Recursively compiles child modules of a given module.
-    This function traverses through all child modules of the provided
-    `module` and compiles them if their type is listed in the
-    `_NN_MODULE_TO_COMPILE` dictionary. The compiled module replaces the
-    original child module in the parent module in place.
-    Args:
-        module (torch.nn.Module): The parent module whose child modules
-            are to be compiled.
-        compiled_modules (dict): A dictionary that maps module types to
-            their corresponding compiler functions.
-    """
+def compile_child_modules(
+    module: torch.nn.Module,
+    compilation_configs: Optional[Dict[torch.nn.Module,
+                                       Tuple[Tuple[Any, ...],
+                                             Dict[str, Any]]]] = None,
+    compilation_counters: Optional[Dict[torch.nn.Module, int]] = None,
+) -> Dict[torch.nn.Module, int]:
+    """
+    Recursively compiles child modules of a given `module`.
+    This function uses a depth-first scan to traverse the child modules of
+    the provided `module` and compiles a child in place (via
+    torch.nn.Module.compile) when its type can be found in the
+    `_NN_MODULE_TORCH_COMPILE_CONFIGS[type(module)]` dictionary. Compiled
+    modules' child modules will not be compiled again.
+    Args:
+        module (torch.nn.Module): The parent module whose child modules are
+            to be compiled.
+        compilation_configs: A dictionary mapping module types to the
+            (args, kwargs) used for compilation. Defaults to None, in which
+            case the registered configuration for type(module) is used.
+        compilation_counters: A dictionary tracking how many times each
+            module type has been compiled. Defaults to None.
+    Returns:
+        Dict[torch.nn.Module, int]: The number of compiled modules, keyed
+        by module type.
+    """
+    # Get the compilation configuration for the module
+    # and propagate it to all its child modules.
+    if compilation_configs is None:
+        # Executed only once, when processing the top-level module.
+        compilation_configs = _NN_MODULE_TORCH_COMPILE_CONFIGS.get(
+            type(module), {})
+
+    # Create global counters that count how many times a module is compiled.
+    if compilation_counters is None:
+        # Executed only once, when processing the top-level module.
+        compilation_counters = collections.defaultdict(int)
+
     for child_name, child in module.named_children():
-        if type(child) in _NN_MODULE_TO_COMPILE:
-            module_compiler = _NN_MODULE_TO_COMPILE[type(child)]
-            setattr(module, child_name, module_compiler(child))
-            compile_child_modules(child)
+        if (type(child) in compilation_configs
+                and child._compiled_call_impl is None):
+            # This child module is registered for compilation and has not
+            # been compiled yet. Once compiled, this child module and its
+            # child modules will not be compiled again.
+            logger.info(f"Compiling {child_name} in {module}")
+            compilation_counters[type(child)] += 1
+            child.compile(
+                # args
+                *compilation_configs[type(child)][0],
+                # kwargs
+                **compilation_configs[type(child)][1],
+            )
+        else:
+            # This child module is not registered for compilation,
+            # but its child modules may be. Check them.
+            logger.info(
+                f"Not compiling {child_name} in {module} but checking its "
+                f"children.")
+            compile_child_modules(
+                child,
+                compilation_configs=compilation_configs,
+                compilation_counters=compilation_counters,
+            )
+    return compilation_counters
 
 
 @contextmanager
@@ -445,7 +592,9 @@ def load_model(self, *, model_config: ModelConfig,
             with device_loading_context(module, target_device):
                 quant_method.process_weights_after_loading(module)
 
-        compile_child_modules(model)
+        compiled_module_counts = compile_child_modules(model)
+        logger.info(
+            f"Number of compiled torch.nn.Module's: {compiled_module_counts}")
 
     return model.eval()

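To make the new flow concrete, here is a minimal, self-contained sketch of the register-then-compile mechanism this file now implements. It mirrors the shapes of `_NN_MODULE_TORCH_COMPILE_CONFIGS`, `register_module_to_compile`, and `compile_child_modules`, but every name in it (`_CONFIGS`, `register`, `compile_children`, `Mul2`, `Toy`) is an illustrative stand-in rather than vLLM code, and it assumes a PyTorch version (2.1+) in which `torch.nn.Module.compile` and the `_compiled_call_impl` attribute exist.

```python
import collections
from typing import Any, Dict, Tuple, Type

import torch
import torch.nn as nn

# Registry shaped like _NN_MODULE_TORCH_COMPILE_CONFIGS:
# {top-level class: {child class: (compile args, compile kwargs)}}.
_CONFIGS: Dict[Type[nn.Module], Dict[Type[nn.Module],
                                     Tuple[Tuple[Any, ...],
                                           Dict[str, Any]]]] = {}


class Mul2(nn.Module):

    def forward(self, x):
        return x * 2


class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.a = Mul2()
        self.b = nn.Sequential(Mul2(), Mul2())

    def forward(self, x):
        return self.b(self.a(x))


def register(top_level_cls, child_cls, args, kwargs):
    _CONFIGS.setdefault(top_level_cls, {})[child_cls] = (args, kwargs)


def compile_children(module, configs=None, counters=None):
    # Depth-first scan; a compiled subtree is not revisited.
    if configs is None:
        configs = _CONFIGS.get(type(module), {})
    if counters is None:
        counters = collections.defaultdict(int)
    for _, child in module.named_children():
        if type(child) in configs and child._compiled_call_impl is None:
            args, kwargs = configs[type(child)]
            # torch.nn.Module.compile wraps the forward pass in place.
            child.compile(*args, **kwargs)
            counters[type(child)] += 1
        else:
            compile_children(child, configs, counters)
    return counters


register(Toy, Mul2, (), {"dynamic": True})
model = Toy()
counters = compile_children(model)
print(dict(counters))  # expect {Mul2: 3}: model.a, model.b[0], model.b[1]
model(torch.rand(4))   # first call triggers the actual compilation
```

The counters returned here play the same role as the `compiled_module_counts` that `load_model` logs in the hunk above.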
vllm/model_executor/models/phi3_small.py (7 additions, 5 deletions)

@@ -464,8 +464,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 # Register modules for compilation if the environment variable is set.
 # The model will be compiled after being loaded and quantized.
 if os.environ.get("COMPILE_PHI3_SMALL", "0") == "1":
-    dynamic_shape_compiler = functools.partial(torch.compile, dynamic=True)
-    register_module_to_compile(torch.nn.LayerNorm, dynamic_shape_compiler)
-    register_module_to_compile(Phi3SmallMLP, dynamic_shape_compiler)
-    register_module_to_compile(Phi3LongRoPEScaledRotaryEmbedding,
-                               dynamic_shape_compiler)
+    register_module_to_compile(Phi3SmallForCausalLM, torch.nn.LayerNorm, (),
+                               {"dynamic": True})
+    register_module_to_compile(Phi3SmallForCausalLM, Phi3SmallMLP, (),
+                               {"dynamic": True})
+    register_module_to_compile(Phi3SmallForCausalLM,
+                               Phi3LongRoPEScaledRotaryEmbedding, (),
+                               {"dynamic": True})
