Merge branch 'main' into fsdp-lora
rasbt authored Jul 2, 2024
2 parents aa965cc + 2da3dd2 commit 0bd44ac
Showing 24 changed files with 712 additions and 237 deletions.
393 changes: 234 additions & 159 deletions README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions litgpt/api.py
@@ -16,6 +16,7 @@
from litgpt.chat.base import generate as stream_generate_fn
from litgpt.prompts import load_prompt_style, has_prompt_style, PromptStyle
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
@@ -171,6 +172,7 @@ def load(

if checkpoint_dir is not None:
checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
load_checkpoint(fabric, model, checkpoint_path)
return cls(
model=model, tokenizer=tokenizer, devices=devices,
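
The loading paths touched by this commit all gain a call to the new check_file_size_on_cpu_and_warn helper from litgpt.utils; only the call sites appear in the diff. A minimal sketch of what such a helper could look like follows; the size threshold and the warning text are illustrative assumptions, not litgpt's actual implementation:

# Hedged sketch of a checkpoint-size check; the threshold and message are assumptions.
import os
import warnings
from pathlib import Path


def check_file_size_on_cpu_and_warn(checkpoint_path: Path, device, size_limit_gb: float = 4.0) -> float:
    size_gb = os.path.getsize(checkpoint_path) / 1e9
    if str(device) == "cpu" and size_gb > size_limit_gb:
        warnings.warn(
            f"The checkpoint is {size_gb:.1f} GB, and loading it on CPU may be slow. "
            "Consider using an accelerator or a smaller checkpoint.",
            UserWarning,
        )
    return size_gb

Each call site invokes it as check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device), matching the lines added in this file and in the chat/generate scripts below.
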
2 changes: 2 additions & 0 deletions litgpt/chat/base.py
@@ -18,6 +18,7 @@
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.scripts.merge_lora import merge_lora
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
@@ -221,6 +222,7 @@ def main(
fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)

checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

# Merge if this is a raw LoRA checkpoint
if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
16 changes: 16 additions & 0 deletions litgpt/config.py
@@ -1444,6 +1444,22 @@ def norm_class(self) -> Type:
lm_head_bias=True,
gelu_approximate="tanh",
),
# https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
dict(
name="Phi-3-mini-4k-instruct",
hf_config=dict(org="microsoft", name="Phi-3-mini-4k-instruct"),
vocab_size=32000,
padded_vocab_size=32064,
block_size=4096,
n_embd=3072,
n_layer=32,
rotary_percentage=1.0,
bias=False,
norm_class_name="RMSNorm",
intermediate_size=8192,
mlp_class_name="LLaMAMLP",
parallel_residual=False,
),
]
configs.extend(phi)

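
With this entry registered, the checkpoint can be resolved through litgpt's existing name lookup. A short usage sketch (assuming an install that includes this commit):

from litgpt import Config

# Resolve the newly registered Phi-3 entry and inspect a few of its fields.
config = Config.from_name("Phi-3-mini-4k-instruct")
print(config.block_size)              # 4096
print(config.n_embd, config.n_layer)  # 3072 32
print(config.mlp_class_name)          # LLaMAMLP
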
2 changes: 1 addition & 1 deletion litgpt/eval/evaluate.py
@@ -20,7 +20,7 @@ def prepare_results(results, save_filepath, print_results=True):
print(make_table(results, "groups"))

json_result = json.dumps(
-        results, indent=2, ensure_ascii=False
+        results, indent=2, ensure_ascii=False, default=str
)
save_filepath.open("w", encoding="utf-8").write(json_result)

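
The added default=str makes the dump tolerant of values that are not natively JSON-serializable, such as Path or dtype objects that can show up in evaluation results. A quick stdlib illustration:

import json
from pathlib import Path

results = {"output_dir": Path("out/evaluate"), "batch_size": 4}
# Without default=str, json.dumps raises TypeError because Path is not JSON serializable.
print(json.dumps(results, indent=2, ensure_ascii=False, default=str))
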
2 changes: 2 additions & 0 deletions litgpt/generate/adapter.py
@@ -17,6 +17,7 @@
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
@@ -96,6 +97,7 @@ def main(
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
2 changes: 2 additions & 0 deletions litgpt/generate/adapter_v2.py
@@ -17,6 +17,7 @@
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
@@ -96,6 +97,7 @@ def main(
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
2 changes: 2 additions & 0 deletions litgpt/generate/base.py
@@ -19,6 +19,7 @@
from litgpt.tokenizer import Tokenizer
from litgpt.prompts import has_prompt_style, load_prompt_style, PromptStyle
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
@@ -217,6 +218,7 @@ def main(
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"
check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)

tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
3 changes: 2 additions & 1 deletion litgpt/generate/full.py
@@ -16,6 +16,7 @@
from litgpt.generate.base import generate
from litgpt.prompts import has_prompt_style, load_prompt_style
from litgpt.utils import (
check_file_size_on_cpu_and_warn,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
@@ -95,7 +96,7 @@ def main(
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = finetuned_path

check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)
tokenizer = Tokenizer(checkpoint_dir)
prompt_style = (
load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)
11 changes: 10 additions & 1 deletion litgpt/prompts.py
@@ -309,6 +309,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
return f"Instruct: {prompt}\nOutput:"


class Phi3(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return f'<|system|>\nYou are a helpful assistant.<|end|>\n<|user|>\n{prompt}<|end|>\n<|assistant|>\n'



class TinyLlama(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return (
@@ -352,6 +358,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
"codellama": CodeLlama,
"phi-1": Phi1,
"phi-2": Phi2,
"phi-3": Phi3,
"tinyllama": TinyLlama,
"gemma": Gemma,
"h2oai": H2Oai,
@@ -386,12 +393,14 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Platypus()
if re.search("Nous-Hermes", model_name):
return NousResearch()
if re.search("CodeLlama|Mistral.*Instruct", model_name):
if re.search("CodeLlama|Mi[sx]tral.*Instruct", model_name):
return CodeLlama()
if re.search("phi-1", model_name):
return Phi1()
if re.search("phi-2", model_name):
return Phi2()
if re.search("Phi-3", model_name):
return Phi3()
if re.search(r"tiny-llama.*chat", model_name):
return TinyLlama()
if re.search(r"(Code)?Gemma.*-it", model_name):
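
The new Phi3 style wraps the user prompt in the model's <|system|>/<|user|>/<|assistant|> chat template, and it is reachable both through the style registry and through checkpoint-name matching. A usage sketch (assuming this commit is installed):

from litgpt.prompts import PromptStyle

style = PromptStyle.from_name("phi-3")
print(style.apply("What is the capital of France?"))
# <|system|>
# You are a helpful assistant.<|end|>
# <|user|>
# What is the capital of France?<|end|>
# <|assistant|>
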
56 changes: 46 additions & 10 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -12,12 +12,7 @@
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor

from litgpt import Config
-from litgpt.utils import (
-    extend_checkpoint_dir,
-    lazy_load,
-    incremental_save,
-    save_config
-)
+from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config


def copy_weights_gpt_neox(
@@ -235,13 +230,36 @@ def copy_weights_phi(
"lm_head.bias": "lm_head.bias",
}

if config.name.startswith("Phi-3"):
weight_map.update(
{
"model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.attn.weight",
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
'model.layers.{}.post_attention_layernorm.weight': "transformer.h.{}.norm_2.weight",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.norm.weight": "transformer.ln_f.weight",
}
)

for name, param in hf_weights.items():
if name.startswith("model.layers."):
from_name, l = layer_template(name, 2)
qkv = qkv_weights.setdefault(l, defaultdict(dict))
if "qkv_proj" in from_name:
weight = load_param(param, f"layer {l} qkv", dtype)
weight = qkv_reassemble(weight, config)
to_name = weight_map[from_name].format(l)
state_dict[to_name] = weight
continue
if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
weight_name, weight_type = from_name.split(".")[-2:]
qkv[weight_type][weight_name] = param
elif from_name.endswith("gate_up_proj.weight"):
weight = load_param(param, f"layer {l} gate_up_proj", dtype)
fc_1, fc_2 = weight.chunk(2, dim=0)
state_dict[f"transformer.h.{l}.mlp.fc_1.weight"] = fc_1
state_dict[f"transformer.h.{l}.mlp.fc_2.weight"] = fc_2
continue
to_name = weight_map[from_name]
if to_name is None:
continue
@@ -272,6 +290,24 @@ def copy_weights_phi(
del qkv_weights[i][weight_type]


def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
"""Reassemble from a normal to an interleaved placement in a QKV matrix.
[Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
"""
q, k, v = param.split(
(
config.n_head * config.head_size,
config.n_query_groups * config.head_size,
config.n_query_groups * config.head_size,
)
)
qs = q.split(config.n_head // config.n_query_groups * config.head_size)
ks = k.split(config.head_size)
vs = v.split(config.head_size)
interleaved = [t for group in zip(qs, ks, vs) for t in group]
return torch.cat(interleaved)


def layer_template(layer_name: str, idx: int) -> Tuple[str, int]:
split = layer_name.split(".")
number = int(split[idx])
@@ -321,14 +357,14 @@ def convert_hf_checkpoint(

if "falcon" in model_name:
copy_fn = partial(copy_weights_falcon, model_name)
-    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
+    elif model_name.lower().startswith("phi"):
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
-        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
-    elif "phi" in model_name:
+        copy_fn = partial(copy_weights_phi, config, qkv_weights)
+    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
-        copy_fn = partial(copy_weights_phi, config, qkv_weights)
+        copy_fn = partial(copy_weights_hf_llama, config, qkv_weights)
else:
copy_fn = copy_weights_gpt_neox

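
The qkv_reassemble helper added above converts Phi-3's fused qkv_proj weight from the grouped [Q..., K..., V...] row order used by the Hugging Face checkpoint into the per-head interleaved [Q, K, V, Q, K, V, ...] order that litgpt's fused attention expects. A toy sketch of the effect, using a stand-in config object with two heads of size 2 and no grouped-query sharing (illustration only):

import torch
from types import SimpleNamespace

from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble

# Stand-in for a litgpt Config; only the three attributes the helper reads.
cfg = SimpleNamespace(n_head=2, n_query_groups=2, head_size=2)

# 12 rows: rows 0-3 are Q (heads 0 and 1), rows 4-7 are K, rows 8-11 are V.
fused = torch.arange(12).unsqueeze(1)
interleaved = qkv_reassemble(fused, cfg)
print(interleaved.squeeze(1).tolist())
# [0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11]  (per head: Q0, K0, V0, then Q1, K1, V1)
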
65 changes: 48 additions & 17 deletions litgpt/scripts/convert_lit_checkpoint.py
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import gc
from collections import defaultdict
from functools import partial
from pathlib import Path
from pprint import pprint
@@ -11,11 +12,7 @@

from litgpt import Config
from litgpt.scripts.convert_hf_checkpoint import layer_template, load_param
-from litgpt.utils import (
-    extend_checkpoint_dir,
-    incremental_save,
-    lazy_load
-)
+from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load


def copy_weights_falcon(
@@ -192,31 +189,65 @@ def copy_weights_phi(
"lm_head.bias": "lm_head.bias",
}

if config.name.startswith("Phi-3"):
weight_map.update(
{
"transformer.h.{}.attn.attn.weight": "model.layers.{}.self_attn.qkv_proj.weight",
"transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
"transformer.h.{}.norm_2.weight": 'model.layers.{}.post_attention_layernorm.weight',
"transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
"transformer.ln_f.weight": "model.norm.weight",
}
)
gate_up_proj_weights = defaultdict(dict)


for name, param in lit_weights.items():
if name.endswith((".attn.attn.weight", ".attn.attn.bias")):
-            from_name, l = layer_template(name, 2)
-            weight_type = name.split(".")[-1]  # weight or bias
-            q = f"model.layers.{l}.self_attn.q_proj.{weight_type}"
-            k = f"model.layers.{l}.self_attn.k_proj.{weight_type}"
-            v = f"model.layers.{l}.self_attn.v_proj.{weight_type}"
+            from_name, l_idx = layer_template(name, 2)
            qkv = load_param(param, name, None)
            qp, kp, vp = qkv_split(qkv, config)
-            for to_name, param in zip((q, k, v), (qp, kp, vp)):
+            if config.name.startswith("Phi-3"):
+                qkv_reassembled = torch.concat([qp, kp, vp], dim=0)
+                to_name = weight_map[from_name].format(l_idx)
                if saver is not None:
-                    param = saver.store_early(param)
-                state_dict[to_name] = param
+                    qkv_reassembled = saver.store_early(qkv_reassembled)
+                state_dict[to_name] = qkv_reassembled
+            else:
+                weight_type = name.split(".")[-1]  # weight or bias
+                q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}"
+                k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}"
+                v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}"
+                for to_name, param in zip((q, k, v), (qp, kp, vp)):
+                    if saver is not None:
+                        param = saver.store_early(param)
+                    state_dict[to_name] = param
+        elif name.endswith((".fc_1.weight", ".fc_2.weight")):
+            from_name, l_idx = layer_template(name, 2)
+            weight = load_param(param, name, None)
+            weight_name = name.split(".")[-2]
+            gate_up_proj_weights[l_idx][weight_name] = weight
        else:
            if "transformer.h" in name:
-                from_name, l = layer_template(name, 2)
+                from_name, l_idx = layer_template(name, 2)
                to_name = weight_map[from_name]
-                to_name = to_name.format(l)
+                to_name = to_name.format(l_idx)
else:
to_name = weight_map[name]
param = load_param(param, name, None)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

if config.name.startswith("Phi-3"):
for i in list(gate_up_proj_weights):
fc_1_weight = gate_up_proj_weights[i]["fc_1"]
fc_2_weight = gate_up_proj_weights[i]["fc_2"]
weight = torch.concat([fc_1_weight, fc_2_weight], dim=0)
layer_name = f"model.layers.{i}.mlp.gate_up_proj.weight"
state_dict[layer_name] = weight
del gate_up_proj_weights[i]


def qkv_split(
param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
@@ -256,11 +287,11 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:

if "falcon" in config.name:
copy_fn = partial(copy_weights_falcon, config.name)
+    elif config.name.lower().startswith("phi"):
+        copy_fn = partial(copy_weights_phi, config)
    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
        untie_weights = "Gemma" in config.name
        copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)
-    elif "phi" in config.name:
-        copy_fn = partial(copy_weights_phi, config)
else:
copy_fn = copy_weights_gpt_neox

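
Together with the Hugging Face to litgpt direction above, this gives a symmetric round trip for Phi-3's fused MLP weight: gate_up_proj is chunked into fc_1/fc_2 on import and concatenated back on export. A minimal sketch of that symmetry (the shapes below are illustrative, not Phi-3's real dimensions):

import torch

# Illustrative fused gate_up_proj weight: 2 * intermediate rows by hidden columns.
gate_up_proj = torch.randn(2 * 8, 3)

# HF -> litgpt (copy_weights_phi in convert_hf_checkpoint.py): split into fc_1 / fc_2.
fc_1, fc_2 = gate_up_proj.chunk(2, dim=0)

# litgpt -> HF (copy_weights_phi in convert_lit_checkpoint.py): concatenate back.
restored = torch.concat([fc_1, fc_2], dim=0)
assert torch.equal(restored, gate_up_proj)
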
7 changes: 7 additions & 0 deletions litgpt/tokenizer.py
@@ -13,6 +13,7 @@ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
if not checkpoint_dir.exists():
raise NotADirectoryError(f"The checkpoint directory does not exist: {str(checkpoint_dir)}")

self.model_name = checkpoint_dir.stem
self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
self.bos_id = None
self.eos_id = None
@@ -114,4 +115,10 @@ def encode(

def decode(self, tensor: torch.Tensor) -> str:
tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        # The Phi-3 tokenizer strips spaces when decoding a single token at a time.
# https://github.com/huggingface/transformers/issues/31643
if self.model_name.startswith("Phi-3") and len(tokens) == 1:
dummy_token_id = 33 # \x1e
dummy_token = self.processor.decode([dummy_token_id])
return self.processor.decode([dummy_token_id] + tokens).replace(dummy_token, "")
return self.processor.decode(tokens)
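
The guard above works around the Phi-3 tokenizer dropping a leading space when a single token is decoded in isolation, which matters for streaming generation where tokens are decoded one at a time: a dummy \x1e token is prepended before decoding and then stripped from the result. A usage sketch (the checkpoint path is illustrative, and the exact token boundaries depend on the downloaded tokenizer files):

from pathlib import Path

from litgpt.tokenizer import Tokenizer

# Assumes a previously downloaded Phi-3 checkpoint directory (path illustrative).
tokenizer = Tokenizer(Path("checkpoints/microsoft/Phi-3-mini-4k-instruct"))

ids = tokenizer.encode("Hello world")
# Decode token by token, as the streaming chat loop does; with the workaround
# above, the space between the words survives single-token decoding.
streamed = "".join(tokenizer.decode(token) for token in ids)
print(streamed)
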
