Skip to content

Commit

Permalink
Return 429 for Vespa timeouts (#758)
Browse files Browse the repository at this point in the history
  • Loading branch information
farshidz authored Feb 7, 2024
1 parent d7bfb6b commit b2125a8
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 83 deletions.
28 changes: 19 additions & 9 deletions src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from marqo.tensor_search.throttling.redis_throttle import throttle
from marqo.tensor_search.web import api_validation, api_utils
from marqo.upgrades.upgrade import UpgradeRunner, RollbackRunner
from marqo.vespa import exceptions as vespa_exceptions
from marqo.vespa.vespa_client import VespaClient

logger = get_logger(__name__)
Expand All @@ -48,7 +49,6 @@ def generate_config() -> config.Config:
if __name__ in ["__main__", "api"]:
on_start(_config)


app = FastAPI(
title="Marqo",
version=version.get_version()
Expand All @@ -75,20 +75,30 @@ def marqo_base_exception_handler(request: Request, exc: base_exceptions.MarqoErr
# More specific errors should take precedence

# Core exceptions
(core_exceptions.InvalidFieldNameError, api_exceptions.InvalidFieldNameError),
(core_exceptions.IndexExistsError, api_exceptions.IndexAlreadyExistsError),
(core_exceptions.IndexNotFoundError, api_exceptions.IndexNotFoundError),
(core_exceptions.VespaDocumentParsingError, api_exceptions.BackendDataParsingError),
(core_exceptions.InvalidFieldNameError, api_exceptions.InvalidFieldNameError, None),
(core_exceptions.IndexExistsError, api_exceptions.IndexAlreadyExistsError, None),
(core_exceptions.IndexNotFoundError, api_exceptions.IndexNotFoundError, None),
(core_exceptions.VespaDocumentParsingError, api_exceptions.BackendDataParsingError, None),

# Vespa client exceptions
(
vespa_exceptions.VespaTimeoutError,
api_exceptions.TooManyRequestsError,
"Throttled by vector store. Try your request again later."
),

# Base exceptions
(base_exceptions.InternalError, api_exceptions.InternalError),
(base_exceptions.InvalidArgumentError, api_exceptions.InvalidArgError),
(base_exceptions.InternalError, api_exceptions.InternalError, None),
(base_exceptions.InvalidArgumentError, api_exceptions.InvalidArgError, None),
]

converted_error = None
for base_exception, api_exception in api_exception_mappings:
for base_exception, api_exception, message in api_exception_mappings:
if isinstance(exc, base_exception):
converted_error = api_exception(exc.message)
if message is None:
converted_error = api_exception(exc.message)
else:
converted_error = api_exception(message)
break

# Completely unhandled exception (500)
Expand Down
7 changes: 7 additions & 0 deletions src/marqo/vespa/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,12 @@ def __str__(self) -> str:
return super().__str__()


class VespaTimeoutError(VespaStatusError):
"""
Raised when Vespa responds with a timeout error.
"""
pass


class InvalidVespaApplicationError(VespaError):
pass
42 changes: 27 additions & 15 deletions src/marqo/vespa/models/query_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
from pydantic import BaseModel, Field


# See https://docs.vespa.ai/en/reference/default-result-format.html
class RootFields(BaseModel):
total_count: int = Field(alias='totalCount')
total_count: Optional[int] = Field(alias='totalCount')


class Degraded(BaseModel):
adaptive_timeout: bool = Field(alias='adaptive-timeout')
match_phase: bool = Field(alias='match-phase')
non_ideal_state: bool = Field(alias='non-ideal-state')
timeout: bool
adaptive_timeout: Optional[bool] = Field(alias='adaptive-timeout')
match_phase: Optional[bool] = Field(alias='match-phase')
non_ideal_state: Optional[bool] = Field(alias='non-ideal-state')
timeout: Optional[bool]


class Coverage(BaseModel):
Expand All @@ -24,32 +25,43 @@ class Coverage(BaseModel):
results_full: int = Field(alias='resultsFull')


class Child(BaseModel):
class Error(BaseModel):
code: int
summary: Optional[str]
source: Optional[str]
message: Optional[str]
stack_trace: Optional[str] = Field(alias='stackTrace')
transient: Optional[bool]


class AbstractChild(BaseModel):
# label, value, and recursive children occur in aggregation results
id: str
id: Optional[str]
relevance: float
source: Optional[str]
label: Optional[str]
value: Optional[str]
fields: Optional[Dict[str, Any]]
coverage: Optional[Coverage]
errors: Optional[List[Error]]
children: Optional[List['Child']]


class Root(BaseModel):
id: str
relevance: float
fields: RootFields
coverage: Coverage
children: List[Child] = []
class Child(AbstractChild):
fields: Optional[Dict[str, Any]]


class Root(AbstractChild):
fields: Optional[RootFields]


class QueryResult(BaseModel):
root: Root
timing: Optional[Dict[str, Any]]
trace: Optional[Dict[str, Any]]

@property
def hits(self) -> List[Child]:
return self.root.children
return self.root.children or []

@property
def total_count(self) -> int:
Expand Down
38 changes: 35 additions & 3 deletions src/marqo/vespa/vespa_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import marqo.logging
import marqo.vespa.concurrency as conc
from marqo.vespa.exceptions import VespaStatusError, VespaError, InvalidVespaApplicationError
from marqo.vespa.exceptions import VespaStatusError, VespaError, InvalidVespaApplicationError, VespaTimeoutError
from marqo.vespa.models import VespaDocument, QueryResult, FeedBatchDocumentResponse, FeedBatchResponse, \
FeedDocumentResponse
from marqo.vespa.models.application_metrics import ApplicationMetrics
Expand Down Expand Up @@ -175,7 +175,7 @@ def query(self, yql: str, hits: int = 10, ranking: str = None, model_restrict: s
except httpx.HTTPError as e:
raise VespaError(e) from e

self._raise_for_status(resp)
self._query_raise_for_status(resp)

return QueryResult(**resp.json())

Expand Down Expand Up @@ -688,7 +688,39 @@ async def _delete_document_async(self,

self._raise_for_status(resp)

def _raise_for_status(self, resp) -> None:
def _query_raise_for_status(self, resp: httpx.Response) -> None:
"""
Query API specific raise for status method.
"""
# See error codes here https://github.com/vespa-engine/vespa/blob/master/container-core/src/main/java/com/yahoo/container/protect/Error.java
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
try:
result = QueryResult(**resp.json())
if (
result.root.errors is not None
and len(result.root.errors) > 0
):
if resp.status_code == 504 and result.root.errors[0].code == 12:
raise VespaTimeoutError(message=resp.text, cause=e) from e
elif (
result.root.errors[0].code == 8
and result.root.errors[
0].message == "Search request soft doomed during query setup and initialization."
):
# The soft doom error is a bug in certain Vespa versions. Newer versions should always return
# a code 12 for timeouts
logger.warn('Detected soft doomed query')
raise VespaTimeoutError(message=resp.text, cause=e) from e

raise e
except VespaStatusError:
raise
except Exception:
raise VespaStatusError(message=resp.text, cause=e) from e

def _raise_for_status(self, resp: httpx.Response) -> None:
try:
resp.raise_for_status()
except httpx.HTTPStatusError as e:
Expand Down
30 changes: 29 additions & 1 deletion tests/marqo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from marqo import config, version
from marqo.core.index_management.index_management import IndexManagement
from marqo.core.monitoring.monitoring import Monitoring
from marqo.core.models.marqo_index import *
from marqo.core.models.marqo_index_request import (StructuredMarqoIndexRequest, UnstructuredMarqoIndexRequest,
FieldRequest, MarqoIndexRequest)
from marqo.core.monitoring.monitoring import Monitoring
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.vespa.vespa_client import VespaClient

Expand Down Expand Up @@ -69,6 +69,9 @@ def clear_indexes(self, indexes: List[MarqoIndex]):
def clear_index_by_name(self, index_name: str):
self.pyvespa_client.delete_all_docs(self.CONTENT_CLUSTER, index_name)

def random_index_name(self) -> str:
return 'a' + str(uuid.uuid4()).replace('-', '')

@classmethod
def structured_marqo_index(
cls,
Expand Down Expand Up @@ -211,6 +214,31 @@ def unstructured_marqo_index_request(
updated_at=updated_at
)

class _AssertRaisesContext:
def __init__(self, expected_exception):
self.expected_exception = expected_exception

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, tb):
if exc_type is None:
raise AssertionError(f"No exception raised, expected: '{self.expected_exception.__name__}'")
if issubclass(exc_type, self.expected_exception) and exc_type is not self.expected_exception:
raise AssertionError(
f"Subclass of '{self.expected_exception.__name__}' "
f"raised: '{exc_type.__name__}', expected exact exception.")
if exc_type is not self.expected_exception:
raise AssertionError(
f"Wrong exception raised: '{exc_type.__name__}', expected: '{self.expected_exception.__name__}'")
return True

def assertRaisesStrict(self, expected_exception):
"""
Assert that a specific exception is raised. Will not pass for subclasses of the expected exception.
"""
return self._AssertRaisesContext(expected_exception)


class AsyncMarqoTestCase(unittest.IsolatedAsyncioTestCase, MarqoTestCase):
pass
Loading

0 comments on commit b2125a8

Please sign in to comment.