Skip to content

Commit

Permalink
[serve] Add deployment handle dependency injection to build_app (#4…
Browse files Browse the repository at this point in the history
…8447)

Adds dependency injection to create a custom `DeploymentHandle` type in
`build_app`. This will be used to create a special type of handle for
local testing mode.

This also allowed me to remove the mocking/patching happening in the
unit tests.

---------

Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
edoakes authored Oct 30, 2024
1 parent 7d912d6 commit 41be27c
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 47 deletions.
75 changes: 54 additions & 21 deletions python/ray/serve/_private/build_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Generic, List, TypeVar
from typing import Callable, Dict, Generic, List, Optional, TypeVar

from ray.dag.py_obj_scanner import _PyObjScanner
from ray.serve._private.constants import SERVE_LOGGER_NAME
Expand All @@ -21,16 +21,24 @@ class IDDict(dict, Generic[K, V]):
"""

def __getitem__(self, key: K) -> V:
return super().__getitem__(id(key))
if not isinstance(key, int):
key = id(key)
return super().__getitem__(key)

def __setitem__(self, key: K, value: V):
return super().__setitem__(id(key), value)
if not isinstance(key, int):
key = id(key)
return super().__setitem__(key, value)

def __delitem__(self, key: K):
return super().__delitem__(id(key))
if not isinstance(key, int):
key = id(key)
return super().__delitem__(key)

def __contains__(self, key: object):
return super().__contains__(id(key))
if not isinstance(key, int):
key = id(key)
return super().__contains__(key)


@dataclass(frozen=True)
Expand All @@ -42,12 +50,27 @@ class BuiltApplication:
ingress_deployment_name: str
# List of unique deployments comprising the app.
deployments: List[Deployment]
# Dict[name, DeploymentHandle] mapping deployment names to the handles that replaced
# them in other deployments' init args/kwargs.
deployment_handles: Dict[str, DeploymentHandle]


def _make_deployment_handle_default(
deployment: Deployment, app_name: str
) -> DeploymentHandle:
return DeploymentHandle(
deployment.name,
app_name=app_name,
)


def build_app(
app: Application,
*,
name: str,
make_deployment_handle: Optional[
Callable[[Deployment, str], DeploymentHandle]
] = None,
) -> BuiltApplication:
"""Builds the application into a list of finalized deployments.
Expand All @@ -59,16 +82,25 @@ def build_app(
Returns: BuiltApplication
"""
if make_deployment_handle is None:
make_deployment_handle = _make_deployment_handle_default

handles = IDDict()
deployment_names = IDDict()
deployments = _build_app_recursive(
app,
app_name=name,
handles=IDDict(),
deployment_names=IDDict(),
handles=handles,
deployment_names=deployment_names,
make_deployment_handle=make_deployment_handle,
)
return BuiltApplication(
name=name,
ingress_deployment_name=app._bound_deployment.name,
deployments=deployments,
deployment_handles={
deployment_names[app]: handle for app, handle in handles.items()
},
)


Expand All @@ -78,6 +110,7 @@ def _build_app_recursive(
app_name: str,
deployment_names: IDDict[Application, str],
handles: IDDict[Application, DeploymentHandle],
make_deployment_handle: Callable[[Deployment, str], DeploymentHandle],
) -> List[Deployment]:
"""Recursively traverses the graph of Application objects.
Expand All @@ -93,13 +126,6 @@ def _build_app_recursive(
if app in handles:
return []

# Create the DeploymentHandle that will be used to replace this application
# in the arguments of its parent(s).
handles[app] = DeploymentHandle(
_get_unique_deployment_name_memoized(app, deployment_names),
app_name=app_name,
)

deployments = []
scanner = _PyObjScanner(source_type=Application)
try:
Expand All @@ -114,19 +140,26 @@ def _build_app_recursive(
app_name=app_name,
handles=handles,
deployment_names=deployment_names,
make_deployment_handle=make_deployment_handle,
)
)

# Replace Application objects with their corresponding DeploymentHandles.
new_init_args, new_init_kwargs = scanner.replace_nodes(handles)
deployments.append(
app._bound_deployment.options(
name=_get_unique_deployment_name_memoized(app, deployment_names),
_init_args=new_init_args,
_init_kwargs=new_init_kwargs,
)
final_deployment = app._bound_deployment.options(
name=_get_unique_deployment_name_memoized(app, deployment_names),
_init_args=new_init_args,
_init_kwargs=new_init_kwargs,
)
return deployments

# Create the DeploymentHandle that will be used to replace this application
# in the arguments of its parent(s).
handles[app] = make_deployment_handle(
final_deployment,
app_name,
)

return deployments + [final_deployment]
finally:
scanner.clear()

Expand Down
81 changes: 55 additions & 26 deletions python/ray/serve/tests/unit/test_build_app.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
import sys
from typing import List
from unittest import mock
from typing import Any, List

import pytest

from ray import serve
from ray.serve._private.build_app import BuiltApplication, build_app
from ray.serve._private.common import DeploymentID
from ray.serve.deployment import Application, Deployment
from ray.serve.handle import DeploymentHandle


@pytest.fixture(autouse=True)
def patch_handle_eq():
"""Patch DeploymentHandle.__eq__ to compare options we care about."""
class FakeDeploymentHandle:
def __init__(self, deployment_name: str, app_name: str):
self.deployment_id = DeploymentID(deployment_name, app_name)

def _patched_handle_eq(self, other):
return all(
[
isinstance(other, type(self)),
self.deployment_id == other.deployment_id,
self.handle_options == other.handle_options,
]
)
@classmethod
def from_deployment(cls, deployment, app_name: str) -> "FakeDeploymentHandle":
return cls(deployment.name, app_name)

with mock.patch(
"ray.serve.handle._DeploymentHandleBase.__eq__", _patched_handle_eq
):
yield
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, FakeDeploymentHandle)
and self.deployment_id == other.deployment_id
)


