Revert "delete all tests except recipe"
This reverts commit 47bbc30.
Felipe Mello committed on Sep 12, 2024
1 parent 47bbc30 · commit 79f93ab
Showing 130 changed files with 18,222 additions and 0 deletions.
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
import runpy
import sys
from pathlib import Path

import pytest
import torchtune
from tests.common import TUNE_PATH
from tests.test_utils import CKPT_MODEL_PATHS, gpu_test


CKPT = "llama2_7b"

# TODO: remove this once we have eval configs exposed properly
pkg_path = Path(torchtune.__file__).parent.absolute()
EVAL_CONFIG_PATH = Path.joinpath(
    pkg_path, "_cli", "eval_configs", "default_eval_config.yaml"
)


@gpu_test(gpu_count=2)
class TestLoRA7BDistributedFinetuneEval:
    @pytest.mark.slow_integration_test
    def test_finetune_and_eval(self, tmpdir, capsys, monkeypatch):
        ckpt_path = Path(CKPT_MODEL_PATHS[CKPT])
        ckpt_dir = ckpt_path.parent

        # Run on prod LoRA FT config but with only 10 steps for now
        ft_cmd = f"""
        tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed \
            --config llama2/7B_lora \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
            checkpointer.checkpoint_dir='{ckpt_dir}' \
            checkpointer.checkpoint_files=[{ckpt_path}] \
            checkpointer.output_dir={tmpdir} \
            checkpointer.model_type=LLAMA2 \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            max_steps_per_epoch=10 \
        """.split()

        monkeypatch.setattr(sys, "argv", ft_cmd)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        # Evaluate the finetuned checkpoint; the eval command is expected
        # to terminate via SystemExit
        eval_cmd = f"""
        tune run eleuther_eval \
            --config eleuther_evaluation \
            output_dir={tmpdir} \
            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
            checkpointer.checkpoint_dir='{tmpdir}' \
            checkpointer.checkpoint_files=[torchtune_model_0.pt] \
            checkpointer.output_dir={tmpdir} \
            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
            tasks=['truthfulqa_mc2'] \
            limit=10 \
            device=cuda \
        """.split()
        monkeypatch.setattr(sys, "argv", eval_cmd)
        with pytest.raises(SystemExit):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        # Pull the accuracy number out of the printed results table
        out = capsys.readouterr().out
        search_results = re.search(
            r"acc(?:_norm)?\s*\|?\s*(?:↑\s*\|?)?([\d.]+)", out.strip()
        )
        assert search_results is not None
        acc_result = float(search_results.group(1))
        assert acc_result >= 0.4
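
As a quick illustration of what that final regex accepts, here is a minimal standalone sketch; the sample row below is only an assumption about the shape of the harness's printed results table, not captured output from a real run.

import re

# Hypothetical results-table row of the kind the eval harness prints
sample_out = "|truthfulqa_mc2|      2|none  |     0|acc   |↑  |0.3861|±  |0.0217|"
match = re.search(r"acc(?:_norm)?\s*\|?\s*(?:↑\s*\|?)?([\d.]+)", sample_out)
assert match is not None
assert float(match.group(1)) == 0.3861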
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import runpy
import sys
from pathlib import Path

import pytest
from tests.common import TUNE_PATH


class TestTuneCLIWithCopyScript:
    """This class tests the `tune cp` command."""

    @pytest.mark.parametrize("already_exists", (True, False))
    def test_copy_successful(self, capsys, monkeypatch, tmpdir, already_exists):
        tmpdir_path = Path(tmpdir)
        dest = tmpdir_path / "my_custom_finetune.yaml"

        if already_exists:
            dest.touch()

        args = f"tune cp llama2/7B_full {dest}".split()

        monkeypatch.setattr(sys, "argv", args)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        out = captured.out.rstrip("\n")

        assert dest.exists(), f"Expected {dest} to exist"
        assert f"Copied file to {dest}" in out

    def test_copy_successful_with_cwd_as_path(self, capsys, monkeypatch, tmpdir):
        tmpdir_path = Path(tmpdir)

        # Needed so we can run the test from tmpdir
        tune_path_as_absolute = Path(TUNE_PATH).absolute()

        # Change cwd to tmpdir
        monkeypatch.chdir(tmpdir_path)

        args = "tune cp llama2/7B_full .".split()
        monkeypatch.setattr(sys, "argv", args)
        runpy.run_path(str(tune_path_as_absolute), run_name="__main__")

        captured = capsys.readouterr()
        out = captured.out.rstrip("\n")

        dest = tmpdir_path / "7B_full.yaml"

        assert dest.exists()
        assert "Copied file to ./7B_full.yaml" in out

    def test_copy_skips_when_dest_already_exists_and_no_clobber_is_true(
        self, capsys, monkeypatch, tmpdir
    ):
        tmpdir_path = Path(tmpdir)
        existing_file = tmpdir_path / "existing_file.yaml"
        existing_file.touch()

        args = f"tune cp llama2/7B_full_low_memory {existing_file} -n".split()

        monkeypatch.setattr(sys, "argv", args)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        out = captured.out.rstrip("\n")
        err = captured.err.rstrip("\n")

        assert err == ""
        assert "not overwriting" in out

    def test_adds_correct_suffix_to_dest_when_no_suffix_is_provided(
        self, capsys, monkeypatch, tmpdir
    ):
        tmpdir_path = Path(tmpdir)
        dest = tmpdir_path / "my_custom_finetune"

        args = f"tune cp llama2/7B_full_low_memory {dest}".split()

        monkeypatch.setattr(sys, "argv", args)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        out = captured.out.rstrip("\n")

        assert dest.with_suffix(".yaml").exists(), f"Expected {dest}.yaml to exist"
        assert f"Copied file to {dest}.yaml" in out

    @pytest.mark.parametrize(
        "tune_command,expected_error_message",
        [
            (
                "tune cp non_existent_recipe .",
                "error: Invalid file name: non_existent_recipe. Try `tune ls` to see all available files to copy.",
            ),
            (
                "tune cp non_existent_config .",
                "error: Invalid file name: non_existent_config. Try `tune ls` to see all available files to copy.",
            ),
            (
                "tune cp full_finetune_single_device /home/mr_bean/full_finetune_single_device.py",
                "error: Cannot create regular file: '/home/mr_bean/full_finetune_single_device.py'. No such file or directory.",
            ),
            (
                "tune cp",
                "error: the following arguments are required: file, destination",
            ),
        ],
    )
    def test_copy_fails_when_given_invalid_recipe(
        self, capsys, monkeypatch, tune_command, expected_error_message
    ):
        args = tune_command.split()

        monkeypatch.setattr(sys, "argv", args)
        with pytest.raises(SystemExit):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        err = captured.err.rstrip("\n")

        assert expected_error_message in err
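
The suffix test above pins down the behavior without prescribing an implementation; a minimal sketch of that behavior, with resolve_destination as a hypothetical helper rather than torchtune's actual code, could look like this.

from pathlib import Path

def resolve_destination(dest: Path, default_suffix: str = ".yaml") -> Path:
    # Only append the default suffix when the user passed a bare name;
    # Path.with_suffix would otherwise replace an existing extension.
    return dest if dest.suffix else dest.with_suffix(default_suffix)

assert resolve_destination(Path("my_custom_finetune")) == Path("my_custom_finetune.yaml")
assert resolve_destination(Path("my_custom_finetune.yaml")) == Path("my_custom_finetune.yaml")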
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import runpy
import sys

import pytest
from tests.common import TUNE_PATH


class TestTuneDownloadCommand:
    """This class tests the `tune download` command."""

    @pytest.fixture
    def snapshot_download(self, mocker, tmpdir):
        from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

        yield mocker.patch(
            "torchtune._cli.download.snapshot_download",
            return_value=tmpdir,
            # Side effects are iterated through on each call
            side_effect=[
                GatedRepoError("test"),
                RepositoryNotFoundError("test"),
                mocker.DEFAULT,
            ],
        )

    def test_download_calls_snapshot(self, capsys, monkeypatch, snapshot_download):
        model = "meta-llama/Llama-2-7b"
        testargs = f"tune download {model}".split()
        monkeypatch.setattr(sys, "argv", testargs)

        # Call the first time and get GatedRepoError
        with pytest.raises(SystemExit, match="2"):
            runpy.run_path(TUNE_PATH, run_name="__main__")
        out_err = capsys.readouterr()
        assert (
            "Ignoring files matching the following patterns: *.safetensors"
            in out_err.out
        )
        assert (
            "It looks like you are trying to access a gated repository." in out_err.err
        )

        # Call the second time and get RepositoryNotFoundError
        with pytest.raises(SystemExit, match="2"):
            runpy.run_path(TUNE_PATH, run_name="__main__")
        out_err = capsys.readouterr()
        assert (
            "Ignoring files matching the following patterns: *.safetensors"
            in out_err.out
        )
        assert "not found on the Hugging Face Hub" in out_err.err

        # Call the third time and get the expected output
        runpy.run_path(TUNE_PATH, run_name="__main__")
        output = capsys.readouterr().out
        assert "Ignoring files matching the following patterns: *.safetensors" in output
        assert "Successfully downloaded model repo" in output

        # Make sure snapshot_download was called three times in total
        assert snapshot_download.call_count == 3
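
The fixture leans on how mock side effects interact with return_value: a list side_effect is consumed one entry per call, exception instances are raised, and unittest.mock.DEFAULT falls back to return_value. A self-contained sketch of just that mechanic:

from unittest import mock

m = mock.MagicMock(
    return_value="downloaded",
    side_effect=[ValueError("gated"), mock.DEFAULT],
)
try:
    m()  # first call raises the first side effect
except ValueError:
    pass
assert m() == "downloaded"  # DEFAULT falls through to return_value
assert m.call_count == 2  # raising calls still count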
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import runpy
import sys

from tests.common import TUNE_PATH

from torchtune._recipe_registry import get_all_recipes


class TestTuneListCommand:
    """This class tests the `tune ls` command."""

    def test_ls_lists_all_recipes_and_configs(self, capsys, monkeypatch):
        testargs = "tune ls".split()

        monkeypatch.setattr(sys, "argv", testargs)
        runpy.run_path(TUNE_PATH, run_name="__main__")

        captured = capsys.readouterr()
        output = captured.out.rstrip("\n")

        for recipe in get_all_recipes():
            assert recipe.name in output
            for config in recipe.configs:
                assert config.name in output
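
This test only assumes that each registry entry exposes a name and a list of configs that each expose a name. A hedged structural sketch, with the field names inferred from the test rather than from torchtune's actual _recipe_registry:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Config:
    name: str

@dataclass
class Recipe:
    name: str
    configs: List[Config] = field(default_factory=list)

# Hypothetical registry contents, for illustration only
recipes = [Recipe("full_finetune_single_device", [Config("llama2/7B_full_single_device")])]
for recipe in recipes:
    print(recipe.name)
    for config in recipe.configs:
        print("  ", config.name)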
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import runpy
import sys

import pytest

from tests.common import TUNE_PATH


class TestTuneRunCommand:
    def test_run_calls_distributed_run_for_distributed_recipe(
        self, capsys, monkeypatch, mocker
    ):
        testargs = "tune run --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full".split()

        monkeypatch.setattr(sys, "argv", testargs)
        distributed_run = mocker.patch("torchtune._cli.tune.Run._run_distributed")
        runpy.run_path(TUNE_PATH, run_name="__main__")
        distributed_run.assert_called_once()

    def test_run_calls_single_device_run_for_single_device_recipe(
        self, capsys, monkeypatch, mocker
    ):
        testargs = "tune run full_finetune_single_device --config llama2/7B_full_single_device".split()

        monkeypatch.setattr(sys, "argv", testargs)
        single_device_run = mocker.patch("torchtune._cli.tune.Run._run_single_device")
        runpy.run_path(TUNE_PATH, run_name="__main__")
        single_device_run.assert_called_once()

    def test_run_fails_when_called_with_distributed_args_for_single_device_recipe(
        self, capsys, monkeypatch
    ):
        testargs = "tune run --nproc_per_node 4 full_finetune_single_device --config llama2/7B_full_single_device".split()

        monkeypatch.setattr(sys, "argv", testargs)
        with pytest.raises(SystemExit, match="2"):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        output = capsys.readouterr()
        assert "does not support distributed training" in output.err

    def test_run_fails_when_config_not_passed_in(self, capsys, monkeypatch):
        testargs = "tune run full_finetune_single_device batch_size=3".split()

        monkeypatch.setattr(sys, "argv", testargs)
        with pytest.raises(SystemExit, match="2"):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        output = capsys.readouterr()
        assert "The '--config' argument is required" in output.err

    def test_run_succeeds_with_local_recipe_file_and_default_config(
        self, capsys, monkeypatch, mocker
    ):
        testargs = "tune run my_custom_recipe.py --config llama2/7B_full".split()
        monkeypatch.setattr(sys, "argv", testargs)
        local_file_run = mocker.patch("torchtune._cli.tune.Run._run_single_device")
        runpy.run_path(TUNE_PATH, run_name="__main__")
        local_file_run.assert_called_once()

    def test_run_calls_local_file_run_for_local_file_recipe(
        self, capsys, monkeypatch, mocker
    ):
        testargs = "tune run my_custom_recipe.py --config custom_config.yaml".split()

        monkeypatch.setattr(sys, "argv", testargs)
        local_file_run = mocker.patch("torchtune._cli.tune.Run._run_single_device")
        runpy.run_path(TUNE_PATH, run_name="__main__")
        local_file_run.assert_called_once()
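
Taken together, these tests pin down a dispatch rule without dictating its implementation: torchrun-style flags imply a distributed launch, which single-device recipes must reject with exit code 2. A minimal sketch of that rule, with choose_launcher as a hypothetical stand-in rather than torchtune's actual logic:

def choose_launcher(recipe_supports_distributed: bool, has_torchrun_args: bool) -> str:
    if has_torchrun_args and not recipe_supports_distributed:
        # Mirrors the "does not support distributed training" failure above
        raise SystemExit(2)
    return "distributed" if has_torchrun_args else "single_device"

assert choose_launcher(True, True) == "distributed"
assert choose_launcher(False, False) == "single_device"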