Skip to content

Commit

Permalink
Add an auth scheme parameter to connections
Browse files Browse the repository at this point in the history
  • Loading branch information
antepusic committed May 29, 2024
1 parent 267e281 commit df4420d
Show file tree
Hide file tree
Showing 11 changed files with 76 additions and 28 deletions.
22 changes: 20 additions & 2 deletions gqlalchemy/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: Optional[str] = None,
):
self.host = host
self.port = port
self.scheme = scheme
self.username = username
self.password = password
self.encrypted = encrypted
Expand All @@ -65,14 +67,21 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: Optional[str] = None,
lazy: bool = False,
):
super().__init__(
host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name
host=host,
port=port,
scheme=scheme,
username=username,
password=password,
encrypted=encrypted,
client_name=client_name,
)
self.lazy = lazy
self._connection = self._create_connection()
Expand Down Expand Up @@ -106,6 +115,7 @@ def _create_connection(self) -> Connection:
connection = mgclient.connect(
host=self.host,
port=self.port,
scheme=self.scheme,
username=self.username,
password=self.password,
sslmode=sslmode,
Expand Down Expand Up @@ -154,14 +164,21 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: Optional[str] = None,
lazy: bool = True,
):
super().__init__(
host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name
host=host,
port=port,
scheme=scheme,
username=username,
password=password,
encrypted=encrypted,
client_name=client_name,
)
self.lazy = lazy
self._connection = self._create_connection()
Expand All @@ -184,6 +201,7 @@ def is_active(self) -> bool:
return self._connection is not None

def _create_connection(self):
# TODO: antepusic fit in scheme
return GraphDatabase.driver(
f"bolt://{self.host}:{self.port}", auth=(self.username, self.password), encrypted=self.encrypted
)
Expand Down
1 change: 1 addition & 0 deletions gqlalchemy/memgraph_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