def _build_and_check(
Expand All @@ -36,11 +32,25 @@ def _build_and_check(
expected_deployments: List[Deployment],
app_name: str = "default",
):
built_app: BuiltApplication = build_app(app, name=app_name)
built_app: BuiltApplication = build_app(
app,
name=app_name,
# Each real DeploymentHandle has a unique ID (intentionally), so the below
# equality checks don't work. Use a fake implementation instead.
make_deployment_handle=FakeDeploymentHandle.from_deployment,
)
assert built_app.name == app_name
assert built_app.ingress_deployment_name == expected_ingress_name
assert len(built_app.deployments) == len(expected_deployments)

# Check that the returned deployment_handles are populated properly.
assert len(built_app.deployment_handles) == len(expected_deployments)
for d in expected_deployments:
h = built_app.deployment_handles.get(d.name, None)
assert h is not None, f"No handle returned for deployment {d.name}."
assert isinstance(h, FakeDeploymentHandle)
assert h.deployment_id == DeploymentID(d.name, app_name=app_name)

for expected_deployment in expected_deployments:
generated_deployment = None
for d in built_app.deployments:
Expand All @@ -56,6 +66,25 @@ def _build_and_check(
assert expected_deployment == generated_deployment


def test_real_deployment_handle_default():
"""Other tests inject a FakeDeploymentHandle, so check the default behavior."""

@serve.deployment
class D:
pass

built_app: BuiltApplication = build_app(
D.bind(D.options(name="Inner").bind()),
name="app-name",
)
assert len(built_app.deployments) == 2
assert len(built_app.deployments[1].init_args) == 1
assert isinstance(built_app.deployments[1].init_args[0], DeploymentHandle)
assert built_app.deployments[1].init_args[0].deployment_id == DeploymentID(
"Inner", app_name="app-name"
)


def test_single_deployment_basic():
@serve.deployment(
num_replicas=123,
Expand Down Expand Up @@ -119,13 +148,13 @@ class Outer:
Outer.options(
name="Outer",
_init_args=(
DeploymentHandle(
FakeDeploymentHandle(
"Inner",
app_name="default",
),
),
_init_kwargs={
"other": DeploymentHandle(
"other": FakeDeploymentHandle(
"Other",
app_name="default",
),
Expand Down Expand Up @@ -154,7 +183,7 @@ class Outer:
name="Outer",
_init_args=(
[
DeploymentHandle(
FakeDeploymentHandle(
"Inner",
app_name="default",
),
Expand Down Expand Up @@ -185,7 +214,7 @@ class Outer:
Outer.options(
name="Outer",
_init_args=(
DeploymentHandle(
FakeDeploymentHandle(
"Inner",
app_name="custom",
),
Expand Down Expand Up @@ -218,11 +247,11 @@ class Outer:
Outer.options(
name="Outer",
_init_args=(
DeploymentHandle(
FakeDeploymentHandle(
"Inner",
app_name="default",
),
DeploymentHandle(
FakeDeploymentHandle(
"Inner_1",
app_name="default",
),
Expand All @@ -248,7 +277,7 @@ class Outer:

shared = Shared.bind()
app = Outer.bind(Inner.bind(shared), shared)
shared_handle = DeploymentHandle(
shared_handle = FakeDeploymentHandle(
"Shared",
app_name="default",
)
Expand All @@ -265,7 +294,7 @@ class Outer:
Outer.options(
name="Outer",
_init_args=(
DeploymentHandle(
FakeDeploymentHandle(
"Inner",
app_name="default",
),
Expand Down

0 comments on commit 41be27c

Please sign in to comment.