Add Phi3 Mini 4K Instruct Model to torchtune #876

Merged · 13 commits · Apr 28, 2024
1 change: 1 addition & 0 deletions README.md
@@ -49,6 +49,7 @@ torchtune currently supports the following models.
| [Code-Llama2](https://huggingface.co/codellama) | 7B, 13B, 70B [[model](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] |
| [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] |
| [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] |
| [Microsoft Phi3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) | Mini [[model](torchtune/models/phi3/)] |

We'll be adding a number of new models in the coming weeks, including support for 70B versions and MoEs.

74 changes: 74 additions & 0 deletions recipes/configs/phi3/mini_full.yaml
@@ -0,0 +1,74 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Phi3 Mini 4K Instruct
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --hf-token <HF_TOKEN> --ignore-patterns ""
#
# Run this config on 4 GPUs using the following:
# tune run --nproc_per_node 4 recipes/full_finetune_distributed.py --config recipes/configs/phi3/mini_full.yaml
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 4 recipes/full_finetune_distributed.py --config recipes/configs/phi3/mini_full.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# Single device full finetuning requires more memory optimizations. It's
# best to use mini_full_low_memory.yaml for those cases.

# Tokenizer
tokenizer:
_component_: torchtune.models.phi3.phi3_tokenizer
path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model

# Dataset
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.phi3.phi3_mini

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Phi-3-mini-4k-instruct
model_type: PHI3_MINI
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
_component_: torch.optim.AdamW
lr: 5e-6
loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/Phi-3-mini-4k-instruct
log_every_n_steps: null
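
The `_component_` entries above are dotted import paths that torchtune resolves into Python objects when the recipe runs. As a rough illustration only (not torchtune's own config machinery), the resolution could look like the sketch below; it assumes `omegaconf` is installed and uses a hypothetical `resolve_component` helper:

```python
# Minimal sketch of resolving a `_component_` path from the config above.
# Illustrative only; torchtune's config utilities handle this internally.
import importlib

from omegaconf import OmegaConf


def resolve_component(dotted_path: str):
    """Map e.g. 'torchtune.models.phi3.phi3_mini' to the callable it names."""
    module_path, _, attr_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_path), attr_name)


cfg = OmegaConf.load("recipes/configs/phi3/mini_full.yaml")
model_builder = resolve_component(cfg.model["_component_"])
model = model_builder()  # instantiates Phi3 Mini with its default architecture args
```

The same pattern applies to the tokenizer, dataset, optimizer, and loss entries.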
79 changes: 79 additions & 0 deletions recipes/configs/phi3/mini_full_low_memory.yaml
@@ -0,0 +1,79 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Phi3 Mini 4K Instruct
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download microsoft/Phi-3-mini-4k-instruct --output-dir /tmp/Phi-3-mini-4k-instruct --hf-token <HF_TOKEN> --ignore-patterns ""
#
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config recipes/configs/phi3/mini_full_low_memory.yaml
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config recipes/configs/phi3/mini_full_low_memory.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Tokenizer
tokenizer:
_component_: torchtune.models.phi3.phi3_tokenizer
path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model

# Dataset
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.phi3.phi3_mini

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
checkpoint_files: [
model-00001-of-00002.safetensors,
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Phi-3-mini-4k-instruct
model_type: PHI3_MINI
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 5e-6
loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
optimizer_in_bwd: True

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Model compilation
compile: False

# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/Phi-3-mini-4k-instruct
log_every_n_steps: null
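
Setting `optimizer_in_bwd: True` fuses the optimizer step into the backward pass, so each parameter can be updated and its gradient freed as soon as that gradient is ready, rather than keeping gradients for the whole model alive until a separate `optimizer.step()`. The sketch below shows the general idea only (it is not torchtune's implementation) and assumes PyTorch 2.1+ for `register_post_accumulate_grad_hook`:

```python
# Rough sketch of "optimizer step in backward" for memory savings.
import torch

model = torch.nn.Linear(16, 16)

# One optimizer per parameter so each can be stepped independently.
optimizers = {p: torch.optim.AdamW([p], lr=5e-6) for p in model.parameters()}


def step_and_free(param: torch.Tensor) -> None:
    opt = optimizers[param]
    opt.step()
    opt.zero_grad()  # releases the gradient immediately


for p in model.parameters():
    # Fires right after the gradient for `p` has been accumulated in backward.
    p.register_post_accumulate_grad_hook(step_and_free)

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # parameters are updated during backward; no separate opt.step()
```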
43 changes: 43 additions & 0 deletions tests/torchtune/modules/test_position_embeddings.py
@@ -11,6 +11,7 @@

from tests.test_utils import assert_expected
from torch import tensor
from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
from torchtune.utils.seed import set_seed
@@ -97,3 +98,45 @@ def test_rope_init_meta_device(self, input_params):
meta_rope._rope_init()
for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
torch.testing.assert_close(p1, p2)


class TestPhi3RotaryPositionalEmbeddings:
"""
Class for testing the Phi3 model's RoPE embeddings. The expected tensors are
computed from the reference implementation here:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
"""

@pytest.fixture
def input_params(self) -> Tuple[int, int, int, int, int]:
bsz = 4
num_heads = 32
embed_dim = 3072
seq_len = 60
max_seq_len = 4096
head_dim = embed_dim // num_heads
return bsz, num_heads, head_dim, seq_len, max_seq_len

@pytest.fixture
def input(self, input_params: Tuple[int, int, int, int, int]) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope_phi3(
self, input_params: Tuple[int, int, int, int, int]
) -> Phi3RotaryPositionalEmbeddings:
_, _, head_dim, _, max_seq_len = input_params
return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

def test_forward(
self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings
) -> None:
x_out = rope_phi3(input)

# check the numerics of the computed tensor
assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4)

Contributor: Can you include in the PR where these numbers come from?

Contributor Author: Info's in the gist! Don't really want to replicate all of the info in the context again.

assert_expected(x_out.sum(), tensor(-381.0620))

# check shapes
assert_expected(x_out.shape, input.shape)
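
For reference, a short usage sketch of the new module, mirroring the shapes exercised by the test above:

```python
# Usage sketch for Phi3RotaryPositionalEmbeddings, using the test's shapes.
import torch

from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

bsz, seq_len, num_heads, head_dim = 4, 60, 32, 3072 // 32
rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=4096)

x = torch.randn(bsz, seq_len, num_heads, head_dim)  # [b, s, n_h, h_d]
x_out = rope(x)
assert x_out.shape == x.shape  # RoPE rotates q/k without changing the shape
```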
6 changes: 3 additions & 3 deletions tests/torchtune/utils/test_checkpointer.py
@@ -14,7 +14,7 @@
from torch import randn

from torchtune.models import llama2
from torchtune.utils._checkpointing import FullModelHFCheckpointer, ModelType
from torchtune.utils._checkpointing import FullModelHFCheckpointer
from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load
from torchtune.utils.seed import set_seed

@@ -161,7 +161,7 @@ def single_file_checkpointer(
return FullModelHFCheckpointer(
checkpoint_dir=tmp_path,
checkpoint_files=[checkpoint_file],
model_type=ModelType.LLAMA2,
model_type="LLAMA2",
output_dir=tmp_path,
)

@@ -173,7 +173,7 @@ def multi_file_checkpointer(
return FullModelHFCheckpointer(
checkpoint_dir=tmp_path,
checkpoint_files=[checkpoint_file_1, checkpoint_file_2],
model_type=ModelType.LLAMA2,
model_type="LLAMA2",
output_dir=tmp_path,
)

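
With this change, `FullModelHFCheckpointer` accepts the model type as a plain string instead of requiring the `ModelType` enum. A sketch of how the Phi-3 checkpoint used by the configs above might be wrapped, assuming the files have already been downloaded to /tmp/Phi-3-mini-4k-instruct:

```python
# Sketch: model_type passed as a string, matching the PHI3_MINI value in the configs.
from torchtune.utils._checkpointing import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Phi-3-mini-4k-instruct",
    checkpoint_files=[
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ],
    model_type="PHI3_MINI",  # previously the ModelType enum was required here
    output_dir="/tmp/Phi-3-mini-4k-instruct",
)
state_dict = checkpointer.load_checkpoint()  # returns the converted state dict
```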
10 changes: 5 additions & 5 deletions torchtune/models/convert_weights.py
@@ -45,7 +45,7 @@
}


def _get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
try:
if "layers" in key:
# Replace layer number with "{}" to create key for lookup
@@ -82,7 +82,7 @@ def meta_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]
converted_state_dict = {}
for key, value in state_dict.items():
if key not in ["rope.freqs"]: # Skip loading the position embeddings
new_key = _get_mapped_key(key, _FROM_META)
new_key = get_mapped_key(key, _FROM_META)
converted_state_dict[new_key] = value

return converted_state_dict
@@ -104,7 +104,7 @@ def tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]
inverted_mapping_dict = {v: k for k, v in _FROM_META.items()}

for key, value in state_dict.items():
new_key = _get_mapped_key(key, inverted_mapping_dict)
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value

return converted_state_dict
@@ -149,7 +149,7 @@ def _permute(t, n_heads):

for key, value in state_dict.items():
if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings
new_key = _get_mapped_key(key, _FROM_HF)
new_key = get_mapped_key(key, _FROM_HF)
if "q_proj" in key:
value = _permute(value, num_heads)
elif "k_proj" in key:
@@ -190,7 +190,7 @@ def _permute(t, n_heads):
)

for key, value in state_dict.items():
new_key = _get_mapped_key(key, inverted_mapping_dict)
new_key = get_mapped_key(key, inverted_mapping_dict)
if "q_proj" in key:
value = _permute(value, num_heads)
elif "k_proj" in key:
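
Making `get_mapped_key` public lets the new `phi3/_convert_weights.py` reuse it with its own mapping table. A sketch of what the helper does, shown with a hypothetical one-entry mapping (the real tables live in this file and in the phi3 module):

```python
# Sketch of the key-mapping logic: the layer index is swapped for "{}" so the
# abstract key can be looked up, then substituted back into the mapped name.
import re

EXAMPLE_MAPPING = {  # hypothetical subset of a real mapping table
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
}


def map_key(key: str, mapping: dict) -> str:
    if "layers" in key:
        abstract_key = re.sub(r"(\.\d+)", ".{}", key)
        layer_num = re.search(r"\d+", key).group(0)
        return mapping[abstract_key].format(layer_num)
    return mapping[key]


print(map_key("model.layers.3.self_attn.q_proj.weight", EXAMPLE_MAPPING))
# -> layers.3.attn.q_proj.weight
```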
10 changes: 10 additions & 0 deletions torchtune/models/phi3/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._component_builders import phi3 # noqa
from ._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf # noqa
from ._model_builders import phi3_mini, phi3_tokenizer # noqa
from ._position_embeddings import Phi3RotaryPositionalEmbeddings # noqa