Skip to content

Commit

Permalink
Rename DatabricksSqlOperator's fields' names to comply with templat…
Browse files Browse the repository at this point in the history
…ed fields validation (#38052)
  • Loading branch information
shahar1 authored Mar 15, 2024
1 parent cbb0cad commit 4742fc0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 11 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@ repos:
^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$|
^airflow\/providers\/apache\/spark\/operators\/spark_submit\.py$|
^airflow\/providers\/databricks\/operators\/databricks_sql\.py$|
)$
- id: ruff
name: Run 'ruff' for extremely fast Python linting
Expand Down
20 changes: 10 additions & 10 deletions airflow/providers/databricks/operators/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ class DatabricksCopyIntoOperator(BaseOperator):
"""

template_fields: Sequence[str] = (
"_file_location",
"_files",
"_table_name",
"file_location",
"files",
"table_name",
"databricks_conn_id",
)

Expand Down Expand Up @@ -249,17 +249,17 @@ def __init__(
raise AirflowException("file_location shouldn't be empty")
if file_format not in COPY_INTO_APPROVED_FORMATS:
raise AirflowException(f"file_format '{file_format}' isn't supported")
self._files = files
self.files = files
self._pattern = pattern
self._file_format = file_format
self.databricks_conn_id = databricks_conn_id
self._http_path = http_path
self._sql_endpoint_name = sql_endpoint_name
self.session_config = session_configuration
self._table_name = table_name
self.table_name = table_name
self._catalog = catalog
self._schema = schema
self._file_location = file_location
self.file_location = file_location
self._expression_list = expression_list
self._credential = credential
self._storage_credential = storage_credential
Expand Down Expand Up @@ -313,14 +313,14 @@ def _create_sql_query(self) -> str:
if self._credential is not None:
maybe_credential = self._generate_options("CREDENTIAL", escaper, self._credential, False)
maybe_with = f" WITH ({maybe_credential} {maybe_encryption})"
location = escaper.escape_item(self._file_location) + maybe_with
location = escaper.escape_item(self.file_location) + maybe_with
if self._expression_list is not None:
location = f"(SELECT {self._expression_list} FROM {location})"
files_or_pattern = ""
if self._pattern is not None:
files_or_pattern = f"PATTERN = {escaper.escape_item(self._pattern)}\n"
elif self._files is not None:
files_or_pattern = f"FILES = {escaper.escape_item(self._files)}\n"
elif self.files is not None:
files_or_pattern = f"FILES = {escaper.escape_item(self.files)}\n"
format_options = self._generate_options("FORMAT_OPTIONS", escaper, self._format_options) + "\n"
copy_options = self._generate_options("COPY_OPTIONS", escaper, self._copy_options) + "\n"
storage_cred = ""
Expand All @@ -340,7 +340,7 @@ def _create_sql_query(self) -> str:
else:
raise AirflowException(f"Incorrect data type for validate parameter: {type(self._validate)}")
# TODO: think on how to make sure that table_name and expression_list aren't used for SQL injection
sql = f"""COPY INTO {self._table_name}{storage_cred}
sql = f"""COPY INTO {self.table_name}{storage_cred}
FROM {location}
FILEFORMAT = {self._file_format}
{validation}{files_or_pattern}{format_options}{copy_options}
Expand Down
24 changes: 24 additions & 0 deletions tests/providers/databricks/operators/test_databricks_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
from airflow.utils import timezone

DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
Expand Down Expand Up @@ -228,3 +229,26 @@ def test_incorrect_params_wrong_format():
file_format=file_format,
table_name="abc",
)


@pytest.mark.db_test
def test_templating(create_task_instance_of_operator):
ti = create_task_instance_of_operator(
DatabricksCopyIntoOperator,
# Templated fields
file_location="{{ 'file-location' }}",
files="{{ 'files' }}",
table_name="{{ 'table-name' }}",
databricks_conn_id="{{ 'databricks-conn-id' }}",
# Other parameters
file_format="JSON",
dag_id="test_template_body_templating_dag",
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
ti.render_templates()
task: DatabricksCopyIntoOperator = ti.task
assert task.file_location == "file-location"
assert task.files == "files"
assert task.table_name == "table-name"
assert task.databricks_conn_id == "databricks-conn-id"

0 comments on commit 4742fc0

Please sign in to comment.