Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prep for Smoke Tests #1296

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def test_chain():
File \".*?/itest_chain\.py\", line \d+, in _accumulate_parts
value \+= self\._text_to_num\.run_remote\(part\)
ValueError: \(showing chained remote errors, root error at the bottom\)
├─ Error in dependency Chainlet `TextToNum`:
├─ Error in dependency Chainlet `TextToNum` \(HTTP status 500\):
│ Chainlet-Traceback \(most recent call last\):
│ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ generated_text = self\._replicator\.run_remote\(data\)
│ ValueError: \(showing chained remote errors, root error at the bottom\)
│ ├─ Error in dependency Chainlet `TextReplicator`:
│ ├─ Error in dependency Chainlet `TextReplicator` \(HTTP status 500\):
│ │ Chainlet-Traceback \(most recent call last\):
│ │ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ │ validate_data\(data\)
Expand Down
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DeployedServiceDescriptor,
DeploymentContext,
DockerImage,
GenericRemoteException,
RemoteConfig,
RemoteErrorDetail,
RPCOptions,
Expand Down Expand Up @@ -55,6 +56,7 @@
"DeploymentContext",
"DockerImage",
"RPCOptions",
"GenericRemoteException",
"RemoteConfig",
"RemoteErrorDetail",
"DeployedServiceDescriptor",
Expand Down
31 changes: 25 additions & 6 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,10 @@ class ChainService(abc.ABC):
"""

_name: str
_entrypoint_service: b10_service.TrussService
_entrypoint_fake_json_data: Any

def __init__(self, name: str, entrypoint_service: b10_service.TrussService):
def __init__(self, name: str):
self._name = name
self._entrypoint_service = entrypoint_service
self._entrypoint_fake_json_data = None

@property
Expand All @@ -151,12 +149,12 @@ def status_page_url(self) -> str:
def run_remote_url(self) -> str:
"""URL to invoke the entrypoint."""

@abc.abstractmethod
def run_remote(self, json: Dict) -> Any:
"""Invokes the entrypoint with JSON data.

Returns:
The JSON response."""
return self._entrypoint_service.predict(json)

@abc.abstractmethod
def get_info(self) -> list[b10_types.DeployedChainlet]:
Expand All @@ -183,7 +181,10 @@ def entrypoint_fake_json_data(self, fake_data: Any) -> None:


class BasetenChainService(ChainService):
# TODO: entrypoint service is for truss model - make chains-specific.
# E.g. chain/chainlet will not have model URLs anymore.
_chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic
_entrypoint_service: b10_service.BasetenService
_remote: b10_remote.BasetenRemote

def __init__(
Expand All @@ -193,8 +194,9 @@ def __init__(
chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic,
remote: b10_remote.BasetenRemote,
) -> None:
super().__init__(name, entrypoint_service)
super().__init__(name)
self._chain_deployment_handle = chain_deployment_handle
self._entrypoint_service = entrypoint_service
self._remote = remote

@property
Expand All @@ -208,6 +210,13 @@ def run_remote_url(self) -> str:
self._chain_deployment_handle.is_draft,
)

def run_remote(self, json: Dict) -> Any:
"""Invokes the entrypoint with JSON data.

Returns:
The JSON response."""
return self._entrypoint_service.predict(json)

@property
def status_page_url(self) -> str:
"""Link to status page on Baseten."""
Expand All @@ -231,14 +240,24 @@ def get_info(self) -> list[b10_types.DeployedChainlet]:


class DockerChainService(ChainService):
_entrypoint_service: DockerTrussService

def __init__(self, name: str, entrypoint_service: DockerTrussService) -> None:
super().__init__(name, entrypoint_service)
super().__init__(name)
self._entrypoint_service = entrypoint_service

@property
def run_remote_url(self) -> str:
"""URL to invoke the entrypoint."""
return self._entrypoint_service.predict_url

def run_remote(self, json: Dict) -> Any:
"""Invokes the entrypoint with JSON data.

