Test tune
joecummings committed Mar 28, 2024
1 parent 04c9d7d commit 2d8569d
Showing 4 changed files with 88 additions and 67 deletions.
73 changes: 73 additions & 0 deletions tests/torchtune/_cli/test_run.py
@@ -3,3 +3,76 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import runpy
import sys

import pytest

from tests.common import TUNE_PATH


class TestTuneRunCommand:
def test_run_calls_distributed_run_for_distributed_recipe(
self, capsys, monkeypatch, mocker
):
testargs = "tune run --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full".split()

monkeypatch.setattr(sys, "argv", testargs)
distributed_run = mocker.patch("torch.distributed.run.run")
runpy.run_path(TUNE_PATH, run_name="__main__")
distributed_run.assert_called_once()

output = capsys.readouterr()
assert "Running with torchrun..." in output.out

def test_run_calls_single_device_run_for_single_device_recipe(
self, capsys, monkeypatch, mocker
):
testargs = "tune run full_finetune_single_device --config llama2/7B_full_single_device".split()

monkeypatch.setattr(sys, "argv", testargs)
single_device_run = mocker.patch("torchtune._cli.tune.Run._run_single_device")
runpy.run_path(TUNE_PATH, run_name="__main__")
single_device_run.assert_called_once()

def test_run_fails_when_called_with_distributed_args_for_single_device_recipe(
self, capsys, monkeypatch
):
testargs = "tune run --nproc_per_node 4 full_finetune_single_device --config llama2/7B_full_single_device".split()

monkeypatch.setattr(sys, "argv", testargs)
with pytest.raises(SystemExit, match="2"):
runpy.run_path(TUNE_PATH, run_name="__main__")

output = capsys.readouterr()
assert "does not support distributed training" in output.err

def test_run_fails_when_config_not_passed_in(self, capsys, monkeypatch):
testargs = "tune run full_finetune_single_device batch_size=3".split()

monkeypatch.setattr(sys, "argv", testargs)
with pytest.raises(SystemExit, match="2"):
runpy.run_path(TUNE_PATH, run_name="__main__")

output = capsys.readouterr()
assert "The '--config' argument is required" in output.err

def test_run_succeeds_with_local_recipe_file_and_default_config(
self, capsys, monkeypatch, mocker
):
testargs = "tune run my_custom_recipe.py --config llama2/7B_full".split()
monkeypatch.setattr(sys, "argv", testargs)
local_file_run = mocker.patch("torchtune._cli.tune.Run._run_single_device")
runpy.run_path(TUNE_PATH, run_name="__main__")
local_file_run.assert_called_once()

def test_run_calls_local_file_run_for_local_file_recipe(
self, capsys, monkeypatch, mocker
):
testargs = "tune run my_custom_recipe.py --config custom_config.yaml".split()

monkeypatch.setattr(sys, "argv", testargs)
local_file_run = mocker.patch("torchtune._cli.tune.Run._run_single_device")
runpy.run_path(TUNE_PATH, run_name="__main__")
local_file_run.assert_called_once()
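
Every test in the new file follows the same pattern: monkeypatch sys.argv to simulate a command line, execute the tune entry point as a script via runpy, and assert against a mocked dispatch target so no recipe actually runs. A minimal, self-contained sketch of that pattern is below; the recipe and config names are placeholders, and the mocker fixture comes from the pytest-mock plugin.

import runpy
import sys

from tests.common import TUNE_PATH  # path to the tune CLI script used throughout the test suite


def test_single_device_dispatch_sketch(monkeypatch, mocker):
    # Simulate the user typing the command; the argument values here are illustrative only.
    testargs = "tune run full_finetune_single_device --config llama2/7B_full_single_device".split()
    monkeypatch.setattr(sys, "argv", testargs)

    # Stub out the dispatch target so the recipe itself never executes.
    target = mocker.patch("torchtune._cli.tune.Run._run_single_device")

    # Run the CLI exactly as invoking the script directly would.
    runpy.run_path(TUNE_PATH, run_name="__main__")

    target.assert_called_once()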
65 changes: 6 additions & 59 deletions tests/torchtune/_cli/test_tune.py
@@ -7,71 +7,18 @@

import runpy
import sys
from pathlib import Path

import pytest

from tests.common import TUNE_PATH

from torchtune import get_all_recipes


class TestTuneRunCommand:
def test_run_calls_distributed_run_for_distributed_recipe(
self, capsys, monkeypatch, mocker
):
testargs = "tune run --num-gpu 4 full_finetune_distributed --config llama2/7B_full".split()
class TestTuneCLI:
def test_tune_without_args_returns_help(self, capsys, monkeypatch):
testargs = ["tune"]

monkeypatch.setattr(sys, "argv", testargs)
distributed_run = mocker.patch("torch.distributed.run.run")
runpy.run_path(TUNE_PATH, run_name="__main__")
distributed_run.assert_called_once()

output = capsys.readouterr()
assert "Running with torchrun..." in output.out

def test_run_calls_single_device_run_for_single_device_recipe(
self, capsys, monkeypatch, mocker
):
testargs = "tune run full_finetune_single_device --config llama2/7B_full_single_device".split()

monkeypatch.setattr(sys, "argv", testargs)
single_device_run = mocker.patch.object(
torchtune._cli.tune.Run, "_run_single_device", autospec=True
)
runpy.run_path(TUNE_PATH, run_name="__main__")
single_device_run.assert_called_once()

def test_run_fails_when_called_with_distributed_args_for_single_device_recipe(
self, capsys, monkeypatch
):
testargs = "tune run --num-gpu 4 full_finetune_single_device --config llama2/7B_full_single_device".split()

monkeypatch.setattr(sys, "argv", testargs)
with pytest.raises(SystemExit):
runpy.run_path(TUNE_PATH, run_name="__main__")

output = capsys.readouterr()
assert "does not support distributed training" in output.err

def test_run_calls_local_file_run_for_local_file_recipe(
self, capsys, monkeypatch, mocker
):
testargs = "tune run my_custom_recipe.py --config custom_config.yaml".split()

monkeypatch.setattr(sys, "argv", testargs)
local_file_run = mocker.patch("torchtune._cli.tune.Run._run_single_device")
runpy.run_path(TUNE_PATH, run_name="__main__")
local_file_run.assert_called_once()

def test_run_fails_when_using_custom_recipe_and_default_config(
self, capsys, monkeypatch
):
testargs = "tune run my_custom_recipe.py --config llama2/7B_full".split()

monkeypatch.setattr(sys, "argv", testargs)
with pytest.raises(SystemExit):
runpy.run_path(TUNE_PATH, run_name="__main__")
captured = capsys.readouterr()
output = captured.out.rstrip("\n")

output = capsys.readouterr()
assert "please copy the config file to your local dir first" in output.err
assert "Welcome to the TorchTune CLI!" in output
12 changes: 7 additions & 5 deletions torchtune/_cli/run.py
@@ -5,11 +5,16 @@
# LICENSE file in the root directory of this source tree.

import argparse
import sys
import textwrap

from pathlib import Path
from typing import Optional

import torchtune

from torch.distributed.run import get_args_parser as get_torchrun_args_parser, run
from torchtune import Config, get_all_recipes, Recipe
from torchtune._cli.subcommand import Subcommand

ROOT = Path(torchtune.__file__).parent.parent
@@ -34,7 +39,7 @@ def __init__(self, subparsers):
$ tune run lora_finetune_single_device --config llama2/7B_lora_single_device
# Run a finetuning recipe in a distributed fashion using torchrun w/ default values
$ tune run --num-gpu=4 full_finetune_distributed --config llama2/7B_full_finetune_distributed
$ tune run --nproc_per_node 4 full_finetune_distributed --config llama2/7B_full_finetune_distributed
# Override a parameter in the config file and specify a number of GPUs for torchrun
$ tune run lora_finetune_single_device \
@@ -72,17 +77,14 @@ def _run_distributed(self, args: argparse.Namespace):
# we don't do this since we test on CPUs for distributed. Will update once multi GPU CI is supported.
print("Running with torchrun...")
# Have to reset the argv so that the recipe can be run with the correct arguments
# args = copy.deepcopy(args)
args.__dict__["training_script"] = args.__dict__.pop("recipe")
args.__dict__["training_script_args"] = args.__dict__.pop("recipe_args")
run(args)
# return 0

def _run_single_device(self, args: argparse.Namespace):
"""Run a recipe on a single device."""
sys.argv = [str(args.recipe)] + args.recipe_args
runpy.run_path(str(args.recipe), run_name="__main__")
# return 0

def _is_distributed_args(self, args: argparse.Namespace):
"""Check if the user is trying to run a distributed recipe."""
@@ -106,7 +108,7 @@ def _get_config(
if config.name == config_str:
return config

# Search through all recipes
# If not, search through all recipes
for recipe in get_all_recipes():
for config in recipe.get_configs():
if config.name == config_str:
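
Only fragments of the Run subcommand appear in the hunks above, so the routing between _run_distributed and _run_single_device is not fully visible in this diff. As a rough sketch of how that routing could be wired, assuming a check on torchrun-style arguments such as --nproc_per_node (the wants_distributed check is an assumption; only the torchrun hand-off and the single-device path mirror lines shown above):

import argparse
import runpy
import sys

from torch.distributed.run import run  # torchrun's programmatic entry point


def _dispatch(parser: argparse.ArgumentParser, args: argparse.Namespace,
              recipe_supports_distributed: bool) -> None:
    # Hypothetical routing: treat any non-default --nproc_per_node as a request for torchrun.
    wants_distributed = getattr(args, "nproc_per_node", None) not in (None, "1", 1)

    if wants_distributed and not recipe_supports_distributed:
        # parser.error prints to stderr and exits with status 2, which is the
        # SystemExit code the tests above match on.
        parser.error(f"Recipe {args.recipe} does not support distributed training")

    if wants_distributed:
        # torchrun expects training_script / training_script_args, so rename the
        # fields before handing off (mirrors the renaming in _run_distributed above).
        args.training_script = args.recipe
        args.training_script_args = args.recipe_args
        run(args)
    else:
        # Single-device path mirrors _run_single_device above: reset argv and run
        # the recipe file in-process as __main__.
        sys.argv = [str(args.recipe)] + args.recipe_args
        runpy.run_path(str(args.recipe), run_name="__main__")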
5 changes: 2 additions & 3 deletions torchtune/_cli/tune.py
@@ -9,8 +9,7 @@
from torchtune._cli.cp import Copy
from torchtune._cli.download import Download
from torchtune._cli.ls import List

# from torchtune._cli.run import Run
from torchtune._cli.run import Run
from torchtune._cli.validate import Validate


@@ -33,7 +32,7 @@ def __init__(self):
Download.create(subparsers)
List.create(subparsers)
Copy.create(subparsers)
# Run.create(subparsers)
Run.create(subparsers)
Validate.create(subparsers)

def parse_args(self) -> argparse.Namespace:
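
The tune.py change simply re-enables the Run subcommand: the commented-out import and the commented-out Run.create(subparsers) call are restored. The registration machinery itself is not part of this diff; a plausible sketch of how a Subcommand such as Run might attach itself to the top-level parser is below (everything beyond the names visible in the diff is an assumption).

import argparse


class Subcommand:
    """Shared base for the CLI subcommands (Download, List, Copy, Run, Validate)."""

    @classmethod
    def create(cls, subparsers: argparse._SubParsersAction):
        return cls(subparsers)


class Run(Subcommand):
    def __init__(self, subparsers):
        self._parser = subparsers.add_parser("run", help="Run a recipe.")
        # Recipe, config, and torchrun arguments would be registered here.
        self._parser.set_defaults(func=self._run_cmd)

    def _run_cmd(self, args: argparse.Namespace) -> None:
        # Route to distributed or single-device execution (see run.py above).
        pass

With wiring like this, the top-level parser would typically dispatch by calling args.func(args) after parse_args.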
