Skip to content

Commit

Permalink
Removed Mandatory Encryption in Neo4jHook (#30418)
Browse files Browse the repository at this point in the history
* Removed mandatory encryption in neo4jhook

* Added unit tests and altered exising

* Added unit-test and fixed existing ones.

* Changed the implementation of get_client

* Changed test for encrypted param

* fix unit test and check if encrypted arg is provided or not

* fix static checks

* fix unit tests fo python 3.7

---------

Co-authored-by: Hussein Awala <hussein@awala.fr>
  • Loading branch information
eldar-elne and hussein-awala authored Apr 21, 2023
1 parent 93a5422 commit cd45842
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 6 deletions.
19 changes: 16 additions & 3 deletions airflow/providers/neo4j/hooks/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

from typing import Any
from urllib.parse import urlsplit

from neo4j import Driver, GraphDatabase

Expand Down Expand Up @@ -61,12 +62,24 @@ def get_conn(self) -> Driver:

is_encrypted = self.connection.extra_dejson.get("encrypted", False)

self.client = GraphDatabase.driver(
uri, auth=(self.connection.login, self.connection.password), encrypted=is_encrypted
)
self.client = self.get_client(self.connection, is_encrypted, uri)

return self.client

def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver:
"""
Function to determine that relevant driver based on extras.
:param conn: Connection object.
:param encrypted: boolean if encrypted connection or not.
:param uri: uri string for connection.
:return: Driver
"""
parsed_uri = urlsplit(uri)
kwargs: dict[str, Any] = {}
if parsed_uri.scheme in ["bolt", "neo4j"]:
kwargs["encrypted"] = encrypted
return GraphDatabase.driver(uri, auth=(conn.login, conn.password), **kwargs)

def get_uri(self, conn: Connection) -> str:
"""
Build the uri based on extras
Expand Down
44 changes: 41 additions & 3 deletions tests/providers/neo4j/hooks/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ class TestNeo4jHookConn:
"conn_extra, expected_uri",
[
({}, "bolt://host:7687"),
({"bolt_scheme": True}, "bolt://host:7687"),
({"certs_self_signed": True, "bolt_scheme": True}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "bolt_scheme": True}, "bolt+s://host:7687"),
({"neo4j_scheme": False}, "bolt://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": False}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": False}, "bolt+s://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": True}, "neo4j+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": True}, "neo4j+s://host:7687"),
],
)
def test_get_uri_neo4j_scheme(self, conn_extra, expected_uri):
Expand Down Expand Up @@ -101,3 +103,39 @@ def test_run_without_schema(self, mock_graph_database):
)
session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value
assert op_result == session.run.return_value.data.return_value

@pytest.mark.parametrize(
"conn_extra, should_provide_encrypted, expected_encrypted",
[
({}, True, False),
({"neo4j_scheme": False, "encrypted": True}, True, True),
({"certs_self_signed": False, "neo4j_scheme": False, "encrypted": False}, True, False),
({"certs_trusted_ca": False, "neo4j_scheme": False, "encrypted": False}, True, False),
({"certs_self_signed": False, "neo4j_scheme": True, "encrypted": False}, True, False),
({"certs_trusted_ca": False, "neo4j_scheme": True, "encrypted": False}, True, False),
({"certs_self_signed": True, "neo4j_scheme": False, "encrypted": False}, False, None),
({"certs_trusted_ca": True, "neo4j_scheme": False, "encrypted": False}, False, None),
({"certs_self_signed": True, "neo4j_scheme": True, "encrypted": False}, False, None),
({"certs_trusted_ca": True, "neo4j_scheme": True, "encrypted": False}, False, None),
],
)
@mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase.driver")
def test_encrypted_provided(
self, mock_graph_database, conn_extra, should_provide_encrypted, expected_encrypted
):
connection = Connection(
conn_type="neo4j",
login="login",
password="password",
host="host",
schema="schema",
extra=conn_extra,
)
with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()):
neo4j_hook = Neo4jHook()
with neo4j_hook.get_conn():
if should_provide_encrypted:
assert "encrypted" in mock_graph_database.call_args[1]
assert mock_graph_database.call_args[1]["encrypted"] == expected_encrypted
else:
assert "encrypted" not in mock_graph_database.call_args[1]

0 comments on commit cd45842

Please sign in to comment.