diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py index cb7920b03..0bd8ec5e3 100644 --- a/altair/utils/__init__.py +++ b/altair/utils/__init__.py @@ -7,11 +7,11 @@ update_nested, display_traceback, SchemaBase, - Undefined, ) from .html import spec_to_html from .plugin_registry import PluginRegistry from .deprecation import AltairDeprecationWarning +from .schemapi import Undefined __all__ = ( diff --git a/altair/utils/core.py b/altair/utils/core.py index c47b9a04d..38bf09504 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -14,7 +14,7 @@ import pandas as pd import numpy as np -from altair.utils.schemapi import SchemaBase, Undefined +from altair.utils.schemapi import SchemaBase try: from pandas.api.types import infer_dtype as _infer_dtype @@ -669,15 +669,6 @@ def infer_encoding_types(args, kwargs, channels): kwargs[encoding] = arg def _wrap_in_channel_class(obj, encoding): - try: - condition = obj["condition"] - except (KeyError, TypeError): - pass - else: - if condition is not Undefined: - obj = obj.copy() - obj["condition"] = _wrap_in_channel_class(condition, encoding) - if isinstance(obj, SchemaBase): return obj diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 478c3e5d0..5a65f86bb 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -516,7 +516,7 @@ def condition(predicate, if_true, if_false, **kwargs): # dict in the appropriate schema if_true = if_true.to_dict() elif isinstance(if_true, str): - if_true = {"shorthand": if_true} + if_true = utils.parse_shorthand(if_true) if_true.update(kwargs) condition.update(if_true) diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index bb078f59e..a59f4bd1e 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -70,7 +70,7 @@ def to_dict(self, validate=True, ignore=(), context=None): elif 'field' in condition and 'type' not in condition: kwds = parse_shorthand(condition['field'], context.get('data', None)) copy = self.copy(deep=['condition']) - copy.condition.update(kwds) + copy['condition'].update(kwds) return super(ValueChannelMixin, copy).to_dict(validate=validate, ignore=ignore, context=context) diff --git a/tests/utils/tests/test_core.py b/tests/utils/tests/test_core.py index c46d1744e..e0b575671 100644 --- a/tests/utils/tests/test_core.py +++ b/tests/utils/tests/test_core.py @@ -249,17 +249,35 @@ def test_infer_encoding_types(channels): assert infer_encoding_types(args, kwds, channels) == expected -def test_infer_encoding_types_with_condition(channels): +def test_infer_encoding_types_with_condition(): + channels = alt.channels + args, kwds = _getargs( - x=alt.condition("pred1", alt.value(1), alt.value(2)), - y=alt.condition("pred2", alt.value(1), "yval"), - strokeWidth=alt.condition("pred3", "sval", alt.value(2)), + size=alt.condition("pred1", alt.value(1), alt.value(2)), + color=alt.condition("pred2", alt.value("red"), "cfield:N"), + opacity=alt.condition("pred3", "ofield:N", alt.value(0.2)), ) + expected = dict( - x=channels.XValue(2, condition=channels.XValue(1, test="pred1")), - y=channels.Y("yval", condition=channels.YValue(1, test="pred2")), - strokeWidth=channels.StrokeWidthValue( - 2, condition=channels.StrokeWidth("sval", test="pred3") + size=channels.SizeValue( + 2, + condition=alt.ConditionalPredicateValueDefnumberExprRef( + value=1, test=alt.Predicate("pred1") + ), + ), + color=channels.Color( + "cfield:N", + condition=alt.ConditionalPredicateValueDefGradientstringnullExprRef( + value="red", test=alt.Predicate("pred2") + ), + ), + opacity=channels.OpacityValue( + 0.2, + condition=alt.ConditionalPredicateMarkPropFieldOrDatumDef( + field=alt.FieldName("ofield"), + test=alt.Predicate("pred3"), + type=alt.StandardType("nominal"), + ), ), ) assert infer_encoding_types(args, kwds, channels) == expected