diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py index 5cc476d742a36..3fcaf0bf178bf 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py @@ -123,7 +123,7 @@ def start_notebook_execution(self): if self.compute: start_execution_params["compute"] = self.compute else: - start_execution_params["compute"] = {"instance_type": "ml.m4.xlarge"} + start_execution_params["compute"] = {"instance_type": "ml.m6i.xlarge"} print(start_execution_params) return self._sagemaker_studio.execution_client.start_execution(**start_execution_params) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py index bc9d3ad1cd0ce..be81da9b2824c 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py @@ -199,3 +199,49 @@ def test_set_xcom_s3_path(self, mock_set_xcom_s3_path): def test_set_xcom_s3_path_negative_missing_context(self): with pytest.raises(AirflowException, match="context is required"): self.hook._set_xcom_s3_path(self.s3Path, {}) + + def test_start_notebook_execution_default_compute(self): + """Test that default compute uses ml.m6i.xlarge instance type.""" + hook_without_compute = SageMakerNotebookHook( + input_config={ + "input_path": "test-data/notebook/test_notebook.ipynb", + "input_params": {"key": "value"}, + }, + output_config={"output_formats": ["NOTEBOOK"]}, + execution_name="test-execution", + waiter_delay=10, + ) + hook_without_compute._sagemaker_studio = MagicMock() + hook_without_compute._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient) + hook_without_compute._sagemaker_studio.execution_client.start_execution.return_value = { + "executionId": "123456" + } + + hook_without_compute.start_notebook_execution() + + call_kwargs = hook_without_compute._sagemaker_studio.execution_client.start_execution.call_args[1] + assert call_kwargs["compute"] == {"instance_type": "ml.m6i.xlarge"} + + def test_start_notebook_execution_custom_compute(self): + """Test that custom compute config is used when provided.""" + custom_compute = {"instance_type": "ml.c5.xlarge", "volume_size_in_gb": 50} + hook_with_compute = SageMakerNotebookHook( + input_config={ + "input_path": "test-data/notebook/test_notebook.ipynb", + "input_params": {"key": "value"}, + }, + output_config={"output_formats": ["NOTEBOOK"]}, + execution_name="test-execution", + waiter_delay=10, + compute=custom_compute, + ) + hook_with_compute._sagemaker_studio = MagicMock() + hook_with_compute._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient) + hook_with_compute._sagemaker_studio.execution_client.start_execution.return_value = { + "executionId": "123456" + } + + hook_with_compute.start_notebook_execution() + + call_kwargs = hook_with_compute._sagemaker_studio.execution_client.start_execution.call_args[1] + assert call_kwargs["compute"] == custom_compute