Returns:
The JSON response."""
return self._entrypoint_service.predict(json)

@property
def status_page_url(self) -> str:
"""Not Implemented.."""
Expand Down
6 changes: 4 additions & 2 deletions truss-chains/truss_chains/remote_chainlet/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,17 @@ def _make_request_params(
if isinstance(inputs, pydantic.BaseModel):
if self._service_descriptor.options.use_binary:
data_dict = inputs.model_dump(mode="python")
kwargs["data"] = serialization.truss_msgpack_serialize(data_dict)
data_key = "content" if for_httpx else "data"
kwargs[data_key] = serialization.truss_msgpack_serialize(data_dict)
headers["Content-Type"] = "application/octet-stream"
else:
data_key = "content" if for_httpx else "data"
kwargs[data_key] = inputs.model_dump_json()
headers["Content-Type"] = "application/json"
else: # inputs is JSON dict.
if self._service_descriptor.options.use_binary:
kwargs["data"] = serialization.truss_msgpack_serialize(inputs)
data_key = "content" if for_httpx else "data"
kwargs[data_key] = serialization.truss_msgpack_serialize(inputs)
headers["Content-Type"] = "application/octet-stream"
else:
kwargs["json"] = inputs
Expand Down
66 changes: 28 additions & 38 deletions truss-chains/truss_chains/remote_chainlet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,20 +210,24 @@ def _resolve_exception_class(
return exception_cls


def _handle_response_error(response_json: dict, remote_name: str):
def _handle_response_error(response_json: dict, remote_name: str, status: int):
try:
error_json = response_json["error"]
except KeyError as e:
logging.error(f"response_json: {response_json}")
raise ValueError(
"Could not get `error` field from JSON from chainlet error response"
"Could not get `error` field from JSON from chainlet "
f"error response. HTTP status: {status}."
) from e

try:
error = definitions.RemoteErrorDetail.model_validate(error_json)
except pydantic.ValidationError as e:
if isinstance(error_json, str):
msg = f"Remote error occurred in `{remote_name}`: '{error_json}'"
msg = (
f"Remote error occurred in `{remote_name}` "
f"(HTTP status {status}): '{error_json}'"
)
raise definitions.GenericRemoteException(msg) from None
raise ValueError(
"Could not parse chainlet error. Error details are expected to be either a "
Expand All @@ -238,7 +242,7 @@ def _handle_response_error(response_json: dict, remote_name: str):
error_format = "\n".join(lines + [last_line])
msg = (
f"(showing chained remote errors, root error at the bottom)\n"
f"├─ Error in dependency Chainlet `{remote_name}`:\n"
f"├─ Error in dependency Chainlet `{remote_name}` (HTTP status {status}):\n"
f"{error_format}"
)
raise exception_cls(msg)
Expand All @@ -255,38 +259,24 @@ def response_raise_errors(response: httpx.Response, remote_name: str) -> None:
Chainlet that raised an exception. E.g. the message might look like this:

```
RemoteChainletError in "Chain"
Traceback (most recent call last):
File "/app/model/Chainlet.py", line 112, in predict
result = await self._chainlet.run(
File "/app/model/Chainlet.py", line 79, in run
value += self._text_to_num.run(part)
File "/packages/remote_stubs.py", line 21, in run
json_result = self.predict_sync(json_args)
File "/packages/truss_chains/stub.py", line 37, in predict_sync
return utils.handle_response(
ValueError: (showing remote errors, root message at the bottom)
--> Preceding Remote Cause:
RemoteChainletError in "TextToNum"
Traceback (most recent call last):
File "/app/model/Chainlet.py", line 113, in predict
result = self._chainlet.run(data=payload["data"])
File "/app/model/Chainlet.py", line 54, in run
generated_text = self._replicator.run(data)
File "/packages/remote_stubs.py", line 7, in run
json_result = self.predict_sync(json_args)
File "/packages/truss_chains/stub.py", line 37, in predict_sync
return utils.handle_response(
ValueError: (showing remote errors, root message at the bottom)
--> Preceding Remote Cause:
RemoteChainletError in "TextReplicator"
Traceback (most recent call last):
File "/app/model/Chainlet.py", line 112, in predict
result = self._chainlet.run(data=payload["data"])
File "/app/model/Chainlet.py", line 36, in run
raise ValueError(f"This input is too long: {len(data)}.")
ValueError: This input is too long: 100.

Chainlet-Traceback (most recent call last):
File "/packages/itest_chain.py", line 132, in run_remote
value = self._accumulate_parts(text_parts.parts)
File "/packages/itest_chain.py", line 144, in _accumulate_parts
value += self._text_to_num.run_remote(part)
ValueError: (showing chained remote errors, root error at the bottom)
├─ Error in dependency Chainlet `TextToNum` (HTTP status 500):
│ Chainlet-Traceback (most recent call last):
│ File "/packages/itest_chain.py", line 87, in run_remote
│ generated_text = self._replicator.run_remote(data)
│ ValueError: (showing chained remote errors, root error at the bottom)
│ ├─ Error in dependency Chainlet `TextReplicator` (HTTP status 500):
│ │ Chainlet-Traceback (most recent call last):
│ │ File "/packages/itest_chain.py", line 52, in run_remote
│ │ validate_data(data)
│ │ File "/packages/itest_chain.py", line 36, in validate_data
│ │ raise ValueError(f"This input is too long: {len(data)}.")
╰ ╰ ValueError: This input is too long: 100.
```
"""
if response.is_error:
Expand All @@ -297,7 +287,7 @@ def response_raise_errors(response: httpx.Response, remote_name: str) -> None:
"Could not get JSON from error response. Status: "
f"`{response.status_code}`."
) from e
_handle_response_error(response_json=response_json, remote_name=remote_name)
_handle_response_error(response_json, remote_name, response.status_code)


