providers/common/sql/docs/operators.rst (31 additions, 0 deletions)
@@ -174,6 +174,37 @@ The below example demonstrates how to instantiate the SQLThresholdCheckOperator

If the value returned by the query is within the thresholds, the task passes. Otherwise, it fails.

.. _howto/operator:SQLInsertRowsOperator:

Insert rows into a table
~~~~~~~~~~~~~~~~~~~~~~~~

Use the :class:`~airflow.providers.common.sql.operators.sql.SQLInsertRowsOperator` to insert rows into a database table
directly from Python data structures or an XCom. Parameters of the operator are:

- ``table_name`` - name of the table in which the rows will be inserted (templated).
- ``conn_id`` - the Airflow connection ID used to connect to the database.
- ``schema`` (optional) - the schema in which the table is defined.
- ``database`` (optional) - name of the database which overrides the one defined in the connection.
- ``columns`` (optional) - list of columns to use for the insert when passing a list of dictionaries.
- ``ignored_columns`` (optional) - list of columns to ignore for the insert. If no columns are specified,
  the column names will be resolved dynamically from the table metadata.
- ``rows`` - the rows to insert: a list of tuples, a list of dictionaries, or an XCom reference.
- ``rows_processor`` (optional) - a function applied to the rows before they are inserted (see the second example below).
- ``preoperator`` (optional) - SQL statement or list of statements to execute before inserting data (templated).
- ``postoperator`` (optional) - SQL statement or list of statements to execute after inserting data (templated).
- ``hook_params`` (optional) - dictionary of additional parameters passed to the underlying hook.
- ``insert_args`` (optional) - dictionary of additional arguments passed to the hook's ``insert_rows`` method,
can include ``replace``, ``executemany``, ``fast_executemany``, ``autocommit``, and others supported by the hook.

The example below shows how to instantiate the SQLInsertRowsOperator task.

.. exampleinclude:: /../tests/system/common/sql/example_sql_insert_rows.py
:language: python
:dedent: 4
:start-after: [START howto_operator_sql_insert_rows]
:end-before: [END howto_operator_sql_insert_rows]
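
When the rows are passed as a list of dictionaries, ``columns`` selects which keys are inserted, and a
``rows_processor`` can reshape the rows just before the insert. The snippet below is a minimal sketch of
that usage; the ``actors`` table, the ``airflow_db`` connection, and the ``uppercase_names`` helper are
illustrative assumptions, not part of the operator:

.. code-block:: python

    def uppercase_names(rows, **context):
        # Normalize the "name" value of every row before it is inserted.
        return [{**row, "name": row["name"].upper()} for row in rows]


    insert_dict_rows = SQLInsertRowsOperator(
        task_id="insert_dict_rows",
        conn_id="airflow_db",
        table_name="actors",
        columns=["name", "firstname", "age"],
        rows=[
            {"name": "Stallone", "firstname": "Sylvester", "age": 78},
            {"name": "Statham", "firstname": "Jason", "age": 57},
        ],
        rows_processor=uppercase_names,
    )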

.. _howto/operator:GenericTransfer:

Generic Transfer
providers/common/sql/src/airflow/providers/common/sql/operators/sql.py
@@ -23,14 +23,17 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, SupportsAbs
 
-from airflow.exceptions import AirflowException, AirflowFailException
+from airflow import XComArg
+from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
 from airflow.models import SkipMixin
 from airflow.providers.common.sql.hooks.handlers import fetch_all_handler, return_single_query_results
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.common.sql.version_compat import BaseHook, BaseOperator
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
+    import jinja2
+
     from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
@@ -1252,6 +1255,135 @@ def execute(self, context: Context):
self.skip_all_except(context["ti"], follow_branch)


class SQLInsertRowsOperator(BaseSQLOperator):
"""
Insert rows (e.g. a collection of tuples) into a database table directly from an XCom or Python data structure.

    :param table_name: the name of the table in which the rows will be inserted (templated).
    :param conn_id: the connection ID used to connect to the database.
    :param schema: (optional) the name of the schema in which the table is defined.
    :param database: (optional) name of the database which overrides the one defined in the connection.
    :param columns: (optional) the list of columns to use for the insert when passing a list of
        dictionaries.
    :param ignored_columns: (optional) the list of columns to ignore for the insert. If no columns
        were specified, the columns will be resolved dynamically from the metadata.
    :param rows: the rows to insert into the table. Rows can be a list of tuples or a list of dictionaries.
        When a list of dictionaries is provided, the column names are inferred from the dictionary keys,
        matched against the table's column names, and ignored columns are filtered out.
    :param rows_processor: (optional) a function that will be applied to the rows before inserting them
        into the table.
    :param preoperator: (optional) SQL statement or list of statements to be executed prior to loading
        the data. (templated)
    :param postoperator: (optional) SQL statement or list of statements to be executed after loading
        the data. (templated)
    :param hook_params: (optional) dictionary of additional parameters passed to the underlying hook.
    :param insert_args: (optional) dictionary of additional arguments passed to the underlying hook's
        ``insert_rows`` method. This allows you to configure options such as ``replace``, ``executemany``,
        ``fast_executemany``, and ``autocommit``.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:SQLInsertRowsOperator`
"""

template_fields: Sequence[str] = (
"table_name",
"conn_id",
"schema",
"database",
"_columns",
"ignored_columns",
"preoperator",
"postoperator",
"insert_args",
)
template_ext: Sequence[str] = (".sql",)
    template_fields_renderers = {"preoperator": "sql", "postoperator": "sql"}

def __init__(
self,
*,
table_name: str,
conn_id: str | None = None,
schema: str | None = None,
database: str | None = None,
columns: Iterable[str] | None = None,
ignored_columns: Iterable[str] | None = None,
rows: list[Any] | XComArg | None = None,
        rows_processor: Callable[..., Any] = lambda rows, **context: rows,
preoperator: str | list[str] | None = None,
postoperator: str | list[str] | None = None,
hook_params: dict | None = None,
insert_args: dict | None = None,
**kwargs,
):
super().__init__(
conn_id=conn_id,
database=database,
hook_params=hook_params,
**kwargs,
)
self.table_name = table_name
self.schema = schema
self._columns: list | None = list(columns) if columns else None
self.ignored_columns = set(ignored_columns or {})
self.rows = rows or []
self._rows_processor = rows_processor
self.preoperator = preoperator
self.postoperator = postoperator
self.insert_args = insert_args or {}
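        # This operator inserts data and returns nothing, so pushing a result to XCom
        # is disabled.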
self.do_xcom_push = False

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
super().render_template_fields(context=context, jinja_env=jinja_env)

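        # Rows passed as an XComArg (e.g. the output of an upstream task) are resolved
        # here, once the regular template fields have been rendered.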
if isinstance(self.rows, XComArg):
self.rows = self.rows.resolve(context=context)

@property
def table_name_with_schema(self) -> str:
if self.schema is not None:
return f"{self.schema}.{self.table_name}"
return self.table_name

@cached_property
def columns(self):
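        # Fall back to the column names resolved from the database metadata when no
        # explicit columns were passed to the operator.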
if self._columns is None:
self._columns = self.get_db_hook().dialect.get_column_names(self.table_name_with_schema)
return self._columns

@property
def column_names(self) -> list[str]:
if self.ignored_columns:
return [column for column in self.columns if column not in self.ignored_columns]
return self.columns

    def _process_rows(self, context: Context):
        # Pass the rows as the first positional argument and the context entries as
        # keyword arguments, matching the default ``lambda rows, **context: rows``.
        return self._rows_processor(self.rows, **context)

def execute(self, context: Context) -> Any:
        if not self.rows:
            raise AirflowSkipException(f"Skipping task {self.task_id} because no rows were provided.")

self.log.debug("Table: %s", self.table_name_with_schema)
self.log.debug("Column names: %s", self.column_names)
if self.preoperator:
self.log.debug("Running preoperator")
self.log.debug(self.preoperator)
self.get_db_hook().run(self.preoperator)
rows = self._process_rows(context=context)
self.get_db_hook().insert_rows(
table=self.table_name_with_schema,
rows=rows,
target_fields=self.column_names,
**self.insert_args,
)
if self.postoperator:
self.log.debug("Running postoperator")
self.log.debug(self.postoperator)
self.get_db_hook().run(self.postoperator)


def _initialize_partition_clause(clause: str | None) -> str | None:
"""Ensure the partition_clause contains only valid patterns."""
if clause is None:
providers/common/sql/tests/system/common/sql/example_sql_insert_rows.py (new file, 90 additions)
@@ -0,0 +1,90 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow import DAG
from airflow.providers.common.sql.operators.sql import SQLInsertRowsOperator
from airflow.utils.timezone import datetime

# Connection settings shared with every task in the DAG through default_args;
# conn_id points at the Airflow metadata database.
connection_args = {
"conn_id": "airflow_db",
"conn_type": "Postgres",
"host": "postgres",
"schema": "postgres",
"login": "postgres",
"password": "postgres",
"port": 5432,
}

with DAG(
"example_sql_insert_rows",
description="Example DAG for SQLInsertRowsOperator.",
default_args=connection_args,
start_date=datetime(2021, 1, 1),
schedule=None,
catchup=False,
) as dag:
"""
### Example SQL insert rows DAG

Runs the SQLInsertRowsOperator against the Airflow metadata DB.
"""

# [START howto_operator_sql_insert_rows]
insert_rows = SQLInsertRowsOperator(
task_id="insert_rows",
table_name="actors",
columns=[
"name",
"firstname",
"age",
],
rows=[
("Stallone", "Sylvester", 78),
("Statham", "Jason", 57),
("Li", "Jet", 61),
("Lundgren", "Dolph", 66),
("Norris", "Chuck", 84),
],
preoperator=[
"""
CREATE TABLE IF NOT EXISTS actors (
index BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
name TEXT NOT NULL,
firstname TEXT NOT NULL,
age BIGINT NOT NULL
);
""",
"TRUNCATE TABLE actors;",
],
postoperator="DROP TABLE IF EXISTS actors;",
insert_args={
"commit_every": 1000,
"autocommit": False,
"executemany": True,
"fast_executemany": True,
},
)
# [END howto_operator_sql_insert_rows]


from tests_common.test_utils.system_tests import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)