diff --git a/providers/neo4j/docs/operators/neo4j.rst b/providers/neo4j/docs/operators/neo4j.rst index cb5f6c7350157..df7426c680c5f 100644 --- a/providers/neo4j/docs/operators/neo4j.rst +++ b/providers/neo4j/docs/operators/neo4j.rst @@ -54,3 +54,24 @@ the connection metadata is structured as follows: :dedent: 4 :start-after: [START run_query_neo4j_operator] :end-before: [END run_query_neo4j_operator] + +Passing parameters into the query +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Neo4jOperator provides ``parameters`` argument to pass parameters into the +query. This allows you to use placeholders in your parameterized query and +substitute them with actual values at execution time. + +When using the ``parameters`` argument, you should prefix placeholders in your +query using the ``$`` syntax. For example, if your query uses a placeholder +like ``$name``, you would provide the parameters as ``{"name": "value"}`` in +the operator. This allows you to write dynamic queries without having to +concatenate strings. This is particularly useful when you want to execute +the same query with different values, or use values from the Airflow +context, such as task instance parameters or variables. + +.. exampleinclude:: /../../neo4j/tests/system/neo4j/example_neo4j_query.py + :language: python + :dedent: 4 + :start-after: [START run_query_neo4j_operator] + :end-before: [END run_query_neo4j_operator] diff --git a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py index 9bd18a2680427..f3580002bcc6a 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py @@ -113,19 +113,23 @@ def get_uri(self, conn: Connection) -> str: return f"{scheme}{encryption_scheme}://{conn.host}:{7687 if conn.port is None else conn.port}" - def run(self, query) -> list[Any]: + def run(self, query: str, parameters: dict[str, Any] | None = None) -> list[Any]: """ Create a neo4j session and execute the query in the session. :param query: Neo4j query + :param parameters: Optional parameters for the query :return: Result """ driver = self.get_conn() - if not self.connection.schema: - with driver.session() as session: - result = session.run(query) - return result.data() - else: - with driver.session(database=self.connection.schema) as session: + session_paramters = {} + + if db := self.connection.schema: + session_paramters["database"] = db + + with driver.session(**session_paramters) as session: + if parameters is not None: + result = session.run(query, parameters) + else: result = session.run(query) - return result.data() + return result.data() diff --git a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py index 819991a1bb8fe..96a6ec3648d79 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from airflow.providers.neo4j.hooks.neo4j import Neo4jHook @@ -42,16 +42,17 @@ class Neo4jOperator(BaseOperator): :param sql: the sql code to be executed. Can receive a str representing a sql statement :param neo4j_conn_id: Reference to :ref:`Neo4j connection id `. + :param parameters: the parameters to send to Neo4j driver session """ - template_fields: Sequence[str] = ("sql",) + template_fields: Sequence[str] = ("sql", "parameters") def __init__( self, *, sql: str, neo4j_conn_id: str = "neo4j_default", - parameters: Iterable | Mapping[str, Any] | None = None, + parameters: dict[str, Any] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -62,4 +63,4 @@ def __init__( def execute(self, context: Context) -> None: self.log.info("Executing: %s", self.sql) hook = Neo4jHook(conn_id=self.neo4j_conn_id) - hook.run(self.sql) + hook.run(self.sql, self.parameters) diff --git a/providers/neo4j/tests/system/neo4j/example_neo4j_query.py b/providers/neo4j/tests/system/neo4j/example_neo4j_query.py new file mode 100644 index 0000000000000..fbc313d760b5f --- /dev/null +++ b/providers/neo4j/tests/system/neo4j/example_neo4j_query.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example use of Neo4j related operators with parameters. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow import DAG +from airflow.providers.neo4j.operators.neo4j import Neo4jOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "example_neo4j_query" + +with DAG( + DAG_ID, + start_date=datetime(2025, 1, 1), + schedule=None, + tags=["example"], + catchup=False, +) as dag: + # [START run_query_neo4j_operator] + + neo4j_task = Neo4jOperator( + task_id="run_neo4j_query_with_parameters", + neo4j_conn_id="neo4j_conn_id", + parameters={"name": "Tom Hanks"}, + sql='MATCH (actor {name: $name, date: "{{ds}}"}) RETURN actor', + dag=dag, + ) + + # [END run_query_neo4j_operator] + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py b/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py index 54c2d3d64ed3c..95d8332bd3cce 100644 --- a/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py +++ b/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py @@ -79,6 +79,32 @@ def test_run_with_schema(self, mock_graph_database): session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value assert op_result == session.run.return_value.data.return_value + @mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase") + def test_run_with_schema_and_params(self, mock_graph_database): + connection = Connection( + conn_type="neo4j", login="login", password="password", host="host", schema="schema" + ) + mock_sql = mock.MagicMock(name="sql") + mock_parameters = mock.MagicMock(name="parameters") + + # Use the environment variable mocking to test saving the configuration as a URI and + # to avoid mocking Airflow models class + with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()): + neo4j_hook = Neo4jHook() + op_result = neo4j_hook.run(mock_sql, mock_parameters) + mock_graph_database.assert_has_calls( + [ + mock.call.driver("bolt://host:7687", auth=("login", "password"), encrypted=False), + mock.call.driver().session(database="schema"), + mock.call.driver().session().__enter__(), + mock.call.driver().session().__enter__().run(mock_sql, mock_parameters), + mock.call.driver().session().__enter__().run().data(), + mock.call.driver().session().__exit__(None, None, None), + ] + ) + session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value + assert op_result == session.run.return_value.data.return_value + @mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase") def test_run_without_schema(self, mock_graph_database): connection = Connection( @@ -104,6 +130,32 @@ def test_run_without_schema(self, mock_graph_database): session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value assert op_result == session.run.return_value.data.return_value + @mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase") + def test_run_without_schema_and_params(self, mock_graph_database): + connection = Connection( + conn_type="neo4j", login="login", password="password", host="host", schema=None + ) + mock_sql = mock.MagicMock(name="sql") + mock_parameters = mock.MagicMock(name="parameters") + + # Use the environment variable mocking to test saving the configuration as a URI and + # to avoid mocking Airflow models class + with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()): + neo4j_hook = Neo4jHook() + op_result = neo4j_hook.run(mock_sql, mock_parameters) + mock_graph_database.assert_has_calls( + [ + mock.call.driver("bolt://host:7687", auth=("login", "password"), encrypted=False), + mock.call.driver().session(), + mock.call.driver().session().__enter__(), + mock.call.driver().session().__enter__().run(mock_sql, mock_parameters), + mock.call.driver().session().__enter__().run().data(), + mock.call.driver().session().__exit__(None, None, None), + ] + ) + session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value + assert op_result == session.run.return_value.data.return_value + @pytest.mark.parametrize( "conn_extra, should_provide_encrypted, expected_encrypted", [ diff --git a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py index 76005e2aeab62..dccfcb56c8210 100644 --- a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py +++ b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py @@ -36,4 +36,15 @@ def test_neo4j_operator_test(self, mock_hook): op = Neo4jOperator(task_id="basic_neo4j", sql=sql) op.execute(mock.MagicMock()) mock_hook.assert_called_once_with(conn_id="neo4j_default") - mock_hook.return_value.run.assert_called_once_with(sql) + mock_hook.return_value.run.assert_called_once_with(sql, None) + + @mock.patch("airflow.providers.neo4j.operators.neo4j.Neo4jHook") + def test_neo4j_operator_test_with_params(self, mock_hook): + sql = """ + MATCH (actor {name: $name}) RETURN actor + """ + parameters = {"name": "Tom Hanks"} + op = Neo4jOperator(task_id="basic_neo4j", sql=sql, parameters=parameters) + op.execute(mock.MagicMock()) + mock_hook.assert_called_once_with(conn_id="neo4j_default") + mock_hook.return_value.run.assert_called_once_with(sql, parameters)