From df4420df6c7dbf95eb7f5d0b45b552ebe3c1fa4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Pu=C5=A1i=C4=87?= Date: Wed, 29 May 2024 14:58:00 +0200 Subject: [PATCH] Add an auth scheme parameter to connections --- gqlalchemy/connection.py | 22 +++++++++++++++++-- gqlalchemy/memgraph_constants.py | 1 + .../export/graph_transporter.py | 7 +++--- .../importing/graph_importer.py | 7 +++--- .../translators/dgl_translator.py | 4 +++- .../translators/nx_translator.py | 15 ++++++++----- .../translators/pyg_translator.py | 4 +++- .../transformations/translators/translator.py | 10 ++++----- gqlalchemy/vendors/database_client.py | 2 ++ gqlalchemy/vendors/memgraph.py | 11 +++++++++- gqlalchemy/vendors/neo4j.py | 21 ++++++++++++------ 11 files changed, 76 insertions(+), 28 deletions(-) diff --git a/gqlalchemy/connection.py b/gqlalchemy/connection.py index 0f4a8699..fa513b59 100644 --- a/gqlalchemy/connection.py +++ b/gqlalchemy/connection.py @@ -32,6 +32,7 @@ def __init__( self, host: str, port: int, + scheme: str, username: str, password: str, encrypted: bool, @@ -39,6 +40,7 @@ def __init__( ): self.host = host self.port = port + self.scheme = scheme self.username = username self.password = password self.encrypted = encrypted @@ -65,6 +67,7 @@ def __init__( self, host: str, port: int, + scheme: str, username: str, password: str, encrypted: bool, @@ -72,7 +75,13 @@ def __init__( lazy: bool = False, ): super().__init__( - host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name + host=host, + port=port, + scheme=scheme, + username=username, + password=password, + encrypted=encrypted, + client_name=client_name, ) self.lazy = lazy self._connection = self._create_connection() @@ -106,6 +115,7 @@ def _create_connection(self) -> Connection: connection = mgclient.connect( host=self.host, port=self.port, + scheme=self.scheme, username=self.username, password=self.password, sslmode=sslmode, @@ -154,6 +164,7 @@ def __init__( self, host: str, port: int, + scheme: str, username: str, password: str, encrypted: bool, @@ -161,7 +172,13 @@ def __init__( lazy: bool = True, ): super().__init__( - host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name + host=host, + port=port, + scheme=scheme, + username=username, + password=password, + encrypted=encrypted, + client_name=client_name, ) self.lazy = lazy self._connection = self._create_connection() @@ -184,6 +201,7 @@ def is_active(self) -> bool: return self._connection is not None def _create_connection(self): + # TODO: antepusic fit in scheme return GraphDatabase.driver( f"bolt://{self.host}:{self.port}", auth=(self.username, self.password), encrypted=self.encrypted ) diff --git a/gqlalchemy/memgraph_constants.py b/gqlalchemy/memgraph_constants.py index 21ede339..26655694 100644 --- a/gqlalchemy/memgraph_constants.py +++ b/gqlalchemy/memgraph_constants.py @@ -2,6 +2,7 @@ MG_HOST = os.getenv("MG_HOST", "127.0.0.1") MG_PORT = int(os.getenv("MG_PORT", "7687")) +MG_SCHEME = os.getenv("MG_SCHEME", "") MG_USERNAME = os.getenv("MG_USERNAME", "") MG_PASSWORD = os.getenv("MG_PASSWORD", "") MG_ENCRYPTED = os.getenv("MG_ENCRYPT", "false").lower() == "true" diff --git a/gqlalchemy/transformations/export/graph_transporter.py b/gqlalchemy/transformations/export/graph_transporter.py index 2bcf2ea6..36e9f07e 100644 --- a/gqlalchemy/transformations/export/graph_transporter.py +++ b/gqlalchemy/transformations/export/graph_transporter.py @@ -41,6 +41,7 @@ def __init__( graph_type: str, host: str = mg_consts.MG_HOST, port: int = mg_consts.MG_PORT, + scheme: str = mg_consts.MG_SCHEME, username: str = mg_consts.MG_USERNAME, password: str = mg_consts.MG_PASSWORD, encrypted: bool = mg_consts.MG_ENCRYPTED, @@ -58,12 +59,12 @@ def __init__( self.graph_type = graph_type.upper() if self.graph_type == GraphType.DGL.name: raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl") - self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy) + self.translator = DGLTranslator(host, port, scheme, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.PYG.name: raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric") - self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy) + self.translator = PyGTranslator(host, port, scheme, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.NX.name: - self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy) + self.translator = NxTranslator(host, port, scheme, username, password, encrypted, client_name, lazy) else: raise ValueError("Unknown export option. Currently supported are DGL, PyG and NetworkX.") diff --git a/gqlalchemy/transformations/importing/graph_importer.py b/gqlalchemy/transformations/importing/graph_importer.py index 96fc5927..498b3ddb 100644 --- a/gqlalchemy/transformations/importing/graph_importer.py +++ b/gqlalchemy/transformations/importing/graph_importer.py @@ -47,6 +47,7 @@ def __init__( graph_type: str, host: str = mg_consts.MG_HOST, port: int = mg_consts.MG_PORT, + scheme: str = mg_consts.MG_SCHEME, username: str = mg_consts.MG_USERNAME, password: str = mg_consts.MG_PASSWORD, encrypted: bool = mg_consts.MG_ENCRYPTED, @@ -57,12 +58,12 @@ def __init__( self.graph_type = graph_type.upper() if self.graph_type == GraphType.DGL.name: raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl") - self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy) + self.translator = DGLTranslator(host, port, scheme, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.PYG.name: raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric") - self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy) + self.translator = PyGTranslator(host, port, scheme, username, password, encrypted, client_name, lazy) elif self.graph_type == GraphType.NX.name: - self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy) + self.translator = NxTranslator(host, port, scheme, username, password, encrypted, client_name, lazy) else: raise ValueError("Unknown import option. Currently supported options are: DGL, PyG and NetworkX.") diff --git a/gqlalchemy/transformations/translators/dgl_translator.py b/gqlalchemy/transformations/translators/dgl_translator.py index 193a679f..3abb3b84 100644 --- a/gqlalchemy/transformations/translators/dgl_translator.py +++ b/gqlalchemy/transformations/translators/dgl_translator.py @@ -23,6 +23,7 @@ from gqlalchemy.memgraph_constants import ( MG_HOST, MG_PORT, + MG_SCHEME, MG_USERNAME, MG_PASSWORD, MG_ENCRYPTED, @@ -45,13 +46,14 @@ def __init__( self, host: str = MG_HOST, port: int = MG_PORT, + scheme: str = MG_SCHEME, username: str = MG_USERNAME, password: str = MG_PASSWORD, encrypted: bool = MG_ENCRYPTED, client_name: str = MG_CLIENT_NAME, lazy: bool = MG_LAZY, ) -> None: - super().__init__(host, port, username, password, encrypted, client_name, lazy) + super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy) def to_cypher_queries(self, graph: Union[dgl.DGLGraph, dgl.DGLHeteroGraph]): """Produce cypher queries for data saved as part of the DGL graph. The method handles both homogeneous and heterogeneous graph. If the graph is homogeneous, a default DGL's labels will be used. diff --git a/gqlalchemy/transformations/translators/nx_translator.py b/gqlalchemy/transformations/translators/nx_translator.py index 2e4d6f85..904ff692 100644 --- a/gqlalchemy/transformations/translators/nx_translator.py +++ b/gqlalchemy/transformations/translators/nx_translator.py @@ -28,6 +28,7 @@ from gqlalchemy.memgraph_constants import ( MG_HOST, MG_PORT, + MG_SCHEME, MG_USERNAME, MG_PASSWORD, MG_ENCRYPTED, @@ -152,6 +153,7 @@ def __init__( self, host: str = MG_HOST, port: int = MG_PORT, + scheme: str = MG_SCHEME, username: str = MG_USERNAME, password: str = MG_PASSWORD, encrypted: bool = MG_ENCRYPTED, @@ -159,7 +161,7 @@ def __init__( lazy: bool = MG_LAZY, ) -> None: self.__all__ = ("nx_to_cypher", "nx_graph_to_memgraph_parallel") - super().__init__(host, port, username, password, encrypted, client_name, lazy) + super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy) def to_cypher_queries(self, graph: nx.Graph, config: NetworkXCypherConfig = None) -> Iterator[str]: """Generates a Cypher query for creating a graph.""" @@ -187,6 +189,7 @@ def nx_graph_to_memgraph_parallel( self._check_for_index_hint( self.host, self.port, + self.scheme, self.username, self.password, self.encrypted, @@ -194,7 +197,7 @@ def nx_graph_to_memgraph_parallel( for query_group in query_groups: self._start_parallel_execution( - query_group, self.host, self.port, self.username, self.password, self.encrypted + query_group, self.host, self.port, self.scheme, self.username, self.password, self.encrypted ) def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None: @@ -212,6 +215,7 @@ def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None: process_queries, self.host, self.port, + self.scheme, self.username, self.password, self.encrypted, @@ -224,10 +228,10 @@ def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None: p.join() def _insert_queries( - self, queries: List[str], host: str, port: int, username: str, password: str, encrypted: bool + self, queries: List[str], host: str, port: int, scheme: str, username: str, password: str, encrypted: bool ) -> None: """Used by multiprocess insertion of nx into memgraph, works on a chunk of queries.""" - memgraph = Memgraph(host, port, username, password, encrypted) + memgraph = Memgraph(host, port, scheme, username, password, encrypted) while len(queries) > 0: try: query = queries.pop() @@ -241,12 +245,13 @@ def _check_for_index_hint( self, host: str = "127.0.0.1", port: int = 7687, + scheme: str = "", username: str = "", password: str = "", encrypted: bool = False, ): """Check if the there are indexes, if not show warnings.""" - memgraph = Memgraph(host, port, username, password, encrypted) + memgraph = Memgraph(host, port, scheme, username, password, encrypted) indexes = memgraph.get_indexes() if len(indexes) == 0: logging.getLogger(__file__).warning( diff --git a/gqlalchemy/transformations/translators/pyg_translator.py b/gqlalchemy/transformations/translators/pyg_translator.py index 95095e95..214ff7a7 100644 --- a/gqlalchemy/transformations/translators/pyg_translator.py +++ b/gqlalchemy/transformations/translators/pyg_translator.py @@ -20,6 +20,7 @@ from gqlalchemy.memgraph_constants import ( MG_HOST, MG_PORT, + MG_SCHEME, MG_USERNAME, MG_PASSWORD, MG_ENCRYPTED, @@ -33,13 +34,14 @@ def __init__( self, host: str = MG_HOST, port: int = MG_PORT, + scheme: str = MG_SCHEME, username: str = MG_USERNAME, password: str = MG_PASSWORD, encrypted: bool = MG_ENCRYPTED, client_name: str = MG_CLIENT_NAME, lazy: bool = MG_LAZY, ) -> None: - super().__init__(host, port, username, password, encrypted, client_name, lazy) + super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy) @classmethod def get_node_properties(cls, graph, node_label: str, node_id: int): diff --git a/gqlalchemy/transformations/translators/translator.py b/gqlalchemy/transformations/translators/translator.py index 51f40369..1381f9df 100644 --- a/gqlalchemy/transformations/translators/translator.py +++ b/gqlalchemy/transformations/translators/translator.py @@ -27,6 +27,7 @@ from gqlalchemy.memgraph_constants import ( MG_HOST, MG_PORT, + MG_SCHEME, MG_USERNAME, MG_PASSWORD, MG_ENCRYPTED, @@ -40,10 +41,8 @@ class Translator(ABC): # Lambda function to concat list of labels - merge_labels: Callable[[Set[str]], str] = ( - lambda labels, default_node_label: LABELS_CONCAT.join([label for label in sorted(labels)]) - if len(labels) - else default_node_label + merge_labels: Callable[[Set[str]], str] = lambda labels, default_node_label: ( + LABELS_CONCAT.join([label for label in sorted(labels)]) if len(labels) else default_node_label ) @abstractmethod @@ -51,6 +50,7 @@ def __init__( self, host: str = MG_HOST, port: int = MG_PORT, + scheme: str = MG_SCHEME, username: str = MG_USERNAME, password: str = MG_PASSWORD, encrypted: bool = MG_ENCRYPTED, @@ -58,7 +58,7 @@ def __init__( lazy: bool = MG_LAZY, ) -> None: super().__init__() - self.connection = Memgraph(host, port, username, password, encrypted, client_name, lazy) + self.connection = Memgraph(host, port, scheme, username, password, encrypted, client_name, lazy) @abstractmethod def to_cypher_queries(graph): diff --git a/gqlalchemy/vendors/database_client.py b/gqlalchemy/vendors/database_client.py index df907f67..83766da3 100644 --- a/gqlalchemy/vendors/database_client.py +++ b/gqlalchemy/vendors/database_client.py @@ -30,6 +30,7 @@ def __init__( self, host: str, port: int, + scheme: str, username: str, password: str, encrypted: bool, @@ -37,6 +38,7 @@ def __init__( ): self._host = host self._port = port + self._scheme = scheme self._username = username self._password = password self._encrypted = encrypted diff --git a/gqlalchemy/vendors/memgraph.py b/gqlalchemy/vendors/memgraph.py index 6918b19a..5d4d5f56 100644 --- a/gqlalchemy/vendors/memgraph.py +++ b/gqlalchemy/vendors/memgraph.py @@ -54,6 +54,7 @@ def __init__( self, host: str = mg_consts.MG_HOST, port: int = mg_consts.MG_PORT, + scheme: str = mg_consts.MG_SCHEME, username: str = mg_consts.MG_USERNAME, password: str = mg_consts.MG_PASSWORD, encrypted: bool = mg_consts.MG_ENCRYPTED, @@ -61,7 +62,13 @@ def __init__( lazy: bool = mg_consts.MG_LAZY, ): super().__init__( - host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name + host=host, + port=port, + scheme=scheme, + username=username, + password=password, + encrypted=encrypted, + client_name=client_name, ) self._lazy = lazy self._on_disk_db = None @@ -124,6 +131,7 @@ def new_connection(self) -> Connection: args = dict( host=self._host, port=self._port, + scheme=self._scheme, username=self._username, password=self._password, encrypted=self._encrypted, @@ -197,6 +205,7 @@ def _new_connection(self) -> Connection: args = dict( host=self._host, port=self._port, + scheme=self._scheme, username=self._username, password=self._password, encrypted=self._encrypted, diff --git a/gqlalchemy/vendors/neo4j.py b/gqlalchemy/vendors/neo4j.py index 24428948..62ce4a48 100644 --- a/gqlalchemy/vendors/neo4j.py +++ b/gqlalchemy/vendors/neo4j.py @@ -33,6 +33,7 @@ NEO4J_HOST = os.getenv("NEO4J_HOST", "localhost") NEO4J_PORT = int(os.getenv("NEO4J_PORT", "7687")) +NEO4J_SCHEME = os.getenv("NEO4J_SCHEME", "") NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j") NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "test") NEO4J_ENCRYPTED = os.getenv("NEO4J_ENCRYPT", "false").lower() == "true" @@ -57,13 +58,14 @@ def __init__( self, host: str = NEO4J_HOST, port: int = NEO4J_PORT, + scheme: str = NEO4J_SCHEME, username: str = NEO4J_USERNAME, password: str = NEO4J_PASSWORD, encrypted: bool = NEO4J_ENCRYPTED, client_name: str = NEO4J_CLIENT_NAME, ): super().__init__( - host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name + host=host, port=port, scheme=scheme, username=username, password=password, encrypted=encrypted, client_name=client_name ) self._cached_connection: Optional[Connection] = None @@ -73,12 +75,16 @@ def get_indexes(self) -> List[Neo4jIndex]: for result in self.execute_and_fetch("SHOW INDEX;"): indexes.append( Neo4jIndex( - result[Neo4jConstants.LABEL][0] - if result[Neo4jConstants.TYPE] != Neo4jConstants.LOOKUP - else result[Neo4jConstants.LABEL], - result[Neo4jConstants.PROPERTIES][0] - if result[Neo4jConstants.TYPE] != Neo4jConstants.LOOKUP - else result[Neo4jConstants.PROPERTIES], + ( + result[Neo4jConstants.LABEL][0] + if result[Neo4jConstants.TYPE] != Neo4jConstants.LOOKUP + else result[Neo4jConstants.LABEL] + ), + ( + result[Neo4jConstants.PROPERTIES][0] + if result[Neo4jConstants.TYPE] != Neo4jConstants.LOOKUP + else result[Neo4jConstants.PROPERTIES] + ), result[Neo4jConstants.TYPE], result[Neo4jConstants.UNIQUENESS], ) @@ -125,6 +131,7 @@ def new_connection(self) -> Connection: args = dict( host=self._host, port=self._port, + scheme=self._scheme, username=self._username, password=self._password, encrypted=self._encrypted,