diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 2cc18f6989..2ff561d784 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -7509,7 +7509,7 @@ def get_model_package_args( if source_uri is not None: model_package_args["source_uri"] = source_uri if model_life_cycle is not None: - model_package_args["model_life_cycle"] = model_life_cycle + model_package_args["model_life_cycle"] = model_life_cycle._to_request_dict() if model_card is not None: original_req = model_card._create_request_args() if original_req.get("ModelCardName") is not None: diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 8f98cd076d..e84c1920f4 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -48,6 +48,7 @@ from sagemaker.s3 import S3Uploader from sagemaker.sklearn import SKLearnModel, SKLearnProcessor from sagemaker.mxnet.model import MXNetModel +from sagemaker.model_life_cycle import ModelLifeCycle from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline @@ -1005,11 +1006,11 @@ def test_model_registration_with_model_life_cycle_object( py_version="py3", role=role, ) - create_model_life_cycle = { - "Stage": "Development", - "StageStatus": "In-Progress", - "StageDescription": "Development In Progress", - } + create_model_life_cycle = ModelLifeCycle( + stage="Development", + stage_status="In-Progress", + stage_description="Development In Progress", + ) step_register = RegisterModel( name="MyRegisterModelStep", diff --git a/tests/integ/test_model_package.py b/tests/integ/test_model_package.py index bc8120bd07..1ac8e33fd8 100644 --- a/tests/integ/test_model_package.py +++ b/tests/integ/test_model_package.py @@ -103,7 +103,7 @@ def test_update_model_life_cycle_model_package(sagemaker_session): inference_instances=["ml.m5.large"], transform_instances=["ml.m5.large"], model_package_group_name=model_group_name, - model_life_cycle=create_model_life_cycle._to_request_dict(), + model_life_cycle=create_model_life_cycle, ) desc_model_package = sagemaker_session.sagemaker_client.describe_model_package( diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 8294eb0039..11cc83a463 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -4369,7 +4369,6 @@ def test_register_default_image(sagemaker_session): stage_status="In-Progress", stage_description="Sending for Staging Verification", ) - update_model_life_cycle_req = update_model_life_cycle._to_request_dict() estimator.register( content_types=content_types, @@ -4384,7 +4383,7 @@ def test_register_default_image(sagemaker_session): nearest_model_name=nearest_model_name, data_input_configuration=data_input_config, model_card=model_card, - model_life_cycle=update_model_life_cycle_req, + model_life_cycle=update_model_life_cycle, ) sagemaker_session.create_model.assert_not_called() exp_model_card = {