MG_HOST = os.getenv("MG_HOST", "127.0.0.1")
MG_PORT = int(os.getenv("MG_PORT", "7687"))
MG_SCHEME = os.getenv("MG_SCHEME", "")
MG_USERNAME = os.getenv("MG_USERNAME", "")
MG_PASSWORD = os.getenv("MG_PASSWORD", "")
MG_ENCRYPTED = os.getenv("MG_ENCRYPT", "false").lower() == "true"
Expand Down
7 changes: 4 additions & 3 deletions gqlalchemy/transformations/export/graph_transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
graph_type: str,
host: str = mg_consts.MG_HOST,
port: int = mg_consts.MG_PORT,
scheme: str = mg_consts.MG_SCHEME,
username: str = mg_consts.MG_USERNAME,
password: str = mg_consts.MG_PASSWORD,
encrypted: bool = mg_consts.MG_ENCRYPTED,
Expand All @@ -58,12 +59,12 @@ def __init__(
self.graph_type = graph_type.upper()
if self.graph_type == GraphType.DGL.name:
raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl")
self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = DGLTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.PYG.name:
raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric")
self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = PyGTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.NX.name:
self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = NxTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
else:
raise ValueError("Unknown export option. Currently supported are DGL, PyG and NetworkX.")

Expand Down
7 changes: 4 additions & 3 deletions gqlalchemy/transformations/importing/graph_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
graph_type: str,
host: str = mg_consts.MG_HOST,
port: int = mg_consts.MG_PORT,
scheme: str = mg_consts.MG_SCHEME,
username: str = mg_consts.MG_USERNAME,
password: str = mg_consts.MG_PASSWORD,
encrypted: bool = mg_consts.MG_ENCRYPTED,
Expand All @@ -57,12 +58,12 @@ def __init__(
self.graph_type = graph_type.upper()
if self.graph_type == GraphType.DGL.name:
raise_if_not_imported(dependency=DGLTranslator, dependency_name="dgl")
self.translator = DGLTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = DGLTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.PYG.name:
raise_if_not_imported(dependency=PyGTranslator, dependency_name="torch_geometric")
self.translator = PyGTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = PyGTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
elif self.graph_type == GraphType.NX.name:
self.translator = NxTranslator(host, port, username, password, encrypted, client_name, lazy)
self.translator = NxTranslator(host, port, scheme, username, password, encrypted, client_name, lazy)
else:
raise ValueError("Unknown import option. Currently supported options are: DGL, PyG and NetworkX.")

Expand Down
4 changes: 3 additions & 1 deletion gqlalchemy/transformations/translators/dgl_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand All @@ -45,13 +46,14 @@ def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
super().__init__(host, port, username, password, encrypted, client_name, lazy)
super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy)

def to_cypher_queries(self, graph: Union[dgl.DGLGraph, dgl.DGLHeteroGraph]):
"""Produce cypher queries for data saved as part of the DGL graph. The method handles both homogeneous and heterogeneous graph. If the graph is homogeneous, a default DGL's labels will be used.
Expand Down
15 changes: 10 additions & 5 deletions gqlalchemy/transformations/translators/nx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand Down Expand Up @@ -152,14 +153,15 @@ def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
self.__all__ = ("nx_to_cypher", "nx_graph_to_memgraph_parallel")
super().__init__(host, port, username, password, encrypted, client_name, lazy)
super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy)

def to_cypher_queries(self, graph: nx.Graph, config: NetworkXCypherConfig = None) -> Iterator[str]:
"""Generates a Cypher query for creating a graph."""
Expand Down Expand Up @@ -187,14 +189,15 @@ def nx_graph_to_memgraph_parallel(
self._check_for_index_hint(
self.host,
self.port,
self.scheme,
self.username,
self.password,
self.encrypted,
)

for query_group in query_groups:
self._start_parallel_execution(
query_group, self.host, self.port, self.username, self.password, self.encrypted
query_group, self.host, self.port, self.scheme, self.username, self.password, self.encrypted
)

def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None:
Expand All @@ -212,6 +215,7 @@ def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None:
process_queries,
self.host,
self.port,
self.scheme,
self.username,
self.password,
self.encrypted,
Expand All @@ -224,10 +228,10 @@ def _start_parallel_execution(self, queries_gen: Iterator[str]) -> None:
p.join()

def _insert_queries(
self, queries: List[str], host: str, port: int, username: str, password: str, encrypted: bool
self, queries: List[str], host: str, port: int, scheme: str, username: str, password: str, encrypted: bool
) -> None:
"""Used by multiprocess insertion of nx into memgraph, works on a chunk of queries."""
memgraph = Memgraph(host, port, username, password, encrypted)
memgraph = Memgraph(host, port, scheme, username, password, encrypted)
while len(queries) > 0:
try:
query = queries.pop()
Expand All @@ -241,12 +245,13 @@ def _check_for_index_hint(
self,
host: str = "127.0.0.1",
port: int = 7687,
scheme: str = "",
username: str = "",
password: str = "",
encrypted: bool = False,
):
"""Check if the there are indexes, if not show warnings."""
memgraph = Memgraph(host, port, username, password, encrypted)
memgraph = Memgraph(host, port, scheme, username, password, encrypted)
indexes = memgraph.get_indexes()
if len(indexes) == 0:
logging.getLogger(__file__).warning(
Expand Down
4 changes: 3 additions & 1 deletion gqlalchemy/transformations/translators/pyg_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand All @@ -33,13 +34,14 @@ def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
super().__init__(host, port, username, password, encrypted, client_name, lazy)
super().__init__(host, port, scheme, username, password, encrypted, client_name, lazy)

@classmethod
def get_node_properties(cls, graph, node_label: str, node_id: int):
Expand Down
10 changes: 5 additions & 5 deletions gqlalchemy/transformations/translators/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from gqlalchemy.memgraph_constants import (
MG_HOST,
MG_PORT,
MG_SCHEME,
MG_USERNAME,
MG_PASSWORD,
MG_ENCRYPTED,
Expand All @@ -40,25 +41,24 @@

class Translator(ABC):
# Lambda function to concat list of labels
merge_labels: Callable[[Set[str]], str] = (
lambda labels, default_node_label: LABELS_CONCAT.join([label for label in sorted(labels)])
if len(labels)
else default_node_label
merge_labels: Callable[[Set[str]], str] = lambda labels, default_node_label: (
LABELS_CONCAT.join([label for label in sorted(labels)]) if len(labels) else default_node_label
)

@abstractmethod
def __init__(
self,
host: str = MG_HOST,
port: int = MG_PORT,
scheme: str = MG_SCHEME,
username: str = MG_USERNAME,
password: str = MG_PASSWORD,
encrypted: bool = MG_ENCRYPTED,
client_name: str = MG_CLIENT_NAME,
lazy: bool = MG_LAZY,
) -> None:
super().__init__()
self.connection = Memgraph(host, port, username, password, encrypted, client_name, lazy)
self.connection = Memgraph(host, port, scheme, username, password, encrypted, client_name, lazy)

@abstractmethod
def to_cypher_queries(graph):
Expand Down
2 changes: 2 additions & 0 deletions gqlalchemy/vendors/database_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ def __init__(
self,
host: str,
port: int,
scheme: str,
username: str,
password: str,
encrypted: bool,
client_name: str,
):
self._host = host
self._port = port
self._scheme = scheme
self._username = username
self._password = password
self._encrypted = encrypted
Expand Down
11 changes: 10 additions & 1 deletion gqlalchemy/vendors/memgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,21 @@ def __init__(
self,
host: str = mg_consts.MG_HOST,
port: int = mg_consts.MG_PORT,
scheme: str = mg_consts.MG_SCHEME,
username: str = mg_consts.MG_USERNAME,
password: str = mg_consts.MG_PASSWORD,
encrypted: bool = mg_consts.MG_ENCRYPTED,
client_name: str = mg_consts.MG_CLIENT_NAME,
lazy: bool = mg_consts.MG_LAZY,
):
super().__init__(
host=host, port=port, username=username, password=password, encrypted=encrypted, client_name=client_name
host=host,
port=port,
scheme=scheme,
username=username,
password=password,
encrypted=encrypted,
client_name=client_name,
)
self._lazy = lazy
self._on_disk_db = None
Expand Down Expand Up @@ -124,6 +131,7 @@ def new_connection(self) -> Connection:
args = dict(
host=self._host,
port=self._port,
scheme=self._scheme,
username=self._username,
password=self._password,
encrypted=self._encrypted,
Expand Down Expand Up @@ -197,6 +205,7 @@ def _new_connection(self) -> Connection:
args = dict(
host=self._host,
port=self._port,
scheme=self._scheme,
username=self._username,
password=self._password,
encrypted=self._encrypted,
Expand Down
Loading

0 comments on commit df4420d

Please sign in to comment.