Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse shorthand when creating the condition 2 #2841

Merged
merged 8 commits into from
Feb 19, 2023
2 changes: 1 addition & 1 deletion altair/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
update_nested,
display_traceback,
SchemaBase,
Undefined,
)
from .html import spec_to_html
from .plugin_registry import PluginRegistry
from .deprecation import AltairDeprecationWarning
from .schemapi import Undefined


__all__ = (
Expand Down
11 changes: 1 addition & 10 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pandas as pd
import numpy as np

from altair.utils.schemapi import SchemaBase, Undefined
from altair.utils.schemapi import SchemaBase

try:
from pandas.api.types import infer_dtype as _infer_dtype
Expand Down Expand Up @@ -669,15 +669,6 @@ def infer_encoding_types(args, kwargs, channels):
kwargs[encoding] = arg

def _wrap_in_channel_class(obj, encoding):
try:
condition = obj["condition"]
except (KeyError, TypeError):
pass
else:
if condition is not Undefined:
obj = obj.copy()
obj["condition"] = _wrap_in_channel_class(condition, encoding)

if isinstance(obj, SchemaBase):
return obj

Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def condition(predicate, if_true, if_false, **kwargs):
# dict in the appropriate schema
if_true = if_true.to_dict()
elif isinstance(if_true, str):
if_true = {"shorthand": if_true}
if_true = utils.parse_shorthand(if_true)
if_true.update(kwargs)
condition.update(if_true)

Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/v5/schema/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def to_dict(self, validate=True, ignore=(), context=None):
elif 'field' in condition and 'type' not in condition:
kwds = parse_shorthand(condition['field'], context.get('data', None))
copy = self.copy(deep=['condition'])
copy.condition.update(kwds)
copy['condition'].update(kwds)
return super(ValueChannelMixin, copy).to_dict(validate=validate,
ignore=ignore,
context=context)
Expand Down
34 changes: 26 additions & 8 deletions tests/utils/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,35 @@ def test_infer_encoding_types(channels):
assert infer_encoding_types(args, kwds, channels) == expected


def test_infer_encoding_types_with_condition(channels):
def test_infer_encoding_types_with_condition():
channels = alt.channels

args, kwds = _getargs(
x=alt.condition("pred1", alt.value(1), alt.value(2)),
y=alt.condition("pred2", alt.value(1), "yval"),
strokeWidth=alt.condition("pred3", "sval", alt.value(2)),
size=alt.condition("pred1", alt.value(1), alt.value(2)),
color=alt.condition("pred2", alt.value("red"), "cfield:N"),
opacity=alt.condition("pred3", "ofield:N", alt.value(0.2)),
)

expected = dict(
x=channels.XValue(2, condition=channels.XValue(1, test="pred1")),
y=channels.Y("yval", condition=channels.YValue(1, test="pred2")),
strokeWidth=channels.StrokeWidthValue(
2, condition=channels.StrokeWidth("sval", test="pred3")
size=channels.SizeValue(
2,
condition=alt.ConditionalPredicateValueDefnumberExprRef(
value=1, test=alt.Predicate("pred1")
),
),
color=channels.Color(
"cfield:N",
condition=alt.ConditionalPredicateValueDefGradientstringnullExprRef(
value="red", test=alt.Predicate("pred2")
),
),
opacity=channels.OpacityValue(
0.2,
condition=alt.ConditionalPredicateMarkPropFieldOrDatumDef(
field=alt.FieldName("ofield"),
test=alt.Predicate("pred3"),
type=alt.StandardType("nominal"),
),
),
)
assert infer_encoding_types(args, kwds, channels) == expected