async def async_response_raise_errors(
Expand All @@ -312,4 +302,4 @@ async def async_response_raise_errors(
"Could not get JSON from error response. Status: "
f"`{response.status}`."
) from e
_handle_response_error(response_json=response_json, remote_name=remote_name)
_handle_response_error(response_json, remote_name, response.status)
7 changes: 4 additions & 3 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class Config:


class BasetenRemote(TrussRemote):
def __init__(self, remote_url: str, api_key: str, **kwargs):
super().__init__(remote_url, **kwargs)
def __init__(self, remote_url: str, api_key: str):
super().__init__(remote_url)
self._auth_service = AuthService(api_key=api_key)
self._api = BasetenApi(remote_url, self._auth_service)

Expand Down Expand Up @@ -310,7 +310,8 @@ def push_chain_atomic(

model_id = chain_deployment_handle.entrypoint_model_id
model_version_id = chain_deployment_handle.entrypoint_model_version_id

# TODO: entrypoint service is for truss model - make chains-specific.
# E.g. chain/chainlet will not have model URLs anymore.
entrypoint_service = BasetenService(
model_id=model_id,
model_version_id=model_version_id,
Expand Down
3 changes: 2 additions & 1 deletion truss/remote/remote_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import os

try:
from configparser import DEFAULTSECT, ConfigParser # type: ignore
Expand All @@ -16,7 +17,7 @@
from truss.remote.baseten import BasetenRemote
from truss.remote.truss_remote import RemoteConfig, TrussRemote

USER_TRUSSRC_PATH = Path("~/.trussrc").expanduser()
USER_TRUSSRC_PATH = Path(os.environ.get("USER_TRUSSRC_PATH", "~/.trussrc")).expanduser()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice



def load_config() -> ConfigParser:
Expand Down
2 changes: 1 addition & 1 deletion truss/remote/truss_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class TrussRemote(ABC):

"""

def __init__(self, remote_url: str, **kwargs) -> None:
def __init__(self, remote_url: str) -> None:
self._remote_url = remote_url

@property
Expand Down