Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 31 additions & 42 deletions providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
if TYPE_CHECKING:
from airflow.models import Connection

# Default Neo4j port
DEFAULT_NEO4J_PORT = 7687


class Neo4jHook(BaseHook):
"""
Expand Down Expand Up @@ -59,24 +62,20 @@ def get_conn(self) -> Driver:
return self.client

self.connection = self.get_connection(self.neo4j_conn_id)

uri = self.get_uri(self.connection)
self.log.info("URI: %s", uri)

is_encrypted = self.connection.extra_dejson.get("encrypted", False)

self.client = self.get_client(self.connection, is_encrypted, uri)

encrypted = self.connection.extra_dejson.get("encrypted", False)
self.client = self._create_driver(self.connection, encrypted, uri)
return self.client

def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver:
def _create_driver(self, conn: Connection, encrypted: bool, uri: str) -> Driver:
"""
Determine that relevant driver based on extras.
Create a Neo4j driver instance.

:param conn: Connection object.
:param encrypted: boolean if encrypted connection or not.
:param uri: uri string for connection.
:return: Driver
:param encrypted: Boolean indicating if encrypted connection is required.
:param uri: URI string for connection.
:return: Neo4j Driver instance.
"""
parsed_uri = urlsplit(uri)
kwargs: dict[str, Any] = {}
Expand All @@ -86,50 +85,40 @@ def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver:

def get_uri(self, conn: Connection) -> str:
"""
Build the uri based on extras.
Build the URI based on connection extras.

- Default - uses bolt scheme(bolt://)
- neo4j_scheme - neo4j://
- certs_self_signed - neo4j+ssc://
- certs_trusted_ca - neo4j+s://
- Default scheme: bolt
- Neo4j scheme: neo4j (if enabled)
- Encryption schemes:
- certs_self_signed: +ssc
- certs_trusted_ca: +s

:param conn: connection object.
:return: uri
:param conn: Connection object.
:return: Constructed URI string.
"""
use_neo4j_scheme = conn.extra_dejson.get("neo4j_scheme", False)
scheme = "neo4j" if use_neo4j_scheme else "bolt"

# Self signed certificates
ssc = conn.extra_dejson.get("certs_self_signed", False)
scheme = "neo4j" if conn.extra_dejson.get("neo4j_scheme", False) else "bolt"

# Only certificates signed by CA.
trusted_ca = conn.extra_dejson.get("certs_trusted_ca", False)
# Determine encryption scheme
encryption_scheme = ""

if ssc:
if conn.extra_dejson.get("certs_self_signed", False):
encryption_scheme = "+ssc"
elif trusted_ca:
elif conn.extra_dejson.get("certs_trusted_ca", False):
encryption_scheme = "+s"

return f"{scheme}{encryption_scheme}://{conn.host}:{7687 if conn.port is None else conn.port}"
port = conn.port or DEFAULT_NEO4J_PORT
return f"{scheme}{encryption_scheme}://{conn.host}:{port}"

def run(self, query: str, parameters: dict[str, Any] | None = None) -> list[Any]:
"""
Create a neo4j session and execute the query in the session.
Execute a Neo4j query within a session.

:param query: Neo4j query
:param parameters: Optional parameters for the query
:return: Result
:param query: Neo4j query string.
:param parameters: Optional parameters for the query.
:return: List of result records.
"""
driver = self.get_conn()
session_paramters = {}

if db := self.connection.schema:
session_paramters["database"] = db
session_params = {"database": self.connection.schema} if self.connection.schema else {}

with driver.session(**session_paramters) as session:
if parameters is not None:
result = session.run(query, parameters)
else:
result = session.run(query)
with driver.session(**session_params) as session:
result = session.run(query, parameters) if parameters else session.run(query)
return result.data()