diff --git a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py index f3580002bcc6a..e0bb6ab52c311 100644 --- a/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py +++ b/providers/neo4j/src/airflow/providers/neo4j/hooks/neo4j.py @@ -32,6 +32,9 @@ if TYPE_CHECKING: from airflow.models import Connection +# Default Neo4j port +DEFAULT_NEO4J_PORT = 7687 + class Neo4jHook(BaseHook): """ @@ -59,24 +62,20 @@ def get_conn(self) -> Driver: return self.client self.connection = self.get_connection(self.neo4j_conn_id) - uri = self.get_uri(self.connection) self.log.info("URI: %s", uri) - - is_encrypted = self.connection.extra_dejson.get("encrypted", False) - - self.client = self.get_client(self.connection, is_encrypted, uri) - + encrypted = self.connection.extra_dejson.get("encrypted", False) + self.client = self._create_driver(self.connection, encrypted, uri) return self.client - def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver: + def _create_driver(self, conn: Connection, encrypted: bool, uri: str) -> Driver: """ - Determine that relevant driver based on extras. + Create a Neo4j driver instance. :param conn: Connection object. - :param encrypted: boolean if encrypted connection or not. - :param uri: uri string for connection. - :return: Driver + :param encrypted: Boolean indicating if encrypted connection is required. + :param uri: URI string for connection. + :return: Neo4j Driver instance. """ parsed_uri = urlsplit(uri) kwargs: dict[str, Any] = {} @@ -86,50 +85,40 @@ def get_client(self, conn: Connection, encrypted: bool, uri: str) -> Driver: def get_uri(self, conn: Connection) -> str: """ - Build the uri based on extras. + Build the URI based on connection extras. - - Default - uses bolt scheme(bolt://) - - neo4j_scheme - neo4j:// - - certs_self_signed - neo4j+ssc:// - - certs_trusted_ca - neo4j+s:// + - Default scheme: bolt + - Neo4j scheme: neo4j (if enabled) + - Encryption schemes: + - certs_self_signed: +ssc + - certs_trusted_ca: +s - :param conn: connection object. - :return: uri + :param conn: Connection object. + :return: Constructed URI string. """ - use_neo4j_scheme = conn.extra_dejson.get("neo4j_scheme", False) - scheme = "neo4j" if use_neo4j_scheme else "bolt" - - # Self signed certificates - ssc = conn.extra_dejson.get("certs_self_signed", False) + scheme = "neo4j" if conn.extra_dejson.get("neo4j_scheme", False) else "bolt" - # Only certificates signed by CA. - trusted_ca = conn.extra_dejson.get("certs_trusted_ca", False) + # Determine encryption scheme encryption_scheme = "" - - if ssc: + if conn.extra_dejson.get("certs_self_signed", False): encryption_scheme = "+ssc" - elif trusted_ca: + elif conn.extra_dejson.get("certs_trusted_ca", False): encryption_scheme = "+s" - return f"{scheme}{encryption_scheme}://{conn.host}:{7687 if conn.port is None else conn.port}" + port = conn.port or DEFAULT_NEO4J_PORT + return f"{scheme}{encryption_scheme}://{conn.host}:{port}" def run(self, query: str, parameters: dict[str, Any] | None = None) -> list[Any]: """ - Create a neo4j session and execute the query in the session. + Execute a Neo4j query within a session. - :param query: Neo4j query - :param parameters: Optional parameters for the query - :return: Result + :param query: Neo4j query string. + :param parameters: Optional parameters for the query. + :return: List of result records. """ driver = self.get_conn() - session_paramters = {} - - if db := self.connection.schema: - session_paramters["database"] = db + session_params = {"database": self.connection.schema} if self.connection.schema else {} - with driver.session(**session_paramters) as session: - if parameters is not None: - result = session.run(query, parameters) - else: - result = session.run(query) + with driver.session(**session_params) as session: + result = session.run(query, parameters) if parameters else session.run(query) return result.data()