Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sequence insert support to OracleHook #42947

Merged
merged 8 commits into from
Oct 26, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions providers/src/airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
@@ -328,6 +328,8 @@ def bulk_insert_rows(
rows: list[tuple],
target_fields: list[str] | None = None,
commit_every: int = 5000,
sequence_column: str | None = None,
sequence_name: str | None = None,
):
"""
Perform bulk inserts efficiently for Oracle DB.
@@ -342,6 +344,8 @@ def bulk_insert_rows(
If None, each rows should have some order as table columns name
:param commit_every: the maximum number of rows to insert in one transaction
Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
:param sequence_column: the column name to which the sequence will be applied, default None.
:param sequence_name: the names of the sequence_name in the table, default None.
"""
if not rows:
raise ValueError("parameter rows could not be None or empty iterable")
@@ -350,11 +354,28 @@ def bulk_insert_rows(
self.set_autocommit(conn, False)
cursor = conn.cursor() # type: ignore[attr-defined]
values_base = target_fields or rows[0]
prepared_stm = "insert into {tablename} {columns} values ({values})".format(
tablename=table,
columns="({})".format(", ".join(target_fields)) if target_fields else "",
values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)),
)

if (sequence_column is None) != (sequence_name is None):
Lee2532 marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Parameters 'sequence_column' and 'sequence_name' must be provided together or not at all."
)

if sequence_column and sequence_name:
Lee2532 marked this conversation as resolved.
Show resolved Hide resolved
prepared_stm = "insert into {tablename} {columns} values ({values})".format(
tablename=table,
columns="({})".format(", ".join([sequence_column] + target_fields))
if target_fields
else f"({sequence_column})",
values=", ".join(
[f"{sequence_name}.NEXTVAL"] + [f":{i}" for i in range(1, len(values_base) + 1)]
),
)
else:
prepared_stm = "insert into {tablename} {columns} values ({values})".format(
tablename=table,
columns="({})".format(", ".join(target_fields)) if target_fields else "",
values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)),
)
row_count = 0
# Chunk the rows
row_chunk = []
30 changes: 30 additions & 0 deletions providers/tests/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
@@ -369,6 +369,36 @@ def test_bulk_insert_rows_no_rows(self):
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows("table", rows)

def test_bulk_insert_sequence_field(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
target_fields = ["col1", "col2", "col3"]
sequence_column = "id"
sequence_name = "my_sequence"
self.db_hook.bulk_insert_rows(
"table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name
)
self.cur.prepare.assert_called_once_with(
"insert into table (id, col1, col2, col3) values (my_sequence.NEXTVAL, :1, :2, :3)"
)
self.cur.executemany.assert_called_once_with(None, rows)

def test_bulk_insert_sequence_without_parameter(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
target_fields = ["col1", "col2", "col3"]
sequence_column = "id"
sequence_name = None
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows(
"table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name
)

sequence_column = None
sequence_name = "my_sequence"
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows(
"table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name
)

def test_callproc_none(self):
parameters = None