Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/support existing environments #1891

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion docs/user-guide/custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ In these cases, to load your custom runtime, MLServer will need access to these
dependencies.

It is possible to load this custom set of dependencies by providing them
through an [environment tarball](../examples/conda/README), whose path can be
through an [environment tarball](../examples/conda/README) or by giving a
path to an already existing Python environment. Both paths can be
specified within your `model-settings.json` file.

```{warning}
Expand Down Expand Up @@ -277,6 +278,21 @@ Note that, in the folder layout above, we are assuming that:
}
```

If you want to use an already existing Python environment, you can use the parameter `environment_path` of your `model-settings.json`:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add in the docs that we can have envs etc. in the enviornment_path?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain that further? I'm not quite sure what I'm supposed to do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ramonpzg minor changes to the docs FYI


```
---
emphasize-lines: 5
---
{
"model": "sum-model",
"implementation": "models.MyCustomRuntime",
"parameters": {
"environment_path": "~/micromambda/envs/my-conda-environment"
}
}
```

## Building a custom MLServer image

```{note}
Expand Down
28 changes: 23 additions & 5 deletions mlserver/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import multiprocessing
import os
import shutil
import sys
import tarfile
import glob
Expand All @@ -18,7 +19,7 @@ def _extract_env(tarball_path: str, env_path: str) -> None:
tarball.extractall(path=env_path)


def _compute_hash(tarball_path: str) -> str:
def _compute_hash_of_file(tarball_path: str) -> str:
"""
From Python 3.11's implementation of `hashlib.file_digest()`:
https://github.com/python/cpython/blob/3.11/Lib/hashlib.py#L257
Expand All @@ -39,9 +40,20 @@ def _compute_hash(tarball_path: str) -> str:
return h.hexdigest()


async def compute_hash(tarball_path: str) -> str:
def _compute_hash_of_string(string: str) -> str:
h = hashlib.sha256()
h.update(string.encode())
return h.hexdigest()


async def compute_hash_of_file(tarball_path: str) -> str:
    """
    Hash a file's contents (SHA-256) without blocking the event loop by
    offloading the file read to the default executor.
    """
    event_loop = asyncio.get_running_loop()
    digest = await event_loop.run_in_executor(
        None, _compute_hash_of_file, tarball_path
    )
    return digest


async def compute_hash_of_string(string: str) -> str:
    """
    Hash a string (SHA-256) on the default executor so callers on the
    event loop are never blocked.
    """
    event_loop = asyncio.get_running_loop()
    return await event_loop.run_in_executor(None, _compute_hash_of_string, string)


class Environment:
Expand All @@ -51,7 +63,8 @@ class Environment:
environment.
"""

def __init__(self, env_path: str, env_hash: str):
def __init__(self, env_path: str, env_hash: str, delete_env: bool = True):
    """
    :param env_path: Folder holding the (already extracted) Python
        environment.
    :param env_hash: Hash identifying this environment; used as the pool
        lookup key by the registry.
    :param delete_env: Whether this object owns ``env_path`` and should
        remove the folder on garbage collection (see ``__del__``). Pass
        ``False`` for pre-existing environments that MLServer did not
        create itself.
    """
    self._delete_env = delete_env
    self._env_path = env_path
    self.env_hash = env_hash

Expand All @@ -67,7 +80,7 @@ async def from_tarball(
await loop.run_in_executor(None, _extract_env, tarball_path, env_path)

if not env_hash:
env_hash = await compute_hash(tarball_path)
env_hash = await compute_hash_of_file(tarball_path)

return cls(env_path, env_hash)

Expand Down Expand Up @@ -136,3 +149,8 @@ def __exit__(self, *exc_details) -> None:
multiprocessing.set_executable(sys.executable)
sys.path = self._prev_sys_path
os.environ["PATH"] = self._prev_bin_path

def __del__(self) -> None:
    """
    Best-effort removal of the environment folder when the object is
    garbage collected.

    Only environments owned by this object (``delete_env=True``, i.e.
    extracted from a tarball) are removed; pre-existing environments are
    left untouched.
    """
    # getattr guard: __del__ can run even if __init__ raised before
    # _delete_env was assigned.
    if getattr(self, "_delete_env", False):
        # Log only when we actually delete something; the previous
        # unconditional log was misleading for delete_env=False.
        logger.info("Cleaning up environment")
        # ignore_errors: the folder may already be gone (e.g. removed by a
        # test harness) — __del__ must never raise.
        shutil.rmtree(self._env_path, ignore_errors=True)
48 changes: 43 additions & 5 deletions mlserver/parallel/registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import os
import shutil
import signal

from typing import Optional, Dict, List
Expand All @@ -9,7 +8,7 @@
from ..utils import to_absolute_path
from ..model import MLModel
from ..settings import Settings
from ..env import Environment, compute_hash
from ..env import Environment, compute_hash_of_file, compute_hash_of_string
from ..registry import model_initialiser

from .errors import EnvironmentNotFound
Expand Down Expand Up @@ -76,11 +75,52 @@ async def _handle_worker_stop(self, signum, frame):
)

async def _get_or_create(self, model: MLModel) -> InferencePool:
    """
    Route to the right pool factory: a pre-existing Python environment
    (when `environment_path` is set) or the tarball-based flow otherwise.
    """
    parameters = model.settings.parameters
    if parameters is not None and parameters.environment_path:
        return await self._get_or_create_with_existing_env(
            parameters.environment_path
        )
    return await self._get_or_create_with_tarball(model)

async def _get_or_create_with_existing_env(
    self, environment_path: str
) -> InferencePool:
    """
    Creates or returns the InferencePool for a model that uses an existing
    python environment.
    """
    # Resolve ~, $VARS and relative segments so equivalent spellings of the
    # same path hash to the same pool key.
    resolved_path = os.path.expanduser(os.path.expandvars(environment_path))
    resolved_path = os.path.abspath(resolved_path)
    logger.info(f"Using environment {resolved_path}")

    # Pools are keyed by the hash of the *path*, not the env contents.
    env_hash = await compute_hash_of_string(resolved_path)
    existing_pool = self._pools.get(env_hash)
    if existing_pool is not None:
        return existing_pool

    # delete_env=False: the environment pre-exists, so we must not remove
    # it when the Environment object is garbage collected.
    environment = Environment(
        env_path=resolved_path,
        env_hash=env_hash,
        delete_env=False,
    )
    new_pool = InferencePool(
        self._settings, env=environment, on_worker_stop=self._on_worker_stop
    )
    self._pools[env_hash] = new_pool
    return new_pool

async def _get_or_create_with_tarball(self, model: MLModel) -> InferencePool:
"""
Creates or returns the InferencePool for a model that uses a
tarball as python environment.
"""
env_tarball = _get_env_tarball(model)
if not env_tarball:
return self._default_pool

env_hash = await compute_hash(env_tarball)
env_hash = await compute_hash_of_file(env_tarball)
if env_hash in self._pools:
return self._pools[env_hash]

Expand Down Expand Up @@ -223,5 +263,3 @@ async def _close_pool(self, env_hash: Optional[str] = None):

if env_hash:
del self._pools[env_hash]
env_path = self._get_env_path(env_hash)
shutil.rmtree(env_path)
4 changes: 4 additions & 0 deletions mlserver/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ class ModelParameters(BaseSettings):
version: Optional[str] = None
"""Version of the model."""

environment_path: Optional[str] = None
"""Path to a directory that contains the python environment to be used
to load this model."""

environment_tarball: Optional[str] = None
"""Path to the environment tarball which should be used to load this
model."""
Expand Down
4 changes: 0 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ async def env(env_tarball: str, tmp_path: str) -> Environment:
env = await Environment.from_tarball(env_tarball, str(tmp_path))
yield env

# Envs can be quite heavy, so let's make sure we're clearing them up once
# the test finishes
shutil.rmtree(tmp_path)


@pytest.fixture(autouse=True)
def logger(settings: Settings):
Expand Down
15 changes: 15 additions & 0 deletions tests/parallel/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,21 @@ def env_model_settings(env_tarball: str) -> ModelSettings:
)


@pytest.fixture
def existing_env_model_settings(env_tarball: str, tmp_path) -> ModelSettings:
    """Model settings pointing at an env pre-extracted into a plain folder."""
    from mlserver.env import _extract_env

    extracted_env_path = str(tmp_path)
    _extract_env(env_tarball, extracted_env_path)

    yield ModelSettings(
        name="exising_env_model",
        implementation=EnvModel,
        parameters=ModelParameters(environment_path=extracted_env_path),
    )


@pytest.fixture
async def worker_with_env(
settings: Settings,
Expand Down
33 changes: 31 additions & 2 deletions tests/parallel/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import asyncio

from mlserver.env import Environment, compute_hash
from mlserver.env import Environment, compute_hash_of_file
from mlserver.model import MLModel
from mlserver.settings import Settings, ModelSettings
from mlserver.types import InferenceRequest
Expand Down Expand Up @@ -30,6 +30,19 @@ async def env_model(
await inference_pool_registry.unload_model(model)


@pytest.fixture
async def existing_env_model(
    inference_pool_registry: InferencePoolRegistry,
    existing_env_model_settings: ModelSettings,
) -> MLModel:
    """Load an EnvModel backed by an existing env; unload it on teardown."""
    loaded_model = await inference_pool_registry.load_model(
        EnvModel(existing_env_model_settings)
    )

    yield loaded_model

    await inference_pool_registry.unload_model(loaded_model)


def test_set_environment_hash(sum_model: MLModel):
env_hash = "0e46fce1decb7a89a8b91c71d8b6975630a17224d4f00094e02e1a732f8e95f3"
_set_environment_hash(sum_model, env_hash)
Expand Down Expand Up @@ -90,6 +103,22 @@ async def test_load_model_with_env(
assert sklearn_version == "1.0.2"


async def test_load_model_with_existing_env(
    inference_pool_registry: InferencePoolRegistry,
    existing_env_model: MLModel,
    inference_request: InferenceRequest,
):
    """Inference must run inside the pre-extracted (non-tarball) env."""
    response = await existing_env_model.predict(inference_request)

    assert len(response.outputs) == 1
    output = response.outputs[0]
    assert output.name == "sklearn_version"

    # Note: These versions come from the `environment.yml` found in
    # `./tests/testdata/environment.yaml`
    [sklearn_version] = StringCodec.decode_output(output)
    assert sklearn_version == "1.0.2"


async def test_load_creates_pool(
inference_pool_registry: InferencePoolRegistry,
env_model_settings: MLModel,
Expand Down Expand Up @@ -124,7 +153,7 @@ async def test_load_reuses_env_folder(
new_model = EnvModel(env_model_settings)

# Make sure there's already existing env
env_hash = await compute_hash(env_tarball)
env_hash = await compute_hash_of_file(env_tarball)
env_path = inference_pool_registry._get_env_path(env_hash)
await Environment.from_tarball(env_tarball, env_path, env_hash)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing import Tuple

from mlserver.env import Environment, compute_hash
from mlserver.env import Environment, compute_hash_of_file


@pytest.fixture
Expand All @@ -15,7 +15,7 @@ def expected_python_folder(env_python_version: Tuple[int, int]) -> str:


async def test_compute_hash(env_tarball: str):
env_hash = await compute_hash(env_tarball)
env_hash = await compute_hash_of_file(env_tarball)
assert len(env_hash) == 64


Expand Down