This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

support llama 3.1
mikecovlee committed Jul 25, 2024
1 parent 349dac6 commit ff68a74
Showing 9 changed files with 262 additions and 47 deletions.
4 changes: 2 additions & 2 deletions mlora/__init__.py
@@ -23,8 +23,8 @@

 assert is_package_available("torch", "2.3.0"), "m-LoRA requires torch>=2.3.0"
 assert is_package_available(
-    "transformers", "4.42.0"
-), "m-LoRA requires transformers>=4.42.0"
+    "transformers", "4.43.0"
+), "m-LoRA requires transformers>=4.43.0"


 setup_logging()
3 changes: 2 additions & 1 deletion mlora/common/__init__.py
@@ -1,4 +1,3 @@
-# Attention and Feed Forward
 from .attention import (
     eager_attention_forward,
     flash_attention_forward,
@@ -59,6 +58,7 @@
     Tokens,
     lora_config_factory,
 )
+from .rope import ROPE_INIT_FUNCTIONS

 __all__ = [
     "prepare_4d_causal_attention_mask",
@@ -103,4 +103,5 @@
     "LoraConfig",
     "MixConfig",
     "lora_config_factory",
+    "ROPE_INIT_FUNCTIONS",
 ]
8 changes: 8 additions & 0 deletions mlora/common/model.py
@@ -71,6 +71,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         input_args: LLMModelInput,
+        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
         past_key_value: Optional[Cache] = None,
@@ -106,6 +107,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         input_args: LLMModelInput,
+        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
         past_key_value: Optional[Cache] = None,
@@ -137,6 +139,12 @@ class LLMForCausalLM(metaclass=ABCMeta):
     def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
         pass

+    @classmethod
+    def rotary_embed(
+        self, input_tensor: torch.Tensor, position_ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        pass
+
     @classmethod
     def decoder_stack(self) -> List[LLMDecoder]:
         pass
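For orientation, the (cos, sin) pair that is now threaded through every decoder and attention forward as rotary_emb is consumed roughly as in the standard Llama formulation sketched below. This is an illustrative sketch, not code from this commit; it assumes query/key tensors shaped (batch, heads, seq_len, head_dim) and cos/sin shaped (batch, seq_len, head_dim).

import torch
from typing import Tuple


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    rotary_emb: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # rotary_emb is the (cos, sin) pair produced once per forward pass by
    # LLMForCausalLM.rotary_embed and passed down to each decoder layer.
    cos, sin = rotary_emb
    cos = cos.unsqueeze(1)  # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    query = (query * cos) + (rotate_half(query) * sin)
    key = (key * cos) + (rotate_half(key) * sin)
    return query, key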
29 changes: 15 additions & 14 deletions mlora/common/modelargs.py
@@ -20,20 +20,21 @@ class DataClass:

 @dataclass
 class LLMModelConfig:
-    name_or_path_: str = ""
-    device_: str = ""
-    dim_: int = 4096
-    head_dim_: int = 256
-    intermediate_: int = 11008
-    n_heads_: int = 32
-    n_kv_heads_: int = 32
-    n_layers_: int = 32
-    hidden_act_: str = "silu"
-    hidden_dropout_: float = 0.0
-    vocab_size_: int = -1
-    pad_token_id_: int = -1
-    rope_theta_: float = 10000.0
-    max_seq_len_: int = 2048
+    name_or_path_: str = None
+    device_: str = None
+    dim_: int = None
+    head_dim_: int = None
+    intermediate_: int = None
+    n_heads_: int = None
+    n_kv_heads_: int = None
+    n_layers_: int = None
+    hidden_act_: str = None
+    hidden_dropout_: float = None
+    vocab_size_: int = None
+    pad_token_id_: int = None
+    rope_theta_: float = None
+    partial_rotary_factor_: float = None
+    max_seq_len_: int = None
     # eager or flash_attn
     attn_implementation_: str = "eager"
     # data type
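The defaults switch to None and a partial_rotary_factor_ field is added, presumably so unset fields can be filled straight from the checkpoint config. For context, a sketch of the rope_scaling dictionary that the new llama3 initializer in mlora/common/rope.py (below) reads from the config; the numbers repeat the defaults noted in that file's comments and are assumptions here, since the real values always come from the checkpoint's config.json.

# Illustrative only: the entries _compute_llama3_parameters expects to find
# in config.rope_scaling_ for a Llama 3.1 checkpoint.
rope_scaling = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}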
80 changes: 80 additions & 0 deletions mlora/common/rope.py
@@ -0,0 +1,80 @@
import math
from typing import Optional, Tuple

import torch

from .modelargs import LLMModelConfig


def _compute_default_rope_parameters(
    config: Optional[LLMModelConfig] = None,
    device: Optional[torch.device] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple[torch.Tensor, float]:
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta_
        partial_rotary_factor = (
            config.partial_rotary_factor_
            if config.partial_rotary_factor_ is not None
            else 1.0
        )
        dim = int((config.dim_ // config.n_heads_) * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies
    inv_freq = 1.0 / (
        base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)
    )
    return inv_freq, attention_factor


def _compute_llama3_parameters(
    config: LLMModelConfig,
    device: torch.device,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple[torch.Tensor, float]:
    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(
        config, device, seq_len, **rope_kwargs
    )

    factor = config.rope_scaling_["factor"]  # `8` in the original implementation
    low_freq_factor = config.rope_scaling_[
        "low_freq_factor"
    ]  # `1` in the original implementation
    high_freq_factor = config.rope_scaling_[
        "high_freq_factor"
    ]  # `4` in the original implementation
    old_context_len = config.rope_scaling_[
        "original_max_position_embeddings"
    ]  # `8192` in the original implementation

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in inv_freq:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
    return inv_freq, attention_factor


ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "llama3": _compute_llama3_parameters,
}
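A quick way to exercise the new initializers (a sketch, not part of the commit): build a stand-in config object carrying the attributes the functions read, request the llama3-scaled inverse frequencies, and expand them into the (cos, sin) pair the decoder layers consume. The rope_theta of 500000.0 and the scaling numbers are assumptions taken from the comments above; real values come from the checkpoint's config.json.

import torch
from types import SimpleNamespace

from mlora.common import ROPE_INIT_FUNCTIONS

# Stand-in for LLMModelConfig with only the attributes the RoPE code reads.
config = SimpleNamespace(
    dim_=4096,
    n_heads_=32,
    rope_theta_=500000.0,  # assumed Llama 3.1 value
    partial_rotary_factor_=None,
    rope_scaling_={
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS["llama3"](config, device="cpu")

# Expand the inverse frequencies into the (cos, sin) pair for 16 positions.
position_ids = torch.arange(16, dtype=torch.float32)[None, :]  # (1, seq_len)
freqs = position_ids[:, :, None] * inv_freq[None, None, :]  # (1, seq_len, head_dim / 2)
emb = torch.cat((freqs, freqs), dim=-1)  # (1, seq_len, head_dim)
cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling
print(inv_freq.shape, cos.shape)  # torch.Size([64]) torch.Size([1, 16, 128])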
60 changes: 47 additions & 13 deletions mlora/model.py
@@ -245,10 +245,9 @@ def __init__(self, model: LLMForCausalLM):
         # adapter configs
         self.adapter_configs_: Dict[str, LoraConfig] = {}

-    # compute the model: output probs
-    def forward(
+    def _prepare_inputs(
         self, input_args: LLMModelInput, past_key_values: Optional[Cache] = None
-    ) -> List[LLMModelOutput]:
+    ):
         assert input_args.batch_tokens_ is not None, "Model have no input."
         assert (
             input_args.gradient_checkpoint_ == "none" or past_key_values is None
@@ -306,15 +305,17 @@ def forward(
         else:
             causal_mask = attention_mask

-        labels = input_args.batch_labels_
-
-        input_args.batch_labels_ = None
-        input_args.batch_tokens_ = None
-        input_args.batch_masks_ = None
-
-        # embed positions
-        hidden_states = inputs_embeds
-
+        return input_ids, inputs_embeds, attention_mask, causal_mask, cache_position
+
+    def _call_decoder_stack(
+        self,
+        hidden_states: torch.Tensor,
+        input_args: LLMModelInput,
+        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+    ):
         # decoder layers
         num_adapters = len(input_args.batch_configs_)
         all_router_logits = [[] for _ in range(num_adapters)]
@@ -325,9 +326,10 @@ def forward(
                 decoder_layer.forward,
                 hidden_states,
                 input_args,
-                causal_mask,
+                rotary_emb,
+                attention_mask,
                 cache_position,
-                past_key_values,
+                past_key_value,
             )
             if len(router_logits) == 0:
                 continue
@@ -339,6 +341,38 @@

         hidden_states = self.model_.norm(hidden_states)

+        return hidden_states, all_router_logits
+
+    # compute the model: output probs
+    def forward(
+        self, input_args: LLMModelInput, past_key_values: Optional[Cache] = None
+    ) -> List[LLMModelOutput]:
+        input_ids, inputs_embeds, attention_mask, causal_mask, cache_position = (
+            self._prepare_inputs(input_args, past_key_values)
+        )
+
+        labels = input_args.batch_labels_
+
+        input_args.batch_labels_ = None
+        input_args.batch_tokens_ = None
+        input_args.batch_masks_ = None
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        rotary_emb = self.model_.rotary_embed(
+            hidden_states, cache_position.unsqueeze(0)
+        )
+
+        hidden_states, all_router_logits = self._call_decoder_stack(
+            hidden_states,
+            input_args,
+            rotary_emb,
+            causal_mask,
+            cache_position,
+            past_key_values,
+        )
+
         # calculate loss
         output = self.output_(hidden_states, input_args)
         assert isinstance(output, List)
