From ac194fd7fae4727700a212c22f9fede7a1c3d613 Mon Sep 17 00:00:00 2001 From: Lee2532 Date: Fri, 11 Oct 2024 22:42:26 +0900 Subject: [PATCH 1/3] FEAT : oracle sequence column insert --- .../airflow/providers/oracle/hooks/oracle.py | 26 +++++++++++++++---- providers/tests/oracle/hooks/test_oracle.py | 13 ++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/providers/src/airflow/providers/oracle/hooks/oracle.py b/providers/src/airflow/providers/oracle/hooks/oracle.py index a252a7599cd34..d272d92056254 100644 --- a/providers/src/airflow/providers/oracle/hooks/oracle.py +++ b/providers/src/airflow/providers/oracle/hooks/oracle.py @@ -328,6 +328,8 @@ def bulk_insert_rows( rows: list[tuple], target_fields: list[str] | None = None, commit_every: int = 5000, + sequence_column: str | None = None, + sequence_name: str | None = None, ): """ Perform bulk inserts efficiently for Oracle DB. @@ -342,6 +344,8 @@ def bulk_insert_rows( If None, each rows should have some order as table columns name :param commit_every: the maximum number of rows to insert in one transaction Default 5000. Set greater than 0. Set 1 to insert each row in each transaction + :param sequence_column: the column name to which the sequence will be applied, default None. + :param sequence_name: the names of the sequence_name in the table, default None. """ if not rows: raise ValueError("parameter rows could not be None or empty iterable") @@ -350,11 +354,23 @@ def bulk_insert_rows( self.set_autocommit(conn, False) cursor = conn.cursor() # type: ignore[attr-defined] values_base = target_fields or rows[0] - prepared_stm = "insert into {tablename} {columns} values ({values})".format( - tablename=table, - columns="({})".format(", ".join(target_fields)) if target_fields else "", - values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)), - ) + + if sequence_column and sequence_name: + prepared_stm = "insert into {tablename} {columns} values ({values})".format( + tablename=table, + columns="({})".format(", ".join([sequence_column] + target_fields)) + if target_fields + else f"({sequence_column})", + values=", ".join( + [f"{sequence_name}.NEXTVAL"] + [f":{i}" for i in range(1, len(values_base) + 1)] + ), + ) + else: + prepared_stm = "insert into {tablename} {columns} values ({values})".format( + tablename=table, + columns="({})".format(", ".join(target_fields)) if target_fields else "", + values=", ".join(f":{i}" for i in range(1, len(values_base) + 1)), + ) row_count = 0 # Chunk the rows row_chunk = [] diff --git a/providers/tests/oracle/hooks/test_oracle.py b/providers/tests/oracle/hooks/test_oracle.py index fc4709020eb73..dd4abad83f38a 100644 --- a/providers/tests/oracle/hooks/test_oracle.py +++ b/providers/tests/oracle/hooks/test_oracle.py @@ -369,6 +369,19 @@ def test_bulk_insert_rows_no_rows(self): with pytest.raises(ValueError): self.db_hook.bulk_insert_rows("table", rows) + def test_bulk_insert_sequence_field(self): + rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + target_fields = ["col1", "col2", "col3"] + sequence_column = "id" + sequence_name = "my_sequence" + self.db_hook.bulk_insert_rows( + "table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name + ) + self.cur.prepare.assert_called_once_with( + "insert into table (id, col1, col2, col3) values (my_sequence.NEXTVAL, :1, :2, :3)" + ) + self.cur.executemany.assert_called_once_with(None, rows) + def test_callproc_none(self): parameters = None From 086cf84fada51f12d2d84070f477952b76e0c1c5 Mon Sep 17 00:00:00 2001 From: Lee2532 Date: Fri, 25 Oct 2024 19:37:57 +0900 Subject: [PATCH 2/3] FEAT : Exception case if you enter only part of it --- .../airflow/providers/oracle/hooks/oracle.py | 5 +++++ providers/tests/oracle/hooks/test_oracle.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/providers/src/airflow/providers/oracle/hooks/oracle.py b/providers/src/airflow/providers/oracle/hooks/oracle.py index d272d92056254..0b86dc169c5c9 100644 --- a/providers/src/airflow/providers/oracle/hooks/oracle.py +++ b/providers/src/airflow/providers/oracle/hooks/oracle.py @@ -355,6 +355,11 @@ def bulk_insert_rows( cursor = conn.cursor() # type: ignore[attr-defined] values_base = target_fields or rows[0] + if (sequence_column is None) != (sequence_name is None): + raise ValueError( + "Parameters 'sequence_column' and 'sequence_name' must be provided together or not at all." + ) + if sequence_column and sequence_name: prepared_stm = "insert into {tablename} {columns} values ({values})".format( tablename=table, diff --git a/providers/tests/oracle/hooks/test_oracle.py b/providers/tests/oracle/hooks/test_oracle.py index dd4abad83f38a..2650d8f7ca98e 100644 --- a/providers/tests/oracle/hooks/test_oracle.py +++ b/providers/tests/oracle/hooks/test_oracle.py @@ -382,6 +382,23 @@ def test_bulk_insert_sequence_field(self): ) self.cur.executemany.assert_called_once_with(None, rows) + def test_bulk_insert_sequence_without_parameter(self): + rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + target_fields = ["col1", "col2", "col3"] + sequence_column = "id" + sequence_name = None + with pytest.raises(ValueError): + self.db_hook.bulk_insert_rows( + "table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name + ) + + sequence_column = None + sequence_name = "my_sequence" + with pytest.raises(ValueError): + self.db_hook.bulk_insert_rows( + "table", rows, target_fields, sequence_column=sequence_column, sequence_name=sequence_name + ) + def test_callproc_none(self): parameters = None From f8aec8289d934c32c0e5d0b2a7b12714ac1ca072 Mon Sep 17 00:00:00 2001 From: Lee2532 Date: Sat, 26 Oct 2024 12:17:14 +0900 Subject: [PATCH 3/3] FIX : pythonic code --- providers/src/airflow/providers/oracle/hooks/oracle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/oracle/hooks/oracle.py b/providers/src/airflow/providers/oracle/hooks/oracle.py index 0b86dc169c5c9..3c51fe31c8a7e 100644 --- a/providers/src/airflow/providers/oracle/hooks/oracle.py +++ b/providers/src/airflow/providers/oracle/hooks/oracle.py @@ -355,7 +355,7 @@ def bulk_insert_rows( cursor = conn.cursor() # type: ignore[attr-defined] values_base = target_fields or rows[0] - if (sequence_column is None) != (sequence_name is None): + if bool(sequence_column) ^ bool(sequence_name): raise ValueError( "Parameters 'sequence_column' and 'sequence_name' must be provided together or not at all." )