Skip to content

Commit

Permalink
Merge undo commands to help doing stuff as block
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelma committed May 24, 2023
1 parent ada77fc commit e3f2eff
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 97 deletions.
59 changes: 33 additions & 26 deletions spinetoolbox/spine_db_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,68 +32,75 @@ def undo_age(self):


class AgedUndoCommand(QUndoCommand):
def __init__(self, parent=None):
def __init__(self, parent=None, identifier=-1):
"""
Args:
parent (QUndoCommand, optional): The parent command, used for defining macros.
"""
super().__init__(parent=parent)
self.parent = parent
self._age = -1
self.children = []
self._id = identifier
self._buddies = []
self.merged = False

def id(self):
return self._id

def ours(self):
return [self] + self._buddies

def mergeWith(self, command):
if not isinstance(command, AgedUndoCommand):
return False
self._buddies += command.ours()
command.merged = True
return True

def redo(self):
if self.merged:
return
super().redo()
for cmd in self._buddies:
cmd.redo()
self._age = time.time()

def undo(self):
if self.merged:
return
for cmd in reversed(self._buddies):
cmd.undo()
super().undo()
self._age = time.time()

@property
def age(self):
return self._age

def check(self):
if self.children and all(cmd.isObsolete() for cmd in self.children):
self.setObsolete(True)

def __del__(self):
# TODO: try to make sure this works as intended. The idea is to fix a segfault at exiting the app
# after running an 'import_data' command.
if self.parent is not None:
return
while self.children:
self.children.pop().parent = None


class SpineDBCommand(AgedUndoCommand):
"""Base class for all commands that modify a Spine DB."""

def __init__(self, db_mngr, db_map, parent=None):
def __init__(self, db_mngr, db_map, **kwargs):
"""
Args:
db_mngr (SpineDBManager): SpineDBManager instance
db_map (DiffDatabaseMapping): DiffDatabaseMapping instance
"""
super().__init__(parent=parent)
super().__init__(**kwargs)
self.db_mngr = db_mngr
self.db_map = db_map
if isinstance(parent, AgedUndoCommand):
# Stores a ref to this in the parent so Python doesn't delete it
parent.children.append(self)


class AddItemsCommand(SpineDBCommand):
def __init__(self, db_mngr, db_map, item_type, data, check=True, parent=None):
def __init__(self, db_mngr, db_map, item_type, data, check=True, **kwargs):
"""
Args:
db_mngr (SpineDBManager): SpineDBManager instance
db_map (DiffDatabaseMapping): DiffDatabaseMapping instance
data (list): list of dict-items to add
item_type (str): the item type
"""
super().__init__(db_mngr, db_map, parent=parent)
super().__init__(db_mngr, db_map, **kwargs)
if not data:
self.setObsolete(True)
self.item_type = item_type
Expand All @@ -119,15 +126,15 @@ def undo(self):


class UpdateItemsCommand(SpineDBCommand):
def __init__(self, db_mngr, db_map, item_type, data, check=True, parent=None):
def __init__(self, db_mngr, db_map, item_type, data, check=True, **kwargs):
"""
Args:
db_mngr (SpineDBManager): SpineDBManager instance
db_map (DiffDatabaseMapping): DiffDatabaseMapping instance
item_type (str): the item type
data (list): list of dict-items to update
"""
super().__init__(db_mngr, db_map, parent=parent)
super().__init__(db_mngr, db_map, **kwargs)
if not data:
self.setObsolete(True)
self.item_type = item_type
Expand All @@ -151,15 +158,15 @@ def undo(self):


