diff --git a/tests/torchtune/utils/__init__.py b/tests/torchtune/utils/__init__.py
deleted file mode 100644
index 2e41cd717f..0000000000
--- a/tests/torchtune/utils/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
diff --git a/tests/torchtune/utils/test_device.py b/tests/torchtune/utils/test_device.py
deleted file mode 100644
index 302b6d10bb..0000000000
--- a/tests/torchtune/utils/test_device.py
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-from unittest import mock
-from unittest.mock import patch
-
-import pytest
-
-import torch
-from torchtune.utils._device import (
-    _get_device_type_from_env,
-    _setup_cuda_device,
-    get_device,
-)
-
-
-class TestDevice:
-
-    cuda_available: bool = torch.cuda.is_available()
-
-    @patch("torch.cuda.is_available", return_value=False)
-    def test_get_cpu_device(self, mock_cuda):
-        devices = [None, "cpu", "meta"]
-        expected_devices = [
-            torch.device("cpu"),
-            torch.device("cpu"),
-            torch.device("meta"),
-        ]
-        for device, expected_device in zip(devices, expected_devices):
-            device = get_device(device)
-            assert device == expected_device
-            assert device.index is None
-
-    @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
-    def test_get_gpu_device(self) -> None:
-        device_idx = torch.cuda.device_count() - 1
-        assert device_idx >= 0
-        with mock.patch.dict(os.environ, {"LOCAL_RANK": str(device_idx)}, clear=True):
-            device = get_device()
-            assert device.type == "cuda"
-            assert device.index == device_idx
-            assert device.index == torch.cuda.current_device()
-
-            # Test that we raise an error if the device index is specified on distributed runs
-            if device_idx > 0:
-                with pytest.raises(
-                    RuntimeError,
-                    match=f"Device specified is cuda:0 but was assigned cuda:{device_idx}",
-                ):
-                    device = get_device("cuda:0")
-
-        invalid_device_idx = device_idx + 10
-        with mock.patch.dict(os.environ, {"LOCAL_RANK": str(invalid_device_idx)}):
-            with pytest.raises(
-                RuntimeError,
-                match="The local rank is larger than the number of available GPUs",
-            ):
-                device = get_device("cuda")
-
-        # Test that we fall back to 0 if LOCAL_RANK is not specified
-        device = torch.device(_get_device_type_from_env())
-        device = _setup_cuda_device(device)
-        assert device.type == "cuda"
-        assert device.index == 0
-        assert device.index == torch.cuda.current_device()
diff --git a/tests/torchtune/utils/test_logging.py b/tests/torchtune/utils/test_logging.py
deleted file mode 100644
index d1dee6bcfe..0000000000
--- a/tests/torchtune/utils/test_logging.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-from io import StringIO
-from unittest import mock
-
-import pytest
-from torchtune.utils.logging import deprecated, log_rank_zero
-
-
-def test_deprecated():
-    @deprecated(msg="Please use `TotallyAwesomeClass` instead.")
-    class DummyClass:
-        pass
-
-    with pytest.warns(
-        FutureWarning,
-        match="DummyClass is deprecated and will be removed in future versions. Please use `TotallyAwesomeClass` instead.",
-    ):
-        DummyClass()
-
-    with pytest.warns(None) as record:
-        DummyClass()
-
-    assert len(record) == 0, "Warning raised twice when it should only be raised once."
-
-    @deprecated(msg="Please use `totally_awesome_func` instead.")
-    def dummy_func():
-        pass
-
-    with pytest.warns(
-        FutureWarning,
-        match="dummy_func is deprecated and will be removed in future versions. Please use `totally_awesome_func` instead.",
-    ):
-        dummy_func()
-
-
-def test_log_rank_zero(capsys):
-    # Create a logger and add a StreamHandler to it so we can
-    # assert on logged strings
-    logger = logging.getLogger(__name__)
-    logger.setLevel("DEBUG")
-    stream = StringIO()
-    handler = logging.StreamHandler(stream)
-    logger.addHandler(handler)
-
-    with mock.patch(
-        "torchtune.utils.logging.dist.is_available", return_value=True
-    ), mock.patch("torchtune.utils.logging.dist.is_initialized", return_value=True):
-        # Make sure rank 0 logs as expected
-        with mock.patch(
-            "torchtune.utils.logging.dist.get_rank",
-            return_value=0,
-        ):
-            log_rank_zero(logger, "this is a test", level=logging.DEBUG)
-            output = stream.getvalue().strip()
-            assert "this is a test" in output
-
-        # Clear the stream
-        stream.truncate(0)
-        stream.seek(0)
-
-        # Make sure all other ranks do not log anything
-        with mock.patch(
-            "torchtune.utils.logging.dist.get_rank",
-            return_value=1,
-        ):
-            log_rank_zero(logger, "this is a test", level=logging.DEBUG)
-            output = stream.getvalue().strip()
-            assert not output