fix TPU parsing and TPU tests #2094

Merged Jun 23, 2020 · 37 commits (changes shown from 35 commits)

Commits
ee00683
added tpu params test
Jun 6, 2020
836aaf9
added tests
Jun 6, 2020
33a6e85
removed xla imports
Jun 6, 2020
a9e55e1
added test cases for TPU
Jun 6, 2020
243237e
fix pep 8 issues
Jun 6, 2020
cb18a87
refactorings and comments
Jun 6, 2020
45c8042
add message to MisconfigurationException
lezwon Jun 7, 2020
2239f9f
test if device is set correctly
Jun 7, 2020
4417897
added TPU device check
Jun 7, 2020
c31a90d
removed device selection
Jun 7, 2020
c7ebb46
remove xla_device call
Jun 7, 2020
56fe55e
readded spawn due to test failures
Jun 7, 2020
e1a5270
add TODO for tpu check
Jun 7, 2020
5b9b9b6
Apply suggestions from code review
Borda Jun 8, 2020
1e83661
Apply suggestions from code review
Borda Jun 8, 2020
6543961
flake8
Borda Jun 9, 2020
3fda595
added tpu args to cli tests
Jun 9, 2020
4ca6e17
added support for tpu_core selection via cli
Jun 14, 2020
d9ccdbb
fixed flake formatting
Jun 14, 2020
acacf59
replaced default_save_path with default_root_dir
Jun 14, 2020
e4c11b1
added check for data type for tpu_cores
Jun 14, 2020
2ae3862
fixed flake indent
Jun 14, 2020
6a7784c
protected
Borda Jun 14, 2020
ec44801
protected
Borda Jun 14, 2020
5377ed2
chlog
Borda Jun 18, 2020
b20a287
added tpu params test
Jun 6, 2020
4bceb92
added tests
Jun 6, 2020
16cd53c
removed xla imports
Jun 6, 2020
5290434
test if device is set correctly
Jun 7, 2020
857f052
added support for tpu_core selection via cli
Jun 14, 2020
3561c4f
replaced default_save_path with default_root_dir
Jun 14, 2020
d6e72c7
added check for data type for tpu_cores
Jun 14, 2020
0a2c2e3
fixed tpu cores error
Jun 19, 2020
a3269e9
rebased with latest changes
Jun 20, 2020
03fe3b5
flake fix
Jun 20, 2020
7f5d95a
Update pytorch_lightning/trainer/distrib_parts.py
lezwon Jun 21, 2020
da06cc4
Merge branch 'master' into 1246_tpu_tests
Borda Jun 23, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed parsing TPU arguments and TPU tests ([#2094](https://github.com/PyTorchLightning/pytorch-lightning/pull/2094))

- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))


67 changes: 56 additions & 11 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -330,7 +330,7 @@ def filter_named_parameters(model, optimizer):
hvd.join()


def normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
if isinstance(s, str):
if s == '-1':
return -1
@@ -348,19 +348,19 @@ def get_all_available_gpus() -> List[int]:
return list(range(torch.cuda.device_count()))


def check_gpus_data_type(gpus: Any) -> None:
def _check_data_type(device_ids: Any) -> None:
"""
Checks that the gpus argument is one of: None, Int, String or List.
Checks that the device_ids argument is one of: None, Int, String or List.
Raises a MisconfigurationException otherwise.

Args:
gpus: parameter as passed to the Trainer
device_ids: gpus/tpu_cores parameter as passed to the Trainer
"""
if gpus is not None and (not isinstance(gpus, (int, str, MutableSequence)) or isinstance(gpus, bool)):
raise MisconfigurationException("GPUs must be int, string or sequence of ints or None.")
if device_ids is not None and (not isinstance(device_ids, (int, str, MutableSequence)) or isinstance(device_ids, bool)):
raise MisconfigurationException("Device ID's (GPU/TPU) must be int, string or sequence of ints or None.")


def normalize_parse_gpu_input_to_list(gpus: Union[int, List[int]]) -> Optional[List[int]]:
def _normalize_parse_gpu_input_to_list(gpus: Union[int, List[int]]) -> Optional[List[int]]:
assert gpus is not None
if isinstance(gpus, MutableSequence):
return list(gpus)
@@ -405,7 +405,7 @@ def sanitize_gpu_ids(gpus: List[int]) -> List[int]:
return gpus


def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
def _parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
"""
Parses the GPU ids given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.
@@ -429,7 +429,7 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
return None

# Check that gpus param is None, Int, String or List
check_gpus_data_type(gpus)
_check_data_type(gpus)

# Handle the case when no gpus are requested
if gpus is None or isinstance(gpus, int) and gpus == 0:
@@ -438,8 +438,8 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[int]]:
# We know user requested GPUs therefore if some of the
# requested GPUs are not available an exception is thrown.

gpus = normalize_parse_gpu_string_input(gpus)
gpus = normalize_parse_gpu_input_to_list(gpus)
gpus = _normalize_parse_gpu_string_input(gpus)
gpus = _normalize_parse_gpu_input_to_list(gpus)
if not gpus:
raise MisconfigurationException("GPUs requested but none are available.")
gpus = sanitize_gpu_ids(gpus)
@@ -493,6 +493,51 @@ def retry_jittered_backoff(func: Callable, num_retries: int = 5, cap_delay: float
sleep_delay = min(cap_delay, random.uniform(base_delay, sleep_delay * 3))


def _parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int], int]]:
"""
Parses the tpu_cores given in the format as accepted by the
:class:`~pytorch_lightning.trainer.Trainer`.

Args:
tpu_cores: An int 1 or string '1' indicates that 1 core with multi-processing should be used.
An int 8 or string '8' indicates that all 8 cores with multi-processing should be used.
A list of ints or a string containing a comma-separated list of integers
indicates the specific TPU core to use.

Returns:
an int (1 or 8), a list with the specific TPU core to use, or ``None`` if no TPU cores were requested
"""

if callable(tpu_cores):
return None

_check_data_type(tpu_cores)

if isinstance(tpu_cores, str):
tpu_cores = _parse_tpu_cores_str(tpu_cores.strip())

if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")

return tpu_cores


def _tpu_cores_valid(tpu_cores):
return tpu_cores in (1, 8, None) or (
isinstance(tpu_cores, (list, tuple, set)) and
len(tpu_cores) == 1 and
tpu_cores[0] in range(1, 9)
)


def _parse_tpu_cores_str(tpu_cores):
if tpu_cores == '1' or tpu_cores == '8':
tpu_cores = int(tpu_cores)
else:
tpu_cores = [int(x.strip()) for x in tpu_cores.split(',') if len(x) > 0]
return tpu_cores


def pick_single_gpu(exclude_gpus: list):
for i in range(torch.cuda.device_count()):
if i in exclude_gpus:
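The behavior of the new `_parse_tpu_cores` helper, as described by its docstring and enforced by `_tpu_cores_valid`, can be summarized in a short sketch. This is only a usage illustration of the code added above, assuming the private helpers are imported directly from `pytorch_lightning.trainer.distrib_parts` as of this PR:

    from pytorch_lightning.trainer.distrib_parts import _parse_tpu_cores
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # An int or numeric string selects that many cores with multi-processing.
    assert _parse_tpu_cores(1) == 1
    assert _parse_tpu_cores('8') == 8

    # A one-element list, or a comma-separated string such as '1,',
    # selects one specific core.
    assert _parse_tpu_cores([1]) == [1]
    assert _parse_tpu_cores('1,') == [1]

    # Everything else is rejected with the error message asserted in the tests below.
    try:
        _parse_tpu_cores(2)
    except MisconfigurationException:
        pass  # "`tpu_cores` can only be 1, 8 or [<1-8>]"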
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/trainer.py
@@ -21,7 +21,7 @@
TrainerDeprecatedAPITillVer0_9, TrainerDeprecatedAPITillVer0_10)
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (
TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device, pick_multiple_gpus)
TrainerDPMixin, _parse_gpu_ids, determine_root_gpu_device, pick_multiple_gpus, _parse_tpu_cores)
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -91,7 +91,7 @@ def __init__(
num_processes: int = 1,
gpus: Optional[Union[List[int], str, int]] = None,
auto_select_gpus: bool = False,
tpu_cores: Optional[Union[List[int], int]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None,
log_gpu_memory: Optional[str] = None,
progress_bar_refresh_rate: int = 1,
overfit_batches: Union[int, float] = 0.0,
@@ -360,13 +360,10 @@ def __init__(

if tpu_cores is None:
tpu_cores = num_tpu_cores
self.on_tpu = tpu_cores is not None
self.tpu_cores = tpu_cores
assert self.tpu_cores in (1, 8, None) or (
isinstance(self.tpu_cores, (list, tuple, set)) and len(self.tpu_cores) == 1
), '`tpu_cores` can only be 1, 8 or [<1-8>]'
self.tpu_cores = _parse_tpu_cores(tpu_cores)
self.on_tpu = self.tpu_cores is not None

self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None
self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores, list) else None

if num_processes != 1 and distributed_backend != "ddp_cpu":
rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
@@ -460,7 +457,7 @@ def __init__(
else:
self.gpus = gpus

self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
self.data_parallel_device_ids = _parse_gpu_ids(self.gpus)
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
self.root_device = torch.device("cpu")

@@ -703,7 +700,7 @@ def use_type(x):
else:
use_type = arg_types[0]

if arg == 'gpus':
if arg == 'gpus' or arg == 'tpu_cores':
use_type = Trainer._allowed_type
arg_default = Trainer._arg_default

@@ -920,6 +917,9 @@ def fit(
elif self.use_tpu: # pragma: no-cover
rank_zero_info(f'training on {self.tpu_cores} TPU cores')

if not XLA_AVAILABLE:
raise MisconfigurationException('No TPU devices found.')

# COLAB_GPU is an env var available by default in Colab environments.
start_method = 'fork' if self.on_colab_kaggle else 'spawn'

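Taken together, the Trainer changes above mean that `tpu_cores` is parsed once by `_parse_tpu_cores` and then drives both `on_tpu` and `tpu_id`; only a one-element list (or its string form) pins a specific core. A minimal sketch, assuming pytorch-lightning at the state of this diff, where the XLA availability check happens in `fit()` rather than in the constructor, so a Trainer can be built on a machine without a TPU:

    from pytorch_lightning import Trainer

    # An int requests that many cores via multi-processing; no core is pinned.
    trainer = Trainer(tpu_cores=8)
    assert trainer.on_tpu is True
    assert trainer.tpu_id is None

    # A one-element list pins a specific core.
    trainer = Trainer(tpu_cores=[1])
    assert trainer.tpu_id == 1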
12 changes: 6 additions & 6 deletions tests/models/test_gpu.py
@@ -6,7 +6,7 @@
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.core import memory
from pytorch_lightning.trainer.distrib_parts import parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

@@ -202,7 +202,7 @@ def test_determine_root_gpu_device(gpus, expected_root_gpu):
pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"),
])
def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
assert parse_gpu_ids(gpus) == expected_gpu_ids
assert _parse_gpu_ids(gpus) == expected_gpu_ids


@pytest.mark.gpus_param_tests
@@ -218,27 +218,27 @@ def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids):
])
def test_parse_gpu_fail_on_unsupported_inputs(mocked_device_count, gpus):
with pytest.raises(MisconfigurationException):
parse_gpu_ids(gpus)
_parse_gpu_ids(gpus)


@pytest.mark.gpus_param_tests
@pytest.mark.parametrize("gpus", [[1, 2, 19], -1, '-1'])
def test_parse_gpu_fail_on_non_existent_id(mocked_device_count_0, gpus):
with pytest.raises(MisconfigurationException):
parse_gpu_ids(gpus)
_parse_gpu_ids(gpus)


@pytest.mark.gpus_param_tests
def test_parse_gpu_fail_on_non_existent_id_2(mocked_device_count):
with pytest.raises(MisconfigurationException):
parse_gpu_ids([1, 2, 19])
_parse_gpu_ids([1, 2, 19])


@pytest.mark.gpus_param_tests
@pytest.mark.parametrize("gpus", [-1, '-1'])
def test_parse_gpu_returns_None_when_no_devices_are_available(mocked_device_count_0, gpus):
with pytest.raises(MisconfigurationException):
parse_gpu_ids(gpus)
_parse_gpu_ids(gpus)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
120 changes: 120 additions & 0 deletions tests/models/test_tpu.py
@@ -0,0 +1,120 @@
import os
from unittest.mock import patch

import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

try:
import torch_xla
# TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team
# device = torch_xla.core.xla_model.xla_device()
# device_type = torch_xla.core.xla_model.xla_device_hw(device)
# TPU_AVAILABLE = device_type == 'TPU'
except ImportError:
TPU_AVAILABLE = False
else:
TPU_AVAILABLE = True


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
pytest.param([1], 'xla:1'),
pytest.param([8], 'xla:8'),
])
def test_single_tpu_core_model(tmpdir, tpu_cores, expected_device):
"""Test if single TPU core training works"""
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
tpu_cores=tpu_cores
)
trainer.fit(model)
assert torch_xla._XLAC._xla_get_default_device() == expected_device


@pytest.mark.spawn
@pytest.mark.parametrize("tpu_cores", [1, 8])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
def test_multi_core_tpu_model(tmpdir, tpu_cores):
"""Test if distributed TPU core training works"""
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
tpu_cores=tpu_cores,
)
trainer.fit(model)
assert trainer.tpu_id is None


@pytest.mark.spawn
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
def test_dataloaders_passed_to_fit(tmpdir):
"""Test if dataloaders passed to trainer works on TPU"""

model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
tpu_cores=8,
)
result = trainer.fit(
model,
train_dataloader=model.train_dataloader(),
val_dataloaders=model.val_dataloader(),
)
assert result, "TPU doesn't work with dataloaders passed to fit()."


@pytest.mark.spawn
@pytest.mark.parametrize("tpu_cores", [1, 8, [1]])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
def test_mixed_precision_with_tpu(tmpdir, tpu_cores):
"""Test if FP16 TPU core training works"""
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
tpu_cores=tpu_cores,
precision=16
)
trainer.fit(model)
assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"


@pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id'], [
pytest.param(1, None),
pytest.param(8, None),
pytest.param([1], 1),
pytest.param([8], 8),
])
def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
"""Test if trainer.tpu_id is set as expected"""
assert Trainer(tpu_cores=tpu_cores).tpu_id == expected_tpu_id


@patch('pytorch_lightning.trainer.trainer.XLA_AVAILABLE', False)
def test_exception_when_no_tpu_found(tmpdir):
"""Test if exception is thrown when xla devices are not available"""
model = EvalModelTemplate()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
tpu_cores=8,
)

with pytest.raises(MisconfigurationException, match='No TPU devices found.'):
trainer.fit(model)
22 changes: 22 additions & 0 deletions tests/trainer/test_trainer.py
@@ -782,6 +782,28 @@ def test_gpu_choice(tmpdir):
Trainer(**trainer_options, gpus=num_gpus + 1, auto_select_gpus=True)


@pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id', 'error_expected'], [
pytest.param(1, None, False),
pytest.param(8, None, False),
pytest.param([1], 1, False),
pytest.param([8], 8, False),
pytest.param('1,', 1, False),
pytest.param('1', None, False),
pytest.param('9, ', 9, True),
pytest.param([9], 9, True),
pytest.param([0], 0, True),
pytest.param(2, None, True),
pytest.param(10, None, True),
])
def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
if error_expected:
with pytest.raises(MisconfigurationException, match=r'.*tpu_cores` can only be 1, 8 or [<1-8>]*'):
Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores, auto_select_gpus=True)
else:
trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores, auto_select_gpus=True)
assert trainer.tpu_id == expected_tpu_id


@pytest.mark.parametrize("trainer_kwargs,expected", [
pytest.param(
dict(distributed_backend=None, gpus=None),
4 changes: 4 additions & 0 deletions tests/trainer/test_trainer_cli.py
@@ -99,6 +99,10 @@ def _raise():
{'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}),
pytest.param('--early_stop_callback',
{'auto_lr_find': False, 'early_stop_callback': True, 'auto_scale_batch_size': False}),
pytest.param('--tpu_cores=8',
{'tpu_cores': 8}),
pytest.param("--tpu_cores=1,",
{'tpu_cores': '1,'})
])
def test_argparse_args_parsing(cli_args, expected):
"""Test multi type argument with bool."""
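For the CLI path exercised by `test_argparse_args_parsing`, the `gpus`-style argument handling (`Trainer._allowed_type`) is now applied to `tpu_cores` as well, so numeric strings arrive as ints and anything else stays a string for the Trainer to parse later. A sketch of the assumed flow through `Trainer.add_argparse_args`:

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    parser = Trainer.add_argparse_args(ArgumentParser())

    # '--tpu_cores=8' parses to the int 8 ...
    args = parser.parse_args(['--tpu_cores=8'])
    assert args.tpu_cores == 8

    # ... while '--tpu_cores=1,' stays a string, which the Trainer
    # later turns into the core list [1].
    args = parser.parse_args(['--tpu_cores=1,'])
    assert args.tpu_cores == '1,'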