class RemoveItemsCommand(SpineDBCommand):
def __init__(self, db_mngr, db_map, item_type, ids, parent=None):
def __init__(self, db_mngr, db_map, item_type, ids, **kwargs):
"""
Args:
db_mngr (SpineDBManager): SpineDBManager instance
db_map (DiffDatabaseMapping): DiffDatabaseMapping instance
item_type (str): the item type
ids (set): set of ids to remove
"""
super().__init__(db_mngr, db_map, parent=parent)
super().__init__(db_mngr, db_map, **kwargs)
if not ids:
self.setObsolete(True)
self.item_type = item_type
Expand Down
35 changes: 4 additions & 31 deletions spinetoolbox/spine_db_editor/widgets/add_items_dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
ManageItemsDialog,
ManageItemsDialogBase,
)
from ...spine_db_commands import AgedUndoCommand, AddItemsCommand, RemoveItemsCommand


class AddReadyEntitiesDialog(ManageItemsDialogBase):
Expand Down Expand Up @@ -492,32 +491,8 @@ def accept(self):
return
element_id = entities[entity_class_id, element_name]["id"]
element_id_list.append(element_id)
active_alt_ids = []
inactive_alt_ids = []
alternative_ids = self.db_map_alt_id_lookup[db_map]
for alt_name in active_alts:
if alt_name not in alternative_ids:
self.parent().msg_error.emit(
f"Invalid alternative '{alt_name}' for db '{db_name}' at row {i + 1}"
)
return
active_alt_ids.append(alternative_ids[alt_name])
for alt_name in inactive_alts:
if alt_name not in alternative_ids:
self.parent().msg_error.emit(
f"Invalid alternative '{alt_name}' for db '{db_name}' at row {i + 1}"
)
return
inactive_alt_ids.append(alternative_ids[alt_name])
item = pre_item.copy()
item.update(
{
'element_id_list': element_id_list,
'class_id': class_id,
'active_alternative_id_list': active_alt_ids,
'inactive_alternative_id_list': inactive_alt_ids,
}
)
item.update({'element_id_list': element_id_list, 'class_id': class_id})
db_map_data.setdefault(db_map, []).append(item)
if not db_map_data:
self.parent().msg_error.emit("Nothing to add")
Expand Down Expand Up @@ -913,11 +888,9 @@ def accept(self):
{"entity_id": ent["id"], "entity_class_id": ent["class_id"], "member_id": member_id} for member_id in added
]
ids_to_remove = [x["id"] for x in self._entity_groups() if x["member_id"] in removed]
macro = AgedUndoCommand()
macro.setText(f"manage {self.entity_item.display_data}'s members")
identifier = self.db_mngr.get_command_identifier()
if items_to_add:
AddItemsCommand(self.db_mngr, self.db_map, "entity_group", items_to_add, parent=macro)
self.db_mngr.add_items("entity_group", {self.db_map: items_to_add}, identifier=identifier)
if ids_to_remove:
RemoveItemsCommand(self.db_mngr, self.db_map, "entity_group", ids_to_remove, parent=macro)
self.db_mngr.undo_stack[self.db_map].push(macro)
self.db_mngr.remove_items({self.db_map: {"entity_group": ids_to_remove}}, identifier=identifier)
super().accept()
2 changes: 0 additions & 2 deletions spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ def add_db_map_ids_to_items(self, db_map_data):
Returns:
list: tuples (db_map, id) that didn't match any item in the view.
"""
# FIXME: It looks like undoing twice and then redoing once restores all the items.
# It should only restores the items corresponding to one redo operation at a time
added_db_map_ids_by_key = {}
for db_map, entities in db_map_data.items():
for entity in entities:
Expand Down
80 changes: 42 additions & 38 deletions spinetoolbox/spine_db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
The SpineDBManager class
"""

