Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Edit Domain: Partial restore state on categories mismatch #6776

Merged
merged 4 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 83 additions & 51 deletions Orange/widgets/data/oweditdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
A widget for manual editing of a domain's attributes.

"""
from __future__ import annotations
import warnings
from xml.sax.saxutils import escape
from itertools import zip_longest, repeat, chain
from itertools import zip_longest, repeat, chain, groupby
from collections import namedtuple, Counter
from functools import singledispatch, partial
from operator import itemgetter
from typing import (
Tuple, List, Any, Optional, Union, Dict, Sequence, Iterable, NamedTuple,
FrozenSet, Type, Callable, TypeVar, Mapping, Hashable, cast, Set
)

import numpy as np
import pandas as pd

from AnyQt.QtWidgets import (
QWidget, QListView, QTreeView, QVBoxLayout, QHBoxLayout, QFormLayout,
QLineEdit, QAction, QActionGroup, QGroupBox,
Expand All @@ -35,6 +38,7 @@
)
from AnyQt.QtCore import pyqtSignal as Signal, pyqtSlot as Slot

from orangecanvas.utils import assocf
from orangewidget.utils.listview import ListViewSearch

import Orange.data
Expand All @@ -46,7 +50,7 @@
from Orange.util import frompyfunc
from Orange.widgets import widget, gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils import itemmodels, ftry, disconnected
from Orange.widgets.utils import itemmodels, ftry, disconnected, unique_everseen as unique
from Orange.widgets.utils.buttons import FixedSizeButton
from Orange.widgets.utils.itemmodels import signal_blocking
from Orange.widgets.utils.widgetpreview import WidgetPreview
Expand All @@ -62,14 +66,6 @@
MAX_HINTS = 1000


def unique(sequence: Iterable[H]) -> Iterable[H]:
    """
    Iterate over the distinct elements of `sequence` in first-seen order.

    The whole input is consumed before the iterator is returned, and the
    elements are used as dictionary keys.
    """
    # A dict remembers the order in which keys were first inserted, so it
    # serves as an ordered set here.
    first_seen = {}
    for element in sequence:
        first_seen.setdefault(element, None)
    return iter(first_seen)


class _DataType:
def __eq__(self, other):
"""Equal if `other` has the same type and all elements compare equal."""
Expand All @@ -83,19 +79,6 @@ def __ne__(self, other):
def __hash__(self):
return hash((type(self), super().__hash__()))

def name_type(self):
    """
    Return a `(name, type code)` tuple for this variable.

    The tuple exists because it is forbidden to use names of variables
    in settings. The type code is looked up by the instance's class name;
    a class name outside the table raises `KeyError`.
    """
    # Read the name first so attribute access precedes the table lookup.
    variable_name = self.name
    type_codes = dict(Categorical=0, Real=2, Time=3, String=4)
    return variable_name, type_codes[type(self).__name__]


#: An ordered sequence of key, value pairs (variable annotations)
AnnotationsType = Tuple[Tuple[str, str], ...]
Expand Down Expand Up @@ -2029,9 +2012,17 @@ class Outputs:
# Error messages declared by this widget.
class Error(widget.OWWidget.Error):
    # Message for a duplicated variable name; the code that raises it is
    # outside this view.
    duplicate_var_name = widget.Msg("A variable name is duplicated.")

# Warning messages used while restoring stored transforms.
class Warning(widget.OWWidget.Warning):
    # Raised when a stored transform cannot be reconstructed; the two
    # placeholders receive the stored transform and the column name.
    transform_restore_failed = widget.Msg(
        "Failed to restore transform {} for column {}"
    )
    # Raised when a stored categories mapping does not match the categories
    # of the current variable; the placeholder receives the variable name(s).
    cat_mapping_does_not_apply = widget.Msg(
        "Categories mapping for {} does not apply to current input"
    )

settings_version = 4

_domain_change_hints = Setting({}, schema_only=True)
_domain_change_hints: dict = Setting({}, schema_only=True)
_merge_dialog_settings = Setting({}, schema_only=True)
output_table_name = Setting("", schema_only=True)

Expand Down Expand Up @@ -2122,8 +2113,8 @@ def clear(self):
self.data = None
self.variables_model.clear()
self.clear_editor()

self._merge_dialog_settings = {}
self.Warning.clear()

def reset_selected(self):
"""Reset the currently selected variable to its original state."""
Expand Down Expand Up @@ -2178,30 +2169,63 @@ def setup_model(self, data: Orange.data.Table):
for i, d in enumerate(columns):
model.setData(model.index(i), d, Qt.EditRole)

def _sanitize_transform(
        self, var: Variable, trs: Sequence[Transform]
) -> tuple[Sequence[Transform], Sequence[tuple[Msg, str]]]:
    """
    Drop transforms from `trs` that cannot be applied to `var`.

    Only categorical variables are filtered: a `CategoriesMapping` is kept
    when the set of its non-None source categories equals the set of
    `var.categories`, and dropped otherwise. Every other transform passes
    through unchanged.

    Returns a pair `(transforms, messages)`. For a non-categorical `var`
    the `trs` argument itself is returned. `messages` holds one
    `(warning, variable name)` tuple per dropped mapping.
    """
    msgs = []
    if not isinstance(var, Categorical):
        return trs, msgs

    kept = []
    for tr in trs:
        if isinstance(tr, CategoriesMapping):
            # None source entries are ignored when comparing against the
            # variable's categories.
            if set(var.categories) != {
                    ci for ci, _ in tr.mapping if ci is not None}:
                msgs.append(
                    (self.Warning.cat_mapping_does_not_apply, var.name))
                continue
        kept.append(tr)
    return kept, msgs

def _restore(self):
"""
Restore the edit transform from saved state.
"""
model = self.variables_model
hints = self._domain_change_hints
first_key = None
msgs = []
for i in range(model.rowCount()):
midx = model.index(i, 0)
coldesc = model.data(midx, Qt.EditRole) # type: DataVector
tr, key = self._restore_transform(coldesc.vtype)
if tr:
model.setData(midx, tr, TransformRole)
if first_key is None:
first_key = key
res = self._find_stored_transform(coldesc.vtype)
if res:
key, tr = res
if tr:
self._store_transform(coldesc.vtype, tr, key)
tr, msgs_ = self._sanitize_transform(coldesc.vtype, tr)
model.setData(midx, tr, TransformRole)
msgs.extend(msgs_)
if first_key is None:
first_key = key
# Reduce the number of hints to MAX_HINTS, but keep all current hints
# Current hints start with `first_key`.
while len(hints) > MAX_HINTS and \
(key := next(iter(hints))) is not first_key:
(key := next(iter(hints))) != first_key:
del hints[key] # pylint: disable=unsupported-delete-operation

# Show warnings for non-applicable transforms
for msg, names in groupby(msgs, key=itemgetter(0)):
msg(", ".join(map(itemgetter(1), names)))

# Restore the current variable selection
selected_rows = [i for i, vec in enumerate(model)
if vec.vtype.name_type()[0] in self._selected_items]
if vec.vtype.name in self._selected_items]
if not selected_rows and model.rowCount():
selected_rows = [0]
itemmodels.select_rows(self.variables_view, selected_rows)
Expand Down Expand Up @@ -2257,8 +2281,9 @@ def _on_variable_changed(self):
self._store_transform(var, transform)
self._invalidate()

def _store_transform(self, var, transform, deconvar=None):
# type: (Variable, List[Transform]) -> None
def _store_transform(
self, var: Variable, transform: Iterable[Transform], deconvar=None
) -> None:
deconvar = deconvar or deconstruct(var)
# Remove the existing key (if any) to put the new one at the end,
# to make sure it comes after the sentinel
Expand All @@ -2267,25 +2292,32 @@ def _store_transform(self, var, transform, deconvar=None):
self._domain_change_hints[deconvar] = \
[deconstruct(t) for t in transform]

def _restore_transform(self, var):
# type: (Variable) -> List[Transform]
def _find_stored_transform(
self, var: Variable
) -> Tuple[tuple, Sequence[Transform]] | None:
"""Find stored transform for `var`."""
def reconstruct_transform(tr_: list[tuple]) -> list[Transform]:
trs = []
for t in tr_:
try:
trs.append(cast(Transform, reconstruct(*t)))
except (AttributeError, TypeError, NameError):
self.Warning.transform_restore_failed(
str(t), var.name, exc_info=True,
)
return trs

hints = self._domain_change_hints
key = deconstruct(var)
tr_ = self._domain_change_hints.get(key, [])
tr = []
tr = hints.get(key) # exact match
if tr is not None:
return key, reconstruct_transform(tr)

for t in tr_:
try:
tr.append(reconstruct(*t))
except (NameError, TypeError) as err:
warnings.warn(
f"Failed to restore transform: {t}, {err}",
UserWarning, stacklevel=2
)
if tr:
self._store_transform(var, tr, key)
else:
key = None
return tr, key
# match by name and type only
item = assocf(hints.items(),
lambda k: k[0] == key[0] and k[1][0] == var.name)
if item is not None:
return item[0], reconstruct_transform(item[1])

def _invalidate(self):
self._set_modified(True)
Expand Down
25 changes: 25 additions & 0 deletions Orange/widgets/data/tests/test_oweditdomain.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,31 @@ def restore(state):
tr = model.data(model.index(4), TransformRole)
self.assertEqual(tr, [AsString(), Rename("Z")])

restore({viris: [("CategoriesMapping", ([("Iris-setosa", "setosa"),
("Iris-versicolor", "versicolor"),
("Iris-virginica", "virginica")],)),
("Rename", ("Species",))]})
tr = model.data(model.index(4), TransformRole)
self.assertEqual(tr, [CategoriesMapping([("Iris-setosa", "setosa"),
("Iris-versicolor", "versicolor"),
("Iris-virginica", "virginica")]),
Rename("Species")])

viris_1 = ("Categorical", ("iris", ("A", "B"), ()))
restore({viris_1: [("Rename", ("K",),),
("CategoriesMapping", ([("A", "AA"), ("B", "BB")],))]})
self.assertTrue(w.Warning.cat_mapping_does_not_apply.is_shown())
w.commit()
output = self.get_output(w.Outputs.data)
self.assertEqual(output.domain.class_var.name, "K")
self.assertEqual(output.domain.class_var.values,
("Iris-setosa", "Iris-versicolor", "Iris-virginica"))

restore({viris: [("Rename", ("A")), ("NonexistantTransform", ("AA",))]})
tr = model.data(model.index(4), TransformRole)
self.assertEqual(tr, [Rename("A")])
self.assertTrue(w.Warning.transform_restore_failed.is_shown())

def test_reset_selected(self):
w = self.widget
model = w.domain_view.model()
Expand Down
Loading