Skip to content

Commit

Permalink
FEAT : oracle sequence column insert
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee2532 authored and potiuk committed Oct 14, 2024
1 parent 4c5ad9c commit ac194fd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
26 changes: 21 additions & 5 deletions providers/src/airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def bulk_insert_rows(
rows: list[tuple],
target_fields: list[str] | None = None,
commit_every: int = 5000,
sequence_column: str | None = None,
sequence_name: str | None = None,
):
"""
Perform bulk inserts efficiently for Oracle DB.
Expand All @@ -342,6 +344,8 @@ def bulk_insert_rows(
If None, each rows should have some order as table columns name
:param commit_every: the maximum number of rows to insert in one transaction
Default 5000. Set greater than 0. Set 1 to insert each row in each transaction
:param sequence_column: the column name to which the sequence will be applied, default None.
:param sequence_name: the names of the sequence_name in the table, default None.
"""
if not rows:
raise ValueError("parameter rows could not be None or empty iterable")
Expand All @@ -350,11 +354,23 @@ def bulk_insert_rows(
self.set_autocommit(conn, False)
cursor = conn.cursor() # type: ignore[attr-defined]
values_base = target_fields or rows[0]
prepared_stm = "insert into {tablename} {columns} values ({values})".format(
tablename=table,
columns="({})".format(", ".join(target_fields)) if target_fields else "",
values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)),
)

if sequence_column and sequence_name:
prepared_stm = "insert into {tablename} {columns} values ({values})".format(
tablename=table,
columns="({})".format(", ".join([sequence_column] + target_fields))
if target_fields
else f"({sequence_column})",
values=", ".join(
[f"{sequence_name}.NEXTVAL"] + [f":{i}" for i in range(1, len(values_base) + 1)]
),
)
else:
prepared_stm = "insert into {tablename} {columns} values ({values})".format(
tablename=table,
columns="({})".format(", ".join(target_fields)) if target_fields else "",
values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)),
)
row_count = 0
# Chunk the rows
row_chunk = []
Expand Down
13 changes: 13 additions & 0 deletions providers/tests/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,19 @@ def test_bulk_insert_rows_no_rows(self):
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows("table", rows)

def test_bulk_insert_sequence_field(self):
rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
target_fields = ["col1", "col2", "col3"]
sequence_column = "id"
sequence_name = "my_sequence"
self.db_hook.bulk_insert_rows(
"table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name
)
self.cur.prepare.assert_called_once_with(
"insert into table (id, col1, col2, col3) values (my_sequence.NEXTVAL, :1, :2, :3)"
)
self.cur.executemany.assert_called_once_with(None, rows)

def test_callproc_none(self):
parameters = None

Expand Down

0 comments on commit ac194fd

Please sign in to comment.