Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed Mandatory Encryption in Neo4jHook #30418

Merged
merged 9 commits into from
Apr 21, 2023
23 changes: 20 additions & 3 deletions airflow/providers/neo4j/hooks/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,29 @@ def get_conn(self) -> Driver:

is_encrypted = self.connection.extra_dejson.get("encrypted", False)

self.client = GraphDatabase.driver(
uri, auth=(self.connection.login, self.connection.password), encrypted=is_encrypted
)
self.client = self.get_client(self.connection, is_encrypted, uri)

return self.client

def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver:
"""
Function to determine that relevant driver based on extras.
:param conn: Connection object.
:param encrypted: boolean if encrypted connection or not.
:return: Neo4jDriver
"""
# Self signed certificates
ssc = conn.extra_dejson.get("certs_self_signed", False)

# Only certificates signed by CA.
trusted_ca = conn.extra_dejson.get("certs_trusted_ca", False)

if trusted_ca or ssc:
driver = GraphDatabase.driver(uri, auth=(conn.login, conn.password))
else:
driver = GraphDatabase.driver(uri, auth=(conn.login, conn.password), encrypted=encrypted)
eladkal marked this conversation as resolved.
Show resolved Hide resolved
return driver

def get_uri(self, conn: Connection) -> str:
"""
Build the uri based on extras
Expand Down
40 changes: 36 additions & 4 deletions tests/providers/neo4j/hooks/test_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@
import pytest

from airflow.models import Connection
from airflow.providers.neo4j.hooks.neo4j import Neo4jHook
from airflow.providers.neo4j.hooks.neo4j import Driver, Neo4jHook


class TestNeo4jHookConn:
@pytest.mark.parametrize(
"conn_extra, expected_uri",
[
({}, "bolt://host:7687"),
({"bolt_scheme": True}, "bolt://host:7687"),
({"certs_self_signed": True, "bolt_scheme": True}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "bolt_scheme": True}, "bolt+s://host:7687"),
({"neo4j_scheme": False}, "bolt://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": False}, "bolt+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": False}, "bolt+s://host:7687"),
({"certs_self_signed": True, "neo4j_scheme": True}, "neo4j+ssc://host:7687"),
({"certs_trusted_ca": True, "neo4j_scheme": True}, "neo4j+s://host:7687"),
],
)
def test_get_uri_neo4j_scheme(self, conn_extra, expected_uri):
Expand Down Expand Up @@ -101,3 +103,33 @@ def test_run_without_schema(self, mock_graph_database):
)
session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value
assert op_result == session.run.return_value.data.return_value

@pytest.mark.parametrize(
"conn_extra, expected",
[
({"certs_self_signed": True, "neo4j_scheme": False, "encrypted": True}, True),
({"certs_self_signed": True, "neo4j_scheme": False, "encrypted": False}, True),
({"certs_trusted_ca": True, "neo4j_scheme": False, "encrypted": False}, True),
({"certs_self_signed": True, "neo4j_scheme": True, "encrypted": False}, True),
({"certs_trusted_ca": True, "neo4j_scheme": True, "encrypted": False}, True),
({"certs_trusted_ca": False, "neo4j_scheme": False, "encrypted": True}, True),
],
)
def test_get_client(self, conn_extra, expected):
connection = Connection(
conn_type="neo4j",
login="login",
password="password",
host="host",
schema="schema",
extra=conn_extra,
)
# Use the environment variable mocking to test saving the configuration as a URI and
# to avoid mocking Airflow models class
with mock.patch.dict("os.environ", AIRFLOW_CONN_NEO4J_DEFAULT=connection.get_uri()):
neo4j_hook = Neo4jHook()
is_encrypted = conn_extra.get("encrypted", False)
with neo4j_hook.get_client(
conn=connection, encrypted=is_encrypted, uri=neo4j_hook.get_uri(connection)
) as client:
assert isinstance(client, Driver) == expected
eladkal marked this conversation as resolved.
Show resolved Hide resolved