diff --git a/providers/common/sql/docs/operators.rst b/providers/common/sql/docs/operators.rst
index fa62b6fc11f32..837a003739748 100644
--- a/providers/common/sql/docs/operators.rst
+++ b/providers/common/sql/docs/operators.rst
@@ -174,6 +174,37 @@ The below example demonstrates how to instantiate the SQLThresholdCheckOperator
 If the value returned by the query, is within the thresholds, the task passes. Otherwise, it fails.
 
+.. _howto/operator:SQLInsertRowsOperator:
+
+Insert rows into a table
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+Use the :class:`~airflow.providers.common.sql.operators.sql.SQLInsertRowsOperator` to insert rows into a database table
+directly from Python data structures or an XCom. Parameters of the operator are:
+
+- ``table_name`` - name of the table into which the rows will be inserted (templated).
+- ``conn_id`` - the Airflow connection ID used to connect to the database.
+- ``schema`` (optional) - the schema in which the table is defined.
+- ``database`` (optional) - name of the database, overriding the one defined in the connection.
+- ``columns`` (optional) - list of columns to use for the insert when passing a list of dictionaries.
+- ``ignored_columns`` (optional) - list of columns to ignore for the insert. If no columns are specified,
+  the columns are resolved dynamically from the table metadata.
+- ``rows`` - the rows to insert: a list of tuples, a list of dictionaries, or an XCom reference.
+- ``rows_processor`` (optional) - a function applied to the rows before inserting them.
+- ``preoperator`` (optional) - SQL statement or list of statements to execute before inserting data (templated).
+- ``postoperator`` (optional) - SQL statement or list of statements to execute after inserting data (templated).
+- ``hook_params`` (optional) - dictionary of additional parameters passed to the underlying hook.
+- ``insert_args`` (optional) - dictionary of additional arguments passed to the hook's ``insert_rows`` method;
+  can include ``replace``, ``executemany``, ``fast_executemany``, ``autocommit``, and others supported by the hook.
+
+The example below shows how to instantiate the SQLInsertRowsOperator task.
+
+.. exampleinclude:: /../tests/system/common/sql/example_sql_insert_rows.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sql_insert_rows]
+    :end-before: [END howto_operator_sql_insert_rows]
+
 .. _howto/operator:GenericTransfer:
 
 Generic Transfer
 ~~~~~~~~~~~~~~~~
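The guide above mentions that rows can come directly from an XCom. A minimal sketch of that pattern follows; the ``actors_staging`` table, the ``my_db`` connection, and the task IDs are illustrative assumptions, not part of this change:

from airflow.providers.common.sql.operators.sql import (
    SQLExecuteQueryOperator,
    SQLInsertRowsOperator,
)

# Hypothetical upstream task; SQLExecuteQueryOperator pushes its query
# result to XCom by default.
extract = SQLExecuteQueryOperator(
    task_id="extract",
    conn_id="my_db",
    sql="SELECT name, firstname, age FROM actors_staging",
)

# "rows" accepts an XComArg, which the operator resolves to the upstream
# task's result at runtime.
insert = SQLInsertRowsOperator(
    task_id="insert",
    conn_id="my_db",
    table_name="actors",
    columns=["name", "firstname", "age"],
    rows=extract.output,
)

extract >> insert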
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
index 0b884d97f0c26..6caa2000059e9 100644
--- a/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
+++ b/providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -23,7 +23,8 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, SupportsAbs
 
-from airflow.exceptions import AirflowException, AirflowFailException
+from airflow import XComArg
+from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
 from airflow.models import SkipMixin
 from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, return_single_query_results
 from airflow.providers.common.sql.hooks.sql import DbApiHook
@@ -31,6 +32,8 @@
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
+    import jinja2
+
     from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
@@ -1252,6 +1255,137 @@ def execute(self, context: Context):
         self.skip_all_except(context["ti"], follow_branch)
 
 
+class SQLInsertRowsOperator(BaseSQLOperator):
+    """
+    Insert rows (e.g. a collection of tuples) into a database table directly from an XCom or Python data structure.
+
+    :param table_name: the name of the table in which the rows will be inserted (templated)
+    :param conn_id: the connection ID used to connect to the database
+    :param schema: (optional) the name of the schema in which the table is defined
+    :param database: (optional) the name of the database (e.g. schema) which overrides the one defined in the connection
+    :param columns: (optional) a list of columns used for the insert when passing a list of
+        dictionaries.
+    :param ignored_columns: (optional) a list of columns ignored for the insert. If no columns
+        were specified, the columns will be resolved dynamically from the metadata.
+    :param rows: the rows to insert into the table. Rows can be a list of tuples, a list of dictionaries,
+        or an XComArg resolved at runtime. When a list of dictionaries is provided, the column names are
+        inferred from the dictionary keys and matched against the table's column names; ignored columns
+        are filtered out.
+    :param rows_processor: (optional) a function applied to the rows before inserting them into the table.
+    :param preoperator: sql statement or list of statements to be executed prior to loading the data (templated)
+    :param postoperator: sql statement or list of statements to be executed after loading the data (templated)
+    :param hook_params: (optional) dictionary of additional parameters passed to the underlying hook.
+    :param insert_args: (optional) dictionary of additional arguments passed to the underlying hook's
+        ``insert_rows`` method. This allows you to configure options such as ``replace``, ``executemany``,
+        ``fast_executemany``, and ``autocommit``.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:SQLInsertRowsOperator`
+    """
+
+    template_fields: Sequence[str] = (
+        "table_name",
+        "conn_id",
+        "schema",
+        "database",
+        "_columns",
+        "ignored_columns",
+        "preoperator",
+        "postoperator",
+        "insert_args",
+    )
+    template_ext: Sequence[str] = (".sql",)
+    template_fields_renderers = {"preoperator": "sql", "postoperator": "sql"}
+
+    def __init__(
+        self,
+        *,
+        table_name: str,
+        conn_id: str | None = None,
+        schema: str | None = None,
+        database: str | None = None,
+        columns: Iterable[str] | None = None,
+        ignored_columns: Iterable[str] | None = None,
+        rows: list[Any] | XComArg | None = None,
+        rows_processor: Callable[..., Any] = lambda rows, **context: rows,
+        preoperator: str | list[str] | None = None,
+        postoperator: str | list[str] | None = None,
+        hook_params: dict | None = None,
+        insert_args: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(
+            conn_id=conn_id,
+            database=database,
+            hook_params=hook_params,
+            **kwargs,
+        )
+        self.table_name = table_name
+        self.schema = schema
+        self._columns: list | None = list(columns) if columns else None
+        self.ignored_columns = set(ignored_columns or {})
+        self.rows = rows or []
+        self._rows_processor = rows_processor
+        self.preoperator = preoperator
+        self.postoperator = postoperator
+        self.insert_args = insert_args or {}
+        self.do_xcom_push = False
+
+    def render_template_fields(
+        self,
+        context: Context,
+        jinja_env: jinja2.Environment | None = None,
+    ) -> None:
+        super().render_template_fields(context=context, jinja_env=jinja_env)
+
+        if isinstance(self.rows, XComArg):
+            self.rows = self.rows.resolve(context=context)
+
+    @property
+    def table_name_with_schema(self) -> str:
+        if self.schema is not None:
+            return f"{self.schema}.{self.table_name}"
+        return self.table_name
+
+    @cached_property
+    def columns(self):
+        if self._columns is None:
+            self._columns = self.get_db_hook().dialect.get_column_names(self.table_name_with_schema)
+        return self._columns
+
+    @property
+    def column_names(self) -> list[str]:
+        if self.ignored_columns:
+            return [column for column in self.columns if column not in self.ignored_columns]
+        return self.columns
+
+    def _process_rows(self, context: Context):
+        return self._rows_processor(self.rows, **context)
+
+    def execute(self, context: Context) -> Any:
+        if not self.rows:
+            raise AirflowSkipException(f"Skipping task {self.task_id} because no rows were provided.")
+
+        self.log.debug("Table: %s", self.table_name_with_schema)
+        self.log.debug("Column names: %s", self.column_names)
+        if self.preoperator:
+            self.log.debug("Running preoperator")
+            self.log.debug(self.preoperator)
+            self.get_db_hook().run(self.preoperator)
+        rows = self._process_rows(context=context)
+        self.get_db_hook().insert_rows(
+            table=self.table_name_with_schema,
+            rows=rows,
+            target_fields=self.column_names,
+            **self.insert_args,
+        )
+        if self.postoperator:
+            self.log.debug("Running postoperator")
+            self.log.debug(self.postoperator)
+            self.get_db_hook().run(self.postoperator)
+
+
 def _initialize_partition_clause(clause: str | None) -> str | None:
     """Ensure the partition_clause contains only valid patterns."""
     if clause is None:
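Because ``rows_processor`` defaults to ``lambda rows, **context: rows`` and is invoked as ``rows_processor(rows, **context)``, a custom processor takes the rows as its first argument and the Airflow context as keyword arguments. A small sketch of such a processor; the upper-casing transformation and the ``my_db`` connection are illustrative assumptions:

from airflow.providers.common.sql.operators.sql import SQLInsertRowsOperator

def uppercase_names(rows, **context):
    # Illustrative transformation applied just before the insert; the
    # Airflow context arrives as keyword arguments and is ignored here.
    return [(name.upper(), firstname, age) for name, firstname, age in rows]

insert_rows = SQLInsertRowsOperator(
    task_id="insert_rows",
    conn_id="my_db",  # assumed connection
    table_name="actors",
    columns=["name", "firstname", "age"],
    rows=[("Stallone", "Sylvester", 78), ("Statham", "Jason", 57)],
    rows_processor=uppercase_names,
)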
diff --git a/providers/common/sql/tests/system/common/sql/example_sql_insert_rows.py b/providers/common/sql/tests/system/common/sql/example_sql_insert_rows.py
new file mode 100644
index 0000000000000..ba24082602e6a
--- /dev/null
+++ b/providers/common/sql/tests/system/common/sql/example_sql_insert_rows.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow import DAG
+from airflow.providers.common.sql.operators.sql import SQLInsertRowsOperator
+from airflow.utils.timezone import datetime
+
+AIRFLOW_DB_METADATA_TABLE = "ab_user"
+connection_args = {
+    "conn_id": "airflow_db",
+    "conn_type": "Postgres",
+    "host": "postgres",
+    "schema": "postgres",
+    "login": "postgres",
+    "password": "postgres",
+    "port": 5432,
+}
+
+with DAG(
+    "example_sql_insert_rows",
+    description="Example DAG for SQLInsertRowsOperator.",
+    default_args=connection_args,
+    start_date=datetime(2021, 1, 1),
+    schedule=None,
+    catchup=False,
+) as dag:
+    """
+    ### Example SQL insert rows DAG
+
+    Runs the SQLInsertRowsOperator against the Airflow metadata DB.
+    """
+
+    # [START howto_operator_sql_insert_rows]
+    insert_rows = SQLInsertRowsOperator(
+        task_id="insert_rows",
+        table_name="actors",
+        columns=[
+            "name",
+            "firstname",
+            "age",
+        ],
+        rows=[
+            ("Stallone", "Sylvester", 78),
+            ("Statham", "Jason", 57),
+            ("Li", "Jet", 61),
+            ("Lundgren", "Dolph", 66),
+            ("Norris", "Chuck", 84),
+        ],
+        preoperator=[
+            """
+            CREATE TABLE IF NOT EXISTS actors (
+                index BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
+                name TEXT NOT NULL,
+                firstname TEXT NOT NULL,
+                age BIGINT NOT NULL
+            );
+            """,
+            "TRUNCATE TABLE actors;",
+        ],
+        postoperator="DROP TABLE IF EXISTS actors;",
+        insert_args={
+            "commit_every": 1000,
+            "autocommit": False,
+            "executemany": True,
+            "fast_executemany": True,
+        },
+    )
+    # [END howto_operator_sql_insert_rows]
+
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
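The docstring also describes passing the rows as a list of dictionaries, with keys matched against the table's columns and ``ignored_columns`` filtered out. A sketch of that form, assuming a hypothetical ``users`` table and ``my_db`` connection:

from airflow.providers.common.sql.operators.sql import SQLInsertRowsOperator

insert_users = SQLInsertRowsOperator(
    task_id="insert_users",
    conn_id="my_db",  # assumed connection
    table_name="users",  # assumed table with name/created_at columns
    rows=[
        {"name": "Alice", "created_at": "2024-01-01"},
        {"name": "Bob", "created_at": "2024-01-02"},
    ],
    # Per the docstring, "created_at" is filtered out of each row and only
    # the remaining keys are matched against the table's columns.
    ignored_columns=["created_at"],
)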