Skip to content

Commit

Permalink
Fix telemetry cleanup bug (from #800)
Browse files Browse the repository at this point in the history
* Fix bug where telemetry does not clean up objects when telemetry=false
* Add a memory profiler endpoint /memory, disabled by default and enabled through an environment variable
  • Loading branch information
farshidz committed Apr 15, 2024
1 parent 3bd949c commit a65ed08
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 79 deletions.
8 changes: 7 additions & 1 deletion requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# tensor search:
# tensor search, identical to requirements.txt:
requests==2.28.1
anyio==3.7.1
fastapi==0.86.0
uvicorn[standard]
fastapi-utils==0.2.1
jsonschema==4.17.1
typing-extensions==4.5.0
urllib3==1.26.0
pydantic==1.10.11
httpx==0.25.0
semver==3.0.2
memory-profiler==0.61.0

# s2_inference:
more_itertools
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ typing-extensions==4.5.0
urllib3==1.26.0
pydantic==1.10.11
httpx==0.25.0
semver==3.0.2
semver==3.0.2
memory-profiler==0.61.0
8 changes: 8 additions & 0 deletions src/marqo/core/models/memory_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import List

from marqo.base_model import ImmutableStrictBaseModel


class MemoryProfile(ImmutableStrictBaseModel):
    """Model describing a memory profile: one usage figure plus a list of statistic strings."""
    # Memory usage figure as a float; the unit is not fixed by this model
    # (presumably MiB when produced via memory_profiler -- TODO confirm against the producer).
    memory_used: float
    # Statistics entries, each already rendered to a string.
    stats: List[str]
20 changes: 20 additions & 0 deletions src/marqo/core/monitoring/memory_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import tracemalloc

from memory_profiler import memory_usage

from marqo.core.models.memory_profile import MemoryProfile


def get_memory_profile() -> MemoryProfile:
    """
    Build a MemoryProfile describing the current state of the process.

    tracemalloc is started, a snapshot is taken and its statistics are grouped
    by 'lineno'; every statistic is rendered with ``str``. Memory usage is
    sampled through ``memory_profiler.memory_usage`` and the first sample is
    reported as ``memory_used``.

    Returns:
        MemoryProfile with ``memory_used`` set to the first sample and
        ``stats`` set to the stringified snapshot statistics.
    """
    tracemalloc.start()

    # Per-line allocation statistics from a snapshot taken right after start()
    line_stats = tracemalloc.take_snapshot().statistics('lineno')

    # -1 selects the current process per memory_profiler's documentation.
    # NOTE(review): with interval=0.1 and timeout=1 this appears to sample for
    # about a second although only the first sample is used -- confirm intent.
    samples = memory_usage(-1, interval=0.1, timeout=1)

    first_sample = samples[0]
    rendered_stats = list(map(str, line_stats))
    return MemoryProfile(memory_used=first_sample, stats=rendered_stats)
7 changes: 7 additions & 0 deletions src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from marqo.api.route import MarqoCustomRoute
from marqo.core import exceptions as core_exceptions
from marqo.core.index_management.index_management import IndexManagement
from marqo.core.monitoring import memory_profiler
from marqo.logging import get_logger
from marqo.tensor_search import tensor_search, utils
from marqo.tensor_search.enums import RequestType, EnvVars
Expand Down Expand Up @@ -154,6 +155,12 @@ def root():
"version": version.get_version()}


@app.get('/memory')
@utils.enable_debug_apis()
def memory():
    """Return the memory profile of this process; access is gated by the debug-API decorator."""
    profile = memory_profiler.get_memory_profile()
    return profile


@app.post("/indexes/{index_name}")
def create_index(index_name: str, settings: IndexSettings, marqo_config: config.Config = Depends(get_config)):
marqo_config.index_management.create_index(settings.to_marqo_index_request(index_name))
Expand Down
80 changes: 43 additions & 37 deletions src/marqo/tensor_search/telemetry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from typing import Any, Callable, Dict, List, Optional, Union
import json
import time
from collections import defaultdict
from contextlib import contextmanager
import time
from contextvars import ContextVar
from marqo.tensor_search.models.add_docs_objects import AddDocsParams
from typing import Any, Callable, Dict, List, Optional, Union

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from marqo.tensor_search.tensor_search_logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -68,7 +69,7 @@ def __init__(self):
self.timers: Dict[str, Timer] = defaultdict(Timer)

def increment_counter(self, k: str):
self.counter[k]+=1
self.counter[k] += 1

@contextmanager
def time(self, k: str, callback: Optional[Callable[[float], None]] = None):
Expand Down Expand Up @@ -106,7 +107,7 @@ def stop(self, k: str) -> float:
logger.warn(f"timer {k} stopped incorrectly. Time not recorded.")

def increment_counter(self, k: str, v: int = 1):
self.counter[k]+=v
self.counter[k] += v

def json(self):
return {
Expand All @@ -117,7 +118,7 @@ def json(self):

class RequestMetricsStore():
current_request: ContextVar[Request] = ContextVar('current_request')

METRIC_STORES: Dict[Request, RequestMetrics] = {}

@classmethod
Expand All @@ -132,7 +133,7 @@ def _get_request(cls) -> Request:
def for_request(cls, r: Optional[Request] = None) -> RequestMetrics:
if r is None:
r = cls._get_request()

return cls.METRIC_STORES[r]

@classmethod
Expand All @@ -147,6 +148,7 @@ def set_in_request(cls, r: Optional[Request] = None, metrics: Optional[RequestMe
@classmethod
def clear_metrics_for(cls, r: Request) -> None:
    """
    Remove the metrics stored for request ``r`` and reset the current-request context variable.

    Uses ``pop`` with a default, so a request without stored metrics does not raise.
    """
    cls.METRIC_STORES.pop(r, None)
    # Overwrite the context variable with None so it no longer holds the request
    cls.current_request.set(None)


class TelemetryMiddleware(BaseHTTPMiddleware):
Expand All @@ -158,8 +160,9 @@ class TelemetryMiddleware(BaseHTTPMiddleware):

DEFAULT_TELEMETRY_QUERY_PARAM = "telemetry"

def __init__(self, app, **options):
self.telemetry_flag: Optional[str] = options.pop("telemetery_flag", TelemetryMiddleware.DEFAULT_TELEMETRY_QUERY_PARAM)
def __init__(self, app, **options):
self.telemetry_flag: Optional[str] = options.pop("telemetery_flag",
TelemetryMiddleware.DEFAULT_TELEMETRY_QUERY_PARAM)
super().__init__(app, **options)

def telemetry_enabled_for_request(self, request: Request) -> bool:
Expand All @@ -182,37 +185,40 @@ async def dispatch(self, request: Request, call_next: Callable[[], Any]):
"""
RequestMetricsStore.set_in_request(request)
try:
response = await call_next(request)

# Early exit if opentelemetry is not to be injected into response.
if not self.telemetry_enabled_for_request(request):
return response

data = await self.get_response_json(response)

# Inject telemetry and fix content-length header
if isinstance(data, dict):
telemetry = RequestMetricsStore.for_request(request).json()
if len(telemetry["timesMs"]) == 0:
telemetry.pop("timesMs")
if len(telemetry["counter"]) == 0:
telemetry.pop("counter")
data["telemetry"] = telemetry
else:
get_logger(__name__).warning(
f"{self.telemetry_flag} set but response payload is not Dict. telemetry not returned"
)
get_logger(__name__).info(
f"Telemetry data={json.dumps(RequestMetricsStore.for_request(request).json(), indent=2)}")

response = await call_next(request)

# Early exit if opentelemetry is not to be injected into response.
if not self.telemetry_enabled_for_request(request):
return response

data = await self.get_response_json(response)

# Inject telemetry and fix content-length header
if isinstance(data, dict):
telemetry = RequestMetricsStore.for_request(request).json()
if len(telemetry["timesMs"]) == 0:
telemetry.pop("timesMs")
if len(telemetry["counter"]) == 0:
telemetry.pop("counter")
data["telemetry"] = telemetry
else:
get_logger(__name__).warning(
f"{self.telemetry_flag} set but response payload is not Dict. telemetry not returned"
)
get_logger(__name__).info(f"Telemetry data={json.dumps(RequestMetricsStore.for_request(request).json(), indent=2)}")

RequestMetricsStore.clear_metrics_for(request)
finally:
logger.debug('Clearing metrics for request')
RequestMetricsStore.clear_metrics_for(request)

body = json.dumps(data).encode()
response.headers["content-length"] = str(len(body))

return Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)
)
14 changes: 14 additions & 0 deletions src/marqo/tensor_search/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,17 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator_function


def enable_debug_apis():
    """
    Decorator factory gating a callable behind the MARQO_ENABLE_DEBUG_API setting.

    The setting is read on every invocation of the decorated callable and
    compared case-insensitively with 'true'. When it matches, the callable is
    invoked with the original arguments and its result returned.

    Raises:
        HTTPException: with status code 403 when the setting is anything other than 'true'.
    """
    def gate(endpoint):
        @functools.wraps(endpoint)
        def guarded(*args, **kwargs):
            flag = read_env_vars_and_defaults(EnvVars.MARQO_ENABLE_DEBUG_API)
            if flag.lower() == 'true':
                return endpoint(*args, **kwargs)

            raise HTTPException(
                status_code=403,
                detail="This API endpoint is disabled. Please set MARQO_ENABLE_DEBUG_API to true to enable it."
            )

        return guarded

    return gate
25 changes: 25 additions & 0 deletions tests/tensor_search/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from marqo.core import exceptions as core_exceptions
from marqo.core.models.marqo_index import FieldType
from marqo.core.models.marqo_index_request import FieldRequest
from marqo.tensor_search.enums import EnvVars
from marqo.vespa import exceptions as vespa_exceptions
from tests.marqo_test import MarqoTestCase

Expand All @@ -34,6 +35,30 @@ def test_add_or_replace_documents_tensor_fields(self):
self.assertEqual(response.status_code, 200)
mock_add_documents.assert_called_once()

def test_memory(self):
    """
    Test that the memory endpoint succeeds and returns the expected keys when debug API is enabled.
    """
    with patch.dict('os.environ', {EnvVars.MARQO_ENABLE_DEBUG_API: 'TRUE'}):
        response = self.client.get("/memory")

        # Check the status first so a disabled or failing endpoint is reported
        # as an HTTP error rather than as a confusing key-set mismatch.
        self.assertEqual(response.status_code, 200)

        data = response.json()
        # unittest assertion (consistent with the sibling tests) gives a diff on failure
        self.assertEqual(set(data.keys()), {"memory_used", "stats"})

def test_memory_defaultDisabled(self):
    """
    The memory endpoint must answer 403 when this test does not enable the debug API.
    """
    status = self.client.get("/memory").status_code
    self.assertEqual(status, 403)

def test_memory_disabled_403(self):
    """
    The memory endpoint must answer 403 when the debug API is explicitly set to 'FALSE'.
    """
    disabled_env = {EnvVars.MARQO_ENABLE_DEBUG_API: 'FALSE'}
    with patch.dict('os.environ', disabled_env):
        status = self.client.get("/memory").status_code
        self.assertEqual(status, 403)


class TestApiErrors(MarqoTestCase):
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/tensor_search/test_api_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,5 @@ def test_base_exception_handler_unhandled_error(self, mock_api_exception_handler

self.assertIsInstance(converted_error, api_exceptions.MarqoWebError)
self.assertNotIn("This should not be propagated.", converted_error.message)
self.assertIn("unexpected internal error", converted_error.message)
self.assertIn("unexpected internal error", converted_error.message)

Loading

0 comments on commit a65ed08

Please sign in to comment.