From 8d30120487874555f55df6f79c3e94860f35a8bf Mon Sep 17 00:00:00 2001 From: mraje1 Date: Mon, 7 Jul 2025 11:26:37 -0400 Subject: [PATCH 1/4] feat(provider): pass parameters to Neo4j driver session Closes: 52723 --- .../airflow/providers/neo4j/hooks/neo4j.py | 11 ++-- .../providers/neo4j/operators/neo4j.py | 5 +- .../tests/unit/neo4j/hooks/test_neo4j.py | 54 +++++++++++++++++++ .../tests/unit/neo4j/operators/test_neo4j.py | 12 +++++ 4 files changed, 76 insertions(+), 6 deletions(-) diff --git a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py index 9bd18a2680427..6d990f5a34991 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py @@ -47,11 +47,13 @@ class Neo4jHook(BaseHook): conn_type = "neo4j" hook_name = "Neo4j" - def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: + def __init__(self, conn_id: str = default_conn_name, parameters=None, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.neo4j_conn_id = conn_id self.connection = kwargs.pop("connection", None) self.client: Driver | None = None + self.parameters = parameters + def get_conn(self) -> Driver: """Initiate a new Neo4j connection with username, password and database schema.""" @@ -113,19 +115,20 @@ def get_uri(self, conn: Connection) -> str: return f"{scheme}{encryption_scheme}://{conn.host}:{7687 if conn.port is None else conn.port}" - def run(self, query) -> list[Any]: + def run(self, query, parameters) -> list[Any]: """ Create a neo4j session and execute the query in the session. 
:param query: Neo4j query + :param parameters: Optional parameters for the query :return: Result """ driver = self.get_conn() if not self.connection.schema: with driver.session() as session: - result = session.run(query) + result = session.run(query, parameters) return result.data() else: with driver.session(database=self.connection.schema) as session: - result = session.run(query) + result = session.run(query, parameters) return result.data() diff --git a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py index 819991a1bb8fe..8cf2d429a429d 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py @@ -42,9 +42,10 @@ class Neo4jOperator(BaseOperator): :param sql: the sql code to be executed. Can receive a str representing a sql statement :param neo4j_conn_id: Reference to :ref:`Neo4j connection id `. + :param parameters: the parameters to send to Neo4j driver session """ - template_fields: Sequence[str] = ("sql",) + template_fields: Sequence[str] = ("sql", "parameters") def __init__( self, @@ -61,5 +62,5 @@ def __init__( def execute(self, context: Context) -> None: self.log.info("Executing: %s", self.sql) - hook = Neo4jHook(conn_id=self.neo4j_conn_id) + hook = Neo4jHook(conn_id=self.neo4j_conn_id, parameters=self.parameters) hook.run(self.sql) diff --git a/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py b/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py index 54c2d3d64ed3c..b0e9b146309b8 100644 --- a/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py +++ b/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py @@ -79,6 +79,33 @@ def test_run_with_schema(self, mock_graph_database): session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value assert op_result == session.run.return_value.data.return_value + 
@mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase") + def test_run_with_schema_and_params(self, mock_graph_database): + connection = Connection( + conn_type="neo4j", login="login", password="password", host="host", schema="schema" + ) + mock_sql = mock.MagicMock(name="sql") + mock_parameters = mock.MagicMock(name="parameters") + + # Use the environment variable mocking to test saving the configuration as a URI and + # to avoid mocking Airflow models class + with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()): + neo4j_hook = Neo4jHook() + op_result = neo4j_hook.run(mock_sql, mock_parameters) + mock_graph_database.assert_has_calls( + [ + mock.call.driver("bolt://host:7687", auth=("login", "password"), encrypted=False), + mock.call.driver().session(database="schema"), + mock.call.driver().session().__enter__(), + mock.call.driver().session().__enter__().run(mock_sql), + mock.call.driver().session().__enter__().run(mock_parameters), + mock.call.driver().session().__enter__().run().data(), + mock.call.driver().session().__exit__(None, None, None), + ] + ) + session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value + assert op_result == session.run.return_value.data.return_value + @mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase") def test_run_without_schema(self, mock_graph_database): connection = Connection( @@ -104,6 +131,33 @@ def test_run_without_schema(self, mock_graph_database): session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value assert op_result == session.run.return_value.data.return_value + @mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase") + def test_run_without_schema_and_params(self, mock_graph_database): + connection = Connection( + conn_type="neo4j", login="login", password="password", host="host", schema=None + ) + mock_sql = mock.MagicMock(name="sql") + mock_parameters = mock.MagicMock(name="parameters") + 
+ # Use the environment variable mocking to test saving the configuration as a URI and + # to avoid mocking Airflow models class + with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()): + neo4j_hook = Neo4jHook() + op_result = neo4j_hook.run(mock_sql, mock_parameters) + mock_graph_database.assert_has_calls( + [ + mock.call.driver("bolt://host:7687", auth=("login", "password"), encrypted=False), + mock.call.driver().session(), + mock.call.driver().session().__enter__(), + mock.call.driver().session().__enter__().run(mock_sql), + mock.call.driver().session().__enter__().run(mock_parameters), + mock.call.driver().session().__enter__().run().data(), + mock.call.driver().session().__exit__(None, None, None), + ] + ) + session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value + assert op_result == session.run.return_value.data.return_value + @pytest.mark.parametrize( "conn_extra, should_provide_encrypted, expected_encrypted", [ diff --git a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py index 76005e2aeab62..b0bd7c998180f 100644 --- a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py +++ b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py @@ -37,3 +37,15 @@ def test_neo4j_operator_test(self, mock_hook): op.execute(mock.MagicMock()) mock_hook.assert_called_once_with(conn_id="neo4j_default") mock_hook.return_value.run.assert_called_once_with(sql) + + @mock.patch("airflow.providers.neo4j.operators.neo4j.Neo4jHook") + def test_neo4j_operator_test_with_params(self, mock_hook): + sql = """ + MATCH (actor {name: $name}) RETURN actor + """ + parameters={ "name": "Tom Hanks"} + op = Neo4jOperator(task_id="basic_neo4j", sql=sql, parameters=parameters, conn_id="test_conn") + op.execute(mock.MagicMock()) + mock_hook.assert_called_once_with(conn_id="test_conn") + mock_hook.return_value.run.assert_called_once_with(sql=sql) + 
mock_hook.return_value.run.assert_called_once_with(parameters=parameters) From 8d1b3f8b507569a968ee57e5ab2989d3bbc0efd0 Mon Sep 17 00:00:00 2001 From: mraje1 Date: Tue, 8 Jul 2025 21:18:21 -0400 Subject: [PATCH 2/4] feat(hook): updated run params and fixed tests Closes: 52723 --- .../airflow/providers/neo4j/hooks/neo4j.py | 23 ++++++++++--------- .../providers/neo4j/operators/neo4j.py | 4 ++-- .../tests/unit/neo4j/hooks/test_neo4j.py | 6 ++--- .../tests/unit/neo4j/operators/test_neo4j.py | 5 ++-- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py index 6d990f5a34991..f3580002bcc6a 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py @@ -47,13 +47,11 @@ class Neo4jHook(BaseHook): conn_type = "neo4j" hook_name = "Neo4j" - def __init__(self, conn_id: str = default_conn_name, parameters=None, *args, **kwargs) -> None: + def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.neo4j_conn_id = conn_id self.connection = kwargs.pop("connection", None) self.client: Driver | None = None - self.parameters = parameters - def get_conn(self) -> Driver: """Initiate a new Neo4j connection with username, password and database schema.""" @@ -115,7 +113,7 @@ def get_uri(self, conn: Connection) -> str: return f"{scheme}{encryption_scheme}://{conn.host}:{7687 if conn.port is None else conn.port}" - def run(self, query, parameters) -> list[Any]: + def run(self, query: str, parameters: dict[str, Any] | None = None) -> list[Any]: """ Create a neo4j session and execute the query in the session. 
@@ -124,11 +122,14 @@ def run(self, query, parameters) -> list[Any]: :return: Result """ driver = self.get_conn() - if not self.connection.schema: - with driver.session() as session: - result = session.run(query, parameters) - return result.data() - else: - with driver.session(database=self.connection.schema) as session: + session_paramters = {} + + if db := self.connection.schema: + session_paramters["database"] = db + + with driver.session(**session_paramters) as session: + if parameters is not None: result = session.run(query, parameters) - return result.data() + else: + result = session.run(query) + return result.data() diff --git a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py index 8cf2d429a429d..fbb7961297965 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py @@ -62,5 +62,5 @@ def __init__( def execute(self, context: Context) -> None: self.log.info("Executing: %s", self.sql) - hook = Neo4jHook(conn_id=self.neo4j_conn_id, parameters=self.parameters) - hook.run(self.sql) + hook = Neo4jHook(conn_id=self.neo4j_conn_id) + hook.run(self.sql, self.parameters) diff --git a/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py b/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py index b0e9b146309b8..95d8332bd3cce 100644 --- a/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py +++ b/providers/neo4j/tests/unit/neo4j/hooks/test_neo4j.py @@ -97,8 +97,7 @@ def test_run_with_schema_and_params(self, mock_graph_database): mock.call.driver("bolt://host:7687", auth=("login", "password"), encrypted=False), mock.call.driver().session(database="schema"), mock.call.driver().session().__enter__(), - mock.call.driver().session().__enter__().run(mock_sql), - mock.call.driver().session().__enter__().run(mock_parameters), + mock.call.driver().session().__enter__().run(mock_sql, mock_parameters), 
mock.call.driver().session().__enter__().run().data(), mock.call.driver().session().__exit__(None, None, None), ] @@ -149,8 +148,7 @@ def test_run_without_schema_and_params(self, mock_graph_database): mock.call.driver("bolt://host:7687", auth=("login", "password"), encrypted=False), mock.call.driver().session(), mock.call.driver().session().__enter__(), - mock.call.driver().session().__enter__().run(mock_sql), - mock.call.driver().session().__enter__().run(mock_parameters), + mock.call.driver().session().__enter__().run(mock_sql, mock_parameters), mock.call.driver().session().__enter__().run().data(), mock.call.driver().session().__exit__(None, None, None), ] diff --git a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py index b0bd7c998180f..f6da9afca2923 100644 --- a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py +++ b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py @@ -36,7 +36,7 @@ def test_neo4j_operator_test(self, mock_hook): op = Neo4jOperator(task_id="basic_neo4j", sql=sql) op.execute(mock.MagicMock()) mock_hook.assert_called_once_with(conn_id="neo4j_default") - mock_hook.return_value.run.assert_called_once_with(sql) + mock_hook.return_value.run.assert_called_once_with(sql, None) @mock.patch("airflow.providers.neo4j.operators.neo4j.Neo4jHook") def test_neo4j_operator_test_with_params(self, mock_hook): @@ -47,5 +47,4 @@ def test_neo4j_operator_test_with_params(self, mock_hook): op = Neo4jOperator(task_id="basic_neo4j", sql=sql, parameters=parameters, conn_id="test_conn") op.execute(mock.MagicMock()) mock_hook.assert_called_once_with(conn_id="test_conn") - mock_hook.return_value.run.assert_called_once_with(sql=sql) - mock_hook.return_value.run.assert_called_once_with(parameters=parameters) + mock_hook.return_value.run.assert_called_once_with(sql, parameters) From 5cd5f75286e95d50d3f87caf35c0ed1a31249f89 Mon Sep 17 00:00:00 2001 From: mraje1 Date: Thu, 10 Jul 2025 17:11:54 
-0400 Subject: [PATCH 3/4] test(unit): updated operator tests --- .../neo4j/src/airflow/providers/neo4j/operators/neo4j.py | 4 ++-- providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py index fbb7961297965..96a6ec3648d79 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/operators/neo4j.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from airflow.providers.neo4j.hooks.neo4j import Neo4jHook @@ -52,7 +52,7 @@ def __init__( *, sql: str, neo4j_conn_id: str = "neo4j_default", - parameters: Iterable | Mapping[str, Any] | None = None, + parameters: dict[str, Any] | None = None, **kwargs, ) -> None: super().__init__(**kwargs) diff --git a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py index f6da9afca2923..dccfcb56c8210 100644 --- a/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py +++ b/providers/neo4j/tests/unit/neo4j/operators/test_neo4j.py @@ -43,8 +43,8 @@ def test_neo4j_operator_test_with_params(self, mock_hook): sql = """ MATCH (actor {name: $name}) RETURN actor """ - parameters={ "name": "Tom Hanks"} - op = Neo4jOperator(task_id="basic_neo4j", sql=sql, parameters=parameters, conn_id="test_conn") + parameters = {"name": "Tom Hanks"} + op = Neo4jOperator(task_id="basic_neo4j", sql=sql, parameters=parameters) op.execute(mock.MagicMock()) - mock_hook.assert_called_once_with(conn_id="test_conn") + mock_hook.assert_called_once_with(conn_id="neo4j_default") mock_hook.return_value.run.assert_called_once_with(sql, parameters) From 
ad78ab015d4d58e8f510d6988707a799b767e051 Mon Sep 17 00:00:00 2001 From: mraje1 Date: Tue, 22 Jul 2025 09:30:48 -0400 Subject: [PATCH 4/4] docs(operator): added parameters documentation --- providers/neo4j/docs/operators/neo4j.rst | 21 +++++++ .../tests/system/neo4j/example_neo4j_query.py | 55 +++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 providers/neo4j/tests/system/neo4j/example_neo4j_query.py diff --git a/providers/neo4j/docs/operators/neo4j.rst b/providers/neo4j/docs/operators/neo4j.rst index cb5f6c7350157..df7426c680c5f 100644 --- a/providers/neo4j/docs/operators/neo4j.rst +++ b/providers/neo4j/docs/operators/neo4j.rst @@ -54,3 +54,24 @@ the connection metadata is structured as follows: :dedent: 4 :start-after: [START run_query_neo4j_operator] :end-before: [END run_query_neo4j_operator] + +Passing parameters into the query +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``Neo4jOperator`` provides a ``parameters`` argument for passing parameters into the +query. This allows you to use placeholders in your parameterized query and +substitute them with actual values at execution time. + +When using the ``parameters`` argument, you should prefix placeholders in your +query using the ``$`` syntax. For example, if your query uses a placeholder +like ``$name``, you would provide the parameters as ``{"name": "value"}`` in +the operator. This allows you to write dynamic queries without having to +concatenate strings. This is particularly useful when you want to execute +the same query with different values, or use values from the Airflow +context, such as task instance parameters or variables. + +.. 
exampleinclude:: /../../neo4j/tests/system/neo4j/example_neo4j_query.py + :language: python + :dedent: 4 + :start-after: [START run_query_neo4j_operator] + :end-before: [END run_query_neo4j_operator] diff --git a/providers/neo4j/tests/system/neo4j/example_neo4j_query.py b/providers/neo4j/tests/system/neo4j/example_neo4j_query.py new file mode 100644 index 0000000000000..fbc313d760b5f --- /dev/null +++ b/providers/neo4j/tests/system/neo4j/example_neo4j_query.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example use of Neo4j related operators with parameters. 
+""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow import DAG +from airflow.providers.neo4j.operators.neo4j import Neo4jOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "example_neo4j_query" + +with DAG( + DAG_ID, + start_date=datetime(2025, 1, 1), + schedule=None, + tags=["example"], + catchup=False, +) as dag: + # [START run_query_neo4j_operator] + + neo4j_task = Neo4jOperator( + task_id="run_neo4j_query_with_parameters", + neo4j_conn_id="neo4j_conn_id", + parameters={"name": "Tom Hanks"}, + sql='MATCH (actor {name: $name, date: "{{ds}}"}) RETURN actor', + dag=dag, + ) + + # [END run_query_neo4j_operator] + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)