diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index f54012138bcd8..96608d549f8a0 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -215,6 +215,7 @@ botocore BounceX Bq bq +bteq bugfix bugfixes buildType diff --git a/providers/teradata/docs/operators/bteq.rst b/providers/teradata/docs/operators/bteq.rst new file mode 100644 index 0000000000000..2424aa9855e1f --- /dev/null +++ b/providers/teradata/docs/operators/bteq.rst @@ -0,0 +1,264 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:BteqOperator: + +BteqOperator +============ + +The :class:`~airflow.providers.teradata.operators.bteq.BteqOperator` enables execution of SQL statements or BTEQ (Basic Teradata Query) scripts using the Teradata BTEQ utility, which can be installed either locally or accessed remotely via SSH. + +This is useful for executing administrative operations, batch queries, or ETL tasks in Teradata environments using the Teradata BTEQ utility. + +.. note:: + + This operator requires the Teradata Tools and Utilities (TTU) including the ``bteq`` binary to be installed + and accessible via the system's ``PATH`` (either locally or on the remote SSH host). + +Use the ``BteqOperator`` when you want to: + +- Run parameterized or templated SQL/BTEQ scripts +- Connect securely to Teradata with Airflow connections +- Execute queries via SSH on remote systems with BTEQ installed + +Prerequisite +------------ + +Make sure your Teradata Airflow connection is defined with the required fields: + +- ``host`` +- ``login`` +- ``password`` +- Optional: ``database``, etc. + +You can define a remote host with a separate SSH connection using the ``ssh_conn_id``. + +.. note:: + + For improved security, it is **highly recommended** to use + **private key-based SSH authentication** (SSH key pairs) instead of username/password + for the SSH connection. + + This avoids password exposure, enables seamless automated execution, and enhances security. + + See the Airflow SSH Connection documentation for details on configuring SSH keys: + https://airflow.apache.org/docs/apache-airflow/stable/howto/connection/ssh.html + + +To execute arbitrary SQL or BTEQ commands in a Teradata database, use the +:class:`~airflow.providers.teradata.operators.bteq.BteqOperator`. + +Common Database Operations with BteqOperator when BTEQ is installed on local machine +------------------------------------------------------------------------------------- + +Creating a Teradata database table +---------------------------------- + +You can use the BteqOperator to create tables in a Teradata database. The following example demonstrates how to create a simple employee table: + +.. 
exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_create_table] + :end-before: [END bteq_operator_howto_guide_create_table] + +The BTEQ script within this operator handles the table creation, including defining columns, data types, and constraints. + + +Inserting data into a Teradata database table +--------------------------------------------- + +The following example demonstrates how to populate the ``my_employees`` table with sample employee records: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_populate_table] + :end-before: [END bteq_operator_howto_guide_populate_table] + +This BTEQ script inserts multiple rows into the table in a single operation, making it efficient for batch data loading. + + +Exporting data from a Teradata database table to a file +------------------------------------------------------- + +The BteqOperator makes it straightforward to export query results to a file. This capability is valuable for data extraction, backups, and transferring data between systems. The following example demonstrates how to query the employee table and export the results: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_export_data_to_a_file] + :end-before: [END bteq_operator_howto_guide_export_data_to_a_file] + +The BTEQ script above handles the data export with options for formatting, file location specification, and error handling during the export process. + + +Fetching and processing records from your Teradata database +----------------------------------------------------------- + +You can use BteqOperator to query and retrieve data from your Teradata tables. The following example demonstrates +how to fetch specific records from the employee table with filtering and formatting: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_get_it_employees] + :end-before: [END bteq_operator_howto_guide_get_it_employees] + +Executing a BTEQ script with the BteqOperator +--------------------------------------------- + +You can use BteqOperator to execute a BTEQ script directly. This is useful for running complex queries or scripts that require multiple SQL statements or specific BTEQ commands. + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_bteq_file_input] + :end-before: [END bteq_operator_howto_guide_bteq_file_input] + + +Common Database Operations with BteqOperator when BTEQ is installed on remote machine +------------------------------------------------------------------------------------- + +Make sure SSH connection is defined with the required fields to connect to remote machine: + +- ``remote_host`` +- ``username`` +- ``password`` +- Optional: ``key_file``, ``private_key``, ``conn_timeout``, etc. + +Creating a Teradata database table +---------------------------------- + +You can use the BteqOperator to create tables in a Teradata database. The following example demonstrates how to create a simple employee table: + +.. 
exampleinclude:: /../../teradata/tests/system/teradata/example_remote_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_create_table] + :end-before: [END bteq_operator_howto_guide_create_table] + +The BTEQ script within this operator handles the table creation, including defining columns, data types, and constraints. + + +Inserting data into a Teradata database table +--------------------------------------------- + +The following example demonstrates how to populate the ``my_employees`` table with sample employee records: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_remote_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_populate_table] + :end-before: [END bteq_operator_howto_guide_populate_table] + +This BTEQ script inserts multiple rows into the table in a single operation, making it efficient for batch data loading. + + +Exporting data from a Teradata database table to a file +------------------------------------------------------- + +The BteqOperator makes it straightforward to export query results to a file. This capability is valuable for data extraction, backups, and transferring data between systems. The following example demonstrates how to query the employee table and export the results: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_remote_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_export_data_to_a_file] + :end-before: [END bteq_operator_howto_guide_export_data_to_a_file] + +The BTEQ script above handles the data export with options for formatting, file location specification, and error handling during the export process. + + +Fetching and processing records from your Teradata database +----------------------------------------------------------- + +You can use BteqOperator to query and retrieve data from your Teradata tables. The following example demonstrates +how to fetch specific records from the employee table with filtering and formatting: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_remote_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_get_it_employees] + :end-before: [END bteq_operator_howto_guide_get_it_employees] + +This example shows how to: +- Execute a SELECT query with WHERE clause filtering +- Format the output for better readability +- Process the result set within the BTEQ script +- Handle empty result sets appropriately + +Executing a BTEQ script with the BteqOperator when BTEQ script file is on remote machine +---------------------------------------------------------------------------------------- + +You can use BteqOperator to execute a BTEQ script directly when file is on remote machine. + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_remote_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_bteq_file_input] + :end-before: [END bteq_operator_howto_guide_bteq_file_input] + + +Using Conditional Logic with BteqOperator +----------------------------------------- + +The BteqOperator supports executing conditional logic within your BTEQ scripts. This powerful feature lets you create dynamic, decision-based workflows that respond to data conditions or processing results: + +.. 
exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_conditional_logic] + :end-before: [END bteq_operator_howto_guide_conditional_logic] + +Conditional execution enables more intelligent data pipelines that can adapt to different scenarios without requiring separate DAG branches. + + +Error Handling in BTEQ Scripts +------------------------------ + +The BteqOperator allows you to implement comprehensive error handling within your BTEQ scripts: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_error_handling] + :end-before: [END bteq_operator_howto_guide_error_handling] + +This approach lets you catch and respond to errors at the BTEQ script level, providing more granular control over error conditions and enabling appropriate recovery actions. + + +Dropping a Teradata Database Table +---------------------------------- + +When your workflow completes or requires cleanup, you can use the BteqOperator to drop database objects. The following example demonstrates how to drop the ``my_employees`` table: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :dedent: 4 + :start-after: [START bteq_operator_howto_guide_drop_table] + :end-before: [END bteq_operator_howto_guide_drop_table] + + +The complete Teradata Operator DAG +---------------------------------- + +When we put everything together, our DAG should look like this: + +.. exampleinclude:: /../../teradata/tests/system/teradata/example_bteq.py + :language: python + :start-after: [START bteq_operator_howto_guide] + :end-before: [END bteq_operator_howto_guide] diff --git a/providers/teradata/provider.yaml b/providers/teradata/provider.yaml index 153e825ccd7ab..01563d9d0dee3 100644 --- a/providers/teradata/provider.yaml +++ b/providers/teradata/provider.yaml @@ -57,11 +57,20 @@ operators: python-modules: - airflow.providers.teradata.operators.teradata - airflow.providers.teradata.operators.teradata_compute_cluster + - integration-name: Bteq + python-modules: + - airflow.providers.teradata.operators.bteq hooks: - integration-name: Teradata python-modules: - airflow.providers.teradata.hooks.teradata + - integration-name: Ttu + python-modules: + - airflow.providers.teradata.hooks.ttu + - integration-name: Bteq + python-modules: + - airflow.providers.teradata.hooks.bteq transfers: - source-integration-name: Teradata diff --git a/providers/teradata/pyproject.toml b/providers/teradata/pyproject.toml index d2dbb359b3090..3866c56d0a64a 100644 --- a/providers/teradata/pyproject.toml +++ b/providers/teradata/pyproject.toml @@ -72,6 +72,9 @@ dependencies = [ "amazon" = [ "apache-airflow-providers-amazon", ] +"ssh" = [ + "apache-airflow-providers-ssh" +] [dependency-groups] dev = [ @@ -81,6 +84,7 @@ dev = [ "apache-airflow-providers-amazon", "apache-airflow-providers-common-sql", "apache-airflow-providers-microsoft-azure", + "apache-airflow-providers-ssh", # Additional devel dependencies (do not remove this line and add extra development dependencies) ] diff --git a/providers/teradata/src/airflow/providers/teradata/get_provider_info.py b/providers/teradata/src/airflow/providers/teradata/get_provider_info.py index 7780d600f0531..f9b935fa551ff 100644 --- a/providers/teradata/src/airflow/providers/teradata/get_provider_info.py +++ b/providers/teradata/src/airflow/providers/teradata/get_provider_info.py @@ 
-45,10 +45,13 @@ def get_provider_info(): "airflow.providers.teradata.operators.teradata", "airflow.providers.teradata.operators.teradata_compute_cluster", ], - } + }, + {"integration-name": "Bteq", "python-modules": ["airflow.providers.teradata.operators.bteq"]}, ], "hooks": [ - {"integration-name": "Teradata", "python-modules": ["airflow.providers.teradata.hooks.teradata"]} + {"integration-name": "Teradata", "python-modules": ["airflow.providers.teradata.hooks.teradata"]}, + {"integration-name": "Ttu", "python-modules": ["airflow.providers.teradata.hooks.ttu"]}, + {"integration-name": "Bteq", "python-modules": ["airflow.providers.teradata.hooks.bteq"]}, ], "transfers": [ { diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py b/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py new file mode 100644 index 0000000000000..89aac594ddb18 --- /dev/null +++ b/providers/teradata/src/airflow/providers/teradata/hooks/bteq.py @@ -0,0 +1,325 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os +import socket +import subprocess +import tempfile +from contextlib import contextmanager + +from paramiko import SSHException + +from airflow.exceptions import AirflowException +from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.providers.teradata.hooks.ttu import TtuHook +from airflow.providers.teradata.utils.bteq_util import ( + prepare_bteq_command_for_local_execution, + prepare_bteq_command_for_remote_execution, + transfer_file_sftp, + verify_bteq_installed, + verify_bteq_installed_remote, +) +from airflow.providers.teradata.utils.encryption_utils import ( + decrypt_remote_file_to_string, + generate_encrypted_file_with_openssl, + generate_random_password, +) + + +class BteqHook(TtuHook): + """ + Hook for executing BTEQ (Basic Teradata Query) scripts. + + This hook provides functionality to execute BTEQ scripts either locally or remotely via SSH. + It extends the `TtuHook` and integrates with Airflow's SSHHook for remote execution. + + The BTEQ scripts are used to interact with Teradata databases, allowing users to perform + operations such as querying, data manipulation, and administrative tasks. + + Features: + - Supports both local and remote execution of BTEQ scripts. + - Handles connection details, script preparation, and execution. + - Provides robust error handling and logging for debugging. + - Allows configuration of session parameters like output width and encoding. + + .. seealso:: + - :ref:`hook API connection ` + + :param bteq_script: The BTEQ script to be executed. This can be a string containing the BTEQ commands. + :param remote_working_dir: Temporary directory location on the remote host (via SSH) where the BTEQ script will be transferred and executed. 
Defaults to `/tmp` if not specified. This is only applicable when `ssh_conn_id` is provided. + :param bteq_script_encoding: Character encoding for the BTEQ script file. Defaults to ASCII if not specified. + :param timeout: Timeout (in seconds) for executing the BTEQ command. Default is 600 seconds (10 minutes). + :param timeout_rc: Return code to use if the BTEQ execution fails due to a timeout. To allow DAG execution to continue after a timeout, include this value in `bteq_quit_rc`. If not specified, a timeout will raise an exception and stop the DAG. + :param bteq_session_encoding: Character encoding for the BTEQ session. Defaults to UTF-8 if not specified. + :param bteq_quit_rc: Accepts a single integer, list, or tuple of return codes. Specifies which BTEQ return codes should be treated as successful, allowing subsequent tasks to continue execution. + """ + + def __init__(self, teradata_conn_id: str, ssh_conn_id: str | None = None, *args, **kwargs): + super().__init__(teradata_conn_id, *args, **kwargs) + self.ssh_conn_id = ssh_conn_id + self.ssh_hook = SSHHook(ssh_conn_id=ssh_conn_id) if ssh_conn_id else None + + def execute_bteq_script( + self, + bteq_script: str, + remote_working_dir: str | None, + bteq_script_encoding: str | None, + timeout: int, + timeout_rc: int | None, + bteq_session_encoding: str | None, + bteq_quit_rc: int | list[int] | tuple[int, ...] | None, + temp_file_read_encoding: str | None, + ) -> int | None: + """Execute the BTEQ script either in local machine or on remote host based on ssh_conn_id.""" + # Remote execution + if self.ssh_hook: + # Write script to local temp file + # Encrypt the file locally + return self.execute_bteq_script_at_remote( + bteq_script, + remote_working_dir, + bteq_script_encoding, + timeout, + timeout_rc, + bteq_session_encoding, + bteq_quit_rc, + temp_file_read_encoding, + ) + return self.execute_bteq_script_at_local( + bteq_script, + bteq_script_encoding, + timeout, + timeout_rc, + bteq_quit_rc, + bteq_session_encoding, + temp_file_read_encoding, + ) + + def execute_bteq_script_at_remote( + self, + bteq_script: str, + remote_working_dir: str | None, + bteq_script_encoding: str | None, + timeout: int, + timeout_rc: int | None, + bteq_session_encoding: str | None, + bteq_quit_rc: int | list[int] | tuple[int, ...] | None, + temp_file_read_encoding: str | None, + ) -> int | None: + with ( + self.preferred_temp_directory() as tmp_dir, + ): + file_path = os.path.join(tmp_dir, "bteq_script.txt") + with open(file_path, "w", encoding=str(temp_file_read_encoding or "UTF-8")) as f: + f.write(bteq_script) + return self._transfer_to_and_execute_bteq_on_remote( + file_path, + remote_working_dir, + bteq_script_encoding, + timeout, + timeout_rc, + bteq_quit_rc, + bteq_session_encoding, + tmp_dir, + ) + + def _transfer_to_and_execute_bteq_on_remote( + self, + file_path: str, + remote_working_dir: str | None, + bteq_script_encoding: str | None, + timeout: int, + timeout_rc: int | None, + bteq_quit_rc: int | list[int] | tuple[int, ...] | None, + bteq_session_encoding: str | None, + tmp_dir: str, + ) -> int | None: + encrypted_file_path = None + remote_encrypted_path = None + try: + if self.ssh_hook and self.ssh_hook.get_conn(): + with self.ssh_hook.get_conn() as ssh_client: + if ssh_client is None: + raise AirflowException("Failed to establish SSH connection. 
`ssh_client` is None.") + verify_bteq_installed_remote(ssh_client) + password = generate_random_password() # Encryption/Decryption password + encrypted_file_path = os.path.join(tmp_dir, "bteq_script.enc") + generate_encrypted_file_with_openssl(file_path, password, encrypted_file_path) + remote_encrypted_path = os.path.join(remote_working_dir or "", "bteq_script.enc") + + transfer_file_sftp(ssh_client, encrypted_file_path, remote_encrypted_path) + + bteq_command_str = prepare_bteq_command_for_remote_execution( + timeout=timeout, + bteq_script_encoding=bteq_script_encoding or "", + bteq_session_encoding=bteq_session_encoding or "", + timeout_rc=timeout_rc or -1, + ) + + exit_status, stdout, stderr = decrypt_remote_file_to_string( + ssh_client, + remote_encrypted_path, + password, + bteq_command_str, + ) + + failure_message = None + password = None # Clear sensitive data + + if "Failure" in stderr or "Error" in stderr: + failure_message = stderr + # Raising an exception if there is any failure in bteq and also user wants to fail the + # task otherwise just log the error message as warning to not fail the task. + if ( + failure_message + and exit_status != 0 + and exit_status + not in ( + bteq_quit_rc + if isinstance(bteq_quit_rc, (list, tuple)) + else [bteq_quit_rc if bteq_quit_rc is not None else 0] + ) + ): + raise AirflowException(f"BTEQ task failed with error: {failure_message}") + if failure_message: + self.log.warning(failure_message) + return exit_status + else: + raise AirflowException("SSH connection is not established. `ssh_hook` is None or invalid.") + except (OSError, socket.gaierror): + raise AirflowException( + "SSH connection timed out. Please check the network or server availability." + ) + except SSHException as e: + raise AirflowException(f"An unexpected error occurred during SSH connection: {str(e)}") + except AirflowException as e: + raise e + except Exception as e: + raise AirflowException( + f"An unexpected error occurred while executing BTEQ script on remote machine: {str(e)}" + ) + finally: + # Remove the local script file + if encrypted_file_path and os.path.exists(encrypted_file_path): + os.remove(encrypted_file_path) + # Cleanup: Delete the remote temporary file + if encrypted_file_path: + cleanup_en_command = f"rm -f {remote_encrypted_path}" + if self.ssh_hook and self.ssh_hook.get_conn(): + with self.ssh_hook.get_conn() as ssh_client: + if ssh_client is None: + raise AirflowException( + "Failed to establish SSH connection. `ssh_client` is None." + ) + ssh_client.exec_command(cleanup_en_command) + + def execute_bteq_script_at_local( + self, + bteq_script: str, + bteq_script_encoding: str | None, + timeout: int, + timeout_rc: int | None, + bteq_quit_rc: int | list[int] | tuple[int, ...] 
| None, + bteq_session_encoding: str | None, + temp_file_read_encoding: str | None, + ) -> int | None: + verify_bteq_installed() + bteq_command_str = prepare_bteq_command_for_local_execution( + self.get_conn(), + timeout=timeout, + bteq_script_encoding=bteq_script_encoding or "", + bteq_session_encoding=bteq_session_encoding or "", + timeout_rc=timeout_rc or -1, + ) + process = subprocess.Popen( + bteq_command_str, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + preexec_fn=os.setsid, + ) + encode_bteq_script = bteq_script.encode(str(temp_file_read_encoding or "UTF-8")) + stdout_data, _ = process.communicate(input=encode_bteq_script) + try: + # https://docs.python.org/3.10/library/subprocess.html#subprocess.Popen.wait timeout is in seconds + process.wait(timeout=timeout + 60) # Adding 1 minute extra for BTEQ script timeout + except subprocess.TimeoutExpired: + self.on_kill() + raise AirflowException(f"BTEQ command timed out after {timeout} seconds.") + conn = self.get_conn() + conn["sp"] = process # For `on_kill` support + failure_message = None + if stdout_data is None: + raise AirflowException("Process stdout is None. Unable to read BTEQ output.") + decoded_line = "" + for line in stdout_data.splitlines(): + try: + decoded_line = line.decode("UTF-8").strip() + except UnicodeDecodeError: + self.log.warning("Failed to decode line: %s", line) + if "Failure" in decoded_line or "Error" in decoded_line: + failure_message = decoded_line + # Raising an exception if there is any failure in bteq and also user wants to fail the + # task otherwise just log the error message as warning to not fail the task. + if ( + failure_message + and process.returncode != 0 + and process.returncode + not in ( + bteq_quit_rc + if isinstance(bteq_quit_rc, (list, tuple)) + else [bteq_quit_rc if bteq_quit_rc is not None else 0] + ) + ): + raise AirflowException(f"BTEQ task failed with error: {failure_message}") + if failure_message: + self.log.warning(failure_message) + + return process.returncode + + def on_kill(self): + """Terminate the subprocess if running.""" + conn = self.get_conn() + process = conn.get("sp") + if process: + try: + process.terminate() + process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.log.warning("Subprocess did not terminate in time. Forcing kill...") + process.kill() + except Exception as e: + self.log.error("Failed to terminate subprocess: %s", str(e)) + + def get_airflow_home_dir(self) -> str: + """Get the AIRFLOW_HOME directory.""" + return os.environ.get("AIRFLOW_HOME", "~/airflow") + + @contextmanager + def preferred_temp_directory(self, prefix="bteq_"): + try: + temp_dir = tempfile.gettempdir() + if not os.path.isdir(temp_dir) or not os.access(temp_dir, os.W_OK): + raise OSError("OS temp dir not usable") + except Exception: + temp_dir = self.get_airflow_home_dir() + + with tempfile.TemporaryDirectory(dir=temp_dir, prefix=prefix) as tmp: + yield tmp diff --git a/providers/teradata/src/airflow/providers/teradata/hooks/ttu.py b/providers/teradata/src/airflow/providers/teradata/hooks/ttu.py new file mode 100644 index 0000000000000..958bd76796e73 --- /dev/null +++ b/providers/teradata/src/airflow/providers/teradata/hooks/ttu.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import subprocess +from abc import ABC +from typing import Any + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class TtuHook(BaseHook, ABC): + """ + Abstract base hook for integrating Teradata Tools and Utilities (TTU) in Airflow. + + This hook provides common connection handling, resource management, and lifecycle + support for TTU based operations such as BTEQ, TLOAD, and TPT. + + It should not be used directly. Instead, it must be subclassed by concrete hooks + like `BteqHook`, `TloadHook`, or `TddlHook` that implement the actual TTU command logic. + + Core Features: + - Establishes a reusable Teradata connection configuration. + - Provides context management for safe resource cleanup. + - Manages subprocess termination (e.g., for long-running TTU jobs). + + Requirements: + - TTU command-line tools must be installed and accessible via PATH. + - A valid Airflow connection with Teradata credentials must be configured. + """ + + def __init__(self, teradata_conn_id: str = "teradata_default", *args, **kwargs) -> None: + super().__init__() + self.teradata_conn_id = teradata_conn_id + self.conn: dict[str, Any] | None = None + + def __enter__(self): + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self.conn is not None: + self.close_conn() + + def get_conn(self) -> dict[str, Any]: + """ + Set up and return a Teradata connection dictionary. + + This dictionary includes connection credentials and a subprocess placeholder. + Ensures connection is created only once per hook instance. + + :return: Dictionary with connection details. + """ + if not self.conn: + connection = self.get_connection(self.teradata_conn_id) + if not connection.login or not connection.password or not connection.host: + raise AirflowException("Missing required connection parameters: login, password, or host.") + + self.conn = dict( + login=connection.login, + password=connection.password, + host=connection.host, + database=connection.schema, + sp=None, # Subprocess placeholder + ) + return self.conn + + def close_conn(self): + """Terminate any active TTU subprocess and clear the connection.""" + if self.conn: + if self.conn.get("sp") and self.conn["sp"].poll() is None: + self.conn["sp"].terminate() + try: + self.conn["sp"].wait(timeout=5) + except subprocess.TimeoutExpired: + self.log.warning("Subprocess did not terminate in time. Forcing kill...") + self.conn["sp"].kill() + self.conn = None diff --git a/providers/teradata/src/airflow/providers/teradata/operators/bteq.py b/providers/teradata/src/airflow/providers/teradata/operators/bteq.py new file mode 100644 index 0000000000000..22779be9bfb33 --- /dev/null +++ b/providers/teradata/src/airflow/providers/teradata/operators/bteq.py @@ -0,0 +1,284 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +from airflow.providers.teradata.utils.bteq_util import ( + is_valid_encoding, + is_valid_file, + is_valid_remote_bteq_script_file, + prepare_bteq_script_for_local_execution, + prepare_bteq_script_for_remote_execution, + read_file, +) + +if TYPE_CHECKING: + from paramiko import SSHClient + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + from airflow.utils.context import Context + +from airflow.models import BaseOperator +from airflow.providers.ssh.hooks.ssh import SSHHook +from airflow.providers.teradata.hooks.bteq import BteqHook +from airflow.providers.teradata.hooks.teradata import TeradataHook + + +def contains_template(parameter_value): + # Check if the parameter contains Jinja templating syntax + return "{{" in parameter_value and "}}" in parameter_value + + +class BteqOperator(BaseOperator): + """ + Teradata Operator to execute SQL Statements or BTEQ (Basic Teradata Query) scripts using Teradata BTEQ utility. + + This supports execution of BTEQ scripts either locally or remotely via SSH. + + The BTEQ scripts are used to interact with Teradata databases, allowing users to perform + operations such as querying, data manipulation, and administrative tasks. + + Features: + - Supports both local and remote execution of BTEQ scripts. + - Handles connection details, script preparation, and execution. + - Provides robust error handling and logging for debugging. + - Allows configuration of session parameters like session and BTEQ I/O encoding. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BteqOperator` + + :param sql: SQL statement(s) to be executed using BTEQ. (templated) + :param file_path: Optional path to an existing SQL or BTEQ script file. If provided, this file will be used instead of the `sql` content. This path represents remote file path when executing remotely via SSH, or local file path when executing locally. + :param teradata_conn_id: Reference to a specific Teradata connection. + :param ssh_conn_id: Optional SSH connection ID for remote execution. Used only when executing scripts remotely. + :param remote_working_dir: Temporary directory location on the remote host (via SSH) where the BTEQ script will be transferred and executed. Defaults to `/tmp` if not specified. This is only applicable when `ssh_conn_id` is provided. + :param bteq_session_encoding: Character set encoding for the BTEQ session. Defaults to ASCII if not specified. + :param bteq_script_encoding: Character encoding for the BTEQ script file. Defaults to ASCII if not specified. + :param bteq_quit_rc: Accepts a single integer, list, or tuple of return codes. 
Specifies which BTEQ return codes should be treated as successful, allowing subsequent tasks to continue execution. + :param timeout: Timeout (in seconds) for executing the BTEQ command. Default is 600 seconds (10 minutes). + :param timeout_rc: Return code to use if the BTEQ execution fails due to a timeout. To allow DAG execution to continue after a timeout, include this value in `bteq_quit_rc`. If not specified, a timeout will raise an exception and stop the DAG. + """ + + template_fields = "sql" + ui_color = "#ff976d" + + def __init__( + self, + *, + sql: str | None = None, + file_path: str | None = None, + teradata_conn_id: str = TeradataHook.default_conn_name, + ssh_conn_id: str | None = None, + remote_working_dir: str | None = None, + bteq_session_encoding: str | None = None, + bteq_script_encoding: str | None = None, + bteq_quit_rc: int | list[int] | tuple[int, ...] | None = None, + timeout: int | Literal[600] = 600, # Default to 10 minutes + timeout_rc: int | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.file_path = file_path + self.teradata_conn_id = teradata_conn_id + self.ssh_conn_id = ssh_conn_id + self.remote_working_dir = remote_working_dir + self.timeout = timeout + self.timeout_rc = timeout_rc + self.bteq_session_encoding = bteq_session_encoding + self.bteq_script_encoding = bteq_script_encoding + self.bteq_quit_rc = bteq_quit_rc + self._hook: BteqHook | None = None + self._ssh_hook: SSHHook | None = None + self.temp_file_read_encoding = "UTF-8" + + def execute(self, context: Context) -> int | None: + """Execute BTEQ code using the BteqHook.""" + if not self.sql and not self.file_path: + raise ValueError( + "BteqOperator requires either the 'sql' or 'file_path' parameter. Both are missing." + ) + self._hook = BteqHook(teradata_conn_id=self.teradata_conn_id, ssh_conn_id=self.ssh_conn_id) + self._ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) if self.ssh_conn_id else None + + # Validate and set BTEQ session and script encoding + if not self.bteq_session_encoding or self.bteq_session_encoding == "ASCII": + self.bteq_session_encoding = "" + if self.bteq_script_encoding == "UTF8": + self.temp_file_read_encoding = "UTF-8" + elif self.bteq_script_encoding == "UTF16": + self.temp_file_read_encoding = "UTF-16" + self.bteq_script_encoding = "" + elif self.bteq_session_encoding == "UTF8" and ( + not self.bteq_script_encoding or self.bteq_script_encoding == "ASCII" + ): + self.bteq_script_encoding = "UTF8" + elif self.bteq_session_encoding == "UTF16": + if not self.bteq_script_encoding or self.bteq_script_encoding == "ASCII": + self.bteq_script_encoding = "UTF8" + # for file reading in python. Mapping BTEQ encoding to Python encoding + if self.bteq_script_encoding == "UTF8": + self.temp_file_read_encoding = "UTF-8" + elif self.bteq_script_encoding == "UTF16": + self.temp_file_read_encoding = "UTF-16" + + if not self.remote_working_dir: + self.remote_working_dir = "/tmp" + # Handling execution on local: + if not self._ssh_hook: + if self.sql: + bteq_script = prepare_bteq_script_for_local_execution( + sql=self.sql, + ) + return self._hook.execute_bteq_script( + bteq_script, + self.remote_working_dir, + self.bteq_script_encoding, + self.timeout, + self.timeout_rc, + self.bteq_session_encoding, + self.bteq_quit_rc, + self.temp_file_read_encoding, + ) + if self.file_path: + if not is_valid_file(self.file_path): + raise ValueError( + f"The provided file path '{self.file_path}' is invalid or does not exist." 
+ ) + try: + is_valid_encoding(self.file_path, self.temp_file_read_encoding or "UTF-8") + except UnicodeDecodeError as e: + errmsg = f"The provided file '{self.file_path}' encoding is different from BTEQ I/O encoding i.e.'UTF-8'." + if self.bteq_script_encoding: + errmsg = f"The provided file '{self.file_path}' encoding is different from the specified BTEQ I/O encoding '{self.bteq_script_encoding}'." + raise ValueError(errmsg) from e + return self._handle_local_bteq_file( + file_path=self.file_path, + context=context, + ) + # Execution on Remote machine + elif self._ssh_hook: + # When sql statement is provided as input through sql parameter, Preparing the bteq script + if self.sql: + bteq_script = prepare_bteq_script_for_remote_execution( + conn=self._hook.get_conn(), + sql=self.sql, + ) + return self._hook.execute_bteq_script( + bteq_script, + self.remote_working_dir, + self.bteq_script_encoding, + self.timeout, + self.timeout_rc, + self.bteq_session_encoding, + self.bteq_quit_rc, + self.temp_file_read_encoding, + ) + if self.file_path: + with self._ssh_hook.get_conn() as ssh_client: + # When .sql or .bteq remote file path is provided as input through file_path parameter, executing on remote machine + if self.file_path and is_valid_remote_bteq_script_file(ssh_client, self.file_path): + return self._handle_remote_bteq_file( + ssh_client=self._ssh_hook.get_conn(), + file_path=self.file_path, + context=context, + ) + raise ValueError( + f"The provided remote file path '{self.file_path}' is invalid or file does not exist on remote machine at given path." + ) + else: + raise ValueError( + "BteqOperator requires either the 'sql' or 'file_path' parameter. Both are missing." + ) + return None + + def _handle_remote_bteq_file( + self, + ssh_client: SSHClient, + file_path: str | None, + context: Context, + ) -> int | None: + if file_path: + with ssh_client: + sftp = ssh_client.open_sftp() + try: + with sftp.open(file_path, "r") as remote_file: + original_content = remote_file.read().decode(self.temp_file_read_encoding or "UTF-8") + finally: + sftp.close() + rendered_content = original_content + if contains_template(original_content): + rendered_content = self.render_template(original_content, context) + if self._hook: + bteq_script = prepare_bteq_script_for_remote_execution( + conn=self._hook.get_conn(), + sql=rendered_content, + ) + return self._hook.execute_bteq_script_at_remote( + bteq_script, + self.remote_working_dir, + self.bteq_script_encoding, + self.timeout, + self.timeout_rc, + self.bteq_session_encoding, + self.bteq_quit_rc, + self.temp_file_read_encoding, + ) + return None + raise ValueError( + "Please provide a valid file path for the BTEQ script to be executed on the remote machine." 
+ ) + + def _handle_local_bteq_file( + self, + file_path: str, + context: Context, + ) -> int | None: + if file_path and is_valid_file(file_path): + file_content = read_file(file_path, encoding=str(self.temp_file_read_encoding or "UTF-8")) + # Manually render using operator's context + rendered_content = file_content + if contains_template(file_content): + rendered_content = self.render_template(file_content, context) + bteq_script = prepare_bteq_script_for_local_execution( + sql=rendered_content, + ) + if self._hook: + result = self._hook.execute_bteq_script( + bteq_script, + self.remote_working_dir, + self.bteq_script_encoding, + self.timeout, + self.timeout_rc, + self.bteq_session_encoding, + self.bteq_quit_rc, + self.temp_file_read_encoding, + ) + return result + return None + + def on_kill(self) -> None: + """Handle task termination by invoking the on_kill method of BteqHook.""" + if self._hook: + self._hook.on_kill() + else: + self.log.warning("BteqHook was not initialized. Nothing to terminate.") diff --git a/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py b/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py new file mode 100644 index 0000000000000..0741ebb20090c --- /dev/null +++ b/providers/teradata/src/airflow/providers/teradata/utils/bteq_util.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import shutil +import stat +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from paramiko import SSHClient + +from airflow.exceptions import AirflowException + + +def verify_bteq_installed(): + """Verify if BTEQ is installed and available in the system's PATH.""" + if shutil.which("bteq") is None: + raise AirflowException("BTEQ is not installed or not available in the system's PATH.") + + +def verify_bteq_installed_remote(ssh_client: SSHClient): + """Verify if BTEQ is installed on the remote machine.""" + stdin, stdout, stderr = ssh_client.exec_command("which bteq") + exit_status = stdout.channel.recv_exit_status() + output = stdout.read().strip() + error = stderr.read().strip() + + if exit_status != 0 or not output: + raise AirflowException( + f"BTEQ is not installed or not available in PATH. stderr: {error.decode() if error else 'N/A'}" + ) + + +def transfer_file_sftp(ssh_client, local_path, remote_path): + sftp = ssh_client.open_sftp() + sftp.put(local_path, remote_path) + sftp.close() + + +# We can not pass host details with bteq command when executing on remote machine. 
Instead, we will prepare .logon in bteq script itself to avoid risk of +# exposing sensitive information +def prepare_bteq_script_for_remote_execution(conn: dict[str, Any], sql: str) -> str: + """Build a BTEQ script with necessary connection and session commands.""" + script_lines = [] + host = conn["host"] + login = conn["login"] + password = conn["password"] + script_lines.append(f" .LOGON {host}/{login},{password}") + return _prepare_bteq_script(script_lines, sql) + + +def prepare_bteq_script_for_local_execution( + sql: str, +) -> str: + """Build a BTEQ script with necessary connection and session commands.""" + script_lines: list[str] = [] + return _prepare_bteq_script(script_lines, sql) + + +def _prepare_bteq_script(script_lines: list[str], sql: str) -> str: + script_lines.append(sql.strip()) + script_lines.append(".EXIT") + return "\n".join(script_lines) + + +def _prepare_bteq_command( + timeout: int, + bteq_script_encoding: str, + bteq_session_encoding: str, + timeout_rc: int, +) -> list[str]: + bteq_core_cmd = ["bteq"] + if bteq_session_encoding: + bteq_core_cmd.append(f" -e {bteq_script_encoding}") + bteq_core_cmd.append(f" -c {bteq_session_encoding}") + bteq_core_cmd.append('"') + bteq_core_cmd.append(f".SET EXITONDELAY ON MAXREQTIME {timeout}") + if timeout_rc is not None and timeout_rc >= 0: + bteq_core_cmd.append(f" RC {timeout_rc}") + bteq_core_cmd.append(";") + # Airflow doesn't display the script of BTEQ in UI but only in log so WIDTH is 500 enough + bteq_core_cmd.append(" .SET WIDTH 500;") + return bteq_core_cmd + + +def prepare_bteq_command_for_remote_execution( + timeout: int, + bteq_script_encoding: str, + bteq_session_encoding: str, + timeout_rc: int, +) -> str: + """Prepare the BTEQ command with necessary parameters.""" + bteq_core_cmd = _prepare_bteq_command(timeout, bteq_script_encoding, bteq_session_encoding, timeout_rc) + bteq_core_cmd.append('"') + return " ".join(bteq_core_cmd) + + +def prepare_bteq_command_for_local_execution( + conn: dict[str, Any], + timeout: int, + bteq_script_encoding: str, + bteq_session_encoding: str, + timeout_rc: int, +) -> str: + """Prepare the BTEQ command with necessary parameters.""" + bteq_core_cmd = _prepare_bteq_command(timeout, bteq_script_encoding, bteq_session_encoding, timeout_rc) + host = conn["host"] + login = conn["login"] + password = conn["password"] + bteq_core_cmd.append(f" .LOGON {host}/{login},{password}") + bteq_core_cmd.append('"') + bteq_command_str = " ".join(bteq_core_cmd) + return bteq_command_str + + +def is_valid_file(file_path: str) -> bool: + return os.path.isfile(file_path) + + +def is_valid_encoding(file_path: str, encoding: str = "UTF-8") -> bool: + """ + Check if the file can be read with the specified encoding. + + :param file_path: Path to the file to be checked. + :param encoding: Encoding to use for reading the file. + :return: True if the file can be read with the specified encoding, False otherwise. + """ + with open(file_path, encoding=encoding) as f: + f.read() + return True + + +def read_file(file_path: str, encoding: str = "UTF-8") -> str: + """ + Read the content of a file with the specified encoding. + + :param file_path: Path to the file to be read. + :param encoding: Encoding to use for reading the file. + :return: Content of the file as a string. 
+ """ + if not os.path.isfile(file_path): + raise FileNotFoundError(f"The file {file_path} does not exist.") + + with open(file_path, encoding=encoding) as f: + return f.read() + + +def is_valid_remote_bteq_script_file(ssh_client: SSHClient, remote_file_path: str, logger=None) -> bool: + """Check if the given remote file path is a valid BTEQ script file.""" + if remote_file_path: + sftp_client = ssh_client.open_sftp() + try: + # Get file metadata + file_stat = sftp_client.stat(remote_file_path) + if file_stat.st_mode: + is_regular_file = stat.S_ISREG(file_stat.st_mode) + return is_regular_file + return False + except FileNotFoundError: + if logger: + logger.error("File does not exist on remote at : %s", remote_file_path) + return False + finally: + sftp_client.close() + else: + return False diff --git a/providers/teradata/src/airflow/providers/teradata/utils/encryption_utils.py b/providers/teradata/src/airflow/providers/teradata/utils/encryption_utils.py new file mode 100644 index 0000000000000..57ed4b9855810 --- /dev/null +++ b/providers/teradata/src/airflow/providers/teradata/utils/encryption_utils.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import secrets +import string +import subprocess + + +def generate_random_password(length=12): + # Define the character set: letters, digits, and special characters + characters = string.ascii_letters + string.digits + string.punctuation + # Generate a random password + password = "".join(secrets.choice(characters) for _ in range(length)) + return password + + +def generate_encrypted_file_with_openssl(file_path: str, password: str, out_file: str): + # Write plaintext temporarily to file + + # Run openssl enc with AES-256-CBC, pbkdf2, salt + cmd = [ + "openssl", + "enc", + "-aes-256-cbc", + "-salt", + "-pbkdf2", + "-pass", + f"pass:{password}", + "-in", + file_path, + "-out", + out_file, + ] + subprocess.run(cmd, check=True) + + +def decrypt_remote_file_to_string(ssh_client, remote_enc_file, password, bteq_command_str): + # Run openssl decrypt command on remote machine + quoted_password = shell_quote_single(password) + + decrypt_cmd = ( + f"openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:{quoted_password} -in {remote_enc_file} | " + + bteq_command_str + ) + # Clear password to prevent lingering sensitive data + password = None + quoted_password = None + stdin, stdout, stderr = ssh_client.exec_command(decrypt_cmd) + # Wait for command to finish + exit_status = stdout.channel.recv_exit_status() + output = stdout.read().decode() + err = stderr.read().decode() + return exit_status, output, err + + +def shell_quote_single(s): + # Escape single quotes in s, then wrap in single quotes + # In shell, to include a single quote inside single quotes, close, add '\'' and reopen + return "'" + s.replace("'", "'\\''") + "'" diff --git a/providers/teradata/tests/system/teradata/example_bteq.py b/providers/teradata/tests/system/teradata/example_bteq.py new file mode 100644 index 0000000000000..a77f81bc697dc --- /dev/null +++ b/providers/teradata/tests/system/teradata/example_bteq.py @@ -0,0 +1,272 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG to show usage of BteqOperator. + +This DAG assumes Airflow Connection with connection id `TTU_DEFAULT` already exists in locally. It +shows how to use Teradata BTEQ commands with BteqOperator as tasks in +airflow dags using BteqeOperator. 
+""" + +from __future__ import annotations + +import datetime +import os + +import pytest + +from airflow import DAG + +try: + from airflow.providers.teradata.operators.bteq import BteqOperator +except ImportError: + pytest.skip("TERADATA provider not available", allow_module_level=True) + +# [START bteq_operator_howto_guide] + + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "example_bteq" +CONN_ID = "teradata_default" +SSH_CONN_ID = "ssh_default" + +host = os.environ.get("host", "localhost") +username = os.environ.get("username", "temp") +password = os.environ.get("password", "temp") +params = { + "host": host, + "username": username, + "password": password, + "DATABASE_NAME": "airflow", + "TABLE_NAME": "my_employees", + "DB_TABLE_NAME": "airflow.my_employees", +} +with DAG( + dag_id=DAG_ID, + start_date=datetime.datetime(2020, 2, 2), + schedule="@once", + catchup=False, + default_args={"teradata_conn_id": CONN_ID, "params": params}, +) as dag: + # [START bteq_operator_howto_guide_create_table] + create_table = BteqOperator( + task_id="create_table", + sql=r""" + CREATE SET TABLE {{params.DB_TABLE_NAME}} ( + emp_id INT, + emp_name VARCHAR(100), + dept VARCHAR(50) + ) PRIMARY INDEX (emp_id); + """, + bteq_quit_rc=[0, 4], + timeout=20, + bteq_session_encoding="UTF8", + bteq_script_encoding="UTF8", + params=params, + ) + # [END bteq_operator_howto_guide_create_table] + # [START bteq_operator_howto_guide_populate_table] + populate_table = BteqOperator( + task_id="populate_table", + sql=r""" + INSERT INTO {{params.DB_TABLE_NAME}} VALUES (1, 'John Doe', 'IT'); + INSERT INTO {{params.DB_TABLE_NAME}} VALUES (2, 'Jane Smith', 'HR'); + """, + params=params, + bteq_session_encoding="UTF8", + bteq_quit_rc=0, + ) + # [END bteq_operator_howto_guide_populate_table] + + # [START bteq_operator_howto_guide_export_data_to_a_file] + export_to_a_file = BteqOperator( + task_id="export_to_a_file", + sql=r""" + .EXPORT FILE = employees_output.txt; + SELECT * FROM {{params.DB_TABLE_NAME}}; + .EXPORT RESET; + """, + bteq_session_encoding="UTF16", + ) + # [END bteq_operator_howto_guide_export_data_to_a_file] + + # [START bteq_operator_howto_guide_get_it_employees] + get_it_employees = BteqOperator( + task_id="get_it_employees", + sql=r""" + SELECT * FROM {{params.DB_TABLE_NAME}} WHERE dept = 'IT'; + """, + bteq_session_encoding="ASCII", + ) + # [END bteq_operator_howto_guide_get_it_employees] + + # [START bteq_operator_howto_guide_conditional_logic] + cond_logic = BteqOperator( + task_id="cond_logic", + sql=r""" + .IF ERRORCODE <> 0 THEN .GOTO handle_error; + + SELECT COUNT(*) FROM {{params.DB_TABLE_NAME}}; + + .LABEL handle_error; + """, + bteq_script_encoding="UTF8", + ) + # [END bteq_operator_howto_guide_conditional_logic] + + # [START bteq_operator_howto_guide_error_handling] + error_handling = BteqOperator( + task_id="error_handling", + sql=r""" + DROP TABLE my_temp; + .IF ERRORCODE = 3807 THEN .GOTO table_not_found; + SELECT 'Table dropped successfully.'; + .GOTO end; + + .LABEL table_not_found; + SELECT 'Table not found - continuing execution'; + .LABEL end; + .LOGOFF; + .QUIT 0; + """, + bteq_script_encoding="UTF16", + ) + # [END bteq_operator_howto_guide_error_handling] + + # [START bteq_operator_howto_guide_drop_table] + drop_table = BteqOperator( + task_id="drop_table", + sql=r""" + DROP TABLE {{params.DB_TABLE_NAME}}; + .IF ERRORCODE = 3807 THEN .GOTO end; + + .LABEL end; + .LOGOFF; + .QUIT 0; + """, + bteq_script_encoding="ASCII", + ) + # [END bteq_operator_howto_guide_drop_table] + # 
[START bteq_operator_howto_guide_bteq_file_input] + execute_bteq_file = BteqOperator( + task_id="execute_bteq_file", + file_path="providers/teradata/tests/system/teradata/script.bteq", + params=params, + ) + # [END bteq_operator_howto_guide_bteq_file_input] + # [START bteq_operator_howto_guide_bteq_file_utf8_input] + execute_bteq_utf8_file = BteqOperator( + task_id="execute_bteq_utf8_file", + file_path="providers/teradata/tests/system/teradata/script.bteq", + params=params, + bteq_script_encoding="UTF8", + ) + # [END bteq_operator_howto_guide_bteq_file_utf8_input] + # [START bteq_operator_howto_guide_bteq_file_utf8_session_ascii_input] + execute_bteq_utf8_session_ascii_file = BteqOperator( + task_id="execute_bteq_utf8_session_ascii_file", + file_path="providers/teradata/tests/system/teradata/script.bteq", + params=params, + bteq_script_encoding="UTF8", + bteq_session_encoding="ASCII", + ) + # [END bteq_operator_howto_guide_bteq_file_utf8_session_ascii_input] + # [START bteq_operator_howto_guide_bteq_file_utf8_session_utf8_input] + execute_bteq_utf8_session_utf8_file = BteqOperator( + task_id="execute_bteq_utf8_session_utf8_file", + file_path="providers/teradata/tests/system/teradata/script.bteq", + params=params, + bteq_script_encoding="UTF8", + bteq_session_encoding="UTF8", + ) + # [END bteq_operator_howto_guide_bteq_file_utf8_session_utf8_input] + # [START bteq_operator_howto_guide_bteq_file_utf8_session_utf16_input] + execute_bteq_utf8_session_utf16_file = BteqOperator( + task_id="execute_bteq_utf8_session_utf16_file", + file_path="providers/teradata/tests/system/teradata/script.bteq", + params=params, + bteq_script_encoding="UTF8", + bteq_session_encoding="UTF16", + ) + # [END bteq_operator_howto_guide_bteq_file_utf8_session_utf16_input] + # [START bteq_operator_howto_guide_bteq_file_utf16_input] + execute_bteq_utf16_file = BteqOperator( + task_id="execute_bteq_utf16_file", + file_path="providers/teradata/tests/system/teradata/script_utf16.bteq", + params=params, + bteq_script_encoding="UTF16", + ) + # [END bteq_operator_howto_guide_bteq_file_utf16_input] + # [START bteq_operator_howto_guide_bteq_file_utf16_input] + execute_bteq_utf16_session_ascii_file = BteqOperator( + task_id="execute_bteq_utf16_session_ascii_file", + file_path="providers/teradata/tests/system/teradata/script_utf16.bteq", + params=params, + bteq_script_encoding="UTF16", + bteq_session_encoding="ASCII", + ) + # [END bteq_operator_howto_guide_bteq_file_utf16_input] + # [START bteq_operator_howto_guide_bteq_file_utf16_session_utf8_input] + execute_bteq_utf16_session_utf8_file = BteqOperator( + task_id="execute_bteq_utf16_session_utf8_file", + file_path="providers/teradata/tests/system/teradata/script_utf16.bteq", + params=params, + bteq_script_encoding="UTF16", + bteq_session_encoding="UTF8", + ) + # [END bteq_operator_howto_guide_bteq_file_utf16_session_utf8_input] + # [START bteq_operator_howto_guide_bteq_file_utf16_session_utf8_input] + execute_bteq_utf16_session_utf16_file = BteqOperator( + task_id="execute_bteq_utf16_session_utf16_file", + file_path="providers/teradata/tests/system/teradata/script_utf16.bteq", + params=params, + bteq_script_encoding="UTF16", + bteq_session_encoding="UTF16", + ) + # [END bteq_operator_howto_guide_bteq_file_utf16_session_utf8_input] + ( + create_table + >> populate_table + >> export_to_a_file + >> get_it_employees + >> cond_logic + >> error_handling + >> drop_table + >> execute_bteq_file + >> execute_bteq_utf8_file + >> execute_bteq_utf8_session_ascii_file + >> 
+        >> execute_bteq_utf8_session_utf16_file
+        >> execute_bteq_utf16_file
+        >> execute_bteq_utf16_session_ascii_file
+        >> execute_bteq_utf16_session_utf8_file
+        >> execute_bteq_utf16_session_utf16_file
+    )
+
+    # [END bteq_operator_howto_guide]
+
+    from tests_common.test_utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/providers/teradata/tests/system/teradata/example_remote_bteq.py b/providers/teradata/tests/system/teradata/example_remote_bteq.py
new file mode 100644
index 0000000000000..76bedff0b950f
--- /dev/null
+++ b/providers/teradata/tests/system/teradata/example_remote_bteq.py
@@ -0,0 +1,272 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG to show usage of BteqOperator.
+
+This DAG assumes that Airflow connections with the ids `teradata_default` and `ssh_default`
+already exist locally. It shows how to run Teradata BTEQ commands on a remote host as tasks
+in Airflow DAGs using the BteqOperator.
+""" + +from __future__ import annotations + +import datetime +import os + +import pytest + +from airflow import DAG + +try: + from airflow.providers.teradata.operators.bteq import BteqOperator +except ImportError: + pytest.skip("TERADATA provider not available", allow_module_level=True) + +# [START bteq_operator_howto_guide] + + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "example_remote_bteq" +CONN_ID = "teradata_default" +SSH_CONN_ID = "ssh_default" + +host = os.environ.get("host", "localhost") +username = os.environ.get("username", "temp") +password = os.environ.get("password", "temp") +params = { + "host": host, + "username": username, + "password": password, + "DATABASE_NAME": "airflow", + "TABLE_NAME": "my_employees", + "DB_TABLE_NAME": "airflow.my_employees", +} +with DAG( + dag_id=DAG_ID, + start_date=datetime.datetime(2020, 2, 2), + schedule="@once", + catchup=False, + default_args={"teradata_conn_id": CONN_ID, "params": params, "ssh_conn_id": SSH_CONN_ID}, +) as dag: + # [START bteq_operator_howto_guide_create_table] + create_table = BteqOperator( + task_id="create_table", + sql=r""" + CREATE SET TABLE {{params.DB_TABLE_NAME}} ( + emp_id INT, + emp_name VARCHAR(100), + dept VARCHAR(50) + ) PRIMARY INDEX (emp_id); + """, + bteq_quit_rc=[0, 4], + timeout=20, + bteq_session_encoding="UTF8", + bteq_script_encoding="UTF8", + params=params, + ) + # [END bteq_operator_howto_guide_create_table] + # [START bteq_operator_howto_guide_populate_table] + populate_table = BteqOperator( + task_id="populate_table", + sql=r""" + INSERT INTO {{params.DB_TABLE_NAME}} VALUES (1, 'John Doe', 'IT'); + INSERT INTO {{params.DB_TABLE_NAME}} VALUES (2, 'Jane Smith', 'HR'); + """, + params=params, + bteq_session_encoding="UTF8", + bteq_quit_rc=0, + ) + # [END bteq_operator_howto_guide_populate_table] + + # [START bteq_operator_howto_guide_export_data_to_a_file] + export_to_a_file = BteqOperator( + task_id="export_to_a_file", + sql=r""" + .EXPORT FILE = employees_output.txt; + SELECT * FROM {{params.DB_TABLE_NAME}}; + .EXPORT RESET; + """, + bteq_session_encoding="UTF16", + ) + # [END bteq_operator_howto_guide_export_data_to_a_file] + + # [START bteq_operator_howto_guide_get_it_employees] + get_it_employees = BteqOperator( + task_id="get_it_employees", + sql=r""" + SELECT * FROM {{params.DB_TABLE_NAME}} WHERE dept = 'IT'; + """, + bteq_session_encoding="ASCII", + ) + # [END bteq_operator_howto_guide_get_it_employees] + + # [START bteq_operator_howto_guide_conditional_logic] + cond_logic = BteqOperator( + task_id="cond_logic", + sql=r""" + .IF ERRORCODE <> 0 THEN .GOTO handle_error; + + SELECT COUNT(*) FROM {{params.DB_TABLE_NAME}}; + + .LABEL handle_error; + """, + bteq_script_encoding="UTF8", + ) + # [END bteq_operator_howto_guide_conditional_logic] + + # [START bteq_operator_howto_guide_error_handling] + error_handling = BteqOperator( + task_id="error_handling", + sql=r""" + DROP TABLE my_temp; + .IF ERRORCODE = 3807 THEN .GOTO table_not_found; + SELECT 'Table dropped successfully.'; + .GOTO end; + + .LABEL table_not_found; + SELECT 'Table not found - continuing execution'; + .LABEL end; + .LOGOFF; + .QUIT 0; + """, + bteq_script_encoding="UTF16", + ) + # [END bteq_operator_howto_guide_error_handling] + + # [START bteq_operator_howto_guide_drop_table] + drop_table = BteqOperator( + task_id="drop_table", + sql=r""" + DROP TABLE {{params.DB_TABLE_NAME}}; + .IF ERRORCODE = 3807 THEN .GOTO end; + + .LABEL end; + .LOGOFF; + .QUIT 0; + """, + bteq_script_encoding="ASCII", + ) + # [END 
bteq_operator_howto_guide_drop_table]
+    # [START bteq_operator_howto_guide_bteq_file_input]
+    execute_bteq_file = BteqOperator(
+        task_id="execute_bteq_file",
+        file_path="/home/devtools/satish/airflow/script.bteq",
+        params=params,
+    )
+    # [END bteq_operator_howto_guide_bteq_file_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf8_input]
+    execute_bteq_utf8_file = BteqOperator(
+        task_id="execute_bteq_utf8_file",
+        file_path="/home/devtools/satish/airflow/script.bteq",
+        params=params,
+        bteq_script_encoding="UTF8",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf8_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf8_session_ascii_input]
+    execute_bteq_utf8_session_ascii_file = BteqOperator(
+        task_id="execute_bteq_utf8_session_ascii_file",
+        file_path="/home/devtools/satish/airflow/script.bteq",
+        params=params,
+        bteq_script_encoding="UTF8",
+        bteq_session_encoding="ASCII",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf8_session_ascii_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf8_session_utf8_input]
+    execute_bteq_utf8_session_utf8_file = BteqOperator(
+        task_id="execute_bteq_utf8_session_utf8_file",
+        file_path="/home/devtools/satish/airflow/script.bteq",
+        params=params,
+        bteq_script_encoding="UTF8",
+        bteq_session_encoding="UTF8",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf8_session_utf8_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf8_session_utf16_input]
+    execute_bteq_utf8_session_utf16_file = BteqOperator(
+        task_id="execute_bteq_utf8_session_utf16_file",
+        file_path="/home/devtools/satish/airflow/script.bteq",
+        params=params,
+        bteq_script_encoding="UTF8",
+        bteq_session_encoding="UTF16",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf8_session_utf16_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf16_input]
+    execute_bteq_utf16_file = BteqOperator(
+        task_id="execute_bteq_utf16_file",
+        file_path="/home/devtools/satish/airflow/script_utf16.bteq",
+        params=params,
+        bteq_script_encoding="UTF16",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf16_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf16_session_ascii_input]
+    execute_bteq_utf16_session_ascii_file = BteqOperator(
+        task_id="execute_bteq_utf16_session_ascii_file",
+        file_path="/home/devtools/satish/airflow/script_utf16.bteq",
+        params=params,
+        bteq_script_encoding="UTF16",
+        bteq_session_encoding="ASCII",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf16_session_ascii_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf16_session_utf8_input]
+    execute_bteq_utf16_session_utf8_file = BteqOperator(
+        task_id="execute_bteq_utf16_session_utf8_file",
+        file_path="/home/devtools/satish/airflow/script_utf16.bteq",
+        params=params,
+        bteq_script_encoding="UTF16",
+        bteq_session_encoding="UTF8",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf16_session_utf8_input]
+    # [START bteq_operator_howto_guide_bteq_file_utf16_session_utf16_input]
+    execute_bteq_utf16_session_utf16_file = BteqOperator(
+        task_id="execute_bteq_utf16_session_utf16_file",
+        file_path="/home/devtools/satish/airflow/script_utf16.bteq",
+        params=params,
+        bteq_script_encoding="UTF16",
+        bteq_session_encoding="UTF16",
+    )
+    # [END bteq_operator_howto_guide_bteq_file_utf16_session_utf16_input]
+    (
+        create_table
+        >> populate_table
+        >> export_to_a_file
+        >> get_it_employees
+        >> cond_logic
+        >> error_handling
+        >> drop_table
+        >> execute_bteq_file
+        >> execute_bteq_utf8_file
+        >> execute_bteq_utf8_session_ascii_file
+        >> execute_bteq_utf8_session_utf8_file
+        >> execute_bteq_utf8_session_utf16_file
+ >> execute_bteq_utf16_file + >> execute_bteq_utf16_session_ascii_file + >> execute_bteq_utf16_session_utf8_file + >> execute_bteq_utf16_session_utf16_file + ) + + # [END bteq_operator_howto_guide] + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/teradata/tests/system/teradata/script b/providers/teradata/tests/system/teradata/script new file mode 100644 index 0000000000000..346e647b1aeae --- /dev/null +++ b/providers/teradata/tests/system/teradata/script @@ -0,0 +1,15 @@ +.LOGON {{params.host}}/{{params.username}},{{params.password}} +.IF ERRORCODE <> 0 THEN .QUIT 8 +.SET WIDTH 500 +.SET SESSION CHARSET 'ASCII' +DATABASE {{params.DATABASE_NAME}}; +CREATE SET TABLE {{params.TABLE_NAME}} ( + emp_id INT, + emp_name VARCHAR(100), + dept VARCHAR(50) + ) PRIMARY INDEX (emp_id); +INSERT INTO {{params.TABLE_NAME}} VALUES (1, 'John Doe', 'IT'); +INSERT INTO {{params.TABLE_NAME}} VALUES (2, 'Jane Smith', 'HR'); +DROP TABLE {{params.TABLE_NAME}}; +.LOGOFF +.quit 0 diff --git a/providers/teradata/tests/system/teradata/script.bteq b/providers/teradata/tests/system/teradata/script.bteq new file mode 100644 index 0000000000000..346e647b1aeae --- /dev/null +++ b/providers/teradata/tests/system/teradata/script.bteq @@ -0,0 +1,15 @@ +.LOGON {{params.host}}/{{params.username}},{{params.password}} +.IF ERRORCODE <> 0 THEN .QUIT 8 +.SET WIDTH 500 +.SET SESSION CHARSET 'ASCII' +DATABASE {{params.DATABASE_NAME}}; +CREATE SET TABLE {{params.TABLE_NAME}} ( + emp_id INT, + emp_name VARCHAR(100), + dept VARCHAR(50) + ) PRIMARY INDEX (emp_id); +INSERT INTO {{params.TABLE_NAME}} VALUES (1, 'John Doe', 'IT'); +INSERT INTO {{params.TABLE_NAME}} VALUES (2, 'Jane Smith', 'HR'); +DROP TABLE {{params.TABLE_NAME}}; +.LOGOFF +.quit 0 diff --git a/providers/teradata/tests/system/teradata/script.sql b/providers/teradata/tests/system/teradata/script.sql new file mode 100644 index 0000000000000..84d93c5fe9aea --- /dev/null +++ b/providers/teradata/tests/system/teradata/script.sql @@ -0,0 +1,29 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+*/ + +DATABASE {{params.DATABASE_NAME}}; +CREATE SET TABLE {{params.TABLE_NAME}} ( + emp_id INT, + emp_name VARCHAR(100), + dept VARCHAR(50) + ) PRIMARY INDEX (emp_id); +INSERT INTO {{params.TABLE_NAME}} VALUES (1, 'John Doe', 'IT'); +INSERT INTO {{params.TABLE_NAME}} VALUES (2, 'Jane Smith', 'HR'); +SELECT * FROM {{params.TABLE_NAME}}; +DROP TABLE {{params.TABLE_NAME}}; diff --git a/providers/teradata/tests/system/teradata/script_utf16.bteq b/providers/teradata/tests/system/teradata/script_utf16.bteq new file mode 100644 index 0000000000000..34d329deb91b0 Binary files /dev/null and b/providers/teradata/tests/system/teradata/script_utf16.bteq differ diff --git a/providers/teradata/tests/unit/teradata/hooks/test_bteq.py b/providers/teradata/tests/unit/teradata/hooks/test_bteq.py new file mode 100644 index 0000000000000..46b50e652b597 --- /dev/null +++ b/providers/teradata/tests/unit/teradata/hooks/test_bteq.py @@ -0,0 +1,364 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import subprocess +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.teradata.hooks.bteq import BteqHook + + +@pytest.fixture +def dummy_bteq_script(): + return "SELECT * FROM dbc.tables;" + + +@pytest.fixture +def dummy_remote_dir(): + return "/tmp" + + +@pytest.fixture +def dummy_encoding(): + return "utf-8" + + +@pytest.fixture +def dummy_password(): + return "dummy_password" + + +@pytest.fixture +def hook_without_ssh(): + return BteqHook(ssh_conn_id=None, teradata_conn_id="teradata_conn") + + +@patch("airflow.providers.teradata.hooks.bteq.SSHHook") +def test_init_sets_ssh_hook(mock_ssh_hook_class): + mock_ssh_instance = MagicMock() + mock_ssh_hook_class.return_value = mock_ssh_instance + + hook = BteqHook(ssh_conn_id="ssh_conn_id", teradata_conn_id="teradata_conn") + + # Validate the call and assignment + mock_ssh_hook_class.assert_called_once_with(ssh_conn_id="ssh_conn_id") + assert hook.ssh_hook == mock_ssh_instance + + +@patch("subprocess.Popen") +@patch.object( + BteqHook, + "get_conn", + return_value={ + "host": "localhost", + "login": "user", + "password": "pass", + "sp": None, + }, +) +@patch("airflow.providers.teradata.utils.bteq_util.verify_bteq_installed") +@patch("airflow.providers.teradata.utils.bteq_util.prepare_bteq_command_for_local_execution") +def test_execute_bteq_script_at_local_timeout( + mock_prepare_cmd, + mock_verify_bteq, + mock_get_conn, + mock_popen, +): + hook = BteqHook(ssh_conn_id=None, teradata_conn_id="teradata_conn") + + # Create mock process with timeout simulation + mock_process = MagicMock() + mock_process.communicate.return_value = (b"some output", None) + mock_process.wait.side_effect = 
subprocess.TimeoutExpired(cmd="bteq_command", timeout=5) + mock_process.returncode = None + mock_popen.return_value = mock_process + mock_prepare_cmd.return_value = "bteq_command" + + with pytest.raises(AirflowException): + hook.execute_bteq_script_at_local( + bteq_script="SELECT * FROM test;", + bteq_script_encoding="utf-8", + timeout=5, + timeout_rc=None, + bteq_quit_rc=0, + bteq_session_encoding=None, + temp_file_read_encoding=None, + ) + + +@patch("subprocess.Popen") +@patch.object( + BteqHook, + "get_conn", + return_value={ + "host": "localhost", + "login": "user", + "password": "pass", + "sp": None, + }, +) +@patch("airflow.providers.teradata.hooks.bteq.verify_bteq_installed") # <- patch here +@patch("airflow.providers.teradata.hooks.bteq.prepare_bteq_command_for_local_execution") # <- patch here too +def test_execute_bteq_script_at_local_success( + mock_prepare_cmd, + mock_verify_bteq, + mock_get_conn, + mock_popen, +): + hook = BteqHook(teradata_conn_id="teradata_conn") + + mock_process = MagicMock() + mock_process.communicate.return_value = (b"Output line 1\nOutput line 2\n", None) + mock_process.wait.return_value = 0 + mock_process.returncode = 0 + mock_popen.return_value = mock_process + mock_prepare_cmd.return_value = "bteq_command" + + ret_code = hook.execute_bteq_script_at_local( + bteq_script="SELECT * FROM test;", + bteq_script_encoding="utf-8", + timeout=10, + timeout_rc=None, + bteq_quit_rc=0, + bteq_session_encoding=None, + temp_file_read_encoding=None, + ) + + mock_verify_bteq.assert_called_once() + mock_prepare_cmd.assert_called_once() + mock_popen.assert_called_once_with( + "bteq_command", + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + shell=True, + preexec_fn=os.setsid, + ) + assert ret_code == 0 + + +@patch("subprocess.Popen") +@patch.object( + BteqHook, + "get_conn", + return_value={ + "host": "localhost", + "login": "user", + "password": "pass", + "sp": None, + }, +) +@patch("airflow.providers.teradata.hooks.bteq.verify_bteq_installed") +@patch("airflow.providers.teradata.hooks.bteq.prepare_bteq_command_for_local_execution") +def test_execute_bteq_script_at_local_failure_raises( + mock_prepare_cmd, + mock_verify_bteq, + mock_get_conn, + mock_popen, +): + hook = BteqHook(ssh_conn_id=None, teradata_conn_id="teradata_conn") + + failure_message = "Failure: some error occurred" + + mock_process = MagicMock() + # The output contains "Failure" + mock_process.communicate.return_value = (failure_message.encode("utf-8"), None) + mock_process.wait.return_value = 1 + mock_process.returncode = 1 + mock_popen.return_value = mock_process + mock_prepare_cmd.return_value = "bteq_command" + + with pytest.raises(AirflowException, match="BTEQ task failed with error: Failure: some error occurred"): + hook.execute_bteq_script_at_local( + bteq_script="SELECT * FROM test;", + bteq_script_encoding="utf-8", + timeout=10, + timeout_rc=None, + bteq_quit_rc=0, # 1 is not allowed here + bteq_session_encoding=None, + temp_file_read_encoding=None, + ) + + +@pytest.fixture(autouse=False) +def patch_ssh_hook_class(): + # Patch SSHHook where bteq.py imports it + with patch("airflow.providers.teradata.hooks.bteq.SSHHook") as mock_ssh_hook_class: + mock_ssh_instance = MagicMock() + mock_ssh_hook_class.return_value = mock_ssh_instance + yield mock_ssh_hook_class + + +@pytest.fixture +def hook_with_ssh(patch_ssh_hook_class): + # Now the BteqHook() call will use the patched SSHHook + return BteqHook(ssh_conn_id="ssh_conn_id", teradata_conn_id="teradata_conn") + + 
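+# An illustrative sketch (not part of the original suite): it shows how the two fixtures above can
+# be combined in a test. With SSHHook patched where bteq.py imports it, constructing the hook does
+# not open a real SSH connection; the assertions mirror test_init_sets_ssh_hook above.
+def test_hook_with_ssh_uses_patched_ssh_hook(hook_with_ssh, patch_ssh_hook_class):
+    # The patched SSHHook class is constructed once with the configured connection id.
+    patch_ssh_hook_class.assert_called_once_with(ssh_conn_id="ssh_conn_id")
+    # The hook keeps a reference to the mocked SSHHook instance.
+    assert hook_with_ssh.ssh_hook == patch_ssh_hook_class.return_value
+
+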
+@patch("airflow.providers.teradata.hooks.bteq.SSHHook") +@patch("airflow.providers.teradata.hooks.bteq.verify_bteq_installed_remote") +@patch("airflow.providers.teradata.hooks.bteq.generate_random_password", return_value="test_password") +@patch("airflow.providers.teradata.hooks.bteq.generate_encrypted_file_with_openssl") +@patch("airflow.providers.teradata.hooks.bteq.transfer_file_sftp") +@patch( + "airflow.providers.teradata.hooks.bteq.prepare_bteq_command_for_remote_execution", + return_value="bteq_command", +) +@patch( + "airflow.providers.teradata.hooks.bteq.decrypt_remote_file_to_string", return_value=(0, ["output"], []) +) +def test_execute_bteq_script_at_remote_success( + mock_decrypt, + mock_prepare_cmd, + mock_transfer, + mock_encrypt, + mock_password, + mock_verify, + mock_ssh_hook_class, +): + # Mock SSHHook instance and its get_conn() context manager + mock_ssh_hook = MagicMock() + mock_ssh_client = MagicMock() + mock_ssh_hook.get_conn.return_value.__enter__.return_value = mock_ssh_client + mock_ssh_hook_class.return_value = mock_ssh_hook + + # Instantiate BteqHook with ssh_conn_id (will use mocked SSHHook) + hook = BteqHook(ssh_conn_id="ssh_conn_id", teradata_conn_id="teradata_conn") + + # Call method under test + ret_code = hook.execute_bteq_script_at_remote( + bteq_script="SELECT 1;", + remote_working_dir="/tmp", + bteq_script_encoding="utf-8", + timeout=10, + timeout_rc=None, + bteq_session_encoding="utf-8", + bteq_quit_rc=0, + temp_file_read_encoding=None, + ) + + # Assert mocks called as expected + mock_verify.assert_called_once_with(mock_ssh_client) + mock_password.assert_called_once() + mock_encrypt.assert_called_once() + mock_transfer.assert_called_once() + mock_prepare_cmd.assert_called_once() + mock_decrypt.assert_called_once() + + # Assert the return code is what decrypt_remote_file_to_string returns (0 here) + assert ret_code == 0 + + +def test_on_kill_terminates_process(hook_without_ssh): + process_mock = MagicMock() + # Patch the hook's get_conn method to return a dict with the mocked process + with patch.object(hook_without_ssh, "get_conn", return_value={"sp": process_mock}): + hook_without_ssh.on_kill() + + process_mock.terminate.assert_called_once() + process_mock.wait.assert_called_once() + + +def test_on_kill_no_process(hook_without_ssh): + # Mock get_connection to avoid AirflowNotFoundException + with patch.object(hook_without_ssh, "get_connection", return_value={"host": "dummy_host"}): + # Provide a dummy conn dict to avoid errors + with patch.object(hook_without_ssh, "get_conn", return_value={"sp": None}): + # This should not raise any exceptions even if sp (process) is None + hook_without_ssh.on_kill() + + +@patch("airflow.providers.teradata.hooks.bteq.verify_bteq_installed_remote") +def test_transfer_to_and_execute_bteq_on_remote_ssh_failure(mock_verify, hook_with_ssh): + # Patch get_conn to simulate SSH failure by returning None + hook_with_ssh.ssh_hook.get_conn = MagicMock(return_value=None) + + # Patch helper functions used in the tested function to avoid side effects + with ( + patch("airflow.providers.teradata.hooks.bteq.generate_random_password", return_value="password"), + patch("airflow.providers.teradata.hooks.bteq.generate_encrypted_file_with_openssl"), + patch("airflow.providers.teradata.hooks.bteq.transfer_file_sftp"), + patch( + "airflow.providers.teradata.hooks.bteq.prepare_bteq_command_for_remote_execution", + return_value="cmd", + ), + patch( + "airflow.providers.teradata.hooks.bteq.decrypt_remote_file_to_string", return_value=(0, [], 
[]) + ), + ): + with pytest.raises(AirflowException) as excinfo: + hook_with_ssh._transfer_to_and_execute_bteq_on_remote( + file_path="/tmp/fakefile", + remote_working_dir="/tmp", + bteq_script_encoding="utf-8", + timeout=10, + timeout_rc=None, + bteq_quit_rc=0, + bteq_session_encoding="utf-8", + tmp_dir="/tmp", + ) + assert "SSH connection is not established" in str(excinfo.value) + + +@patch("airflow.providers.teradata.hooks.bteq.verify_bteq_installed_remote") +@patch("airflow.providers.teradata.hooks.bteq.generate_random_password", return_value="testpass") +@patch("airflow.providers.teradata.hooks.bteq.generate_encrypted_file_with_openssl") +@patch("airflow.providers.teradata.hooks.bteq.transfer_file_sftp") +@patch( + "airflow.providers.teradata.hooks.bteq.prepare_bteq_command_for_remote_execution", + return_value="bteq_remote_command", +) +@patch( + "airflow.providers.teradata.hooks.bteq.decrypt_remote_file_to_string", + side_effect=Exception("mocked exception"), +) +def test_remote_execution_cleanup_on_exception( + mock_decrypt, + mock_prepare, + mock_transfer, + mock_generate_enc, + mock_generate_pass, + mock_verify_remote, + hook_with_ssh, +): + temp_dir = "/tmp" + local_file_path = os.path.join(temp_dir, "bteq_script.txt") + remote_working_dir = temp_dir + + # Make sure the local encrypted file exists for cleanup + encrypted_file_path = os.path.join(temp_dir, "bteq_script.enc") + with open(encrypted_file_path, "w") as f: + f.write("dummy") + + with pytest.raises(AirflowException): + hook_with_ssh._transfer_to_and_execute_bteq_on_remote( + file_path=local_file_path, + remote_working_dir=remote_working_dir, + bteq_script_encoding="utf-8", + timeout=5, + timeout_rc=None, + bteq_quit_rc=0, + bteq_session_encoding="utf-8", + tmp_dir=temp_dir, + ) + + # After exception, encrypted file should be deleted + assert not os.path.exists(encrypted_file_path) diff --git a/providers/teradata/tests/unit/teradata/hooks/test_ttu.py b/providers/teradata/tests/unit/teradata/hooks/test_ttu.py new file mode 100644 index 0000000000000..2c49cb7401819 --- /dev/null +++ b/providers/teradata/tests/unit/teradata/hooks/test_ttu.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import subprocess +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.teradata.hooks.ttu import TtuHook + + +class TestTtuHook: + @mock.patch("airflow.providers.teradata.hooks.ttu.TtuHook.get_connection") + def test_get_conn_with_valid_params(self, mock_get_connection): + # Setup + mock_conn = mock.MagicMock() + mock_conn.login = "test_user" + mock_conn.password = "test_pass" + mock_conn.host = "test_host" + mock_conn.extra_dejson = {} + mock_get_connection.return_value = mock_conn + + # Execute + hook = TtuHook() + conn = hook.get_conn() + + # Assert + assert conn["login"] == "test_user" + assert conn["password"] == "test_pass" + assert conn["host"] == "test_host" + + @mock.patch("airflow.providers.teradata.hooks.ttu.TtuHook.get_connection") + def test_get_conn_missing_params(self, mock_get_connection): + # Setup + mock_conn = mock.MagicMock() + mock_conn.login = None + mock_conn.password = "test_pass" + mock_conn.host = "test_host" + mock_conn.extra_dejson = {} + mock_get_connection.return_value = mock_conn + + # Execute and Assert + hook = TtuHook() + with pytest.raises(AirflowException, match="Missing required connection parameters"): + hook.get_conn() + + @mock.patch("subprocess.Popen") + @mock.patch("airflow.providers.teradata.hooks.ttu.TtuHook.get_connection") + def test_close_conn_subprocess_running(self, mock_get_connection, mock_popen): + # Setup + mock_conn = mock.MagicMock() + mock_conn.login = "test_user" + mock_conn.password = "test_pass" + mock_conn.host = "test_host" + mock_conn.extra_dejson = {} + mock_get_connection.return_value = mock_conn + + mock_process = mock.MagicMock() + mock_process.poll.return_value = None + mock_popen.return_value = mock_process + + # Execute + hook = TtuHook() + conn = hook.get_conn() + conn["sp"] = mock_process + hook.close_conn() + + # Assert + mock_process.terminate.assert_called_once() + mock_process.wait.assert_called_once_with(timeout=5) + assert hook.conn is None + + @mock.patch("subprocess.Popen") + @mock.patch("airflow.providers.teradata.hooks.ttu.TtuHook.get_connection") + def test_close_conn_subprocess_timeout(self, mock_get_connection, mock_popen): + # Setup + mock_conn = mock.MagicMock() + mock_conn.login = "test_user" + mock_conn.password = "test_pass" + mock_conn.host = "test_host" + mock_conn.extra_dejson = {} + mock_get_connection.return_value = mock_conn + + mock_process = mock.MagicMock() + mock_process.poll.return_value = None + mock_process.wait.side_effect = subprocess.TimeoutExpired(cmd="test", timeout=5) + mock_popen.return_value = mock_process + + # Execute + hook = TtuHook() + conn = hook.get_conn() + conn["sp"] = mock_process + hook.close_conn() + + # Assert + mock_process.terminate.assert_called_once() + mock_process.wait.assert_called_once() + mock_process.kill.assert_called_once() + assert hook.conn is None + + @mock.patch("airflow.providers.teradata.hooks.ttu.TtuHook.__exit__") + @mock.patch("airflow.providers.teradata.hooks.ttu.TtuHook.__enter__") + def test_hook_context_manager(self, mock_enter, mock_exit): + # Setup + hook = TtuHook() + mock_enter.return_value = hook + + # Execute + with hook as h: + assert h == hook + + # Assert + mock_exit.assert_called_once() + # Ensure the exit method was called with the correct parameters + # Context manager's __exit__ is called with (exc_type, exc_val, exc_tb) + args = mock_exit.call_args[0] + assert len(args) == 3 # Verify we have the correct number of arguments + 
assert args[0] is None # type should be None + assert args[1] is None # value should be None + assert args[2] is None # traceback should be None diff --git a/providers/teradata/tests/unit/teradata/operators/test_bteq.py b/providers/teradata/tests/unit/teradata/operators/test_bteq.py new file mode 100644 index 0000000000000..d6690096aa83f --- /dev/null +++ b/providers/teradata/tests/unit/teradata/operators/test_bteq.py @@ -0,0 +1,288 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import tempfile +import unittest +from unittest import mock + +import pytest + +from airflow.providers.teradata.hooks.bteq import BteqHook +from airflow.providers.teradata.operators.bteq import BteqOperator + +log = logging.getLogger(__name__) + + +class TestBteqOperator: + @mock.patch.object(BteqHook, "execute_bteq_script") + @mock.patch.object(BteqHook, "__init__", return_value=None) + def test_execute(self, mock_hook_init, mock_execute_bteq): + task_id = "test_bteq_operator" + sql = "SELECT * FROM my_table;" + teradata_conn_id = "teradata_default" + mock_context = {} + # Given + expected_result = "BTEQ execution result" + mock_execute_bteq.return_value = expected_result + operator = BteqOperator( + task_id=task_id, + sql=sql, + teradata_conn_id=teradata_conn_id, + ) + + # When + result = operator.execute(mock_context) + + # Then + mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None) + mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", "/tmp", "", 600, None, "", None, "UTF-8") + assert result == "BTEQ execution result" + + @mock.patch.object(BteqHook, "execute_bteq_script") + @mock.patch.object(BteqHook, "__init__", return_value=None) + def test_execute_sql_only(self, mock_hook_init, mock_execute_bteq): + # Arrange + task_id = "test_bteq_operator" + sql = "SELECT * FROM my_table;" + teradata_conn_id = "teradata_default" + mock_context = {} + expected_result = "BTEQ execution result" + mock_execute_bteq.return_value = expected_result + + operator = BteqOperator( + task_id=task_id, + sql=sql, + teradata_conn_id=teradata_conn_id, + ) + # Manually set _hook since we bypassed __init__ + operator._hook = mock.MagicMock() + operator._hook.execute_bteq_script = mock_execute_bteq + + # Act + result = operator.execute(mock_context) + + # Assert + mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None) + mock_execute_bteq.assert_called_once_with( + sql + "\n.EXIT", # Assuming the prepare_bteq_script_for_local_execution appends ".EXIT" + "/tmp", # default remote_working_dir + "", # bteq_script_encoding (default ASCII => empty string) + 600, # timeout default + None, # timeout_rc + "", # bteq_session_encoding + None, # bteq_quit_rc + "UTF-8", + ) + 
assert result == expected_result + + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.execute_bteq_script") + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.__init__", return_value=None) + def test_execute_sql_local(self, mock_hook_init, mock_execute_script): + sql = "SELECT * FROM test_table;" + expected_result = 0 + mock_execute_script.return_value = expected_result + context = {} + + op = BteqOperator( + task_id="test_local_sql", + sql=sql, + teradata_conn_id="td_conn", + ) + op._hook = mock.Mock() + op._hook.execute_bteq_script = mock_execute_script + + result = op.execute(context) + + mock_hook_init.assert_called_once_with(teradata_conn_id="td_conn", ssh_conn_id=None) + mock_execute_script.assert_called_once() + assert result == expected_result + + @mock.patch.object(BteqHook, "on_kill") + def test_on_kill(self, mock_on_kill): + task_id = "test_bteq_operator" + sql = "SELECT * FROM my_table;" + # Given + operator = BteqOperator( + task_id=task_id, + sql=sql, + ) + operator._hook = BteqHook(None) + + # When + operator.on_kill() + + # Then + mock_on_kill.assert_called_once() + + def test_on_kill_not_initialized(self): + task_id = "test_bteq_operator" + sql = "SELECT * FROM my_table;" + # Given + operator = BteqOperator( + task_id=task_id, + sql=sql, + ) + operator._hook = None + + # When/Then (no exception should be raised) + operator.on_kill() + + def test_template_fields(self): + # Verify template fields are defined correctly + print(BteqOperator.template_fields) + assert BteqOperator.template_fields == "sql" + + def test_execute_raises_if_no_sql_or_file(self): + op = BteqOperator(task_id="fail_case", teradata_conn_id="td_conn") + with pytest.raises(ValueError, match="requires either the 'sql' or 'file_path' parameter"): + op.execute({}) + + @mock.patch("airflow.providers.teradata.operators.bteq.is_valid_file", return_value=False) + def test_invalid_file_path(self, mock_is_valid_file): + op = BteqOperator( + task_id="fail_invalid_file", + file_path="/invalid/path.sql", + teradata_conn_id="td_conn", + ) + with pytest.raises(ValueError, match="is invalid or does not exist"): + op.execute({}) + + @mock.patch("airflow.providers.teradata.operators.bteq.is_valid_file", return_value=True) + @mock.patch( + "airflow.providers.teradata.operators.bteq.is_valid_encoding", + side_effect=UnicodeDecodeError("utf8", b"", 0, 1, "error"), + ) + def test_file_encoding_error(self, mock_encoding, mock_valid_file): + op = BteqOperator( + task_id="encoding_fail", + file_path="/tmp/test.sql", + bteq_script_encoding="UTF-8", + teradata_conn_id="td_conn", + ) + with pytest.raises(ValueError, match="encoding is different from BTEQ I/O encoding"): + op.execute({}) + + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.execute_bteq_script") + @mock.patch("airflow.providers.teradata.operators.bteq.is_valid_file", return_value=True) + @mock.patch("airflow.providers.teradata.operators.bteq.is_valid_encoding") + @mock.patch("airflow.providers.teradata.operators.bteq.read_file") + def test_execute_local_file( + self, + mock_read_file, + mock_valid_encoding, + mock_valid_file, + mock_execute_bteq_script, + ): + mock_execute_bteq_script.return_value = 0 + sql_content = "SELECT * FROM table_name;" + mock_read_file.return_value = sql_content + + with tempfile.NamedTemporaryFile("w+", suffix=".sql", delete=False) as tmp_file: + tmp_file.write(sql_content) + tmp_file_path = tmp_file.name + + op = BteqOperator( + task_id="test_bteq_local_file", + file_path=tmp_file_path, + 
teradata_conn_id="teradata_default", + ) + + result = op.execute(context={}) + + assert result == 0 + mock_execute_bteq_script.assert_called_once() + + def test_on_kill_calls_hook(self): + op = BteqOperator(task_id="kill_test", teradata_conn_id="td_conn") + op._hook = mock.Mock() + op.on_kill() + op._hook.on_kill.assert_called_once() + + def test_on_kill_logs_if_no_hook(self): + op = BteqOperator(task_id="kill_no_hook", teradata_conn_id="td_conn") + op._hook = None + + with mock.patch.object(op.log, "warning") as mock_log_info: + op.on_kill() + mock_log_info.assert_called_once_with("BteqHook was not initialized. Nothing to terminate.") + + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.execute_bteq_script") + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.get_conn") + @mock.patch("airflow.providers.teradata.operators.bteq.SSHHook") + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.__init__", return_value=None) + def test_remote_execution_with_sql( + self, + mock_bteq_hook_init, + mock_ssh_hook_class, + mock_get_conn, + mock_execute_bteq_script, + ): + mock_execute_bteq_script.return_value = 0 + mock_ssh_hook_instance = mock.Mock() + mock_ssh_hook_class.return_value = mock_ssh_hook_instance + + op = BteqOperator( + task_id="test_remote_sql", + sql="SELECT * FROM customers;", + ssh_conn_id="ssh_default", + teradata_conn_id="teradata_default", + ) + + result = op.execute(context={}) + + mock_bteq_hook_init.assert_called_once_with( + teradata_conn_id="teradata_default", ssh_conn_id="ssh_default" + ) + mock_execute_bteq_script.assert_called_once() + assert result == 0 + + @mock.patch("airflow.models.BaseOperator.render_template") + def test_render_template_in_sql(self, mock_render): + op = BteqOperator(task_id="render_test", sql="SELECT * FROM {{ params.table }};") + mock_render.return_value = "SELECT * FROM my_table;" + rendered_sql = op.render_template("sql", op.sql, context={"params": {"table": "my_table"}}) + assert rendered_sql == "SELECT * FROM my_table;" + + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.execute_bteq_script", return_value=99) + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.__init__", return_value=None) + def test_bteq_timeout_with_custom_rc(self, mock_hook_init, mock_exec): + op = BteqOperator( + task_id="timeout_case", + sql="SELECT 1", + teradata_conn_id="td_conn", + timeout=30, + timeout_rc=99, + bteq_quit_rc=[99], + ) + result = op.execute({}) + assert result == 99 + mock_exec.assert_called_once() + + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.execute_bteq_script", return_value=42) + @mock.patch("airflow.providers.teradata.operators.bteq.BteqHook.__init__", return_value=None) + def test_bteq_return_code_not_in_quit_rc(self, mock_hook_init, mock_exec): + op = BteqOperator( + task_id="rc_not_allowed", sql="SELECT 1", teradata_conn_id="td_conn", bteq_quit_rc=[0, 1] + ) + result = op.execute({}) + assert result == 42 # still returns, but caller can fail on RC if desired + + +if __name__ == "__main__": + unittest.main() diff --git a/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py b/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py new file mode 100644 index 0000000000000..f0ee54aaa6804 --- /dev/null +++ b/providers/teradata/tests/unit/teradata/utils/test_bteq_util.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +import stat +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.teradata.utils.bteq_util import ( + is_valid_encoding, + is_valid_file, + is_valid_remote_bteq_script_file, + prepare_bteq_script_for_local_execution, + prepare_bteq_script_for_remote_execution, + read_file, + transfer_file_sftp, + verify_bteq_installed, + verify_bteq_installed_remote, +) + + +class TestBteqUtils: + @patch("shutil.which") + def test_verify_bteq_installed_success(self, mock_which): + mock_which.return_value = "/usr/bin/bteq" + # Should not raise + verify_bteq_installed() + mock_which.assert_called_with("bteq") + + @patch("shutil.which") + def test_verify_bteq_installed_fail(self, mock_which): + mock_which.return_value = None + with pytest.raises(AirflowException): + verify_bteq_installed() + + def test_prepare_bteq_script_for_remote_execution(self): + conn = {"host": "myhost", "login": "user", "password": "pass"} + sql = "SELECT * FROM DUAL;" + script = prepare_bteq_script_for_remote_execution(conn, sql) + assert ".LOGON myhost/user,pass" in script + assert "SELECT * FROM DUAL;" in script + assert ".EXIT" in script + + def test_prepare_bteq_script_for_local_execution(self): + sql = "SELECT 1;" + script = prepare_bteq_script_for_local_execution(sql) + assert "SELECT 1;" in script + assert ".EXIT" in script + + @patch("paramiko.SSHClient.exec_command") + def test_verify_bteq_installed_remote_success(self, mock_exec): + mock_stdin = MagicMock() + mock_stdout = MagicMock() + mock_stderr = MagicMock() + mock_stdout.channel.recv_exit_status.return_value = 0 + mock_stdout.read.return_value = b"/usr/bin/bteq" + mock_stderr.read.return_value = b"" + mock_exec.return_value = (mock_stdin, mock_stdout, mock_stderr) + + ssh_client = MagicMock() + ssh_client.exec_command = mock_exec + + # Should not raise + verify_bteq_installed_remote(ssh_client) + + @patch("paramiko.SSHClient.exec_command") + def test_verify_bteq_installed_remote_fail(self, mock_exec): + mock_stdin = MagicMock() + mock_stdout = MagicMock() + mock_stderr = MagicMock() + mock_stdout.channel.recv_exit_status.return_value = 1 + mock_stdout.read.return_value = b"" + mock_stderr.read.return_value = b"command not found" + mock_exec.return_value = (mock_stdin, mock_stdout, mock_stderr) + + ssh_client = MagicMock() + ssh_client.exec_command = mock_exec + + with pytest.raises(AirflowException): + verify_bteq_installed_remote(ssh_client) + + @patch("paramiko.SSHClient.open_sftp") + def test_transfer_file_sftp(self, mock_open_sftp): + mock_sftp = MagicMock() + mock_open_sftp.return_value = mock_sftp + + ssh_client = MagicMock() + ssh_client.open_sftp = mock_open_sftp + + transfer_file_sftp(ssh_client, "local_file.txt", "remote_file.txt") + + 
mock_open_sftp.assert_called_once() + mock_sftp.put.assert_called_once_with("local_file.txt", "remote_file.txt") + mock_sftp.close.assert_called_once() + + def test_is_valid_file(self): + # create temp file + with open("temp_test_file.txt", "w") as f: + f.write("hello") + + assert is_valid_file("temp_test_file.txt") is True + assert is_valid_file("non_existent_file.txt") is False + + os.remove("temp_test_file.txt") + + def test_is_valid_encoding(self): + # Write a file with UTF-8 encoding + with open("temp_utf8_file.txt", "w", encoding="utf-8") as f: + f.write("hello world") + + # Should return True + assert is_valid_encoding("temp_utf8_file.txt", encoding="utf-8") is True + + # Cleanup + os.remove("temp_utf8_file.txt") + + def test_read_file_success(self): + content = "Sample content" + with open("temp_read_file.txt", "w") as f: + f.write(content) + + read_content = read_file("temp_read_file.txt") + assert read_content == content + os.remove("temp_read_file.txt") + + def test_read_file_file_not_found(self): + with pytest.raises(FileNotFoundError): + read_file("non_existent_file.txt") + + @patch("paramiko.SSHClient.open_sftp") + def test_is_valid_remote_bteq_script_file_exists(self, mock_open_sftp): + mock_sftp = MagicMock() + mock_open_sftp.return_value = mock_sftp + + # Mock stat to return a regular file mode + mock_stat = MagicMock() + mock_stat.st_mode = stat.S_IFREG + mock_sftp.stat.return_value = mock_stat + + ssh_client = MagicMock() + ssh_client.open_sftp = mock_open_sftp + + result = is_valid_remote_bteq_script_file(ssh_client, "/remote/path/to/file") + assert result is True + mock_sftp.close.assert_called_once() + + @patch("paramiko.SSHClient.open_sftp") + def test_is_valid_remote_bteq_script_file_not_exists(self, mock_open_sftp): + mock_sftp = MagicMock() + mock_open_sftp.return_value = mock_sftp + + # Raise FileNotFoundError for stat + mock_sftp.stat.side_effect = FileNotFoundError + + ssh_client = MagicMock() + ssh_client.open_sftp = mock_open_sftp + + result = is_valid_remote_bteq_script_file(ssh_client, "/remote/path/to/file") + assert result is False + mock_sftp.close.assert_called_once() + + def test_is_valid_remote_bteq_script_file_none_path(self): + ssh_client = MagicMock() + result = is_valid_remote_bteq_script_file(ssh_client, None) + assert result is False + + +if __name__ == "__main__": + unittest.main() diff --git a/providers/teradata/tests/unit/teradata/utils/test_encryption_utils.py b/providers/teradata/tests/unit/teradata/utils/test_encryption_utils.py new file mode 100644 index 0000000000000..ba688182ae6a9 --- /dev/null +++ b/providers/teradata/tests/unit/teradata/utils/test_encryption_utils.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import string +import unittest +from unittest.mock import MagicMock, patch + +from airflow.providers.teradata.utils.encryption_utils import ( + decrypt_remote_file_to_string, + generate_encrypted_file_with_openssl, + generate_random_password, + shell_quote_single, +) + + +class TestEncryptionUtils: + def test_generate_random_password_length(self): + pwd = generate_random_password(16) + assert len(pwd) == 16 + # Check characters are in allowed set + allowed_chars = string.ascii_letters + string.digits + string.punctuation + assert (all(c in allowed_chars for c in pwd)) is True + + @patch("subprocess.run") + def test_generate_encrypted_file_with_openssl_calls_subprocess(self, mock_run): + file_path = "/tmp/plain.txt" + password = "testpass" + out_file = "/tmp/encrypted.enc" + + generate_encrypted_file_with_openssl(file_path, password, out_file) + + mock_run.assert_called_once_with( + [ + "openssl", + "enc", + "-aes-256-cbc", + "-salt", + "-pbkdf2", + "-pass", + f"pass:{password}", + "-in", + file_path, + "-out", + out_file, + ], + check=True, + ) + + def test_shell_quote_single_simple(self): + s = "simple" + quoted = shell_quote_single(s) + assert quoted == "'simple'" + + def test_shell_quote_single_with_single_quote(self): + s = "O'Reilly" + quoted = shell_quote_single(s) + assert quoted == "'O'\\''Reilly'" + + def test_decrypt_remote_file_to_string(self): + password = "mysecret" + remote_enc_file = "/remote/encrypted.enc" + bteq_command_str = "bteq -c UTF-8" + + ssh_client = MagicMock() + mock_stdin = MagicMock() + mock_stdout = MagicMock() + mock_stderr = MagicMock() + + # Setup mock outputs and exit code + mock_stdout.channel.recv_exit_status.return_value = 0 + mock_stdout.read.return_value = b"decrypted output" + mock_stderr.read.return_value = b"" + + ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr) + + exit_status, output, err = decrypt_remote_file_to_string( + ssh_client, remote_enc_file, password, bteq_command_str + ) + + quoted_password = shell_quote_single(password) + expected_cmd = ( + f"openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:{quoted_password} -in {remote_enc_file} | " + + bteq_command_str + ) + + ssh_client.exec_command.assert_called_once_with(expected_cmd) + assert exit_status == 0 + assert output == "decrypted output" + assert err == "" + + +if __name__ == "__main__": + unittest.main()
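+
+# For reference, an illustrative sketch of the shell pipeline these tests mock, as inferred from the
+# mocked subprocess and exec_command calls above (not an authoritative description of the provider):
+#   encrypt locally:   openssl enc -aes-256-cbc -salt -pbkdf2 -pass pass:<password> -in <script> -out <script>.enc
+#   decrypt remotely:  openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:<password> -in <script>.enc | bteq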