Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for apps created by factory functions #37

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.0
rev: v0.4.9
hooks:
- id: ruff
args:
Expand Down
3 changes: 2 additions & 1 deletion src/fastapi_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _run(
proxy_headers: bool = False,
) -> None:
try:
use_uvicorn_app = get_import_string(path=path, app_name=app)
use_uvicorn_app, is_factory = get_import_string(path=path, app_name=app)
except FastAPICLIException as e:
logger.error(str(e))
raise typer.Exit(code=1) from None
Expand Down Expand Up @@ -97,6 +97,7 @@ def _run(
workers=workers,
root_path=root_path,
proxy_headers=proxy_headers,
factory=is_factory,
)


Expand Down
38 changes: 26 additions & 12 deletions src/fastapi_cli/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from typing import Union
from typing import Any, Callable, Tuple, Union, get_type_hints

from rich import print
from rich.padding import Padding
Expand Down Expand Up @@ -98,7 +98,9 @@ def get_module_data_from_path(path: Path) -> ModuleData:
)


def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) -> str:
def get_app_name(
*, mod_data: ModuleData, app_name: Union[str, None] = None
) -> Tuple[str, bool]:
try:
mod = importlib.import_module(mod_data.module_import_str)
except (ImportError, ValueError) as e:
Expand All @@ -119,26 +121,38 @@ def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) ->
f"Could not find app name {app_name} in {mod_data.module_import_str}"
)
app = getattr(mod, app_name)
is_factory = False
if not isinstance(app, FastAPI):
raise FastAPICLIException(
f"The app name {app_name} in {mod_data.module_import_str} doesn't seem to be a FastAPI app"
)
return app_name
for preferred_name in ["app", "api"]:
is_factory = check_factory(app)
if not is_factory:
raise FastAPICLIException(
f"The app name {app_name} in {mod_data.module_import_str} doesn't seem to be a FastAPI app"
)
return app_name, is_factory
for preferred_name in ["app", "api", "create_app", "create_api"]:
if preferred_name in object_names_set:
obj = getattr(mod, preferred_name)
if isinstance(obj, FastAPI):
return preferred_name
return preferred_name, False
if check_factory(obj):
return preferred_name, True
for name in object_names:
obj = getattr(mod, name)
if isinstance(obj, FastAPI):
return name
return name, False
raise FastAPICLIException("Could not find FastAPI app in module, try using --app")


def check_factory(fn: Callable[[], Any]) -> bool:
"""Checks whether the return-type of a factory function is FastAPI"""
type_hints = get_type_hints(fn)
return_type = type_hints.get("return")
return return_type is not None and issubclass(return_type, FastAPI)