import json
import os
import json
from PySide6.QtCore import Qt, QObject, Signal, Slot
from PySide6.QtWidgets import QMessageBox, QWidget
from PySide6.QtGui import QFontMetrics, QFont, QWindow
Expand Down Expand Up @@ -46,7 +46,7 @@
from spinedb_api.spine_io.exporters.excel import export_spine_database_to_xlsx
from .spine_db_icon_manager import SpineDBIconManager
from .spine_db_worker import SpineDBWorker
from .spine_db_commands import AgedUndoStack, AgedUndoCommand, AddItemsCommand, UpdateItemsCommand, RemoveItemsCommand
from .spine_db_commands import AgedUndoStack, AddItemsCommand, UpdateItemsCommand, RemoveItemsCommand
from .mvcmodels.shared import PARSED_ROLE
from .spine_db_editor.widgets.multi_spine_db_editor import MultiSpineDBEditor
from .helpers import get_upgrade_db_promt_text, busy_effect
Expand Down Expand Up @@ -102,6 +102,7 @@ def __init__(self, settings, parent):
self.redo_action = {}
self._icon_mngr = {}
self._connect_signals()
self._cmd_id = 0

def _connect_signals(self):
self.error_msg.connect(self.receive_error_msg)
Expand Down Expand Up @@ -836,30 +837,23 @@ def import_data(self, db_map_data, command_text="Import data"):
to `get_data_for_import`
command_text (str, optional): What to call the command that condenses the operation.
"""
db_map_error_log = dict()
db_map_error_log = {}
for db_map, data in db_map_data.items():
try:
data_for_import = get_data_for_import(db_map, **data)
except (TypeError, ValueError) as err:
msg = f"Failed to import data: {err}. Please check that your data source has the right format."
db_map_error_log.setdefault(db_map, []).append(msg)
continue
macro = AgedUndoCommand()
macro.setText(command_text)
# NOTE: we push the import macro before adding the children,
# because we *need* to call redo() on the children one by one so the data gets in gradually
self.undo_stack[db_map].push(macro)
identifier = self.get_command_identifier()
for item_type, (to_add, to_update, import_error_log) in data_for_import:
if item_type in ("object_class", "relationship_class", "object", "relationship"):
continue
db_map_error_log.setdefault(db_map, []).extend([str(x) for x in import_error_log])
if to_update:
UpdateItemsCommand(self, db_map, item_type, to_update, check=False, parent=macro).redo()
self.update_items(item_type, {db_map: to_update}, check=False, identifier=identifier)
if to_add:
AddItemsCommand(self, db_map, item_type, to_add, check=False, parent=macro).redo()
macro.check()
if macro.isObsolete():
self.undo_stack[db_map].undo()
self.add_items(item_type, {db_map: to_add}, check=False, identifier=identifier)
if any(db_map_error_log.values()):
self.error_msg.emit(db_map_error_log)

Expand Down Expand Up @@ -961,13 +955,11 @@ def add_parameter_value_metadata(self, db_map_data):

def _add_ext_item_metadata(self, db_map_data, item_type):
for db_map, items in db_map_data.items():
macro = AgedUndoCommand()
macro.setText(f"add {item_type} to {db_map.codename}")
identifier = self.get_command_identifier()
metadata_items = db_map.get_metadata_to_add_with_entity_metadata_items(*items)
if metadata_items:
AddItemsCommand(self, db_map, "metadata", metadata_items, parent=macro)
AddItemsCommand(self, db_map, item_type, items, parent=macro)
self.undo_stack[db_map].push(macro)
self.add_items("metadata", {db_map: metadata_items}, identifier=identifier)
self.add_items(item_type, {db_map: items}, identifier=identifier)

def add_ext_entity_metadata(self, db_map_data):
"""Adds entity metadata together with all necessary metadata to db.
Expand Down Expand Up @@ -1099,13 +1091,11 @@ def update_parameter_value_metadata(self, db_map_data):

