Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions providers/mysql/src/airflow/providers/mysql/hooks/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def bulk_load(self, table: str, tmp_file: str) -> None:
raise ValueError(f"Invalid table name: {table}")

cur.execute(
f"LOAD DATA LOCAL INFILE %s INTO TABLE {table}",
f"LOAD DATA LOCAL INFILE %s INTO TABLE `{table}`",
(tmp_file,),
)
conn.commit()
Expand All @@ -266,7 +266,7 @@ def bulk_dump(self, table: str, tmp_file: str) -> None:
raise ValueError(f"Invalid table name: {table}")

cur.execute(
f"SELECT * INTO OUTFILE %s FROM {table}",
f"SELECT * INTO OUTFILE %s FROM `{table}`",
(tmp_file,),
)
conn.commit()
Expand Down Expand Up @@ -331,7 +331,7 @@ def bulk_load_custom(
cursor = conn.cursor()

cursor.execute(
f"LOAD DATA LOCAL INFILE %s %s INTO TABLE {table} %s",
f"LOAD DATA LOCAL INFILE %s %s INTO TABLE `{table}` %s",
(tmp_file, duplicate_key_handling, extra_options),
)

Expand Down
46 changes: 16 additions & 30 deletions providers/mysql/tests/unit/mysql/hooks/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import json
import os
import uuid
from contextlib import closing
from unittest import mock

Expand Down Expand Up @@ -302,26 +301,29 @@ def test_run_multi_queries(self):

def test_bulk_load(self):
self.db_hook.bulk_load("table", "/tmp/file")
self.cur.execute.assert_called_once_with("LOAD DATA LOCAL INFILE %s INTO TABLE table", ("/tmp/file",))
self.cur.execute.assert_called_once_with(
"LOAD DATA LOCAL INFILE %s INTO TABLE `table`", ("/tmp/file",)
)

def test_bulk_dump(self):
self.db_hook.bulk_dump("table", "/tmp/file")
self.cur.execute.assert_called_once_with("SELECT * INTO OUTFILE %s FROM table", ("/tmp/file",))
self.cur.execute.assert_called_once_with("SELECT * INTO OUTFILE %s FROM `table`", ("/tmp/file",))

def test_serialize_cell(self):
assert self.db_hook._serialize_cell("foo", None) == "foo"

def test_bulk_load_custom(self):
@pytest.mark.parametrize("table", ["table", "where"])
def test_bulk_load_custom(self, table):
self.db_hook.bulk_load_custom(
"table",
table,
"/tmp/file",
"IGNORE",
"""FIELDS TERMINATED BY ';'
OPTIONALLY ENCLOSED BY '"'
IGNORE 1 LINES""",
)
self.cur.execute.assert_called_once_with(
"LOAD DATA LOCAL INFILE %s %s INTO TABLE table %s",
f"LOAD DATA LOCAL INFILE %s %s INTO TABLE `{table}` %s",
(
"/tmp/file",
"IGNORE",
Expand Down Expand Up @@ -441,13 +443,14 @@ def teardown_method(self):
cursor.execute(f"DROP TABLE IF EXISTS {table}")

@pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"])
@pytest.mark.parametrize("table", ["test_airflow", "where"])
@mock.patch.dict(
"os.environ",
{
"AIRFLOW_CONN_AIRFLOW_DB": "mysql://root@mysql/airflow?charset=utf8mb4",
},
)
def test_mysql_hook_test_bulk_load(self, client, tmp_path):
def test_mysql_hook_test_bulk_load(self, client, table, tmp_path):
with MySqlContext(client):
records = ("foo", "bar", "baz")
path = tmp_path / "testfile"
Expand All @@ -456,35 +459,18 @@ def test_mysql_hook_test_bulk_load(self, client, tmp_path):
hook = MySqlHook("airflow_db", local_infile=True)
with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS test_airflow (
f"""
CREATE TABLE IF NOT EXISTS `{table}`(
dummy VARCHAR(50)
)
"""
)
cursor.execute("TRUNCATE TABLE test_airflow")
hook.bulk_load("test_airflow", os.fspath(path))
cursor.execute("SELECT dummy FROM test_airflow")
cursor.execute(f"TRUNCATE TABLE `{table}`")
hook.bulk_load(table, os.fspath(path))
cursor.execute(f"SELECT dummy FROM `{table}`")
results = tuple(result[0] for result in cursor.fetchall())
assert sorted(results) == sorted(records)

@pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"])
def test_mysql_hook_test_bulk_dump(self, client):
with MySqlContext(client):
hook = MySqlHook("airflow_db")
priv = hook.get_first("SELECT @@global.secure_file_priv")
# Use random names to allow re-running
if priv and priv[0]:
# Confirm that no error occurs
hook.bulk_dump(
"INFORMATION_SCHEMA.TABLES",
os.path.join(priv[0], f"TABLES_{client}-{uuid.uuid1()}"),
)
elif priv == ("",):
hook.bulk_dump("INFORMATION_SCHEMA.TABLES", f"TABLES_{client}_{uuid.uuid1()}")
else:
raise pytest.skip("Skip test_mysql_hook_test_bulk_load since file output is not permitted")

@pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"])
@mock.patch("airflow.providers.mysql.hooks.mysql.MySqlHook.get_conn")
def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn, client):
Expand All @@ -498,5 +484,5 @@ def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn, client):
hook.bulk_dump(table, tmp_file)

assert mock_execute.call_count == 1
query = f"SELECT * INTO OUTFILE %s FROM {table}"
query = f"SELECT * INTO OUTFILE %s FROM `{table}`"
assert_equal_ignore_multiple_spaces(mock_execute.call_args.args[0], query)