Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse shorthand when creating the condition 2 #2841

Merged
merged 8 commits into from
Feb 19, 2023
11 changes: 1 addition & 10 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pandas as pd
import numpy as np

from altair.utils.schemapi import SchemaBase, Undefined
from altair.utils.schemapi import SchemaBase, Undefined # noqa: F401
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this? If this PR removed the only place where Undefined was used, shouldn't we remove the import instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking on this @joelostblom. I left the Undefined import here because of this line: https://github.com/altair-viz/altair/blob/7bc754e8dd6748a8b4633b6f3cd2345ab7078fb2/altair/utils/__init__.py#L10

Do you think it's worth importing Undefined directly in that __init__ file instead? I haven't experimented at all, so it's possible other errors will show up if I move it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that would make more sense to me. Could you try that? If it doesn't work and seems laborious to resolve we can just keep it here, but let's give it a quick try first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seemed to work @joelostblom in c91468f
My quick tests did not catch any unwanted side effects.


try:
from pandas.api.types import infer_dtype as _infer_dtype
Expand Down Expand Up @@ -669,15 +669,6 @@ def infer_encoding_types(args, kwargs, channels):
kwargs[encoding] = arg

def _wrap_in_channel_class(obj, encoding):
try:
condition = obj["condition"]
except (KeyError, TypeError):
pass
else:
if condition is not Undefined:
obj = obj.copy()
obj["condition"] = _wrap_in_channel_class(condition, encoding)

if isinstance(obj, SchemaBase):
return obj

Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def condition(predicate, if_true, if_false, **kwargs):
# dict in the appropriate schema
if_true = if_true.to_dict()
elif isinstance(if_true, str):
if_true = {"shorthand": if_true}
if_true = utils.parse_shorthand(if_true)
if_true.update(kwargs)
condition.update(if_true)

Expand Down
2 changes: 1 addition & 1 deletion altair/vegalite/v5/schema/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def to_dict(self, validate=True, ignore=(), context=None):
elif 'field' in condition and 'type' not in condition:
kwds = parse_shorthand(condition['field'], context.get('data', None))
copy = self.copy(deep=['condition'])
copy.condition.update(kwds)
copy['condition'].update(kwds)
return super(ValueChannelMixin, copy).to_dict(validate=validate,
ignore=ignore,
context=context)
Expand Down
34 changes: 26 additions & 8 deletions tests/utils/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,35 @@ def test_infer_encoding_types(channels):
assert infer_encoding_types(args, kwds, channels) == expected


def test_infer_encoding_types_with_condition(channels):
def test_infer_encoding_types_with_condition():
channels = alt.channels

args, kwds = _getargs(
x=alt.condition("pred1", alt.value(1), alt.value(2)),
y=alt.condition("pred2", alt.value(1), "yval"),
strokeWidth=alt.condition("pred3", "sval", alt.value(2)),
size=alt.condition("pred1", alt.value(1), alt.value(2)),
color=alt.condition("pred2", alt.value("red"), "cfield:N"),
opacity=alt.condition("pred3", "ofield:N", alt.value(0.2)),
)

expected = dict(
x=channels.XValue(2, condition=channels.XValue(1, test="pred1")),
y=channels.Y("yval", condition=channels.YValue(1, test="pred2")),
strokeWidth=channels.StrokeWidthValue(
2, condition=channels.StrokeWidth("sval", test="pred3")
size=channels.SizeValue(
2,
condition=alt.ConditionalPredicateValueDefnumberExprRef(
value=1, test=alt.Predicate("pred1")
),
),
color=channels.Color(
"cfield:N",
condition=alt.ConditionalPredicateValueDefGradientstringnullExprRef(
value="red", test=alt.Predicate("pred2")
),
),
opacity=channels.OpacityValue(
0.2,
condition=alt.ConditionalPredicateMarkPropFieldOrDatumDef(
field=alt.FieldName("ofield"),
test=alt.Predicate("pred3"),
type=alt.StandardType("nominal"),
),
),
)
assert infer_encoding_types(args, kwds, channels) == expected