def _update_ext_item_metadata(self, db_map_data, item_type):
for db_map, items in db_map_data.items():
macro = AgedUndoCommand()
macro.setText(f"update {item_type} to {db_map.codename}")
identifier = self.get_command_identifier()
metadata_items = db_map.get_metadata_to_add_with_entity_metadata_items(*items)
if metadata_items:
AddItemsCommand(self, db_map, "metadata", metadata_items, parent=macro)
UpdateItemsCommand(self, db_map, item_type, items, parent=macro)
self.undo_stack[db_map].push(macro)
self.add_items("metadata", {db_map: metadata_items}, identifier=identifier)
self.update_items(item_type, {db_map: items}, identifier=identifier)

def update_ext_entity_metadata(self, db_map_data):
"""Updates entity metadata in db.
Expand All @@ -1131,20 +1121,18 @@ def set_scenario_alternatives(self, db_map_data):
"""
db_map_error_log = {}
for db_map, data in db_map_data.items():
macro = AgedUndoCommand()
macro.setText(f"set scenario alternatives in {db_map.codename}")
identifier = self.get_command_identifier()
items_to_add, ids_to_remove, errors = db_map.get_data_to_set_scenario_alternatives(*data)
if ids_to_remove:
RemoveItemsCommand(self, db_map, "scenario_alternative", ids_to_remove, parent=macro)
self.remove_items({db_map: {"scenario_alternative": ids_to_remove}}, identifier=identifier)
if items_to_add:
AddItemsCommand(self, db_map, "scenario_alternative", items_to_add, parent=macro)
self.add_items("scenario_alternative", {db_map: items_to_add}, identifier=identifier)
if errors:
db_map_error_log.setdefault(db_map, []).extend([str(x) for x in errors])
self.undo_stack[db_map].push(macro)
if any(db_map_error_log.values()):
self.error_msg.emit(db_map_error_log)

def purge_items(self, db_map_item_types):
def purge_items(self, db_map_item_types, **kwargs):
"""Purges selected items from given database.
Args:
Expand All @@ -1154,25 +1142,41 @@ def purge_items(self, db_map_item_types):
db_map: {item_type: {Asterisk} for item_type in item_types}
for db_map, item_types in db_map_item_types.items()
}
self.remove_items(db_map_typed_data)
self.remove_items(db_map_typed_data, **kwargs)

def add_items(self, item_type, db_map_data):
def add_items(self, item_type, db_map_data, identifier=None, **kwargs):
"""Pushes commands to add items to undo stack."""
if identifier is None:
identifier = self.get_command_identifier()
for db_map, data in db_map_data.items():
self.undo_stack[db_map].push(AddItemsCommand(self, db_map, item_type, data))
self.undo_stack[db_map].push(
AddItemsCommand(self, db_map, item_type, data, identifier=identifier, **kwargs)
)

def update_items(self, item_type, db_map_data):
def update_items(self, item_type, db_map_data, identifier=None, **kwargs):
"""Pushes commands to update items to undo stack."""
if identifier is None:
identifier = self.get_command_identifier()
for db_map, data in db_map_data.items():
self.undo_stack[db_map].push(UpdateItemsCommand(self, db_map, item_type, data))
self.undo_stack[db_map].push(
UpdateItemsCommand(self, db_map, item_type, data, identifier=identifier, **kwargs)
)

def remove_items(self, db_map_typed_ids):
def remove_items(self, db_map_typed_ids, identifier=None, **kwargs):
"""Pushes commands to remove items to undo stack."""
if identifier is None:
identifier = self.get_command_identifier()
for db_map, ids_per_type in db_map_typed_ids.items():
macro = AgedUndoCommand()
for item_type, ids in ids_per_type.items():
RemoveItemsCommand(self, db_map, item_type, ids, parent=macro)
self.undo_stack[db_map].push(macro)
self.undo_stack[db_map].push(
RemoveItemsCommand(self, db_map, item_type, ids, identifier=identifier, **kwargs)
)

def get_command_identifier(self):
try:
return self._cmd_id
finally:
self._cmd_id += 1

@busy_effect
def do_add_items(self, db_map, item_type, data, check=True):
Expand Down

0 comments on commit e3f2eff

Please sign in to comment.