diff --git a/traits/observers/expression.py b/traits/observers/expression.py index 27d5fd5d7..abdc42fd7 100644 --- a/traits/observers/expression.py +++ b/traits/observers/expression.py @@ -17,6 +17,8 @@ ObserverGraph as _ObserverGraph, ) +# Expression is a public user interface for constructing ObserverGraph. + class Expression: """ @@ -25,25 +27,8 @@ class Expression: ``HasTraits.observe`` method or the ``observe`` decorator. An Expression is typically created using one of the top-level functions - provided in this module, e.g.``trait``. + provided in this module, e.g. ``trait``. """ - def __init__(self): - # ``_levels`` is a list of list of IObserver. - # Each item corresponds to a layer of branches in the ObserverGraph. - # The last item is the most nested level. - # e.g. _levels = [[observer1, observer2], [observer3, observer4]] - # observer3 and observer4 are both leaf nodes of a tree, and they are - # "siblings" of each other. Each of observer3 and observer4 has two - # parents: observer1 and observer2. - # When ObserverGraph(s) are constructured from this expression, one - # starts from the end of this list, to the top, and then continues to - # the prior_expressions - self._levels = [] - - # Represent prior expressions to be combined in series (JOIN) - # or in parallel (OR). This is either an instance of _SeriesExpression - # or an instance of _ParallelExpression. - self._prior_expression = None def __eq__(self, other): """ Return true if the other value is an Expression with equivalent @@ -59,7 +44,7 @@ def __eq__(self, other): def __or__(self, expression): """ Create a new expression that matches this expression OR - the given expression. Equivalent expressions will be ignored. + the given expression. e.g. ``trait("age") | trait("number")`` will match either trait **age** or trait **number** on an object. @@ -72,11 +57,7 @@ def __or__(self, expression): ------- new_expression : traits.observers.expression.Expression """ - if self == expression: - return self._copy() - new = Expression() - new._prior_expression = _ParallelExpression([self, expression]) - return new + return ParallelExpression(self, expression) def then(self, expression): """ Create a new expression by extending this expression with @@ -93,14 +74,7 @@ def then(self, expression): ------- new_expression : traits.observers.expression.Expression """ - - if self._prior_expression is None and not self._levels: - # this expression is empty... - new = expression._copy() - else: - new = Expression() - new._prior_expression = _SeriesExpression([self, expression]) - return new + return SeriesExpression(self, expression) def trait(self, name, notify=True, optional=False): """ Create a new expression for observing a trait with the exact @@ -135,162 +109,75 @@ def _as_graphs(self): ------- graphs : list of ObserverGraph """ - return _create_graphs(self) + return self._create_graphs(branches=[]) - def _new_with_branches(self, nodes): - """ Create a new Expression with a new leaf nodes. + def _create_graphs(self, branches): + """ Return a list of ObserverGraph with the given branches. Parameters ---------- - nodes : list of IObserver - - Returns - ------- - new_expression : traits.observers.expression.Expression - """ - expression = self._copy() - expression._levels.append(nodes) - return expression - - def _copy(self): - """ Return a copy of this expression. + branches : list of ObserverGraph + Graphs to be used as branches. Returns ------- - new_expression : traits.observers.expression.Expression + graphs : list of ObserverGraph """ - expression = Expression() - expression._levels = self._levels.copy() - if self._prior_expression is not None: - expression._prior_expression = self._prior_expression._copy() - return expression - + raise NotImplementedError("'_create_graphs' must be implemented.") -def _create_graphs(expression, graphs=None): - """ Create ObserverGraphs from a given expression. - Parameters - ---------- - expression : traits.observers.expression.Expression - graphs : collection of ObserverGraph - Leaf graphs to be added. - Needed when this function is called recursively. - - Returns - ------- - graphs : list of ObserverGraph - New graphs +class SingleObserverExpression(Expression): + """ Container of Expression for wrapping a single observer. """ - if graphs is None: - graphs = [] - - for nodes in expression._levels[::-1]: - graphs = [ - _ObserverGraph(node=node, children=graphs) for node in nodes - ] - if expression._prior_expression is not None: - graphs = expression._prior_expression._create_graphs( - graphs=graphs, - ) - return graphs + def __init__(self, observer): + self.observer = observer + def _create_graphs(self, branches): + return [ + _ObserverGraph(node=self.observer, children=branches), + ] -# _SeriesExpression and _ParallelExpression share an undeclared interface -# which require the classes to have implemented ``copy`` and ``_create_graphs`` -class _SeriesExpression: +class SeriesExpression(Expression): """ Container of Expression for joining expressions in series. - Used internally in this module. Parameters ---------- - expressions : list of Expression - List of Expression to be combined in series. + first : traits.observers.expression.Expression + Left expression to be joined in series. + second : traits.observers.expression.Expression + Right expression to be joined in series. """ - def __init__(self, expressions): - self.expressions = expressions.copy() - - def _copy(self): - """ Return a copy of this instance. - The internal ``expressions`` list is copied so it can be mutated. - - Returns - ------- - series_expression : _SeriesExpression - """ - return _SeriesExpression(self.expressions) - - def _create_graphs(self, graphs): - """ - Create new ObserverGraph(s) from the joined expressions. - - Parameters - ---------- - graphs : collection of ObserverGraph - Leaf graphs to be added. - Needed when this function is called recursively. + def __init__(self, first, second): + self._first = first + self._second = second - Returns - ------- - graphs : list of ObserverGraph - New graphs - """ - for expr in self.expressions[::-1]: - graphs = _create_graphs( - expr, - graphs=graphs, - ) - return graphs + def _create_graphs(self, branches): + branches = self._second._create_graphs(branches=branches) + return self._first._create_graphs(branches=branches) -class _ParallelExpression: +class ParallelExpression(Expression): """ Container of Expression for joining expressions in parallel. - Used internally in this module. Parameters ---------- - expressions : list of Expression - List of Expression to be combined in parallel. + left : traits.observers.expression.Expression + Left expression to be joined in parallel. + right : traits.observers.expression.Expression + Right expression to be joined in parallel. """ - def __init__(self, expressions): - self.expressions = expressions.copy() - - def _copy(self): - """ Return a copy of this instance. - The internal ``expressions`` list is copied so it can be mutated. - - Returns - ------- - parallel_expression : _ParallelExpression - """ - return _ParallelExpression(self.expressions) - - def _create_graphs(self, graphs): - """ - Create new ObserverGraph(s) from the joined expressions. - - Parameters - ---------- - graphs : collection of ObserverGraph - Leaf graphs to be added. - Needed when this function is called recursively. + def __init__(self, left, right): + self._left = left + self._right = right - Returns - ------- - graphs : list of ObserverGraph - New graphs - """ - new_graphs = [] - for expr in self.expressions: - or_graphs = _create_graphs( - expr, - graphs=graphs, - ) - new_graphs.extend(or_graphs) - return new_graphs + def _create_graphs(self, branches): + left_graphs = self._left._create_graphs(branches=branches) + right_graphs = self._right._create_graphs(branches=branches) + return left_graphs + right_graphs def join_(*expressions): @@ -330,4 +217,4 @@ def trait(name, notify=True, optional=False): """ observer = _NamedTraitObserver( name=name, notify=notify, optional=optional) - return Expression()._new_with_branches(nodes=[observer]) + return SingleObserverExpression(observer) diff --git a/traits/observers/tests/test_expression.py b/traits/observers/tests/test_expression.py index 071996d92..509b40998 100644 --- a/traits/observers/tests/test_expression.py +++ b/traits/observers/tests/test_expression.py @@ -46,7 +46,7 @@ def create_expression(observer): ------- expression : Expression """ - return expression.Expression()._new_with_branches(nodes=[observer]) + return expression.SingleObserverExpression(observer) class TestExpressionComposition(unittest.TestCase): @@ -75,19 +75,6 @@ def test_or_operator(self): actual = expr._as_graphs() self.assertEqual(actual, expected) - def test_or_operator_same_elements(self): - observer = 1 - expr1 = create_expression(observer) - expr2 = create_expression(observer) - expr = expr1 | expr2 - - # the two elements are equal - expected = [ - create_graph(observer), - ] - actual = expr._as_graphs() - self.assertEqual(actual, expected) - def test_or_maintain_order(self): # Test __or__ will maintain the order provided by the user. observer1 = 1 @@ -116,18 +103,6 @@ def test_then_operator(self): actual = expr._as_graphs() self.assertEqual(actual, expected) - def test_then_optimization(self): - # If the expression is empty to start with, just make a copy - # An empty bootstrapping expression is common when an user creates an - # expression using a high-level helper function. - expr1 = expression.Expression() - expr2 = create_expression(1) - expr = expr1.then(expr2) - - self.assertEqual(expr._levels, expr2._levels) - self.assertIsNot(expr._levels, expr2._levels) - self.assertIsNone(expr._prior_expression) - def test_chained_then_or(self): observer1 = 1 observer2 = 2 @@ -306,48 +281,3 @@ def test_join_equality_with_then(self): def test_equality_different_type(self): expr = create_expression(1) self.assertNotEqual(expr, "1") - - -class TestExpressionCopy(unittest.TestCase): - """ Test the Expression._copy method.""" - - def test_expression_copy_current_levels(self): - expr = create_expression(1) - copied = expr._copy() - self.assertEqual(expr._levels, copied._levels) - self.assertIsNot(copied._levels, expr._levels) - self.assertEqual(copied._as_graphs(), expr._as_graphs()) - - def test_expression_copy_prior_expression_parallel(self): - expr = create_expression(1) | create_expression(2) - self.assertIsNotNone(expr._prior_expression) - - copied = expr._copy() - self.assertEqual(copied._as_graphs(), expr._as_graphs()) - self.assertIsNotNone(copied._prior_expression) - self.assertIsNot(copied._prior_expression, expr._prior_expression) - self.assertEqual( - copied._prior_expression.expressions, - expr._prior_expression.expressions, - ) - self.assertIsNot( - copied._prior_expression.expressions, - expr._prior_expression.expressions, - ) - - def test_expression_copy_prior_expression_serial(self): - expr = create_expression(1).then(create_expression(2)) - self.assertIsNotNone(expr._prior_expression) - - copied = expr._copy() - self.assertEqual(copied._as_graphs(), expr._as_graphs()) - self.assertIsNotNone(copied._prior_expression) - self.assertIsNot(copied._prior_expression, expr._prior_expression) - self.assertEqual( - copied._prior_expression.expressions, - expr._prior_expression.expressions, - ) - self.assertIsNot( - copied._prior_expression.expressions, - expr._prior_expression.expressions, - )