Skip to content

Commit

Permalink
Merge pull request #1183 from basetenlabs/bump-version-0.9.44
Browse files Browse the repository at this point in the history
Release 0.9.44
  • Loading branch information
rcano-baseten authored Oct 10, 2024
2 parents aa92e33 + c57f21e commit 5a4626e
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.43"
version = "0.9.44"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
3 changes: 2 additions & 1 deletion truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import IO, List, Optional, Tuple

import truss
from truss.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.error import ApiError
Expand Down Expand Up @@ -271,7 +272,7 @@ def create_truss_service(
return model_version_json["id"], model_version_json["version_id"]

if model_id is None:
if environment:
if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
model_version_json = api.create_model_from_truss(
model_name=model_name,
Expand Down
111 changes: 111 additions & 0 deletions truss/tests/remote/baseten/test_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from tempfile import NamedTemporaryFile
from unittest.mock import MagicMock

import pytest
from truss.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote.baseten import core
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.core import create_truss_service
from truss.remote.baseten.error import ApiError


Expand Down Expand Up @@ -84,3 +87,111 @@ def test_get_prod_version_from_versions_error():
]
prod_version = core.get_prod_version_from_versions(versions)
assert prod_version is None


@pytest.mark.parametrize(
"environment",
[
None,
PRODUCTION_ENVIRONMENT_NAME,
],
)
def test_create_truss_service_handles_eligible_environment_values(environment):
api = MagicMock()
return_value = {
"id": "id",
"version_id": "model_version_id",
}
api.create_model_from_truss.return_value = return_value
model_id, model_version_id = create_truss_service(
api,
"model_name",
"s3_key",
"config",
is_trusted=False,
preserve_previous_prod_deployment=False,
is_draft=False,
model_id=None,
deployment_name="deployment_name",
environment=environment,
)
assert model_id == return_value["id"]
assert model_version_id == return_value["version_id"]
api.create_model_from_truss.assert_called_once()


@pytest.mark.parametrize(
"model_id",
[
"some_model_id",
None,
],
)
def test_create_truss_services_handles_is_draft(model_id):
api = MagicMock()
return_value = {
"id": "id",
"version_id": "model_version_id",
}
api.create_development_model_from_truss.return_value = return_value
model_id, model_version_id = create_truss_service(
api,
"model_name",
"s3_key",
"config",
is_trusted=False,
preserve_previous_prod_deployment=False,
is_draft=True,
model_id=model_id,
deployment_name="deployment_name",
)
assert model_id == return_value["id"]
assert model_version_id == return_value["version_id"]
api.create_development_model_from_truss.assert_called_once()


@pytest.mark.parametrize(
"inputs",
[
{
"environment": None,
"deployment_name": "some deployment",
"is_trusted": True,
"preserve_previous_prod_deployment": False,
},
{
"environment": PRODUCTION_ENVIRONMENT_NAME,
"deployment_name": None,
"is_trusted": True,
"preserve_previous_prod_deployment": False,
},
{
"environment": "staging",
"deployment_name": "some_deployment_name",
"is_trusted": False,
"preserve_previous_prod_deployment": True,
},
],
)
def test_create_truss_service_handles_existing_model(inputs):
api = MagicMock()
return_value = {
"id": "model_version_id",
}
api.create_model_version_from_truss.return_value = return_value
model_id, model_version_id = create_truss_service(
api,
"model_name",
"s3_key",
"config",
is_draft=False,
model_id="model_id",
**inputs,
)

assert model_id == "model_id"
assert model_version_id == return_value["id"]
api.create_model_version_from_truss.assert_called_once()
_, kwargs = api.create_model_version_from_truss.call_args
for k, v in inputs.items():
assert kwargs[k] == v

0 comments on commit 5a4626e

Please sign in to comment.