Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions providers/teradata/src/airflow/providers/teradata/hooks/bteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.teradata.hooks.ttu import TtuHook
from airflow.providers.teradata.utils.bteq_util import (
get_remote_tmp_dir,
identify_os,
prepare_bteq_command_for_local_execution,
prepare_bteq_command_for_remote_execution,
transfer_file_sftp,
Expand Down Expand Up @@ -161,7 +163,13 @@ def _transfer_to_and_execute_bteq_on_remote(
password = generate_random_password() # Encryption/Decryption password
encrypted_file_path = os.path.join(tmp_dir, "bteq_script.enc")
generate_encrypted_file_with_openssl(file_path, password, encrypted_file_path)
if not remote_working_dir:
remote_working_dir = get_remote_tmp_dir(ssh_client)
self.log.debug(
"Transferring encrypted BTEQ script to remote host: %s", remote_working_dir
)
remote_encrypted_path = os.path.join(remote_working_dir or "", "bteq_script.enc")
remote_encrypted_path = remote_encrypted_path.replace("/", "\\")

transfer_file_sftp(ssh_client, encrypted_file_path, remote_encrypted_path)

Expand Down Expand Up @@ -219,14 +227,20 @@ def _transfer_to_and_execute_bteq_on_remote(
if encrypted_file_path and os.path.exists(encrypted_file_path):
os.remove(encrypted_file_path)
# Cleanup: Delete the remote temporary file
if encrypted_file_path:
cleanup_en_command = f"rm -f {remote_encrypted_path}"
if remote_encrypted_path:
if self.ssh_hook and self.ssh_hook.get_conn():
with self.ssh_hook.get_conn() as ssh_client:
if ssh_client is None:
raise AirflowException(
"Failed to establish SSH connection. `ssh_client` is None."
)
# Detect OS
os_info = identify_os(ssh_client)
if "windows" in os_info:
cleanup_en_command = f'del /f /q "{remote_encrypted_path}"'
else:
cleanup_en_command = f"rm -f '{remote_encrypted_path}'"
self.log.debug("cleaning up remote file: %s", cleanup_en_command)
ssh_client.exec_command(cleanup_en_command)

def execute_bteq_script_at_local(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ def execute(self, context: Context) -> int | None:
elif self.bteq_script_encoding == "UTF16":
self.temp_file_read_encoding = "UTF-16"

if not self.remote_working_dir:
self.remote_working_dir = "/tmp"
# Handling execution on local:
if not self._ssh_hook:
if self.sql:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from airflow.exceptions import AirflowException


def identify_os(ssh_client: SSHClient) -> str:
    """
    Return the lower-cased OS banner reported by the remote host.

    Runs ``uname || ver`` over the given SSH connection and returns whatever the
    command wrote to stdout, lower-cased. Callers test the result for substrings
    such as ``"windows"`` or ``"darwin"``.

    :param ssh_client: Connected SSH client used to run the probe command.
    :return: Lower-cased stdout of the probe command (may be empty).
    """
    _stdin, stdout, _stderr = ssh_client.exec_command("uname || ver")
    # The banner is only searched for ASCII markers, so undecodable bytes
    # (e.g. output in a non-UTF-8 console code page) are replaced rather than
    # allowed to raise UnicodeDecodeError and abort the whole task.
    return stdout.read().decode("utf-8", errors="replace").lower()


def verify_bteq_installed():
"""Verify if BTEQ is installed and available in the system's PATH."""
if shutil.which("bteq") is None:
Expand All @@ -36,7 +41,23 @@ def verify_bteq_installed():

def verify_bteq_installed_remote(ssh_client: SSHClient):
"""Verify if BTEQ is installed on the remote machine."""
stdin, stdout, stderr = ssh_client.exec_command("which bteq")
# Detect OS
os_info = identify_os(ssh_client)

if "windows" in os_info:
check_cmd = "where bteq"
elif "darwin" in os_info:
# Check if zsh exists first
stdin, stdout, stderr = ssh_client.exec_command("command -v zsh")
zsh_path = stdout.read().strip()
if zsh_path:
check_cmd = 'zsh -l -c "which bteq"'
else:
check_cmd = "which bteq"
else:
check_cmd = "which bteq"

stdin, stdout, stderr = ssh_client.exec_command(check_cmd)
exit_status = stdout.channel.recv_exit_status()
output = stdout.read().strip()
error = stderr.read().strip()
Expand All @@ -53,6 +74,20 @@ def transfer_file_sftp(ssh_client, local_path, remote_path):
sftp.close()


def get_remote_tmp_dir(ssh_client):
    """
    Return a temporary directory path to use on the remote host.

    The remote OS is detected with ``identify_os``. Non-Windows hosts get
    ``/tmp``. On Windows the value of ``%TEMP%`` is queried, falling back to
    ``C:\\Temp`` when no usable value comes back.

    :param ssh_client: Connected SSH client used to run commands on the remote host.
    :return: Path of the remote temporary directory.
    """
    os_info = identify_os(ssh_client)

    if "windows" not in os_info:
        return "/tmp"

    # Try getting the Windows temp dir from the environment.
    _stdin, stdout, _stderr = ssh_client.exec_command("echo %TEMP%")
    tmp_dir = stdout.read().decode("utf-8", errors="replace").strip()
    # If the variable is not expanded (unset, or the remote shell does not use
    # %VAR% syntax) the literal text "%TEMP%" is echoed back; treat that the
    # same as an empty answer instead of using it as a directory name.
    if not tmp_dir or tmp_dir.upper() == "%TEMP%":
        tmp_dir = "C:\\Temp"
    return tmp_dir


# We cannot pass host details on the bteq command line when executing on a remote machine. Instead, we put the
# .logon statement in the BTEQ script itself to avoid the risk of exposing sensitive information.
def prepare_bteq_script_for_remote_execution(conn: dict[str, Any], sql: str) -> str:
Expand Down
32 changes: 27 additions & 5 deletions providers/teradata/tests/unit/teradata/hooks/test_bteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ def test_execute_bteq_script_at_remote_success(
mock_ssh_hook.get_conn.return_value.__enter__.return_value = mock_ssh_client
mock_ssh_hook_class.return_value = mock_ssh_hook

# Instantiate BteqHook with ssh_conn_id (will use mocked SSHHook)
# Mock exec_command to simulate 'uname || ver'
mock_stdin = MagicMock()
mock_stdout = MagicMock()
mock_stderr = MagicMock()
mock_stdout.read.return_value = b"Linux\n"
mock_ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr)

# Instantiate BteqHook
hook = BteqHook(ssh_conn_id="ssh_conn_id", teradata_conn_id="teradata_conn")

# Call method under test
Expand Down Expand Up @@ -342,13 +349,28 @@ def test_remote_execution_cleanup_on_exception(
temp_dir = "/tmp"
local_file_path = os.path.join(temp_dir, "bteq_script.txt")
remote_working_dir = temp_dir

# Make sure the local encrypted file exists for cleanup
encrypted_file_path = os.path.join(temp_dir, "bteq_script.enc")

# Create dummy local encrypted file
with open(encrypted_file_path, "w") as f:
f.write("dummy")

with pytest.raises(AirflowException):
# Simulate decrypt failing
mock_decrypt.side_effect = Exception("mocked exception")

# Patch exec_command for remote cleanup (identify_os, rm)
ssh_client = hook_with_ssh.ssh_hook.get_conn.return_value.__enter__.return_value

mock_stdin = MagicMock()
mock_stdout = MagicMock()
mock_stderr = MagicMock()

# For identify_os ("uname || ver")
mock_stdout.read.return_value = b"Linux\n"
ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr)

# Run the test
with pytest.raises(AirflowException, match="mocked exception"):
hook_with_ssh._transfer_to_and_execute_bteq_on_remote(
file_path=local_file_path,
remote_working_dir=remote_working_dir,
Expand All @@ -360,5 +382,5 @@ def test_remote_execution_cleanup_on_exception(
tmp_dir=temp_dir,
)

# After exception, encrypted file should be deleted
# Verify local encrypted file is deleted
assert not os.path.exists(encrypted_file_path)
4 changes: 2 additions & 2 deletions providers/teradata/tests/unit/teradata/operators/test_bteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_execute(self, mock_hook_init, mock_execute_bteq):

# Then
mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None)
mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", "/tmp", "", 600, None, "", None, "UTF-8")
mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", None, "", 600, None, "", None, "UTF-8")
assert result == "BTEQ execution result"

@mock.patch.object(BteqHook, "execute_bteq_script")
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_execute_sql_only(self, mock_hook_init, mock_execute_bteq):
mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None)
mock_execute_bteq.assert_called_once_with(
sql + "\n.EXIT", # Assuming the prepare_bteq_script_for_local_execution appends ".EXIT"
"/tmp", # default remote_working_dir
None, # default remote_working_dir
"", # bteq_script_encoding (default ASCII => empty string)
600, # timeout default
None, # timeout_rc
Expand Down
Loading