From 1a9b11a4ed3605ec82627d003e3b444fb7e20874 Mon Sep 17 00:00:00 2001 From: ron-damon Date: Sat, 3 Jul 2021 19:57:32 -0300 Subject: [PATCH 1/2] add run_job_kwargs to glue job run --- airflow/providers/amazon/aws/hooks/glue.py | 9 +++++++-- airflow/providers/amazon/aws/operators/glue.py | 6 +++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index 74e6416b9c4d2..a5e43278dc35a 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -95,7 +95,11 @@ def get_iam_execution_role(self) -> Dict: self.log.error("Failed to create aws glue job, error: %s", general_error) raise - def initialize_job(self, script_arguments: Optional[dict] = None) -> Dict[str, str]: + def initialize_job( + self, + script_arguments: Optional[dict] = None, + run_kwargs: Optional[dict] = None, + ) -> Dict[str, str]: """ Initializes connection with AWS Glue to run job @@ -103,10 +107,11 @@ def initialize_job(self, script_arguments: Optional[dict] = None) -> Dict[str, s """ glue_client = self.get_conn() script_arguments = script_arguments or {} + run_kwargs = run_kwargs or {} try: job_name = self.get_or_create_glue_job() - job_run = glue_client.start_job_run(JobName=job_name, Arguments=script_arguments) + job_run = glue_client.start_job_run(JobName=job_name, Arguments=script_arguments, **run_kwargs) return job_run except Exception as general_error: self.log.error("Failed to run aws glue job, error: %s", general_error) diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 81d3468d592fd..5951177c49f14 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -52,6 +52,8 @@ class AwsGlueJobOperator(BaseOperator): :type iam_role_name: Optional[str] :param create_job_kwargs: Extra arguments for Glue Job Creation :type create_job_kwargs: Optional[dict] + :param run_job_kwargs: Extra arguments for Glue Job Run + :type run_job_kwargs: Optional[dict] """ template_fields = ('script_args',) @@ -77,6 +79,7 @@ def __init__( s3_bucket: Optional[str] = None, iam_role_name: Optional[str] = None, create_job_kwargs: Optional[dict] = None, + run_job_kwargs: Optional[dict] = None, **kwargs, ): super().__init__(**kwargs) @@ -94,6 +97,7 @@ def __init__( self.s3_protocol = "s3://" self.s3_artifacts_prefix = 'artifacts/glue-scripts/' self.create_job_kwargs = create_job_kwargs + self.run_job_kwargs = run_job_kwargs or {} def execute(self, context): """ @@ -124,7 +128,7 @@ def execute(self, context): create_job_kwargs=self.create_job_kwargs, ) self.log.info("Initializing AWS Glue Job: %s", self.job_name) - glue_job_run = glue_job.initialize_job(self.script_args) + glue_job_run = glue_job.initialize_job(self.script_args, self.run_job_kwargs) glue_job_run = glue_job.job_completion(self.job_name, glue_job_run['JobRunId']) self.log.info( "AWS Glue Job: %s status: %s. Run Id: %s", From c11e92eec29c5572b9843482c0f173e4254675c3 Mon Sep 17 00:00:00 2001 From: ron-damon Date: Thu, 30 Sep 2021 09:45:05 -0300 Subject: [PATCH 2/2] add run_kwargs to hook and operator tests --- tests/providers/amazon/aws/hooks/test_glue.py | 3 ++- tests/providers/amazon/aws/operators/test_glue.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py index 8d3ebf3090ccc..946af6382b3b5 100644 --- a/tests/providers/amazon/aws/hooks/test_glue.py +++ b/tests/providers/amazon/aws/hooks/test_glue.py @@ -81,6 +81,7 @@ def test_get_or_create_glue_job(self, mock_get_conn, mock_get_iam_execution_role def test_initialize_job(self, mock_get_conn, mock_get_or_create_glue_job, mock_get_job_state): some_data_path = "s3://glue-datasets/examples/medicare/SampleData.csv" some_script_arguments = {"--s3_input_data_path": some_data_path} + some_run_kwargs = {"NumberOfWorkers": 5} some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py" some_s3_bucket = "my-includes" @@ -96,7 +97,7 @@ def test_initialize_job(self, mock_get_conn, mock_get_or_create_glue_job, mock_g s3_bucket=some_s3_bucket, region_name=self.some_aws_region, ) - glue_job_run = glue_job_hook.initialize_job(some_script_arguments) + glue_job_run = glue_job_hook.initialize_job(some_script_arguments, some_run_kwargs) glue_job_run_state = glue_job_hook.get_job_state(glue_job_run['JobName'], glue_job_run['JobRunId']) assert glue_job_run_state == mock_job_run_state, 'Mocks but be equal' diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py index e0693fc9f37b2..aed0b94ae6af8 100644 --- a/tests/providers/amazon/aws/operators/test_glue.py +++ b/tests/providers/amazon/aws/operators/test_glue.py @@ -58,5 +58,5 @@ def test_execute_without_failure( mock_initialize_job.return_value = {'JobRunState': 'RUNNING', 'JobRunId': '11111'} mock_get_job_state.return_value = 'SUCCEEDED' glue.execute(None) - mock_initialize_job.assert_called_once_with({}) + mock_initialize_job.assert_called_once_with({}, {}) assert glue.job_name == 'my_test_job'