Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions providers/teradata/src/airflow/providers/teradata/hooks/bteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.teradata.hooks.ttu import TtuHook
from airflow.providers.teradata.utils.bteq_util import (
get_remote_tmp_dir,
identify_os,
prepare_bteq_command_for_local_execution,
prepare_bteq_command_for_remote_execution,
transfer_file_sftp,
Expand Down Expand Up @@ -161,7 +163,13 @@ def _transfer_to_and_execute_bteq_on_remote(
password = generate_random_password() # Encryption/Decryption password
encrypted_file_path = os.path.join(tmp_dir, "bteq_script.enc")
generate_encrypted_file_with_openssl(file_path, password, encrypted_file_path)
if not remote_working_dir:
remote_working_dir = get_remote_tmp_dir(ssh_client)
self.log.debug(
"Transferring encrypted BTEQ script to remote host: %s", remote_working_dir
)
remote_encrypted_path = os.path.join(remote_working_dir or "", "bteq_script.enc")
remote_encrypted_path = remote_encrypted_path.replace("/", "\\")

transfer_file_sftp(ssh_client, encrypted_file_path, remote_encrypted_path)

Expand Down Expand Up @@ -219,14 +227,20 @@ def _transfer_to_and_execute_bteq_on_remote(
if encrypted_file_path and os.path.exists(encrypted_file_path):
os.remove(encrypted_file_path)
# Cleanup: Delete the remote temporary file
if encrypted_file_path:
cleanup_en_command = f"rm -f {remote_encrypted_path}"
if remote_encrypted_path:
if self.ssh_hook and self.ssh_hook.get_conn():
with self.ssh_hook.get_conn() as ssh_client:
if ssh_client is None:
raise AirflowException(
"Failed to establish SSH connection. `ssh_client` is None."
)
# Detect OS
os_info = identify_os(ssh_client)
if "windows" in os_info:
cleanup_en_command = f'del /f /q "{remote_encrypted_path}"'
else:
cleanup_en_command = f"rm -f '{remote_encrypted_path}'"
self.log.debug("cleaning up remote file: %s", cleanup_en_command)
ssh_client.exec_command(cleanup_en_command)

def execute_bteq_script_at_local(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ def execute(self, context: Context) -> int | None:
elif self.bteq_script_encoding == "UTF16":
self.temp_file_read_encoding = "UTF-16"

if not self.remote_working_dir:
self.remote_working_dir = "/tmp"
# Handling execution on local:
if not self._ssh_hook:
if self.sql:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from airflow.exceptions import AirflowException


def identify_os(ssh_client: SSHClient) -> str:
    """
    Return the lower-cased operating-system banner reported by the remote host.

    Runs ``uname || ver`` over the given SSH connection and returns whatever the
    command wrote to stdout. Callers look for substrings such as ``"windows"``
    in the result.

    :param ssh_client: Connected SSH client exposing ``exec_command``.
    :return: Lower-cased stdout of the probe command. The text is not stripped,
        so a trailing newline is preserved; it is an empty string when the
        command printed nothing.
    """
    _stdin, stdout, _stderr = ssh_client.exec_command("uname || ver")
    # The banner is only used for substring checks, and the remote console
    # encoding is not guaranteed to be UTF-8, so undecodable bytes are replaced
    # instead of aborting OS detection with UnicodeDecodeError.
    return stdout.read().decode(errors="replace").lower()


def verify_bteq_installed():
"""Verify if BTEQ is installed and available in the system's PATH."""
if shutil.which("bteq") is None:
Expand All @@ -36,7 +41,15 @@ def verify_bteq_installed():

def verify_bteq_installed_remote(ssh_client: SSHClient):
"""Verify if BTEQ is installed on the remote machine."""
stdin, stdout, stderr = ssh_client.exec_command("which bteq")
# Detect OS
os_info = identify_os(ssh_client)

if "windows" in os_info:
check_cmd = "where bteq"
else:
check_cmd = "which bteq"

stdin, stdout, stderr = ssh_client.exec_command(check_cmd)
exit_status = stdout.channel.recv_exit_status()
output = stdout.read().strip()
error = stderr.read().strip()
Expand All @@ -53,6 +66,20 @@ def transfer_file_sftp(ssh_client, local_path, remote_path):
sftp.close()


def get_remote_tmp_dir(ssh_client):
    """
    Return the temporary directory to use on the remote host.

    The remote operating system is detected with ``identify_os``. Windows hosts
    are asked for their ``%TEMP%`` directory, falling back to ``C:\\Temp`` when
    no usable value comes back; every other host gets ``/tmp``.

    :param ssh_client: Connected SSH client exposing ``exec_command``.
    :return: Path of the temporary directory on the remote host.
    """
    os_info = identify_os(ssh_client)

    if "windows" not in os_info:
        return "/tmp"

    # Ask the remote Windows shell to expand its temp-directory variable.
    _stdin, stdout, _stderr = ssh_client.exec_command("echo %TEMP%")
    tmp_dir = stdout.read().decode(errors="replace").strip()
    # When the variable is not expanded (e.g. it is undefined on the cmd.exe
    # command line) the shell echoes the literal "%TEMP%" back rather than an
    # empty string; that is not a real directory, so fall back for it as well.
    if not tmp_dir or tmp_dir.upper() == "%TEMP%":
        tmp_dir = "C:\\Temp"
    return tmp_dir


# Host details cannot be passed on the bteq command line when executing on a remote machine. Instead, the .logon
# statement is written into the BTEQ script itself to avoid the risk of exposing sensitive information.
def prepare_bteq_script_for_remote_execution(conn: dict[str, Any], sql: str) -> str:
Expand Down
32 changes: 27 additions & 5 deletions providers/teradata/tests/unit/teradata/hooks/test_bteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ def test_execute_bteq_script_at_remote_success(
mock_ssh_hook.get_conn.return_value.__enter__.return_value = mock_ssh_client
mock_ssh_hook_class.return_value = mock_ssh_hook

# Instantiate BteqHook with ssh_conn_id (will use mocked SSHHook)
# Mock exec_command to simulate 'uname || ver'
mock_stdin = MagicMock()
mock_stdout = MagicMock()
mock_stderr = MagicMock()
mock_stdout.read.return_value = b"Linux\n"
mock_ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr)

# Instantiate BteqHook
hook = BteqHook(ssh_conn_id="ssh_conn_id", teradata_conn_id="teradata_conn")

# Call method under test
Expand Down Expand Up @@ -342,13 +349,28 @@ def test_remote_execution_cleanup_on_exception(
temp_dir = "/tmp"
local_file_path = os.path.join(temp_dir, "bteq_script.txt")
remote_working_dir = temp_dir

# Make sure the local encrypted file exists for cleanup
encrypted_file_path = os.path.join(temp_dir, "bteq_script.enc")

# Create dummy local encrypted file
with open(encrypted_file_path, "w") as f:
f.write("dummy")

with pytest.raises(AirflowException):
# Simulate decrypt failing
mock_decrypt.side_effect = Exception("mocked exception")

# Patch exec_command for remote cleanup (identify_os, rm)
ssh_client = hook_with_ssh.ssh_hook.get_conn.return_value.__enter__.return_value

mock_stdin = MagicMock()
mock_stdout = MagicMock()
mock_stderr = MagicMock()

# For identify_os ("uname || ver")
mock_stdout.read.return_value = b"Linux\n"
ssh_client.exec_command.return_value = (mock_stdin, mock_stdout, mock_stderr)

# Run the test
with pytest.raises(AirflowException, match="mocked exception"):
hook_with_ssh._transfer_to_and_execute_bteq_on_remote(
file_path=local_file_path,
remote_working_dir=remote_working_dir,
Expand All @@ -360,5 +382,5 @@ def test_remote_execution_cleanup_on_exception(
tmp_dir=temp_dir,
)

# After exception, encrypted file should be deleted
# Verify local encrypted file is deleted
assert not os.path.exists(encrypted_file_path)
4 changes: 2 additions & 2 deletions providers/teradata/tests/unit/teradata/operators/test_bteq.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_execute(self, mock_hook_init, mock_execute_bteq):

# Then
mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None)
mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", "/tmp", "", 600, None, "", None, "UTF-8")
mock_execute_bteq.assert_called_once_with(sql + "\n.EXIT", None, "", 600, None, "", None, "UTF-8")
assert result == "BTEQ execution result"

@mock.patch.object(BteqHook, "execute_bteq_script")
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_execute_sql_only(self, mock_hook_init, mock_execute_bteq):
mock_hook_init.assert_called_once_with(teradata_conn_id=teradata_conn_id, ssh_conn_id=None)
mock_execute_bteq.assert_called_once_with(
sql + "\n.EXIT", # Assuming the prepare_bteq_script_for_local_execution appends ".EXIT"
"/tmp", # default remote_working_dir
None, # default remote_working_dir
"", # bteq_script_encoding (default ASCII => empty string)
600, # timeout default
None, # timeout_rc
Expand Down
124 changes: 108 additions & 16 deletions providers/teradata/tests/unit/teradata/utils/test_bteq_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.teradata.utils.bteq_util import (
identify_os,
is_valid_encoding,
is_valid_file,
is_valid_remote_bteq_script_file,
Expand All @@ -38,6 +39,62 @@


class TestBteqUtils:
def test_identify_os_linux(self):
    """identify_os returns the lower-cased `uname` output of a Linux host."""
    # Arrange: fake SSH client whose stdout stream yields the Linux banner.
    remote_stdout = MagicMock(**{"read.return_value": b"Linux\n"})
    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, MagicMock())

    # Act
    detected = identify_os(client)

    # Assert: the probe command was issued once and the banner keeps its newline.
    client.exec_command.assert_called_once_with("uname || ver")
    assert detected == "linux\n"

def test_identify_os_windows(self):
    """identify_os output contains "windows" for a Windows `ver` banner."""
    # Arrange: fake SSH client whose stdout stream yields the Windows banner.
    banner = b"Microsoft Windows [Version 10.0.19045.3324]\n"
    remote_stdout = MagicMock(**{"read.return_value": banner})
    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, MagicMock())

    # Act
    detected = identify_os(client)

    # Assert
    client.exec_command.assert_called_once_with("uname || ver")
    assert "windows" in detected

def test_identify_os_macos(self):
    """identify_os returns the lower-cased `uname` output of a macOS host."""
    # Arrange: fake SSH client whose stdout stream yields the Darwin banner.
    remote_stdout = MagicMock(**{"read.return_value": b"Darwin\n"})
    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, MagicMock())

    # Act
    detected = identify_os(client)

    # Assert
    client.exec_command.assert_called_once_with("uname || ver")
    assert detected == "darwin\n"

def test_identify_os_empty_response(self):
    """identify_os returns an empty string when the probe prints nothing."""
    # Arrange: fake SSH client whose stdout stream is empty.
    remote_stdout = MagicMock(**{"read.return_value": b""})
    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, MagicMock())

    # Act
    detected = identify_os(client)

    # Assert
    client.exec_command.assert_called_once_with("uname || ver")
    assert detected == ""

@patch("shutil.which")
def test_verify_bteq_installed_success(self, mock_which):
mock_which.return_value = "/usr/bin/bteq"
Expand Down Expand Up @@ -65,6 +122,57 @@ def test_prepare_bteq_script_for_local_execution(self):
assert "SELECT 1;" in script
assert ".EXIT" in script

@patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="linux")
def test_verify_bteq_installed_remote_linux(self, mock_os):
    """With the OS reported as Linux, BTEQ is looked up with `which`."""
    # Remote stdout reports a successful lookup.
    remote_stdout = MagicMock()
    remote_stdout.channel.recv_exit_status.return_value = 0
    remote_stdout.read.return_value = b"/usr/bin/bteq"

    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, MagicMock())

    verify_bteq_installed_remote(client)

    # identify_os is patched, so the lookup must be the only command sent.
    client.exec_command.assert_called_once_with("which bteq")

@patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="windows")
def test_verify_bteq_installed_remote_windows(self, mock_os):
    """With the OS reported as Windows, BTEQ is looked up with `where`."""
    # Remote stdout reports a successful lookup.
    remote_stdout = MagicMock()
    remote_stdout.channel.recv_exit_status.return_value = 0
    remote_stdout.read.return_value = b"C:\\Program Files\\bteq.exe"

    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, MagicMock())

    verify_bteq_installed_remote(client)

    # identify_os is patched, so the lookup must be the only command sent.
    client.exec_command.assert_called_once_with("where bteq")

@patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="darwin")
def test_verify_bteq_installed_remote_macos(self, mock_os):
    """With the OS reported as Darwin, BTEQ is looked up with `which`."""
    # Remote stdout reports a successful lookup.
    remote_stdout = MagicMock()
    remote_stdout.channel.recv_exit_status.return_value = 0
    remote_stdout.read.return_value = b"/usr/local/bin/bteq"

    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, MagicMock())

    verify_bteq_installed_remote(client)

    # identify_os is patched, so the lookup must be the only command sent.
    client.exec_command.assert_called_once_with("which bteq")

@patch("airflow.providers.teradata.utils.bteq_util.identify_os", return_value="linux")
def test_verify_bteq_installed_remote_fail(self, mock_os):
    """A failing lookup on the remote host raises AirflowException."""
    # Remote streams describe a failed lookup: no output, non-zero exit status.
    remote_stdout = MagicMock()
    remote_stdout.read.return_value = b""
    remote_stdout.channel.recv_exit_status.return_value = 1
    remote_stderr = MagicMock()
    remote_stderr.read.return_value = b"command not found"

    client = MagicMock()
    client.exec_command.return_value = (MagicMock(), remote_stdout, remote_stderr)

    expected_message = "BTEQ is not installed or not available in PATH"
    with pytest.raises(AirflowException, match=expected_message):
        verify_bteq_installed_remote(client)
    # Checked after the raises block so the assertion is actually executed.
    client.exec_command.assert_called_once_with("which bteq")

@patch("paramiko.SSHClient.exec_command")
def test_verify_bteq_installed_remote_success(self, mock_exec):
mock_stdin = MagicMock()
Expand All @@ -81,22 +189,6 @@ def test_verify_bteq_installed_remote_success(self, mock_exec):
# Should not raise
verify_bteq_installed_remote(ssh_client)

@patch("paramiko.SSHClient.exec_command")
def test_verify_bteq_installed_remote_fail(self, mock_exec):
mock_stdin = MagicMock()
mock_stdout = MagicMock()
mock_stderr = MagicMock()
mock_stdout.channel.recv_exit_status.return_value = 1
mock_stdout.read.return_value = b""
mock_stderr.read.return_value = b"command not found"
mock_exec.return_value = (mock_stdin, mock_stdout, mock_stderr)

ssh_client = MagicMock()
ssh_client.exec_command = mock_exec

with pytest.raises(AirflowException):
verify_bteq_installed_remote(ssh_client)

@patch("paramiko.SSHClient.open_sftp")
def test_transfer_file_sftp(self, mock_open_sftp):
mock_sftp = MagicMock()
Expand Down
Loading