def get_import_string(
*, path: Union[Path, None] = None, app_name: Union[str, None] = None
) -> str:
) -> Tuple[str, bool]:
if not path:
path = get_default_path()
logger.info(f"Using path [blue]{path}[/blue]")
Expand All @@ -147,7 +161,7 @@ def get_import_string(
raise FastAPICLIException(f"Path does not exist {path}")
mod_data = get_module_data_from_path(path)
sys.path.insert(0, str(mod_data.extra_sys_path))
use_app_name = get_app_name(mod_data=mod_data, app_name=app_name)
use_app_name, is_factory = get_app_name(mod_data=mod_data, app_name=app_name)
import_example = Syntax(
f"from {mod_data.module_import_str} import {use_app_name}", "python"
)
Expand All @@ -164,4 +178,4 @@ def get_import_string(
print(import_panel)
import_string = f"{mod_data.module_import_str}:{use_app_name}"
logger.info(f"Using import string [b green]{import_string}[/b green]")
return import_string
return import_string, is_factory
11 changes: 11 additions & 0 deletions tests/assets/factory_create_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import FastAPI


def create_api() -> FastAPI:
app = FastAPI()

@app.get("/")
def app_root():
return {"message": "single file factory app"}

return app
24 changes: 24 additions & 0 deletions tests/assets/factory_create_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from fastapi import FastAPI


class App(FastAPI): ...


def create_app_other() -> App:
app = App()

@app.get("/")
def app_root():
return {"message": "single file factory app inherited"}

return app


def create_app() -> FastAPI:
app = FastAPI()

@app.get("/")
def app_root():
return {"message": "single file factory app"}

return app
11 changes: 11 additions & 0 deletions tests/assets/package/mod/factory_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import FastAPI


def create_api() -> FastAPI:
app = FastAPI()

@app.get("/")
def root():
return {"message": "package create_api"}

return app
11 changes: 11 additions & 0 deletions tests/assets/package/mod/factory_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import FastAPI


def create_app() -> FastAPI:
app = FastAPI()

@app.get("/")
def root():
return {"message": "package create_app"}

return app
14 changes: 14 additions & 0 deletions tests/assets/package/mod/factory_inherit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from fastapi import FastAPI


class App(FastAPI): ...


def create_app() -> App:
app = App()

@app.get("/")
def root():
return {"message": "package build_app"}

return app
11 changes: 11 additions & 0 deletions tests/assets/package/mod/factory_other.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import FastAPI


def build_app() -> FastAPI:
app = FastAPI()

@app.get("/")
def root():
return {"message": "package build_app"}

return app
58 changes: 58 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_dev() -> None:
"workers": None,
"root_path": "",
"proxy_headers": True,
"factory": False,
}
assert "Using import string single_file_app:app" in result.output
assert (
Expand All @@ -40,6 +41,33 @@ def test_dev() -> None:
assert "│ fastapi run" in result.output


def test_dev_factory() -> None:
with changing_dir(assets_path):
with patch.object(uvicorn, "run") as mock_run:
result = runner.invoke(app, ["dev", "factory_create_app.py"])
assert result.exit_code == 0, result.output
assert mock_run.called
assert mock_run.call_args
assert mock_run.call_args.kwargs == {
"app": "factory_create_app:create_app",
"host": "127.0.0.1",
"port": 8000,
"reload": True,
"workers": None,
"root_path": "",
"proxy_headers": True,
"factory": True,
}
assert "Using import string factory_create_app:create_app" in result.output
assert (
"╭────────── FastAPI CLI - Development mode ───────────╮" in result.output
)
assert "│ Serving at: http://127.0.0.1:8000" in result.output
assert "│ API docs: http://127.0.0.1:8000/docs" in result.output
assert "│ Running in development mode, for production use:" in result.output
assert "│ fastapi run" in result.output


def test_dev_args() -> None:
with changing_dir(assets_path):
with patch.object(uvicorn, "run") as mock_run:
Expand Down Expand Up @@ -71,6 +99,7 @@ def test_dev_args() -> None:
"workers": None,
"root_path": "/api",
"proxy_headers": False,
"factory": False,
}
assert "Using import string single_file_app:api" in result.output
assert (
Expand All @@ -97,6 +126,7 @@ def test_run() -> None:
"workers": None,
"root_path": "",
"proxy_headers": True,
"factory": False,
}
assert "Using import string single_file_app:app" in result.output
assert (
Expand All @@ -108,6 +138,33 @@ def test_run() -> None:
assert "│ fastapi dev" in result.output


def test_run_factory() -> None:
with changing_dir(assets_path):
with patch.object(uvicorn, "run") as mock_run:
result = runner.invoke(app, ["run", "factory_create_app.py"])
assert result.exit_code == 0, result.output
assert mock_run.called
assert mock_run.call_args
assert mock_run.call_args.kwargs == {
"app": "factory_create_app:create_app",
"host": "0.0.0.0",
"port": 8000,
"reload": False,
"workers": None,
"root_path": "",
"proxy_headers": True,
"factory": True,
}
assert "Using import string factory_create_app:create_app" in result.output
assert (
"╭─────────── FastAPI CLI - Production mode ───────────╮" in result.output
)
assert "│ Serving at: http://0.0.0.0:8000" in result.output
assert "│ API docs: http://0.0.0.0:8000/docs" in result.output
assert "│ Running in production mode, for development use:" in result.output
assert "│ fastapi dev" in result.output


def test_run_args() -> None:
with changing_dir(assets_path):
with patch.object(uvicorn, "run") as mock_run:
Expand Down Expand Up @@ -141,6 +198,7 @@ def test_run_args() -> None:
"workers": 2,
"root_path": "/api",
"proxy_headers": False,
"factory": False,
}
assert "Using import string single_file_app:api" in result.output
assert (
Expand Down
32 changes: 32 additions & 0 deletions tests/test_utils_check_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from fastapi import FastAPI
from fastapi_cli.discover import check_factory


def test_check_untyped_factory() -> None:
def create_app(): # type: ignore[no-untyped-def]
return FastAPI() # pragma: no cover

assert check_factory(create_app) is False


def test_check_typed_factory() -> None:
def create_app() -> FastAPI:
return FastAPI() # pragma: no cover

assert check_factory(create_app) is True


def test_check_typed_factory_inherited() -> None:
class MyApp(FastAPI): ...

def create_app() -> MyApp:
return MyApp() # pragma: no cover

assert check_factory(create_app) is True


def test_create_app_with_different_type() -> None:
def create_app() -> int:
return 1 # pragma: no cover

assert check_factory(create_app) is False
9 changes: 6 additions & 3 deletions tests/test_utils_default_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@

def test_app_dir_main(capsys: CaptureFixture[str]) -> None:
with changing_dir(assets_path / "default_files" / "default_app_dir_main"):
import_string = get_import_string()
import_string, is_factory = get_import_string()
assert import_string == "app.main:app"
assert is_factory is False

captured = capsys.readouterr()
assert "Using path app/main.py" in captured.out
Expand All @@ -36,8 +37,9 @@ def test_app_dir_main(capsys: CaptureFixture[str]) -> None:

def test_app_dir_app(capsys: CaptureFixture[str]) -> None:
with changing_dir(assets_path / "default_files" / "default_app_dir_app"):
import_string = get_import_string()
import_string, is_factory = get_import_string()
assert import_string == "app.app:app"
assert is_factory is False

captured = capsys.readouterr()
assert "Using path app/app.py" in captured.out
Expand All @@ -58,8 +60,9 @@ def test_app_dir_app(capsys: CaptureFixture[str]) -> None:

def test_app_dir_api(capsys: CaptureFixture[str]) -> None:
with changing_dir(assets_path / "default_files" / "default_app_dir_api"):
import_string = get_import_string()
import_string, is_factory = get_import_string()
assert import_string == "app.api:app"
assert is_factory is False

captured = capsys.readouterr()
assert "Using path app/api.py" in captured.out
Expand Down
9 changes: 6 additions & 3 deletions tests/test_utils_default_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def test_single_file_main(capsys: CaptureFixture[str]) -> None:
mod = importlib.import_module("main")

importlib.reload(mod)
import_string = get_import_string()
import_string, is_factory = get_import_string()
assert import_string == "main:app"
assert is_factory is False

captured = capsys.readouterr()
assert "Using path main.py" in captured.out
Expand All @@ -47,8 +48,9 @@ def test_single_file_app(capsys: CaptureFixture[str]) -> None:
mod = importlib.import_module("app")

importlib.reload(mod)
import_string = get_import_string()
import_string, is_factory = get_import_string()
assert import_string == "app:app"
assert is_factory is False

captured = capsys.readouterr()
assert "Using path app.py" in captured.out
Expand All @@ -74,8 +76,9 @@ def test_single_file_api(capsys: CaptureFixture[str]) -> None:
mod = importlib.import_module("api")

importlib.reload(mod)
import_string = get_import_string()
import_string, is_factory = get_import_string()
assert import_string == "api:app"
assert is_factory is False

captured = capsys.readouterr()
assert "Using path api.py" in captured.out
Expand Down
Loading