Commit e5b5d63: rename

This repository was archived by the owner on Aug 16, 2024 and is now read-only.

mikecovlee committed Jul 17, 2024
1 parent 5049697 commit e5b5d63
Showing 11 changed files with 27 additions and 53 deletions.

mlora/__init__.py (2 additions & 2 deletions)
@@ -4,7 +4,7 @@
     Cache,
     LLMBatchConfig,
     LLMForCausalLM,
-    LLMModelArgs,
+    LLMModelConfig,
     LLMModelInput,
     LLMModelOutput,
     LoraConfig,
@@ -32,7 +32,7 @@
 __all__ = [
     "Cache",
     "cache_factory",
-    "LLMModelArgs",
+    "LLMModelConfig",
     "LLMModelOutput",
     "LLMForCausalLM",
     "LLMBatchConfig",
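
This commit renames `LLMModelArgs` to `LLMModelConfig` throughout the package and drops the `TokenizerArgs` dataclass. Because the renamed symbol stays in the package root's `__all__`, downstream code only has to update the imported name. A minimal migration sketch:

    # Before this commit:
    # from mlora import LLMModelArgs
    # After this commit:
    from mlora import LLMModelConfig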

mlora/common/__init__.py (2 additions & 4 deletions)
@@ -46,13 +46,12 @@
     DataClass,
     Labels,
     LLMBatchConfig,
-    LLMModelArgs,
+    LLMModelConfig,
     LLMModelInput,
     LLMModelOutput,
     LoraConfig,
     Masks,
     MixConfig,
-    TokenizerArgs,
     Tokens,
     lora_config_factory,
 )
@@ -91,8 +90,7 @@
     "Labels",
     "Masks",
     "DataClass",
-    "TokenizerArgs",
-    "LLMModelArgs",
+    "LLMModelConfig",
     "LLMModelOutput",
     "LLMBatchConfig",
     "LLMModelInput",

mlora/common/cache.py (5 additions & 21 deletions)
@@ -4,7 +4,7 @@
 import torch

 from .model import Cache
-from .modelargs import LLMModelArgs
+from .modelargs import LLMModelConfig


 class DynamicCache(Cache):
@@ -146,25 +146,9 @@ def batch_select_indices(self, indices: torch.Tensor):


 class StaticCache(Cache):
-    """
-    Static Cache class to be used with `torch.compile(model)`.
-
-    Parameters:
-        config (`LLMModelArgs):
-            The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
-        max_cache_len (`int`):
-            The maximum sequence length with which the model will be used.
-        device (`torch.device`):
-            The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
-            The default `dtype` to use when initializing the layer.
-    """
-
     def __init__(
         self,
-        config: LLMModelArgs,
+        config: LLMModelConfig,
         max_batch_size: int,
         max_cache_len: int,
         device,
@@ -259,7 +243,7 @@ def reset(self):
 class SlidingWindowCache(StaticCache):
     def __init__(
         self,
-        config: LLMModelArgs,
+        config: LLMModelConfig,
         max_batch_size: int,
         max_cache_len: int,
         device,
@@ -333,7 +317,7 @@ def get_max_length(self) -> Optional[int]:
 class HybridCache(Cache):
     def __init__(
         self,
-        config: LLMModelArgs,
+        config: LLMModelConfig,
         max_batch_size,
         max_cache_len,
         device="cpu",
@@ -497,7 +481,7 @@ def reset(self):

 def cache_factory(
     cache_implementation: str,
-    config: LLMModelArgs,
+    config: LLMModelConfig,
     max_batch_size: int,
     max_cache_len: int,
 ):
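
All cache classes and `cache_factory` now annotate their `config` parameter as `LLMModelConfig`; behaviour is unchanged. A hedged usage sketch of the factory signature shown above follows. The "static" implementation name and the config fields a static cache actually reads (layer count, head dimensions, dtype, and so on) are assumptions, since only the signature appears in this diff:

    from mlora import LLMModelConfig, cache_factory

    # Sketch only: a real call needs a fully populated config with the
    # shape/dtype fields the chosen cache reads, which this diff does not show.
    config = LLMModelConfig(name_or_path_="meta-llama/Llama-2-7b-hf", device_="cuda:0")
    kv_cache = cache_factory(
        cache_implementation="static",  # assumed name; "dynamic" appears elsewhere in this commit
        config=config,                  # previously typed as LLMModelArgs
        max_batch_size=4,
        max_cache_len=1024,
    )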

mlora/common/feed_forward.py (2 additions & 2 deletions)
@@ -7,7 +7,7 @@
 from .lora_linear import Linear, get_range_tensor
 from .mix_lora import moe_layer_factory
 from .model import LLMFeedForward
-from .modelargs import LLMModelArgs, LLMModelInput, MixConfig
+from .modelargs import LLMModelConfig, LLMModelInput, MixConfig


 class FeedForward(torch.nn.Module):
@@ -30,7 +30,7 @@ def forward(

     # MixLoRA
     def init_moe_weight(
-        self, args: LLMModelArgs, config: MixConfig, gate: Optional[torch.Tensor] = None
+        self, args: LLMModelConfig, config: MixConfig, gate: Optional[torch.Tensor] = None
     ):
         self.moes_[config.adapter_name] = moe_layer_factory(args, config)
         if gate is None:

mlora/common/mix_lora.py (4 additions & 4 deletions)
@@ -5,7 +5,7 @@
 from transformers.activations import ACT2FN

 from .model import LLMFeedForward
-from .modelargs import LLMModelArgs, MixConfig
+from .modelargs import LLMModelConfig, MixConfig


 def _mixtral_load_balancing_loss_func(
@@ -130,7 +130,7 @@ def _mixtral_compatible_forward(


 class MixtralSparseMoe(torch.nn.Module):
-    def __init__(self, args: LLMModelArgs, config: MixConfig) -> None:
+    def __init__(self, args: LLMModelConfig, config: MixConfig) -> None:
         super().__init__()

         self.adapter_name_: str = config.adapter_name
@@ -326,7 +326,7 @@ def forward(self, router_outputs, attention_mask) -> torch.Tensor:


 class SwitchSparseMoe(torch.nn.Module):
-    def __init__(self, args: LLMModelArgs, config: MixConfig) -> None:
+    def __init__(self, args: LLMModelConfig, config: MixConfig) -> None:
         super().__init__()

         self.adapter_name_: str = config.adapter_name
@@ -439,7 +439,7 @@ def router_loss_factory(config: MixConfig) -> torch.nn.Module:
 moe_layer_dict = {"mixtral": MixtralSparseMoe, "switch": SwitchSparseMoe}


-def moe_layer_factory(args: LLMModelArgs, config: MixConfig) -> torch.nn.Module:
+def moe_layer_factory(args: LLMModelConfig, config: MixConfig) -> torch.nn.Module:
     if config.routing_strategy_ not in router_loss_dict:
         raise ValueError(f"Unknown routing strategy {config.routing_strategy_}")
     return moe_layer_dict[config.routing_strategy_](args, config)
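
`moe_layer_factory` keeps its registry-style dispatch on `config.routing_strategy_`; only the `args` annotation changes. A self-contained stand-in of that pattern follows (fake classes and a plain namespace, not the real MoE layers, since `MixConfig` construction is not shown in this diff):

    from types import SimpleNamespace

    class FakeMixtralMoe:
        def __init__(self, args, config):
            self.adapter_name_ = config.adapter_name

    class FakeSwitchMoe:
        def __init__(self, args, config):
            self.adapter_name_ = config.adapter_name

    # A registry maps the routing strategy name to the layer class,
    # and the factory simply looks it up.
    registry = {"mixtral": FakeMixtralMoe, "switch": FakeSwitchMoe}

    def make_moe_layer(args, config):
        if config.routing_strategy_ not in registry:
            raise ValueError(f"Unknown routing strategy {config.routing_strategy_}")
        return registry[config.routing_strategy_](args, config)

    cfg = SimpleNamespace(routing_strategy_="switch", adapter_name="demo_adapter")
    print(type(make_moe_layer(None, cfg)).__name__)  # FakeSwitchMoe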

mlora/common/model.py (2 additions & 2 deletions)
@@ -4,7 +4,7 @@

 import torch

-from .modelargs import LLMModelArgs, LLMModelInput
+from .modelargs import LLMModelConfig, LLMModelInput


 @dataclass
@@ -160,7 +160,7 @@ def cache_implementation(self) -> str:
         return "dynamic"

     @classmethod
-    def model_config(self) -> LLMModelArgs:
+    def model_config(self) -> LLMModelConfig:
         pass

     @staticmethod

mlora/common/modelargs.py (1 addition & 9 deletions)
@@ -19,15 +19,7 @@ class DataClass:


 @dataclass
-class TokenizerArgs:
-    vocab_size_: int = -1
-    bos_id_: int = -1
-    eos_id_: int = -1
-    pad_id_: int = -1
-
-
-@dataclass
-class LLMModelArgs:
+class LLMModelConfig:
     name_or_path_: str = ""
     device_: str = ""
     dim_: int = 4096
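
`TokenizerArgs` (vocabulary size plus bos/eos/pad ids) is removed outright; this diff does not show where that information now lives. The renamed `LLMModelConfig` keeps its fields and defaults, so it can still be constructed directly. A sketch using only the fields visible in this hunk (the model id is hypothetical):

    from mlora.common import LLMModelConfig

    config = LLMModelConfig(
        name_or_path_="meta-llama/Llama-2-7b-hf",  # hypothetical path or hub id
        device_="cuda:0",
        dim_=4096,                                 # default shown above
    )
    print(config.dim_)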

mlora/model.py (3 additions & 3 deletions)
@@ -15,7 +15,7 @@
     Cache,
     LLMDecoder,
     LLMForCausalLM,
-    LLMModelArgs,
+    LLMModelConfig,
     LLMModelInput,
     LLMModelOutput,
     LLMOutput,
@@ -155,7 +155,7 @@ def forward(

 def init_lora_layer_weight(
     layer: LLMDecoder,
-    args: LLMModelArgs,
+    args: LLMModelConfig,
     config: LoraConfig,
     weight: Optional[Dict[str, torch.Tensor]],
 ):
@@ -228,7 +228,7 @@ def init_lora_layer_weight(
 class LLMModel(torch.nn.Module):
     def __init__(self, model: LLMForCausalLM):
         super().__init__()
-        args: LLMModelArgs = model.config_
+        args: LLMModelConfig = model.config_
         if args.vocab_size_ >= torch.finfo(args.dtype_).max:
             logging.warn(
                 f"vocab_size >= max({args.dtype_}), consider load model with higher precision."
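
The check above compares `vocab_size_` with the largest finite value of the model's compute dtype and warns when it does not fit, which is why the message suggests loading the model at higher precision. A short worked example of that comparison (the 151,936-entry vocabulary is only an illustrative size):

    import torch

    vocab_size = 151_936
    print(torch.finfo(torch.float16).max)                  # 65504.0
    print(torch.finfo(torch.bfloat16).max)                 # ~3.39e38
    print(vocab_size >= torch.finfo(torch.float16).max)    # True  -> warning fires
    print(vocab_size >= torch.finfo(torch.bfloat16).max)   # False -> no warning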

mlora/models/modeling_chatglm.py (2 additions & 2 deletions)
@@ -17,7 +17,7 @@
     LLMDecoder,
     LLMFeedForward,
     LLMForCausalLM,
-    LLMModelArgs,
+    LLMModelConfig,
     LLMModelInput,
 )
 from mlora.common.mix_lora import _mixtral_slice_tensor
@@ -29,7 +29,7 @@


 @dataclass
-class GLMConfig(LLMModelArgs):
+class GLMConfig(LLMModelConfig):
     post_layer_norm: bool = True
     rmsnorm: bool = True
     layernorm_epsilon: float = 1e-5

mlora/models/modeling_llama.py (2 additions & 2 deletions)
@@ -23,7 +23,7 @@
     LLMDecoder,
     LLMFeedForward,
     LLMForCausalLM,
-    LLMModelArgs,
+    LLMModelConfig,
     LLMModelInput,
     prepare_4d_causal_attention_mask,
     scaled_dot_product_attention,
@@ -37,7 +37,7 @@


 @dataclass
-class LlamaConfig(LLMModelArgs):
+class LlamaConfig(LLMModelConfig):
     rms_norm_eps_: float = 1e-6


mlora/models/modeling_phi.py (2 additions & 2 deletions)
@@ -23,7 +23,7 @@
     LLMDecoder,
     LLMFeedForward,
     LLMForCausalLM,
-    LLMModelArgs,
+    LLMModelConfig,
     LLMModelInput,
     prepare_4d_causal_attention_mask,
     scaled_dot_product_attention,
@@ -37,7 +37,7 @@


 @dataclass
-class PhiConfig(LLMModelArgs):
+class PhiConfig(LLMModelConfig):
     partial_rotary_factor_: float = 0.5
     layer_norm_eps_: float = 1e-05
     resid_pdrop_: float = 0.0
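
Each backbone follows the same pattern: a `@dataclass` subclass of the shared config that adds architecture-specific fields (GLMConfig, LlamaConfig, PhiConfig above). A hedged sketch of how a new backbone config would look; the class name and field below are invented for illustration:

    from dataclasses import dataclass

    from mlora.common import LLMModelConfig

    @dataclass
    class MyBackboneConfig(LLMModelConfig):  # hypothetical new backbone
        rope_scaling_factor_: float = 1.0    # invented field; a default keeps it optional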
