Skip to content

Commit

Permalink
Lazy init for activations
Browse files Browse the repository at this point in the history
  • Loading branch information
ProExpertProg committed Oct 17, 2024
1 parent 0d42d55 commit 408d576
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
26 changes: 17 additions & 9 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict


@CustomOp.register("fatrelu_and_mul")
Expand Down Expand Up @@ -250,15 +251,22 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data.copy_(loaded_weight)


_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"relu2": ReLUSquaredActivation(),
"quick_gelu": QuickGELU(),
}
_ACTIVATION_REGISTRY = LazyDict({
"gelu":
lambda: nn.GELU(),
"gelu_fast":
lambda: FastGELU(),
"gelu_new":
lambda: NewGELU(),
"gelu_pytorch_tanh":
lambda: nn.GELU(approximate="tanh"),
"relu":
lambda: nn.ReLU(),
"relu2":
lambda: ReLUSquaredActivation(),
"quick_gelu":
lambda: QuickGELU(),
})


def get_act_fn(
Expand Down
22 changes: 22 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from collections.abc import Mapping
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Expand Down Expand Up @@ -1442,3 +1443,24 @@ def dec(self, num=1):
@property
def value(self):
return self._value


# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):

def __init__(self, factory: Dict[str, Callable[[], T]]):
self._factory = factory
self._dict: Dict[str, T] = {}

def __getitem__(self, key) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
self._dict[key] = self._factory[key]()
return self._dict[key]

def __iter__(self):
return iter(self._factory)

def __len__(self):
return len(self._factory)

0 comments on commit 408d576

Please sign in to comment.