Improve e2e test coverage
nnarayen committed Feb 4, 2025
1 parent b3d5a40 commit 21a88eb
Showing 4 changed files with 10 additions and 17 deletions.
5 changes: 0 additions & 5 deletions truss-chains/tests/openai/model/model.py
@@ -8,7 +8,6 @@ def __init__(self):
     def load(self):
         self._predict_count = 0
         self._completions_count = 0
-        self._chat_completions_count = 0
 
     async def predict(self, inputs: Dict) -> int:
         self._predict_count += inputs["increment"]
@@ -17,7 +16,3 @@ async def predict(self, inputs: Dict) -> int:
     async def completions(self, inputs: Dict) -> int:
         self._completions_count += inputs["increment"]
         return self._completions_count
-
-    async def chat_completions(self, inputs: Dict) -> int:
-        self._chat_completions_count += inputs["increment"]
-        return self._chat_completions_count
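
For reference, the test model after this change reduces to the sketch below. The class name, the `Dict` import, and `predict`'s return line are assumptions, since the diff shows only the changed hunks:

from typing import Dict


class Model:
    def load(self):
        self._predict_count = 0
        self._completions_count = 0

    async def predict(self, inputs: Dict) -> int:
        self._predict_count += inputs["increment"]
        return self._predict_count  # assumed; the diff truncates here

    async def completions(self, inputs: Dict) -> int:
        self._completions_count += inputs["increment"]
        return self._completions_count

With `chat_completions` gone, the server should no longer expose a chat completions route for this model, which is what the e2e test below verifies.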
4 changes: 2 additions & 2 deletions truss-chains/tests/test_e2e.py
@@ -352,8 +352,8 @@ def test_custom_openai_endpoints():
     assert response.status_code == 200
     assert response.json() == 2
 
+    # Written model intentionally does not support chat completions
     response = requests.post(
         f"{base_url}/v1/chat/completions", json={"increment": 3}
     )
-    assert response.status_code == 200
-    assert response.json() == 3
+    assert response.status_code == 404
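
The flipped assertion captures the intended contract: when the deployed model defines no `chat_completions` method, the `/v1/chat/completions` route is not served and the POST answers 404 instead of echoing a counter. A client-side sketch of probing for that support (the helper is illustrative, not part of this commit):

import requests


def supports_chat_completions(base_url: str) -> bool:
    # A 404 here means the model did not define `chat_completions`;
    # any other status suggests the endpoint is mounted.
    response = requests.post(
        f"{base_url}/v1/chat/completions", json={"increment": 1}
    )
    return response.status_code != 404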
17 changes: 8 additions & 9 deletions truss/templates/server/model_wrapper.py
@@ -250,16 +250,16 @@ def _safe_extract_descriptor(
     def from_model(cls, model) -> "ModelDescriptor":
         preprocess = cls._safe_extract_descriptor(model, ModelMethod.PREPROCESS.value)
         predict = cls._safe_extract_descriptor(model, ModelMethod.PREDICT.value)
-        if predict and predict.arg_config == ArgConfig.REQUEST_ONLY:
+        if predict is None:
+            raise errors.ModelDefinitionError(
+                f"Truss model must have a `{ModelMethod.PREDICT.value}` method."
+            )
+        elif preprocess and predict.arg_config == ArgConfig.REQUEST_ONLY:
             raise errors.ModelDefinitionError(
                 f"When using `{ModelMethod.PREPROCESS.value}`, the {ModelMethod.PREDICT.value} method "
                 f"cannot only have the request argument (because the result of `{ModelMethod.PREPROCESS.value}` "
                 "would be discarded)."
             )
-        elif predict is None:
-            raise errors.ModelDefinitionError(
-                f"Truss model must have a `{ModelMethod.PREDICT.value}` method."
-            )
 
         postprocess = cls._safe_extract_descriptor(model, ModelMethod.POSTPROCESS.value)
         if postprocess and postprocess.arg_config == ArgConfig.REQUEST_ONLY:
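
The reordering fixes two problems at once: a model with no `predict` at all now fails with the clearer "must have a `predict` method" error first, and the request-only restriction now applies only when a `preprocess` method actually exists (previously it fired for any request-only `predict`, even without `preprocess`). A hypothetical model the corrected check rejects, assuming truss classifies a sole request parameter as `ArgConfig.REQUEST_ONLY`:

class Model:
    def preprocess(self, inputs):
        # This return value would be silently discarded, because a
        # request-only predict never receives the preprocessed inputs.
        return {"scaled": inputs["x"] * 2}

    def predict(self, request):  # request-only signature
        ...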
@@ -271,7 +271,7 @@ def from_model(cls, model) -> "ModelDescriptor":
         completions = cls._safe_extract_descriptor(model, ModelMethod.COMPLETIONS.value)
         chats = cls._safe_extract_descriptor(model, ModelMethod.CHAT_COMPLETIONS.value)
         is_healthy = cls._safe_extract_descriptor(model, ModelMethod.IS_HEALTHY.value)
-        if is_healthy and is_healthy.arg_config is not ArgConfig.NONE:
+        if is_healthy and is_healthy.arg_config != ArgConfig.NONE:
             raise errors.ModelDefinitionError(
                 f"`{ModelMethod.IS_HEALTHY.value}` must have only one argument: `self`."
             )
@@ -378,8 +378,7 @@ def load(self):
             self._logger.info(
                 f"Completed model.load() execution in {_elapsed_ms(start_time)} ms"
             )
-        except Exception as e:
-            raise e
+        except Exception:
             self._logger.exception("Exception while loading model")
             self._status = ModelWrapper.Status.FAILED
 
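Dropping the re-raise means a failing `load()` no longer propagates out of the wrapper; the exception is logged with its traceback and the wrapper transitions to `FAILED`, presumably so health checks can surface the state. The pattern in isolation (a sketch, not the full truss code):

try:
    model.load()
    status = Status.READY  # assumed happy path, not shown in the diff
except Exception:
    # No `as e` / `raise e`: log with traceback and record failure
    # instead of crashing the serving process.
    logger.exception("Exception while loading model")
    status = Status.FAILED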
@@ -712,7 +711,7 @@ async def completions(
     ) -> OutputType:
         descriptor = self.model_descriptor.completions
         assert descriptor, (
-            f"`{ModelMethod.COMPLETIONS}` must only be called if model has it."
+            f"`{ModelMethod.COMPLETIONS.value}` must only be called if model has it."
         )
 
         async def exec_fn(
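
The `.value` fix matters because interpolating a plain `Enum` member into an f-string renders its qualified name, not its payload. Assuming `ModelMethod` is a string-valued `Enum`, as the surrounding `.value` accesses suggest:

from enum import Enum


class ModelMethod(Enum):
    COMPLETIONS = "completions"


print(f"`{ModelMethod.COMPLETIONS}`")        # prints: `ModelMethod.COMPLETIONS`
print(f"`{ModelMethod.COMPLETIONS.value}`")  # prints: `completions`

So the assertion message now names the `completions` endpoint instead of leaking the enum's repr-style name.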
1 change: 0 additions & 1 deletion truss/tests/templates/server/test_model_wrapper.py
@@ -193,7 +193,6 @@ async def mock_predict(return_value, request: Request):
 @pytest.mark.anyio
 async def test_open_ai_completion_endpoints(open_ai_container_fs, helpers):
     app_path = open_ai_container_fs / "app"
-    print(app_path)
     with _clear_model_load_modules(), helpers.sys_paths(app_path), _change_directory(
         app_path
     ):
