Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix encoding channel argument parsing #1597

Merged
merged 3 commits into from
Jul 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions altair/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .core import (
infer_vegalite_type,
infer_encoding_types,
sanitize_dataframe,
parse_shorthand,
use_signature,
Expand All @@ -16,6 +17,7 @@

__all__ = (
'infer_vegalite_type',
'infer_encoding_types',
'sanitize_dataframe',
'spec_to_html',
'parse_shorthand',
Expand Down
84 changes: 84 additions & 0 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import traceback
import warnings

import jsonschema
import six
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -406,3 +407,86 @@ def display_traceback(in_ipython=True):
ip.showtraceback(exc_info)
else:
traceback.print_exception(*exc_info)


def infer_encoding_types(args, kwargs, channels):
"""Infer typed keyword arguments for args and kwargs

Parameters
----------
args : tuple
List of function args
kwargs : dict
Dict of function kwargs
channels : module
The module containing all altair encoding channel classes.

Returns
-------
kwargs : dict
All args and kwargs in a single dict, with keys and types
based on the channels mapping.
"""
# Construct a dictionary of channel type to encoding name
# TODO: cache this somehow?
channel_objs = (getattr(channels, name) for name in dir(channels))
channel_objs = (c for c in channel_objs
if isinstance(c, type) and issubclass(c, SchemaBase))
channel_to_name = {c: c._encoding_name for c in channel_objs}
name_to_channel = {}
for chan, name in channel_to_name.items():
chans = name_to_channel.setdefault(name, {})
key = 'value' if chan.__name__.endswith('Value') else 'field'
chans[key] = chan

# First use the mapping to convert args to kwargs based on their types.
for arg in args:
if isinstance(arg, (list, tuple)) and len(arg) > 0:
type_ = type(arg[0])
else:
type_ = type(arg)

encoding = channel_to_name.get(type_, None)
if encoding is None:
raise NotImplementedError("positional of type {}"
"".format(type_))
if encoding in kwargs:
raise ValueError("encoding {} specified twice.".format(encoding))
kwargs[encoding] = arg

def _wrap_in_channel_class(obj, encoding):
try:
condition = obj['condition']
except (KeyError, TypeError):
pass
else:
if condition is not Undefined:
obj = obj.copy()
obj['condition'] = _wrap_in_channel_class(condition, encoding)

if isinstance(obj, SchemaBase):
return obj

if isinstance(obj, six.string_types):
obj = {'shorthand': obj}

if isinstance(obj, (list, tuple)):
return [_wrap_in_channel_class(subobj, encoding) for subobj in obj]

if encoding not in name_to_channel:
warnings.warn("Unrecognized encoding channel '{}'".format(encoding))
return obj

classes = name_to_channel[encoding]
cls = classes['value'] if 'value' in obj else classes['field']

try:
# Don't force validation here; some objects won't be valid until
# they're created in the context of a chart.
return cls.from_dict(obj, validate=False)
except jsonschema.ValidationError:
# our attempts at finding the correct class have failed
return obj

return {encoding: _wrap_in_channel_class(obj, encoding)
for encoding, obj in kwargs.items()}
103 changes: 102 additions & 1 deletion altair/utils/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,58 @@
import types

import pandas as pd
import pytest

import altair as alt
from .. import parse_shorthand, update_nested
from .. import parse_shorthand, update_nested, infer_encoding_types

FAKE_CHANNELS_MODULE = '''
"""Fake channels module for utility tests."""

from altair.utils import schemapi


class FieldChannel(object):
def __init__(self, shorthand, **kwargs):
kwargs['shorthand'] = shorthand
return super(FieldChannel, self).__init__(**kwargs)


class ValueChannel(object):
def __init__(self, value, **kwargs):
kwargs['value'] = value
return super(ValueChannel, self).__init__(**kwargs)


class X(FieldChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "x"


class XValue(ValueChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "x"


class Y(FieldChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "y"


class YValue(ValueChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "y"


class StrokeWidth(FieldChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "strokeWidth"


class StrokeWidthValue(ValueChannel, schemapi.SchemaBase):
_schema = {}
_encoding_name = "strokeWidth"
'''


def test_parse_shorthand():
Expand Down Expand Up @@ -124,3 +175,53 @@ def test_update_nested():
output2 = update_nested(original, update)
assert output2 is original
assert output == output2


@pytest.fixture
def channels():
channels = types.ModuleType('channels')
exec(FAKE_CHANNELS_MODULE, channels.__dict__)
return channels


def _getargs(*args, **kwargs):
return args, kwargs


def test_infer_encoding_types(channels):
expected = dict(x=channels.X('xval'),
y=channels.YValue('yval'),
strokeWidth=channels.StrokeWidthValue(value=4))

# All positional args
args, kwds = _getargs(channels.X('xval'),
channels.YValue('yval'),
channels.StrokeWidthValue(4))
assert infer_encoding_types(args, kwds, channels) == expected

# All keyword args
args, kwds = _getargs(x='xval',
y=alt.value('yval'),
strokeWidth=alt.value(4))
assert infer_encoding_types(args, kwds, channels) == expected

# Mixed positional & keyword
args, kwds = _getargs(channels.X('xval'),
channels.YValue('yval'),
strokeWidth=alt.value(4))
assert infer_encoding_types(args, kwds, channels) == expected


def test_infer_encoding_types_with_condition(channels):
args, kwds = _getargs(
x=alt.condition('pred1', alt.value(1), alt.value(2)),
y=alt.condition('pred2', alt.value(1), 'yval'),
strokeWidth=alt.condition('pred3', 'sval', alt.value(2))
)
expected = dict(
x=channels.XValue(2, condition=channels.XValue(1, test='pred1')),
y=channels.Y('yval', condition=channels.YValue(1, test='pred2')),
strokeWidth=channels.StrokeWidthValue(2,
condition=channels.StrokeWidth('sval', test='pred3'))
)
assert infer_encoding_types(args, kwds, channels) == expected
69 changes: 5 additions & 64 deletions altair/vegalite/v2/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,72 +1198,13 @@ def serve(self, ip='127.0.0.1', port=8888, n_retries=50, files=None,
class EncodingMixin(object):
@utils.use_signature(core.EncodingWithFacet)
def encode(self, *args, **kwargs):
# First convert args to kwargs by inferring the class from the argument
if args:
channels_mapping = _get_channels_mapping()
for arg in args:
if isinstance(arg, (list, tuple)) and len(arg) > 0:
type_ = type(arg[0])
else:
type_ = type(arg)

encoding = channels_mapping.get(type_, None)
if encoding is None:
raise NotImplementedError("non-keyword arg of type {}"
"".format(type(arg)))
if encoding in kwargs:
raise ValueError("encode: encoding {} specified twice"
"".format(encoding))
kwargs[encoding] = arg

def _wrap_in_channel_class(obj, prop):
clsname = prop.title()

if isinstance(obj, core.SchemaBase):
return obj

if isinstance(obj, six.string_types):
obj = {'shorthand': obj}

if isinstance(obj, (list, tuple)):
return [_wrap_in_channel_class(subobj, prop) for subobj in obj]

if 'value' in obj:
clsname += 'Value'

try:
cls = getattr(channels, clsname)
except AttributeError:
raise ValueError("Unrecognized encoding channel '{}'".format(prop))

try:
# Don't force validation here; some objects won't be valid until
# they're created in the context of a chart.
return cls.from_dict(obj, validate=False)
except jsonschema.ValidationError:
# our attempts at finding the correct class have failed
return obj

for prop, obj in list(kwargs.items()):
try:
condition = obj['condition']
except (KeyError, TypeError):
pass
else:
if condition is not Undefined:
obj['condition'] = _wrap_in_channel_class(condition, prop)
kwargs[prop] = _wrap_in_channel_class(obj, prop)


copy = self.copy(deep=True, ignore=['data'])
# Convert args to kwargs based on their types.
kwargs = utils.infer_encoding_types(args, kwargs, channels)

# get a copy of the dict representation of the previous encoding
encoding = copy.encoding
if encoding is Undefined:
encoding = {}
elif isinstance(encoding, dict):
pass
else:
copy = self.copy(deep=['encoding'])
encoding = copy._get('encoding', {})
if isinstance(encoding, core.VegaLiteSchema):
encoding = {k: v for k, v in encoding._kwds.items()
if v is not Undefined}

Expand Down
Loading