diff --git a/CHANGELOG.md b/CHANGELOG.md index a79a6b51985..d135720aa1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## [4.12.0] - unreleased +### Added + +- For all `go.Figure` functions accepting a selector argument (e.g., `select_traces`), this argument can now also be a function which is passed each relevant graph object (in the case of `select_traces`, it is passed every trace in the figure). For graph objects where this function returns true, the graph object is included in the selection. + ### Updated - Updated Plotly.js to version 1.57.0. See the [plotly.js CHANGELOG](https://github.com/plotly/plotly.js/blob/v1.57.0/CHANGELOG.md) for more information. These changes are reflected in the auto-generated `plotly.graph_objects` module. diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index 7bc0e66f56d..9733d5b3241 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -813,24 +813,33 @@ def _perform_select_traces(self, filter_by_subplot, grid_subplot_refs, selector) def _selector_matches(obj, selector): if selector is None: return True + # If selector is a dict, compare the fields + if (type(selector) == type(dict())) or isinstance(selector, BasePlotlyType): + # This returns True if selector is an empty dict + for k in selector: + if k not in obj: + return False - for k in selector: - if k not in obj: - return False - - obj_val = obj[k] - selector_val = selector[k] - - if isinstance(obj_val, BasePlotlyType): - obj_val = obj_val.to_plotly_json() + obj_val = obj[k] + selector_val = selector[k] - if isinstance(selector_val, BasePlotlyType): - selector_val = selector_val.to_plotly_json() + if isinstance(obj_val, BasePlotlyType): + obj_val = obj_val.to_plotly_json() - if obj_val != selector_val: - return False + if isinstance(selector_val, BasePlotlyType): + selector_val = selector_val.to_plotly_json() - return True + if obj_val != selector_val: + return False + return True + # If selector is a function, call it with the obj as the argument + elif type(selector) == type(lambda x: True): + return selector(obj) + else: + raise TypeError( + "selector must be dict or a function " + "accepting a graph object returning a boolean." + ) def for_each_trace(self, fn, selector=None, row=None, col=None, secondary_y=None): """ diff --git a/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_selector_matches.py b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_selector_matches.py new file mode 100644 index 00000000000..138d0b48921 --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_selector_matches.py @@ -0,0 +1,91 @@ +import pytest + +import plotly.graph_objects as go +from plotly.basedatatypes import BaseFigure + + +def test_selector_none(): + # should return True + assert BaseFigure._selector_matches({}, None) == True # arbitrary, + + +def test_selector_empty_dict(): + # should return True + assert ( + BaseFigure._selector_matches(dict(hello="everybody"), {}) == True # arbitrary, + ) + + +def test_selector_matches_subset_of_obj(): + # should return True + assert ( + BaseFigure._selector_matches( + dict(hello="everybody", today="cloudy", myiq=55), + dict(myiq=55, today="cloudy"), + ) + == True + ) + + +def test_selector_has_nonmatching_key(): + # should return False + assert ( + BaseFigure._selector_matches( + dict(hello="everybody", today="cloudy", myiq=55), + dict(myiq=55, cronenberg="scanners"), + ) + == False + ) + + +def test_selector_has_nonmatching_value(): + # should return False + assert ( + BaseFigure._selector_matches( + dict(hello="everybody", today="cloudy", myiq=55), + dict(myiq=55, today="sunny"), + ) + == False + ) + + +def test_baseplotlytypes_could_match(): + # should return True + obj = go.layout.Annotation(x=1, y=2, text="pat metheny") + sel = go.layout.Annotation(x=1, y=2, text="pat metheny") + assert BaseFigure._selector_matches(obj, sel) == True + + +def test_baseplotlytypes_could_not_match(): + # should return False + obj = go.layout.Annotation(x=1, y=3, text="pat metheny") + sel = go.layout.Annotation(x=1, y=2, text="pat metheny") + assert BaseFigure._selector_matches(obj, sel) == False + + +def test_baseplotlytypes_cannot_match_subset(): + # should return False because "undefined" keys in sel return None, and are + # compared (because "key in sel" returned True, it's value was None) + obj = go.layout.Annotation(x=1, y=2, text="pat metheny") + sel = go.layout.Annotation(x=1, y=2,) + assert BaseFigure._selector_matches(obj, sel) == False + + +def test_function_selector_could_match(): + # should return True + obj = go.layout.Annotation(x=1, y=2, text="pat metheny") + + def _sel(d): + return d["x"] == 1 and d["y"] == 2 and d["text"] == "pat metheny" + + assert BaseFigure._selector_matches(obj, _sel) == True + + +def test_function_selector_could_not_match(): + # should return False + obj = go.layout.Annotation(x=1, y=2, text="pat metheny") + + def _sel(d): + return d["x"] == 1 and d["y"] == 3 and d["text"] == "pat metheny" + + assert BaseFigure._selector_matches(obj, _sel) == False diff --git a/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py index b1c8487ceae..ccfe50382e0 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py +++ b/packages/python/plotly/plotly/tests/test_core/test_update_objects/test_update_traces.py @@ -225,6 +225,43 @@ def test_select_property_and_grid(self): # Valid row/col and valid selector but the intersection is empty self.assert_select_traces([], selector={"type": "markers"}, row=3, col=1) + def test_select_with_function(self): + def _check_trace_key(k, v): + def f(t): + try: + return t[k] == v + except LookupError: + return False + + return f + + # (1, 1) + self.assert_select_traces( + [0], selector=_check_trace_key("mode", "markers"), row=1, col=1 + ) + self.assert_select_traces( + [1], selector=_check_trace_key("type", "bar"), row=1, col=1 + ) + + # (2, 1) + self.assert_select_traces( + [2, 9], selector=_check_trace_key("mode", "lines"), row=2, col=1 + ) + + # (1, 2) + self.assert_select_traces( + [4], selector=_check_trace_key("marker.color", "green"), row=1, col=2 + ) + + # Valid row/col and valid selector but the intersection is empty + self.assert_select_traces( + [], selector=_check_trace_key("type", "markers"), row=3, col=1 + ) + + def test_select_traces_type_error(self): + with self.assertRaises(TypeError): + self.assert_select_traces([0], selector=123, row=1, col=1) + def test_for_each_trace_lowercase_names(self): # Names are all uppercase to start original_names = [t.name for t in self.fig.data]