From dcc22d69bbe030887d9fe9562966ba7d5c0cdd65 Mon Sep 17 00:00:00 2001 From: katarinasupe Date: Tue, 25 Jun 2024 13:48:44 +0200 Subject: [PATCH] Add test + review comments --- gqlalchemy/vendors/memgraph.py | 48 +++++++++++++++++++++++----------- tests/ogm/test_transactions.py | 7 +++++ 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/gqlalchemy/vendors/memgraph.py b/gqlalchemy/vendors/memgraph.py index d1ac7fde..13ec5eb2 100644 --- a/gqlalchemy/vendors/memgraph.py +++ b/gqlalchemy/vendors/memgraph.py @@ -69,6 +69,34 @@ class MemgraphConstants: UNIQUE = "unique" +def create_transaction(transaction_data) -> MemgraphTransaction: + """Create a MemgraphTransaction object from transaction data. + Args: + transaction_data (dict): A dictionary containing transaction data. + Returns: + MemgraphTransaction: A MemgraphTransaction object. + """ + return MemgraphTransaction( + username=transaction_data["username"], + transaction_id=transaction_data["transaction_id"], + query=transaction_data["query"], + metadata=transaction_data["metadata"], + ) + + +def create_terminated_transaction(transaction_data) -> MemgraphTerminatedTransaction: + """Create a MemgraphTerminatedTransaction object from transaction data. + Args: + transaction_data (dict): A dictionary containing transaction data. + Returns: + MemgraphTerminatedTransaction: A MemgraphTerminatedTransaction object. + """ + return MemgraphTerminatedTransaction( + transaction_id=transaction_data["transaction_id"], + killed=transaction_data["killed"], + ) + + class Memgraph(DatabaseClient): def __init__( self, @@ -460,16 +488,7 @@ def get_transactions(self) -> List[MemgraphTransaction]: """ transactions_data = self.execute_and_fetch("SHOW TRANSACTIONS;") - transactions = [] - - for transaction_data in transactions_data: - transaction = MemgraphTransaction( - username=transaction_data["username"], - transaction_id=transaction_data["transaction_id"], - query=transaction_data["query"], - metadata=transaction_data["metadata"], - ) - transactions.append(transaction) + transactions = list(map(create_transaction, transactions_data)) return transactions @@ -484,11 +503,10 @@ def terminate_transactions(self, transaction_ids: List[str]) -> List[MemgraphTer query = ( "TERMINATE TRANSACTIONS " + ", ".join([f"'{transaction_id}'" for transaction_id in transaction_ids]) + ";" ) - print(query) + terminated_transactions = [] - results = self.execute_and_fetch(query) + transactions_data = self.execute_and_fetch(query) + + terminated_transactions = list(map(create_terminated_transaction, transactions_data)) - for result in results: - terminated_transactions.append(MemgraphTerminatedTransaction(result["transaction_id"], result["killed"])) - print(terminated_transactions) return terminated_transactions diff --git a/tests/ogm/test_transactions.py b/tests/ogm/test_transactions.py index f304a8bb..c814eb10 100644 --- a/tests/ogm/test_transactions.py +++ b/tests/ogm/test_transactions.py @@ -5,3 +5,10 @@ def test_get_transactions(memgraph): assert result[0].transaction_id != "" assert result[0].query == ["SHOW TRANSACTIONS;"] assert result[0].metadata == {} + + +def test_terminate_transactions(memgraph): + result = memgraph.get_transactions() + terminated_transactions = memgraph.terminate_transactions([result[0].transaction_id]) + assert terminated_transactions[0].killed is False + assert terminated_transactions[0].transaction_id == result[0].transaction_id