diff --git a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py index 6e38a5fff3ce2..30f45818e09b3 100644 --- a/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py +++ b/providers/mysql/src/airflow/providers/mysql/hooks/mysql.py @@ -249,7 +249,7 @@ def bulk_load(self, table: str, tmp_file: str) -> None: raise ValueError(f"Invalid table name: {table}") cur.execute( - f"LOAD DATA LOCAL INFILE %s INTO TABLE {table}", + f"LOAD DATA LOCAL INFILE %s INTO TABLE `{table}`", (tmp_file,), ) conn.commit() @@ -266,7 +266,7 @@ def bulk_dump(self, table: str, tmp_file: str) -> None: raise ValueError(f"Invalid table name: {table}") cur.execute( - f"SELECT * INTO OUTFILE %s FROM {table}", + f"SELECT * INTO OUTFILE %s FROM `{table}`", (tmp_file,), ) conn.commit() @@ -331,7 +331,7 @@ def bulk_load_custom( cursor = conn.cursor() cursor.execute( - f"LOAD DATA LOCAL INFILE %s %s INTO TABLE {table} %s", + f"LOAD DATA LOCAL INFILE %s %s INTO TABLE `{table}` %s", (tmp_file, duplicate_key_handling, extra_options), ) diff --git a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py index 6facc6f4b44e6..18ffe66abbecb 100644 --- a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py +++ b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py @@ -19,7 +19,6 @@ import json import os -import uuid from contextlib import closing from unittest import mock @@ -302,18 +301,21 @@ def test_run_multi_queries(self): def test_bulk_load(self): self.db_hook.bulk_load("table", "/tmp/file") - self.cur.execute.assert_called_once_with("LOAD DATA LOCAL INFILE %s INTO TABLE table", ("/tmp/file",)) + self.cur.execute.assert_called_once_with( + "LOAD DATA LOCAL INFILE %s INTO TABLE `table`", ("/tmp/file",) + ) def test_bulk_dump(self): self.db_hook.bulk_dump("table", "/tmp/file") - self.cur.execute.assert_called_once_with("SELECT * INTO OUTFILE %s FROM table", ("/tmp/file",)) + self.cur.execute.assert_called_once_with("SELECT * INTO OUTFILE %s FROM `table`", ("/tmp/file",)) def test_serialize_cell(self): assert self.db_hook._serialize_cell("foo", None) == "foo" - def test_bulk_load_custom(self): + @pytest.mark.parametrize("table", ["table", "where"]) + def test_bulk_load_custom(self, table): self.db_hook.bulk_load_custom( - "table", + table, "/tmp/file", "IGNORE", """FIELDS TERMINATED BY ';' @@ -321,7 +323,7 @@ def test_bulk_load_custom(self): IGNORE 1 LINES""", ) self.cur.execute.assert_called_once_with( - "LOAD DATA LOCAL INFILE %s %s INTO TABLE table %s", + f"LOAD DATA LOCAL INFILE %s %s INTO TABLE `{table}` %s", ( "/tmp/file", "IGNORE", @@ -441,13 +443,14 @@ def teardown_method(self): cursor.execute(f"DROP TABLE IF EXISTS {table}") @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) + @pytest.mark.parametrize("table", ["test_airflow", "where"]) @mock.patch.dict( "os.environ", { "AIRFLOW_CONN_AIRFLOW_DB": "mysql://root@mysql/airflow?charset=utf8mb4", }, ) - def test_mysql_hook_test_bulk_load(self, client, tmp_path): + def test_mysql_hook_test_bulk_load(self, client, table, tmp_path): with MySqlContext(client): records = ("foo", "bar", "baz") path = tmp_path / "testfile" @@ -456,35 +459,18 @@ def test_mysql_hook_test_bulk_load(self, client, tmp_path): hook = MySqlHook("airflow_db", local_infile=True) with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor: cursor.execute( - """ - CREATE TABLE IF NOT EXISTS test_airflow ( + f""" + CREATE TABLE IF NOT EXISTS `{table}`( dummy VARCHAR(50) ) """ ) - cursor.execute("TRUNCATE TABLE test_airflow") - hook.bulk_load("test_airflow", os.fspath(path)) - cursor.execute("SELECT dummy FROM test_airflow") + cursor.execute(f"TRUNCATE TABLE `{table}`") + hook.bulk_load(table, os.fspath(path)) + cursor.execute(f"SELECT dummy FROM `{table}`") results = tuple(result[0] for result in cursor.fetchall()) assert sorted(results) == sorted(records) - @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) - def test_mysql_hook_test_bulk_dump(self, client): - with MySqlContext(client): - hook = MySqlHook("airflow_db") - priv = hook.get_first("SELECT @@global.secure_file_priv") - # Use random names to allow re-running - if priv and priv[0]: - # Confirm that no error occurs - hook.bulk_dump( - "INFORMATION_SCHEMA.TABLES", - os.path.join(priv[0], f"TABLES_{client}-{uuid.uuid1()}"), - ) - elif priv == ("",): - hook.bulk_dump("INFORMATION_SCHEMA.TABLES", f"TABLES_{client}_{uuid.uuid1()}") - else: - raise pytest.skip("Skip test_mysql_hook_test_bulk_load since file output is not permitted") - @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) @mock.patch("airflow.providers.mysql.hooks.mysql.MySqlHook.get_conn") def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn, client): @@ -498,5 +484,5 @@ def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn, client): hook.bulk_dump(table, tmp_file) assert mock_execute.call_count == 1 - query = f"SELECT * INTO OUTFILE %s FROM {table}" + query = f"SELECT * INTO OUTFILE %s FROM `{table}`" assert_equal_ignore_multiple_spaces(mock_execute.call_args.args[0], query)