Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: bug: memory leak when using bentoml>=1.2 #4775

Merged
merged 7 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_bentoml_sdk/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def encode(self, obj: Path) -> bytes:
return obj.read_bytes()

def decode(self, obj: bytes | t.BinaryIO | UploadFile | PurePath | str) -> t.Any:
from bentoml._internal.context import request_directory
from bentoml._internal.context import request_temp_dir

media_type: str | None = None

Expand Down Expand Up @@ -156,7 +156,7 @@ def decode(self, obj: bytes | t.BinaryIO | UploadFile | PurePath | str) -> t.Any
f"Invalid content type {media_type}, expected {self.content_type}"
)
with tempfile.NamedTemporaryFile(
suffix=filename, dir=request_directory.get(), delete=False
suffix=filename, dir=request_temp_dir(), delete=False
) as f:
f.write(body)
return Path(f.name)
Expand Down
51 changes: 28 additions & 23 deletions src/bentoml/_internal/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import contextvars
import os
import tempfile
import typing as t
from abc import ABC
from abc import abstractmethod
Expand All @@ -13,15 +12,28 @@
import starlette.datastructures

from .utils.http import Cookie
from .utils.temp import TempfilePool

if TYPE_CHECKING:
import starlette.requests
import starlette.responses

# A request-unique directory for storing temporary files
request_directory: contextvars.ContextVar[str] = contextvars.ContextVar(
"request_directory"
_request_var: contextvars.ContextVar[starlette.requests.Request] = (
contextvars.ContextVar("request")
)
_response_var: contextvars.ContextVar[ServiceContext.ResponseContext] = (
contextvars.ContextVar("response")
)

request_tempdir_pool = TempfilePool(prefix="bentoml-request-")


def request_temp_dir() -> str:
"""A request-unique directory for storing temporary files"""
request = _request_var.get()
if not hasattr(request.state, "temp_dir"):
request.state.temp_dir = request_tempdir_pool.acquire()
return request.state.temp_dir
frostming marked this conversation as resolved.
Show resolved Hide resolved


class Metadata(t.Mapping[str, str], ABC):
Expand Down Expand Up @@ -81,12 +93,6 @@ def mutablecopy(self) -> Metadata:

class ServiceContext:
def __init__(self) -> None:
self._request_var: contextvars.ContextVar[starlette.requests.Request] = (
contextvars.ContextVar("request")
)
self._response_var: contextvars.ContextVar[ServiceContext.ResponseContext] = (
contextvars.ContextVar("response")
)
# A dictionary for storing global state shared by the process
self.state: dict[str, t.Any] = {}

Expand All @@ -95,28 +101,27 @@ def in_request(
self, request: starlette.requests.Request
) -> t.Generator[ServiceContext, None, None]:
request.metadata = request.headers # type: ignore[attr-defined]
request_token = self._request_var.set(request)
response_token = self._response_var.set(ServiceContext.ResponseContext())
with tempfile.TemporaryDirectory(prefix="bentoml-request-") as temp_dir:
dir_token = request_directory.set(temp_dir)
try:
yield self
finally:
self._request_var.reset(request_token)
self._response_var.reset(response_token)
request_directory.reset(dir_token)
request_token = _request_var.set(request)
response_token = _response_var.set(ServiceContext.ResponseContext())
try:
yield self
finally:
if hasattr(request.state, "temp_dir"):
request_tempdir_pool.release(request.state.temp_dir)
_request_var.reset(request_token)
_response_var.reset(response_token)

@property
def request(self) -> starlette.requests.Request:
return self._request_var.get()
return _request_var.get()

@property
def response(self) -> ResponseContext:
return self._response_var.get()
return _response_var.get()

@property
def temp_dir(self) -> str:
return request_directory.get()
return request_temp_dir()

@attr.define
class ResponseContext:
Expand Down
4 changes: 3 additions & 1 deletion src/bentoml/_internal/server/base_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def on_startup(self) -> list[LifecycleHook]:

@property
def on_shutdown(self) -> list[LifecycleHook]:
return []
from ..context import request_tempdir_pool

return [request_tempdir_pool.cleanup]

def mark_as_ready(self) -> None:
self._is_ready = True
Expand Down
45 changes: 45 additions & 0 deletions src/bentoml/_internal/utils/temp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

import shutil
import tempfile
from collections import deque
from functools import partial
from pathlib import Path
from threading import Lock


class TempfilePool:
"""A simple pool to get temp directories,
so they are reused as much as possible.
"""

def __init__(
self,
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
) -> None:
self._pool: deque[str] = deque([])
self._lock = Lock()
self._new = partial(tempfile.mkdtemp, suffix=suffix, prefix=prefix, dir=dir)

def cleanup(self) -> None:
while len(self._pool):
dir = self._pool.popleft()
shutil.rmtree(dir, ignore_errors=True)

def acquire(self) -> str:
with self._lock:
if not len(self._pool):
return self._new()
else:
return self._pool.popleft()

def release(self, dir: str) -> None:
for child in Path(dir).iterdir():
if child.is_dir():
shutil.rmtree(child)
else:
child.unlink()
with self._lock:
self._pool.append(dir)
Loading