Skip to content

Commit

Permalink
hotfix import torch (#15849)
Browse files Browse the repository at this point in the history
* fix import torch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* plugin

* fix

* skip

* patch require

* seed

* warn

* .

* ..

* skip True

* 0.0.3

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Nov 28, 2022
1 parent 95d5ccb commit ad4bd66
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 23 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci-pkg-install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- name: DocTests actions
working-directory: .actions/
run: |
pip install pytest -q
pip install -q pytest
python -m pytest setup_tools.py
- run: python -c "print('NB_DIRS=' + str(2 if '${{ matrix.pkg-name }}' == 'pytorch' else 1))" >> $GITHUB_ENV
Expand All @@ -67,7 +67,10 @@ jobs:

- name: DocTest package
env:
LIGHTING_TESTING: 1 # path for require wrapper
PY_IGNORE_IMPORTMISMATCH: 1
run: |
pip install -q "pytest-doctestplus>=0.9.0"
pip list
PKG_NAME=$(python -c "print({'app': 'lightning_app', 'lite': 'lightning_lite', 'pytorch': 'pytorch_lightning', 'lightning': 'lightning'}['${{matrix.pkg-name}}'])")
python -m pytest src/${PKG_NAME} --ignore-glob="**/cli/*-template/**"
2 changes: 1 addition & 1 deletion requirements/app/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ beautifulsoup4>=4.8.0, <4.11.2
inquirer>=2.10.0
psutil<5.9.4
click<=8.1.3
lightning_api_access>=0.0.1
lightning_api_access>=0.0.3
1 change: 1 addition & 0 deletions requirements/app/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ codecov==2.1.12
pytest==7.2.0
pytest-timeout==2.1.0
pytest-cov==4.0.0
pytest-doctestplus>=0.9.0
playwright==1.27.1
httpx
trio<0.22.0
Expand Down
18 changes: 13 additions & 5 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
Expand All @@ -13,16 +12,21 @@
from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

logger = Logger(__name__)

# Skip doctests if requirements aren't available
if not _is_torch_available():
__doctest_skip__ = ["PythonServer", "PythonServer.*"]


class _PyTorchSpawnRunExecutor(WorkRunExecutor):

"""This Executor enables to move PyTorch tensors on GPU.
Without this executor, it woud raise the following expection:
Without this executor, it would raise the following exception:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
"""
Expand Down Expand Up @@ -86,6 +90,7 @@ def _get_sample_data() -> Dict[Any, Any]:


class PythonServer(LightningWork, abc.ABC):
@requires("torch")
def __init__( # type: ignore
self,
host: str = "127.0.0.1",
Expand Down Expand Up @@ -127,15 +132,16 @@ def predict(self, request):
and this can be accessed as `response.json()["prediction"]` in the client if
you are using requests library
.. doctest::
Example:
>>> from lightning_app.components.serve.python_server import PythonServer
>>> from lightning_app import LightningApp
>>>
...
>>> class SimpleServer(PythonServer):
...
... def setup(self):
... self._model = lambda x: x + " " + x
...
... def predict(self, request):
... return {"prediction": self._model(request.image)}
...
Expand Down Expand Up @@ -199,11 +205,13 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
return out

def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
from torch import inference_mode

input_type: type = self.configure_input_type()
output_type: type = self.configure_output_type()

def predict_fn(request: input_type): # type: ignore
with torch.inference_mode():
with inference_mode():
return self.predict(request)

fastapi_app.post("/predict", response_model=output_type)(predict_fn)
Expand Down
13 changes: 9 additions & 4 deletions src/lightning_app/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities."""

import functools
import os
import warnings
from typing import List, Union

from lightning_utilities.core.imports import module_available
Expand Down Expand Up @@ -52,10 +54,13 @@ def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
unavailable_modules = [f"'{module}'" for module in module_paths if not module_available(module)]
if any(unavailable_modules) and not bool(int(os.getenv("LIGHTING_TESTING", "0"))):
raise ModuleNotFoundError(
f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
)
if any(unavailable_modules):
is_lit_testing = bool(int(os.getenv("LIGHTING_TESTING", "0")))
msg = f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
if is_lit_testing:
warnings.warn(msg)
else:
raise ModuleNotFoundError(msg)
return func(*args, **kwargs)

return wrapper
Expand Down
13 changes: 7 additions & 6 deletions src/lightning_app/utilities/name_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,12 +1332,13 @@ def get_unique_name():
Original source:
https://raw.githubusercontent.com/moby/moby/master/pkg/namesgenerator/names-generator.go
Examples
--------
>>> get_unique_name() # doctest: +SKIP
'focused-turing-23'
>>> get_unique_name() # doctest: +SKIP
'thirsty-allen-9200'
Examples:
>>> import random ; random.seed(42)
>>> get_unique_name()
'meek-ardinghelli-4506'
>>> get_unique_name()
'truthful-dijkstra-2286'
"""
adjective, surname, i = choice(_adjectives), choice(_surnames), randint(0, 9999)
return f"{adjective}-{surname}-{i}"
4 changes: 2 additions & 2 deletions tests/tests_app/core/test_lightning_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def run(self):


# TODO: Find why this test is flaky.
@pytest.mark.skipif(True, reason="flaky test.")
@pytest.mark.skip(reason="flaky test.")
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime])
def test_app_state_api_with_flows(runtime_cls, tmpdir):
"""This test validates the AppState can properly broadcast changes from flows."""
Expand Down Expand Up @@ -180,7 +180,7 @@ def maybe_apply_changes(self):


# FIXME: This test doesn't assert anything
@pytest.mark.skipif(True, reason="TODO: Resolve flaky test.")
@pytest.mark.skip(reason="TODO: Resolve flaky test.")
@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime])
def test_app_stage_from_frontend(runtime_cls):
"""This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/core/test_lightning_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def run(self):


# TODO (tchaton) Resolve this test.
@pytest.mark.skipif(True, reason="flaky test which never terminates")
@pytest.mark.skip(reason="flaky test which never terminates")
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime])
@pytest.mark.parametrize("use_same_args", [False, True])
def test_state_wait_for_all_all_works(tmpdir, runtime_cls, use_same_args):
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/structures/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def run(self):
self.counter += 1


@pytest.mark.skipif(True, reason="tchaton: Resolve this test.")
@pytest.mark.skip(reason="tchaton: Resolve this test.")
@pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime, SingleProcessRuntime])
@pytest.mark.parametrize("run_once_iterable", [False, True])
@pytest.mark.parametrize("cache_calls", [False, True])
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_app/utilities/packaging/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from lightning_app.utilities.redis import check_if_redis_running


@pytest.mark.skipif(True, reason="FIXME (tchaton)")
@pytest.mark.skip(reason="FIXME (tchaton)")
@pytest.mark.skipif(not _is_docker_available(), reason="docker is required for this test.")
@pytest.mark.skipif(not check_if_redis_running(), reason="redis is required for this test.")
@_RunIf(skip_windows=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_lite/test_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def run(rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdi
_atomic_save(model.state_dict(), os.path.join(tmpdir, "model_spawn.pt"))


@pytest.mark.skipif(True, reason="Skipping as it takes 80 seconds.")
@pytest.mark.skip(reason="Skipping as it takes 80 seconds.")
@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
"precision, strategy, devices, accelerator",
Expand Down

0 comments on commit ad4bd66

Please sign in to comment.