diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py index 414ef8bad8462..5c659ecf5c7ab 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py @@ -56,6 +56,7 @@ class GenericTransfer(BaseOperator): executed prior to loading the data. (templated) :param insert_args: extra params for `insert_rows` method. :param page_size: number of records to be read in paginated mode (optional). + :param paginated_sql_statement_clause: SQL statement clause to be used for pagination (optional). """ template_fields: Sequence[str] = ( @@ -65,6 +66,8 @@ class GenericTransfer(BaseOperator): "destination_table", "preoperator", "insert_args", + "page_size", + "paginated_sql_statement_clause", ) template_ext: Sequence[str] = ( ".sql", @@ -85,6 +88,7 @@ def __init__( preoperator: str | list[str] | None = None, insert_args: dict | None = None, page_size: int | None = None, + paginated_sql_statement_clause: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -97,9 +101,7 @@ def __init__( self.preoperator = preoperator self.insert_args = insert_args or {} self.page_size = page_size - self._paginated_sql_statement_format = kwargs.get( - "paginated_sql_statement_format", "{} LIMIT {} OFFSET {}" - ) + self.paginated_sql_statement_clause = paginated_sql_statement_clause or "{} LIMIT {} OFFSET {}" @classmethod def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook: @@ -126,7 +128,7 @@ def destination_hook(self) -> DbApiHook: def get_paginated_sql(self, offset: int) -> str: """Format the paginated SQL statement using the current format.""" - return self._paginated_sql_statement_format.format(self.sql, self.page_size, offset) + return self.paginated_sql_statement_clause.format(self.sql, self.page_size, offset) def render_template_fields( self, diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py index 07a250cc40d4c..ee82b5440162b 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py @@ -241,6 +241,8 @@ def test_templated_fields(self): destination_conn_id="{{ destination_conn_id }}", preoperator="{{ preoperator }}", insert_args="{{ insert_args }}", + page_size="{{ page_size }}", + paginated_sql_statement_clause="{{ paginated_sql_statement_clause }}", dag=dag, ) operator.render_template_fields( @@ -251,6 +253,8 @@ def test_templated_fields(self): "destination_conn_id": "my_destination_conn_id", "preoperator": "my_preoperator", "insert_args": {"commit_every": 5000, "executemany": True, "replace": True}, + "page_size": 1000, + "paginated_sql_statement_clause": "{} OFFSET {} ROWS FETCH NEXT {} ROWS ONLY;", } ) assert operator.sql == "my_sql" @@ -259,6 +263,8 @@ def test_templated_fields(self): assert operator.destination_conn_id == "my_destination_conn_id" assert operator.preoperator == "my_preoperator" assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True} + assert operator.page_size == 1000 + assert operator.paginated_sql_statement_clause == "{} OFFSET {} ROWS FETCH NEXT {} ROWS ONLY;" def test_non_paginated_read(self): with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection):