Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 56 additions & 48 deletions providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@
import os
import urllib.parse
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from botocore.exceptions import ClientError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
Expand All @@ -38,12 +36,13 @@
GlueJobCompleteTrigger,
)
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class GlueJobOperator(BaseOperator):
class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
"""
Create an AWS Glue Job.

Expand Down Expand Up @@ -82,7 +81,8 @@ class GlueJobOperator(BaseOperator):
For more information see: https://repost.aws/questions/QUaKgpLBMPSGWO0iq2Fob_bw/glue-run-concurrent-jobs#ANFpCL2fRnQRqgDFuIU_rpvA
"""

template_fields: Sequence[str] = (
aws_hook_class = GlueJobHook
template_fields: Sequence[str] = aws_template_fields(
"job_name",
"script_location",
"script_args",
Expand Down Expand Up @@ -112,8 +112,6 @@ def __init__(
script_args: dict | None = None,
retry_limit: int = 0,
num_of_dpus: int | float | None = None,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
s3_bucket: str | None = None,
iam_role_name: str | None = None,
iam_role_arn: str | None = None,
Expand All @@ -137,8 +135,6 @@ def __init__(
self.script_args = script_args or {}
self.retry_limit = retry_limit
self.num_of_dpus = num_of_dpus
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.s3_bucket = s3_bucket
self.iam_role_name = iam_role_name
self.iam_role_arn = iam_role_arn
Expand All @@ -155,39 +151,49 @@ def __init__(
self.stop_job_run_on_kill = stop_job_run_on_kill
self._job_run_id: str | None = None
self.sleep_before_return: int = sleep_before_return
self.s3_script_location: str | None = None

@property
def _hook_parameters(self):
    """
    Build the keyword arguments used to construct the ``GlueJobHook``.

    If ``script_location`` is a local path (it does not start with the S3
    protocol), the script is first uploaded to S3 via
    ``upload_etl_script_to_s3`` and the resulting S3 URI is handed to the
    hook. ``self.s3_script_location`` caches the resolved URI so the upload
    happens at most once per operator instance.
    """
    # Upload script to S3 before creating the hook.
    if self.script_location is None:
        self.s3_script_location = None
    # location provided, but it's not in S3 yet.
    elif self.script_location and self.s3_script_location is None:
        if not self.script_location.startswith(self.s3_protocol):
            self.upload_etl_script_to_s3()
        else:
            self.s3_script_location = self.script_location

    return {
        **super()._hook_parameters,
        "job_name": self.job_name,
        "desc": self.job_desc,
        "concurrent_run_limit": self.concurrent_run_limit,
        "script_location": self.s3_script_location,
        "retry_limit": self.retry_limit,
        "num_of_dpus": self.num_of_dpus,
        "aws_conn_id": self.aws_conn_id,
        "region_name": self.region_name,
        "s3_bucket": self.s3_bucket,
        "iam_role_name": self.iam_role_name,
        "iam_role_arn": self.iam_role_arn,
        "create_job_kwargs": self.create_job_kwargs,
        "update_config": self.update_config,
        "job_poll_interval": self.job_poll_interval,
    }

def upload_etl_script_to_s3(self):
    """Upload the ETL script to S3."""
    # The object key is the local file name placed under the artifacts prefix.
    file_name = os.path.basename(self.script_location)
    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
    s3_hook.load_file(
        self.script_location,
        self.s3_artifacts_prefix + file_name,
        bucket_name=self.s3_bucket,
        replace=self.replace_script_file,
    )
    # Record the uploaded location so the hook runs the S3 copy of the script.
    self.s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{file_name}"

def execute(self, context: Context):
"""
Expand All @@ -200,19 +206,19 @@ def execute(self, context: Context):
self.job_name,
self.wait_for_completion,
)
glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
glue_job_run = self.hook.initialize_job(self.script_args, self.run_job_kwargs)
self._job_run_id = glue_job_run["JobRunId"]
glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
region_name=self.glue_job_hook.conn_region_name,
aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.hook.conn_partition),
region_name=self.hook.conn_region_name,
job_name=urllib.parse.quote(self.job_name, safe=""),
job_run_id=self._job_run_id,
)
GlueJobRunDetailsLink.persist(
context=context,
operator=self,
region_name=self.glue_job_hook.conn_region_name,
aws_partition=self.glue_job_hook.conn_partition,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
job_name=urllib.parse.quote(self.job_name, safe=""),
job_run_id=self._job_run_id,
)
Expand All @@ -230,7 +236,7 @@ def execute(self, context: Context):
method_name="execute_complete",
)
elif self.wait_for_completion:
glue_job_run = self.glue_job_hook.job_completion(
glue_job_run = self.hook.job_completion(
self.job_name, self._job_run_id, self.verbose, self.sleep_before_return
)
self.log.info(
Expand All @@ -254,7 +260,7 @@ def on_kill(self):
"""Cancel the running AWS Glue Job."""
if self.stop_job_run_on_kill:
self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
response = self.glue_job_hook.conn.batch_stop_job_run(
response = self.hook.conn.batch_stop_job_run(
JobName=self.job_name,
JobRunIds=[self._job_run_id],
)
Expand Down Expand Up @@ -290,7 +296,9 @@ class GlueDataQualityOperator(AwsBaseOperator[GlueDataQualityHook]):
"""

aws_hook_class = GlueDataQualityHook
template_fields: Sequence[str] = ("name", "ruleset", "description", "data_quality_ruleset_kwargs")
template_fields: Sequence[str] = aws_template_fields(
"name", "ruleset", "description", "data_quality_ruleset_kwargs"
)

template_fields_renderers = {
"data_quality_ruleset_kwargs": "json",
Expand Down Expand Up @@ -387,7 +395,7 @@ class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualit

aws_hook_class = GlueDataQualityHook

template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"datasource",
"role",
"rule_set_names",
Expand Down Expand Up @@ -553,7 +561,7 @@ class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQuali
"""

aws_hook_class = GlueDataQualityHook
template_fields: Sequence[str] = (
template_fields: Sequence[str] = aws_template_fields(
"datasource",
"role",
"recommendation_run_kwargs",
Expand Down
62 changes: 62 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,68 @@ def test_replace_script_file(
"folder/file", "artifacts/glue-scripts/file", bucket_name="bucket_name", replace=True
)

assert glue.s3_script_location == "s3://bucket_name/artifacts/glue-scripts/file"

@mock.patch.object(GlueJobHook, "get_job_state")
@mock.patch.object(GlueJobHook, "initialize_job")
@mock.patch.object(GlueJobHook, "get_conn")
@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(S3Hook, "load_file")
@mock.patch.object(GlueJobOperator, "upload_etl_script_to_s3")
def test_upload_script_to_s3_no_upload(
    self,
    upload_mock,
    load_file_mock,
    conn_mock,
    get_conn_mock,
    initialize_job_mock,
    get_job_state_mock,
):
    """A script_location that is already an S3 URI is used as-is: no upload."""
    operator = GlueJobOperator(
        task_id=TASK_ID,
        job_name=JOB_NAME,
        iam_role_name="role_arn",
        s3_bucket="bucket_name",
        script_location="s3://my_bucket/folder/file",
        replace_script_file=True,
    )
    initialize_job_mock.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID}
    get_job_state_mock.return_value = "SUCCEEDED"

    operator.execute(mock.MagicMock())

    assert operator.s3_script_location == "s3://my_bucket/folder/file"
    load_file_mock.assert_not_called()
    upload_mock.assert_not_called()

@mock.patch.object(GlueJobHook, "get_job_state")
@mock.patch.object(GlueJobHook, "initialize_job")
@mock.patch.object(GlueJobHook, "get_conn")
@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(S3Hook, "load_file")
@mock.patch.object(GlueJobOperator, "upload_etl_script_to_s3")
def test_no_script_file(
    self,
    upload_mock,
    load_file_mock,
    conn_mock,
    get_conn_mock,
    initialize_job_mock,
    get_job_state_mock,
):
    """Without a script_location, no script is resolved and nothing is uploaded."""
    operator = GlueJobOperator(
        task_id=TASK_ID,
        job_name=JOB_NAME,
        iam_role_name="role_arn",
        replace_script_file=True,
    )
    initialize_job_mock.return_value = {"JobRunState": "RUNNING", "JobRunId": JOB_RUN_ID}
    get_job_state_mock.return_value = "SUCCEEDED"

    operator.execute(mock.MagicMock())

    assert operator.s3_script_location is None
    upload_mock.assert_not_called()

def test_template_fields(self):
operator = GlueJobOperator(
task_id=TASK_ID,
Expand Down