Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed Mandatory Encryption in Neo4jHook #30418

Merged
merged 9 commits into from
Apr 21, 2023
19 changes: 16 additions & 3 deletions airflow/providers/neo4j/hooks/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

from typing import Any
from urllib.parse import urlsplit

from neo4j import Driver, GraphDatabase

Expand Down Expand Up @@ -61,12 +62,24 @@ def get_conn(self) -> Driver:

is_encrypted = self.connection.extra_dejson.get("encrypted", False)

self.client = GraphDatabase.driver(
uri, auth=(self.connection.login, self.connection.password), encrypted=is_encrypted
)
self.client = self.get_client(self.connection, is_encrypted, uri)

return self.client

def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver:
"""
Function to determine that relevant driver based on extras.
:param conn: Connection object.
:param encrypted: boolean if encrypted connection or not.
:param uri: uri string for connection.
:return: Driver
"""
parsed_uri = urlsplit(uri)
kwargs: dict[str, Any] = {}
if parsed_uri.scheme in ["bolt", "neo4j"]:
kwargs["encrypted"] = encrypted
return GraphDatabase.driver(uri, auth=(conn.login, conn.password), **kwargs)

def get_uri(self, conn: Connection) -> str:
"""
Build the uri based on extras
Expand Down
44 changes: 41 additions & 3 deletions tests/providers/neo4j/hooks/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ class TestNeo4jHookConn:
"conn_extra, expected_uri",
[
({}, "bolt://host:7687"),
({"bolt_scheme": True}, "bolt://host:7687"),
({"certs_self_signed": True, "bolt_scheme": True}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "bolt_scheme": True}, "bolt+s://host:7687"),
({"neo4j_scheme": False}, "bolt://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": False}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": False}, "bolt+s://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": True}, "neo4j+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": True}, "neo4j+s://host:7687"),
],
)
def test_get_uri_neo4j_scheme(self, conn_extra, expected_uri):
Expand Down Expand Up @@ -101,3 +103,39 @@ def test_run_without_schema(self, mock_graph_database):
)
session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value
assert op_result == session.run.return_value.data.return_value

@pytest.mark.parametrize(
"conn_extra, should_provide_encrypted, expected_encrypted",
[
({}, True, False),
({"neo4j_scheme": False, "encrypted": True}, True, True),
({"certs_self_signed": False, "neo4j_scheme": False, "encrypted": False}, True, False),
({"certs_trusted_ca": False, "neo4j_scheme": False, "encrypted": False}, True, False),
({"certs_self_signed": False, "neo4j_scheme": True, "encrypted": False}, True, False),
({"certs_trusted_ca": False, "neo4j_scheme": True, "encrypted": False}, True, False),
({"certs_self_signed": True, "neo4j_scheme": False, "encrypted": False}, False, None),
({"certs_trusted_ca": True, "neo4j_scheme": False, "encrypted": False}, False, None),
({"certs_self_signed": True, "neo4j_scheme": True, "encrypted": False}, False, None),
({"certs_trusted_ca": True, "neo4j_scheme": True, "encrypted": False}, False, None),
],
)
@mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase.driver")
def test_encrypted_provided(
self, mock_graph_database, conn_extra, should_provide_encrypted, expected_encrypted
):
connection = Connection(
conn_type="neo4j",
login="login",
password="password",
host="host",
schema="schema",
extra=conn_extra,
)
with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()):
neo4j_hook = Neo4jHook()
with neo4j_hook.get_conn():
if should_provide_encrypted:
assert "encrypted" in mock_graph_database.call_args.kwargs
assert mock_graph_database.call_args.kwargs["encrypted"] == expected_encrypted
else:
assert "encrypted" not in mock_graph_database.call_args.kwargs