Skip to content

Commit

Permalink
Add test + review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
katarinasupe committed Jun 25, 2024
1 parent 5cda32a commit dcc22d6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 15 deletions.
48 changes: 33 additions & 15 deletions gqlalchemy/vendors/memgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,34 @@ class MemgraphConstants:
UNIQUE = "unique"


def create_transaction(transaction_data) -> MemgraphTransaction:
"""Create a MemgraphTransaction object from transaction data.
Args:
transaction_data (dict): A dictionary containing transaction data.
Returns:
MemgraphTransaction: A MemgraphTransaction object.
"""
return MemgraphTransaction(
username=transaction_data["username"],
transaction_id=transaction_data["transaction_id"],
query=transaction_data["query"],
metadata=transaction_data["metadata"],
)


def create_terminated_transaction(transaction_data) -> MemgraphTerminatedTransaction:
"""Create a MemgraphTerminatedTransaction object from transaction data.
Args:
transaction_data (dict): A dictionary containing transaction data.
Returns:
MemgraphTerminatedTransaction: A MemgraphTerminatedTransaction object.
"""
return MemgraphTerminatedTransaction(
transaction_id=transaction_data["transaction_id"],
killed=transaction_data["killed"],
)


class Memgraph(DatabaseClient):
def __init__(
self,
Expand Down Expand Up @@ -460,16 +488,7 @@ def get_transactions(self) -> List[MemgraphTransaction]:
"""

transactions_data = self.execute_and_fetch("SHOW TRANSACTIONS;")
transactions = []

for transaction_data in transactions_data:
transaction = MemgraphTransaction(
username=transaction_data["username"],
transaction_id=transaction_data["transaction_id"],
query=transaction_data["query"],
metadata=transaction_data["metadata"],
)
transactions.append(transaction)
transactions = list(map(create_transaction, transactions_data))

return transactions

Expand All @@ -484,11 +503,10 @@ def terminate_transactions(self, transaction_ids: List[str]) -> List[MemgraphTer
query = (
"TERMINATE TRANSACTIONS " + ", ".join([f"'{transaction_id}'" for transaction_id in transaction_ids]) + ";"
)
print(query)

terminated_transactions = []
results = self.execute_and_fetch(query)
transactions_data = self.execute_and_fetch(query)

terminated_transactions = list(map(create_terminated_transaction, transactions_data))

for result in results:
terminated_transactions.append(MemgraphTerminatedTransaction(result["transaction_id"], result["killed"]))
print(terminated_transactions)
return terminated_transactions
7 changes: 7 additions & 0 deletions tests/ogm/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,10 @@ def test_get_transactions(memgraph):
assert result[0].transaction_id != ""
assert result[0].query == ["SHOW TRANSACTIONS;"]
assert result[0].metadata == {}


def test_terminate_transactions(memgraph):
result = memgraph.get_transactions()
terminated_transactions = memgraph.terminate_transactions([result[0].transaction_id])
assert terminated_transactions[0].killed is False
assert terminated_transactions[0].transaction_id == result[0].transaction_id

0 comments on commit dcc22d6

Please sign in to comment.