From dec7f516cefd22c6e97923f369fcf8b32332d59b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 3 Jul 2019 18:57:53 -0700 Subject: [PATCH 1/3] ENH: use common utility to infer all encoding types --- altair/utils/__init__.py | 2 + altair/utils/core.py | 83 +++++++++++++++++++++++++++ altair/vegalite/v2/api.py | 69 ++-------------------- altair/vegalite/v2/schema/channels.py | 36 ++++++++++++ altair/vegalite/v3/api.py | 61 +------------------- altair/vegalite/v3/schema/channels.py | 55 ++++++++++++++++++ tools/generate_schema_wrapper.py | 3 + tools/schemapi/codegen.py | 8 ++- 8 files changed, 193 insertions(+), 124 deletions(-) diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py index f12576220..bf1ad90d1 100644 --- a/altair/utils/__init__.py +++ b/altair/utils/__init__.py @@ -1,5 +1,6 @@ from .core import ( infer_vegalite_type, + infer_encoding_types, sanitize_dataframe, parse_shorthand, use_signature, @@ -16,6 +17,7 @@ __all__ = ( 'infer_vegalite_type', + 'infer_encoding_types' 'sanitize_dataframe', 'spec_to_html', 'parse_shorthand', diff --git a/altair/utils/core.py b/altair/utils/core.py index b6253453e..a173cfa0d 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -406,3 +406,86 @@ def display_traceback(in_ipython=True): ip.showtraceback(exc_info) else: traceback.print_exception(*exc_info) + + +def infer_encoding_types(args, kwargs, channels): + """Infer typed keyword arguments for args and kwargs + + Parameters + ---------- + args : tuple + List of function args + kwargs : dict + Dict of function kwargs + channels : module + The module containing all altair encoding channel classes. + + Returns + ------- + kwargs : dict + All args and kwargs in a single dict, with keys and types + based on the channels mapping. + """ + # Construct a dictionary of channel type to encoding name + # TODO: cache this somehow? + channel_objs = (getattr(channels, name) for name in dir(channels)) + channel_objs = (c for c in channel_objs + if isinstance(c, type) and issubclass(c, SchemaBase)) + channel_to_name = {c: c._encoding_name for c in channel_objs} + name_to_channel = {} + for chan, name in channel_to_name.items(): + chans = name_to_channel.setdefault(name, {}) + key = 'value' if chan.__name__.endswith('Value') else 'field' + chans[key] = chan + + # First use the mapping to convert args to kwargs based on their types. + for arg in args: + if isinstance(arg, (list, tuple)) and len(arg) > 0: + type_ = type(arg[0]) + else: + type_ = type(arg) + + encoding = channel_to_name.get(type_, None) + if encoding is None: + raise NotImplementedError("positional of type {}" + "".format(type_)) + if encoding in kwargs: + raise ValueError("encoding {} specified twice.".format(encoding)) + kwargs[encoding] = arg + + def _wrap_in_channel_class(obj, encoding): + try: + condition = obj['condition'] + except (KeyError, TypeError): + pass + else: + if condition is not Undefined: + obj = obj.copy() + obj['condition'] = _wrap_in_channel_class(condition, encoding) + + if isinstance(obj, SchemaBase): + return obj + + if isinstance(obj, six.string_types): + obj = {'shorthand': obj} + + if isinstance(obj, (list, tuple)): + return [_wrap_in_channel_class(subobj, encoding) for subobj in obj] + + if encoding not in name_to_channel: + warnings.warn("Unrecognized encoding channel '{}'".format(encoding)) + return obj + + classes = name_to_channel[encoding] + cls = classes['value'] if 'value' in obj else classes['field'] + + try: + # Don't force validation here; some objects won't be valid until + # they're created in the context of a chart. + return cls.from_dict(obj, validate=False) + except jsonschema.ValidationError: + # our attempts at finding the correct class have failed + return obj + + return {encoding: _wrap_in_channel_class(obj, encoding) + for encoding, obj in kwargs.items()} diff --git a/altair/vegalite/v2/api.py b/altair/vegalite/v2/api.py index 2a97f0d2d..c62ec73d8 100644 --- a/altair/vegalite/v2/api.py +++ b/altair/vegalite/v2/api.py @@ -1198,72 +1198,13 @@ def serve(self, ip='127.0.0.1', port=8888, n_retries=50, files=None, class EncodingMixin(object): @utils.use_signature(core.EncodingWithFacet) def encode(self, *args, **kwargs): - # First convert args to kwargs by inferring the class from the argument - if args: - channels_mapping = _get_channels_mapping() - for arg in args: - if isinstance(arg, (list, tuple)) and len(arg) > 0: - type_ = type(arg[0]) - else: - type_ = type(arg) - - encoding = channels_mapping.get(type_, None) - if encoding is None: - raise NotImplementedError("non-keyword arg of type {}" - "".format(type(arg))) - if encoding in kwargs: - raise ValueError("encode: encoding {} specified twice" - "".format(encoding)) - kwargs[encoding] = arg - - def _wrap_in_channel_class(obj, prop): - clsname = prop.title() - - if isinstance(obj, core.SchemaBase): - return obj - - if isinstance(obj, six.string_types): - obj = {'shorthand': obj} - - if isinstance(obj, (list, tuple)): - return [_wrap_in_channel_class(subobj, prop) for subobj in obj] - - if 'value' in obj: - clsname += 'Value' - - try: - cls = getattr(channels, clsname) - except AttributeError: - raise ValueError("Unrecognized encoding channel '{}'".format(prop)) - - try: - # Don't force validation here; some objects won't be valid until - # they're created in the context of a chart. - return cls.from_dict(obj, validate=False) - except jsonschema.ValidationError: - # our attempts at finding the correct class have failed - return obj - - for prop, obj in list(kwargs.items()): - try: - condition = obj['condition'] - except (KeyError, TypeError): - pass - else: - if condition is not Undefined: - obj['condition'] = _wrap_in_channel_class(condition, prop) - kwargs[prop] = _wrap_in_channel_class(obj, prop) - - - copy = self.copy(deep=True, ignore=['data']) + # Convert args to kwargs based on their types. + kwargs = utils.infer_encoding_types(args, kwargs, channels) # get a copy of the dict representation of the previous encoding - encoding = copy.encoding - if encoding is Undefined: - encoding = {} - elif isinstance(encoding, dict): - pass - else: + copy = self.copy(deep=['encoding']) + encoding = copy._get('encoding', {}) + if isinstance(encoding, core.VegaLiteSchema): encoding = {k: v for k, v in encoding._kwds.items() if v is not Undefined} diff --git a/altair/vegalite/v2/schema/channels.py b/altair/vegalite/v2/schema/channels.py index dc9a9b4ad..8187599d5 100644 --- a/altair/vegalite/v2/schema/channels.py +++ b/altair/vegalite/v2/schema/channels.py @@ -197,6 +197,7 @@ class Color(FieldChannelMixin, core.MarkPropFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "color" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -222,6 +223,7 @@ class ColorValue(ValueChannelMixin, core.MarkPropValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "color" def __init__(self, value, condition=Undefined, **kwds): super(ColorValue, self).__init__(value=value, condition=condition, **kwds) @@ -324,6 +326,7 @@ class Column(FieldChannelMixin, core.FacetFieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "column" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, header=Undefined, sort=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, @@ -403,6 +406,7 @@ class Detail(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "detail" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -528,6 +532,7 @@ class Fill(FieldChannelMixin, core.MarkPropFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "fill" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -553,6 +558,7 @@ class FillValue(ValueChannelMixin, core.MarkPropValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "fill" def __init__(self, value, condition=Undefined, **kwds): super(FillValue, self).__init__(value=value, condition=condition, **kwds) @@ -634,6 +640,7 @@ class Href(FieldChannelMixin, core.FieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "href" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -658,6 +665,7 @@ class HrefValue(ValueChannelMixin, core.ValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "href" def __init__(self, value, condition=Undefined, **kwds): super(HrefValue, self).__init__(value=value, condition=condition, **kwds) @@ -733,6 +741,7 @@ class Key(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "key" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -810,6 +819,7 @@ class Latitude(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "latitude" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -887,6 +897,7 @@ class Latitude2(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "latitude2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -964,6 +975,7 @@ class Longitude(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "longitude" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -1041,6 +1053,7 @@ class Longitude2(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "longitude2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -1166,6 +1179,7 @@ class Opacity(FieldChannelMixin, core.MarkPropFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "opacity" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -1191,6 +1205,7 @@ class OpacityValue(ValueChannelMixin, core.MarkPropValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "opacity" def __init__(self, value, condition=Undefined, **kwds): super(OpacityValue, self).__init__(value=value, condition=condition, **kwds) @@ -1267,6 +1282,7 @@ class Order(FieldChannelMixin, core.OrderFieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "order" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, sort=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -1288,6 +1304,7 @@ class OrderValue(ValueChannelMixin, core.ValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "order" def __init__(self, value, **kwds): super(OrderValue, self).__init__(value=value, **kwds) @@ -1390,6 +1407,7 @@ class Row(FieldChannelMixin, core.FacetFieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "row" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, header=Undefined, sort=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, @@ -1517,6 +1535,7 @@ class Shape(FieldChannelMixin, core.MarkPropFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "shape" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -1542,6 +1561,7 @@ class ShapeValue(ValueChannelMixin, core.MarkPropValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "shape" def __init__(self, value, condition=Undefined, **kwds): super(ShapeValue, self).__init__(value=value, condition=condition, **kwds) @@ -1665,6 +1685,7 @@ class Size(FieldChannelMixin, core.MarkPropFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "size" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -1690,6 +1711,7 @@ class SizeValue(ValueChannelMixin, core.MarkPropValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "size" def __init__(self, value, condition=Undefined, **kwds): super(SizeValue, self).__init__(value=value, condition=condition, **kwds) @@ -1813,6 +1835,7 @@ class Stroke(FieldChannelMixin, core.MarkPropFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "stroke" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -1838,6 +1861,7 @@ class StrokeValue(ValueChannelMixin, core.MarkPropValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "stroke" def __init__(self, value, condition=Undefined, **kwds): super(StrokeValue, self).__init__(value=value, condition=condition, **kwds) @@ -1922,6 +1946,7 @@ class Text(FieldChannelMixin, core.TextFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "text" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, format=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, @@ -1947,6 +1972,7 @@ class TextValue(ValueChannelMixin, core.TextValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "text" def __init__(self, value, condition=Undefined, **kwds): super(TextValue, self).__init__(value=value, condition=condition, **kwds) @@ -2031,6 +2057,7 @@ class Tooltip(FieldChannelMixin, core.TextFieldDefWithCondition): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "tooltip" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, format=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, @@ -2056,6 +2083,7 @@ class TooltipValue(ValueChannelMixin, core.TextValueDefWithCondition): A constant value in visual domain. """ _class_is_valid_at_instantiation = False + _encoding_name = "tooltip" def __init__(self, value, condition=Undefined, **kwds): super(TooltipValue, self).__init__(value=value, condition=condition, **kwds) @@ -2199,6 +2227,7 @@ class X(FieldChannelMixin, core.PositionFieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "x" def __init__(self, shorthand=Undefined, aggregate=Undefined, axis=Undefined, bin=Undefined, field=Undefined, scale=Undefined, sort=Undefined, stack=Undefined, timeUnit=Undefined, @@ -2222,6 +2251,7 @@ class XValue(ValueChannelMixin, core.ValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "x" def __init__(self, value, **kwds): super(XValue, self).__init__(value=value, **kwds) @@ -2297,6 +2327,7 @@ class X2(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "x2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -2318,6 +2349,7 @@ class X2Value(ValueChannelMixin, core.ValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "x2" def __init__(self, value, **kwds): super(X2Value, self).__init__(value=value, **kwds) @@ -2461,6 +2493,7 @@ class Y(FieldChannelMixin, core.PositionFieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "y" def __init__(self, shorthand=Undefined, aggregate=Undefined, axis=Undefined, bin=Undefined, field=Undefined, scale=Undefined, sort=Undefined, stack=Undefined, timeUnit=Undefined, @@ -2484,6 +2517,7 @@ class YValue(ValueChannelMixin, core.ValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "y" def __init__(self, value, **kwds): super(YValue, self).__init__(value=value, **kwds) @@ -2559,6 +2593,7 @@ class Y2(FieldChannelMixin, core.FieldDef): `__. """ _class_is_valid_at_instantiation = False + _encoding_name = "y2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -2580,6 +2615,7 @@ class Y2Value(ValueChannelMixin, core.ValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "y2" def __init__(self, value, **kwds): super(Y2Value, self).__init__(value=value, **kwds) diff --git a/altair/vegalite/v3/api.py b/altair/vegalite/v3/api.py index 2ae58a949..2c35f375e 100644 --- a/altair/vegalite/v3/api.py +++ b/altair/vegalite/v3/api.py @@ -1389,66 +1389,11 @@ def serve(self, ip='127.0.0.1', port=8888, n_retries=50, files=None, class EncodingMixin(object): @utils.use_signature(core.FacetedEncoding) def encode(self, *args, **kwargs): - # First convert args to kwargs by inferring the class from the argument - if args: - channels_mapping = _get_channels_mapping() - for arg in args: - if isinstance(arg, (list, tuple)) and len(arg) > 0: - type_ = type(arg[0]) - else: - type_ = type(arg) - - encoding = channels_mapping.get(type_, None) - if encoding is None: - raise NotImplementedError("non-keyword arg of type {}" - "".format(type(arg))) - if encoding in kwargs: - raise ValueError("encode: encoding {} specified twice" - "".format(encoding)) - kwargs[encoding] = arg - - def _wrap_in_channel_class(obj, prop): - clsname = prop.title() - - if isinstance(obj, core.SchemaBase): - return obj - - if isinstance(obj, six.string_types): - obj = {'shorthand': obj} - - if isinstance(obj, (list, tuple)): - return [_wrap_in_channel_class(subobj, prop) for subobj in obj] - - if 'value' in obj: - clsname += 'Value' - - try: - cls = getattr(channels, clsname) - except AttributeError: - raise ValueError("Unrecognized encoding channel '{}'".format(prop)) - - try: - # Don't force validation here; some objects won't be valid until - # they're created in the context of a chart. - return cls.from_dict(obj, validate=False) - except jsonschema.ValidationError: - # our attempts at finding the correct class have failed - return obj - - for prop, obj in list(kwargs.items()): - try: - condition = obj['condition'] - except (KeyError, TypeError): - pass - else: - if condition is not Undefined: - obj['condition'] = _wrap_in_channel_class(condition, prop) - if obj is not None: - kwargs[prop] = _wrap_in_channel_class(obj, prop) - - copy = self.copy(deep=['encoding']) + # Convert args to kwargs based on their types. + kwargs = utils.infer_encoding_types(args, kwargs, channels) # get a copy of the dict representation of the previous encoding + copy = self.copy(deep=['encoding']) encoding = copy._get('encoding', {}) if isinstance(encoding, core.VegaLiteSchema): encoding = {k: v for k, v in encoding._kwds.items() diff --git a/altair/vegalite/v3/schema/channels.py b/altair/vegalite/v3/schema/channels.py index ba5f5902c..0601271c8 100644 --- a/altair/vegalite/v3/schema/channels.py +++ b/altair/vegalite/v3/schema/channels.py @@ -243,6 +243,7 @@ class Color(FieldChannelMixin, core.StringFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "color" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -261,6 +262,7 @@ class ColorValue(ValueChannelMixin, core.StringValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "color" def __init__(self, value, **kwds): super(ColorValue, self).__init__(value=value, **kwds) @@ -403,6 +405,7 @@ class Column(FieldChannelMixin, core.FacetFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "column" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, header=Undefined, sort=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, @@ -522,6 +525,7 @@ class Detail(FieldChannelMixin, core.FieldDefWithoutScale): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "detail" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -666,6 +670,7 @@ class Facet(FieldChannelMixin, core.FacetFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "facet" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, header=Undefined, sort=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, @@ -839,6 +844,7 @@ class Fill(FieldChannelMixin, core.StringFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "fill" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -857,6 +863,7 @@ class FillValue(ValueChannelMixin, core.StringValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "fill" def __init__(self, value, **kwds): super(FillValue, self).__init__(value=value, **kwds) @@ -1026,6 +1033,7 @@ class FillOpacity(FieldChannelMixin, core.NumericFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "fillOpacity" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -1044,6 +1052,7 @@ class FillOpacityValue(ValueChannelMixin, core.NumericValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "fillOpacity" def __init__(self, value, **kwds): super(FillOpacityValue, self).__init__(value=value, **kwds) @@ -1193,6 +1202,7 @@ class Href(FieldChannelMixin, core.TextFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "href" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, format=Undefined, formatType=Undefined, timeUnit=Undefined, @@ -1212,6 +1222,7 @@ class HrefValue(ValueChannelMixin, core.TextValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "href" def __init__(self, value, **kwds): super(HrefValue, self).__init__(value=value, **kwds) @@ -1327,6 +1338,7 @@ class Key(FieldChannelMixin, core.FieldDefWithoutScale): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "key" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -1443,6 +1455,7 @@ class Latitude(FieldChannelMixin, core.LatLongFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "latitude" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -1464,6 +1477,7 @@ class LatitudeValue(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "latitude" def __init__(self, value, **kwds): super(LatitudeValue, self).__init__(value=value, **kwds) @@ -1547,6 +1561,7 @@ class Latitude2(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "latitude2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -1568,6 +1583,7 @@ class Latitude2Value(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "latitude2" def __init__(self, value, **kwds): super(Latitude2Value, self).__init__(value=value, **kwds) @@ -1682,6 +1698,7 @@ class Longitude(FieldChannelMixin, core.LatLongFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "longitude" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -1703,6 +1720,7 @@ class LongitudeValue(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "longitude" def __init__(self, value, **kwds): super(LongitudeValue, self).__init__(value=value, **kwds) @@ -1786,6 +1804,7 @@ class Longitude2(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "longitude2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -1807,6 +1826,7 @@ class Longitude2Value(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "longitude2" def __init__(self, value, **kwds): super(Longitude2Value, self).__init__(value=value, **kwds) @@ -1976,6 +1996,7 @@ class Opacity(FieldChannelMixin, core.NumericFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "opacity" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -1994,6 +2015,7 @@ class OpacityValue(ValueChannelMixin, core.NumericValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "opacity" def __init__(self, value, **kwds): super(OpacityValue, self).__init__(value=value, **kwds) @@ -2110,6 +2132,7 @@ class Order(FieldChannelMixin, core.OrderFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "order" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, sort=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, **kwds): @@ -2131,6 +2154,7 @@ class OrderValue(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "order" def __init__(self, value, **kwds): super(OrderValue, self).__init__(value=value, **kwds) @@ -2273,6 +2297,7 @@ class Row(FieldChannelMixin, core.FacetFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "row" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, header=Undefined, sort=Undefined, timeUnit=Undefined, title=Undefined, type=Undefined, @@ -2446,6 +2471,7 @@ class Shape(FieldChannelMixin, core.ShapeFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "shape" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -2464,6 +2490,7 @@ class ShapeValue(ValueChannelMixin, core.ShapeValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "shape" def __init__(self, value, **kwds): super(ShapeValue, self).__init__(value=value, **kwds) @@ -2633,6 +2660,7 @@ class Size(FieldChannelMixin, core.NumericFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "size" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -2651,6 +2679,7 @@ class SizeValue(ValueChannelMixin, core.NumericValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "size" def __init__(self, value, **kwds): super(SizeValue, self).__init__(value=value, **kwds) @@ -2820,6 +2849,7 @@ class Stroke(FieldChannelMixin, core.StringFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "stroke" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -2838,6 +2868,7 @@ class StrokeValue(ValueChannelMixin, core.StringValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "stroke" def __init__(self, value, **kwds): super(StrokeValue, self).__init__(value=value, **kwds) @@ -3007,6 +3038,7 @@ class StrokeOpacity(FieldChannelMixin, core.NumericFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "strokeOpacity" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -3026,6 +3058,7 @@ class StrokeOpacityValue(ValueChannelMixin, core.NumericValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "strokeOpacity" def __init__(self, value, **kwds): super(StrokeOpacityValue, self).__init__(value=value, **kwds) @@ -3195,6 +3228,7 @@ class StrokeWidth(FieldChannelMixin, core.NumericFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "strokeWidth" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, legend=Undefined, scale=Undefined, sort=Undefined, timeUnit=Undefined, @@ -3213,6 +3247,7 @@ class StrokeWidthValue(ValueChannelMixin, core.NumericValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "strokeWidth" def __init__(self, value, **kwds): super(StrokeWidthValue, self).__init__(value=value, **kwds) @@ -3362,6 +3397,7 @@ class Text(FieldChannelMixin, core.TextFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "text" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, format=Undefined, formatType=Undefined, timeUnit=Undefined, @@ -3381,6 +3417,7 @@ class TextValue(ValueChannelMixin, core.TextValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "text" def __init__(self, value, **kwds): super(TextValue, self).__init__(value=value, **kwds) @@ -3530,6 +3567,7 @@ class Tooltip(FieldChannelMixin, core.TextFieldDefWithCondition): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "tooltip" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, condition=Undefined, field=Undefined, format=Undefined, formatType=Undefined, timeUnit=Undefined, @@ -3549,6 +3587,7 @@ class TooltipValue(ValueChannelMixin, core.TextValueDefWithCondition): optional. """ _class_is_valid_at_instantiation = False + _encoding_name = "tooltip" def __init__(self, value, **kwds): super(TooltipValue, self).__init__(value=value, **kwds) @@ -3743,6 +3782,7 @@ class X(FieldChannelMixin, core.PositionFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "x" def __init__(self, shorthand=Undefined, aggregate=Undefined, axis=Undefined, bin=Undefined, field=Undefined, impute=Undefined, scale=Undefined, sort=Undefined, stack=Undefined, @@ -3766,6 +3806,7 @@ class XValue(ValueChannelMixin, core.XValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "x" def __init__(self, value, **kwds): super(XValue, self).__init__(value=value, **kwds) @@ -3849,6 +3890,7 @@ class X2(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "x2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -3870,6 +3912,7 @@ class X2Value(ValueChannelMixin, core.XValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "x2" def __init__(self, value, **kwds): super(X2Value, self).__init__(value=value, **kwds) @@ -3953,6 +3996,7 @@ class XError(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "xError" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -3974,6 +4018,7 @@ class XErrorValue(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "xError" def __init__(self, value, **kwds): super(XErrorValue, self).__init__(value=value, **kwds) @@ -4057,6 +4102,7 @@ class XError2(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "xError2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -4078,6 +4124,7 @@ class XError2Value(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "xError2" def __init__(self, value, **kwds): super(XError2Value, self).__init__(value=value, **kwds) @@ -4272,6 +4319,7 @@ class Y(FieldChannelMixin, core.PositionFieldDef): ``x``, ``y`` ). """ _class_is_valid_at_instantiation = False + _encoding_name = "y" def __init__(self, shorthand=Undefined, aggregate=Undefined, axis=Undefined, bin=Undefined, field=Undefined, impute=Undefined, scale=Undefined, sort=Undefined, stack=Undefined, @@ -4295,6 +4343,7 @@ class YValue(ValueChannelMixin, core.YValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "y" def __init__(self, value, **kwds): super(YValue, self).__init__(value=value, **kwds) @@ -4378,6 +4427,7 @@ class Y2(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "y2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -4399,6 +4449,7 @@ class Y2Value(ValueChannelMixin, core.YValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "y2" def __init__(self, value, **kwds): super(Y2Value, self).__init__(value=value, **kwds) @@ -4482,6 +4533,7 @@ class YError(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "yError" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -4503,6 +4555,7 @@ class YErrorValue(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "yError" def __init__(self, value, **kwds): super(YErrorValue, self).__init__(value=value, **kwds) @@ -4586,6 +4639,7 @@ class YError2(FieldChannelMixin, core.SecondaryFieldDef): defined, axis/header/legend title will be used. """ _class_is_valid_at_instantiation = False + _encoding_name = "yError2" def __init__(self, shorthand=Undefined, aggregate=Undefined, bin=Undefined, field=Undefined, timeUnit=Undefined, title=Undefined, **kwds): @@ -4607,6 +4661,7 @@ class YError2Value(ValueChannelMixin, core.NumberValueDef): between ``0`` to ``1`` for opacity). """ _class_is_valid_at_instantiation = False + _encoding_name = "yError2" def __init__(self, value, **kwds): super(YError2Value, self).__init__(value=value, **kwds) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 8fe8ba668..f2f18700d 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -148,6 +148,7 @@ class FieldSchemaGenerator(SchemaGenerator): class {classname}(FieldChannelMixin, core.{basename}): """{docstring}""" _class_is_valid_at_instantiation = False + _encoding_name = "{encodingname}" {init_code} ''') @@ -158,6 +159,7 @@ class ValueSchemaGenerator(SchemaGenerator): class {classname}(ValueChannelMixin, core.{basename}): """{docstring}""" _class_is_valid_at_instantiation = False + _encoding_name = "{encodingname}" {init_code} ''') @@ -320,6 +322,7 @@ def generate_vegalite_channel_wrappers(schemafile, version, imports=None): gen = Generator(classname=classname, basename=basename, schema=defschema, rootschema=schema, + encodingname=prop, nodefault=nodefault) contents.append(gen.schema_class()) return '\n'.join(contents) diff --git a/tools/schemapi/codegen.py b/tools/schemapi/codegen.py index 7913715b0..8ded209d3 100644 --- a/tools/schemapi/codegen.py +++ b/tools/schemapi/codegen.py @@ -75,6 +75,8 @@ class SchemaGenerator(object): rootschemarepr : CodeSnippet or object, optional An object whose repr will be used in the place of the explicit root schema. + **kwargs : dict + Additional keywords for derived classes. """ schema_class_template = textwrap.dedent(''' class {classname}({basename}): @@ -95,7 +97,7 @@ def _process_description(self, description): def __init__(self, classname, schema, rootschema=None, basename='SchemaBase', schemarepr=None, rootschemarepr=None, - nodefault=()): + nodefault=(), **kwargs): self.classname = classname self.schema = schema self.rootschema = rootschema @@ -103,6 +105,7 @@ def __init__(self, classname, schema, rootschema=None, self.schemarepr = schemarepr self.rootschemarepr = rootschemarepr self.nodefault = nodefault + self.kwargs = kwargs def schema_class(self): """Generate code for a schema class""" @@ -120,7 +123,8 @@ def schema_class(self): schema=schemarepr, rootschema=rootschemarepr, docstring=self.docstring(indent=4), - init_code=self.init_code(indent=4) + init_code=self.init_code(indent=4), + **self.kwargs ) def docstring(self, indent=0): From 6aebe43723689521b9ec2bb6f354a9819bf767be Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 3 Jul 2019 22:35:10 -0700 Subject: [PATCH 2/3] TST: add unit test for utils.infer_encoding_types --- altair/utils/tests/test_core.py | 103 +++++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 1 deletion(-) diff --git a/altair/utils/tests/test_core.py b/altair/utils/tests/test_core.py index 0f168a9f9..ed94b8b10 100644 --- a/altair/utils/tests/test_core.py +++ b/altair/utils/tests/test_core.py @@ -1,7 +1,58 @@ +import types + import pandas as pd +import pytest import altair as alt -from .. import parse_shorthand, update_nested +from .. import parse_shorthand, update_nested, infer_encoding_types + +FAKE_CHANNELS_MODULE = ''' +"""Fake channels module for utility tests.""" + +from altair.utils import schemapi + + +class FieldChannel(object): + def __init__(self, shorthand, **kwargs): + kwargs['shorthand'] = shorthand + return super(FieldChannel, self).__init__(**kwargs) + + +class ValueChannel(object): + def __init__(self, value, **kwargs): + kwargs['value'] = value + return super(ValueChannel, self).__init__(**kwargs) + + +class X(FieldChannel, schemapi.SchemaBase): + _schema = {} + _encoding_name = "x" + + +class XValue(ValueChannel, schemapi.SchemaBase): + _schema = {} + _encoding_name = "x" + + +class Y(FieldChannel, schemapi.SchemaBase): + _schema = {} + _encoding_name = "y" + + +class YValue(ValueChannel, schemapi.SchemaBase): + _schema = {} + _encoding_name = "y" + + +class StrokeWidth(FieldChannel, schemapi.SchemaBase): + _schema = {} + _encoding_name = "strokeWidth" + + +class StrokeWidthValue(ValueChannel, schemapi.SchemaBase): + _schema = {} + _encoding_name = "strokeWidth" +''' def test_parse_shorthand(): @@ -124,3 +175,53 @@ def test_update_nested(): output2 = update_nested(original, update) assert output2 is original assert output == output2 + + +@pytest.fixture +def channels(): + channels = types.ModuleType('channels') + exec(FAKE_CHANNELS_MODULE, channels.__dict__) + return channels + + +def _getargs(*args, **kwargs): + return args, kwargs + + +def test_infer_encoding_types(channels): + expected = dict(x=channels.X('xval'), + y=channels.YValue('yval'), + strokeWidth=channels.StrokeWidthValue(value=4)) + + # All positional args + args, kwds = _getargs(channels.X('xval'), + channels.YValue('yval'), + channels.StrokeWidthValue(4)) + assert infer_encoding_types(args, kwds, channels) == expected + + # All keyword args + args, kwds = _getargs(x='xval', + y=alt.value('yval'), + strokeWidth=alt.value(4)) + assert infer_encoding_types(args, kwds, channels) == expected + + # Mixed positional & keyword + args, kwds = _getargs(channels.X('xval'), + channels.YValue('yval'), + strokeWidth=alt.value(4)) + assert infer_encoding_types(args, kwds, channels) == expected + + +def test_infer_encoding_types_with_condition(channels): + args, kwds = _getargs( + x=alt.condition('pred1', alt.value(1), alt.value(2)), + y=alt.condition('pred2', alt.value(1), 'yval'), + strokeWidth=alt.condition('pred3', 'sval', alt.value(2)) + ) + expected = dict( + x=channels.XValue(2, condition=channels.XValue(1, test='pred1')), + y=channels.Y('yval', condition=channels.YValue(1, test='pred2')), + strokeWidth=channels.StrokeWidthValue(2, + condition=channels.StrokeWidth('sval', test='pred3')) + ) + assert infer_encoding_types(args, kwds, channels) == expected From 6cd6511e4f9a1d2e8d664636bd1775ccbee63361 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 3 Jul 2019 22:41:31 -0700 Subject: [PATCH 3/3] MAINT: fix flake8 errors --- altair/utils/__init__.py | 2 +- altair/utils/core.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py index bf1ad90d1..5c7233db8 100644 --- a/altair/utils/__init__.py +++ b/altair/utils/__init__.py @@ -17,7 +17,7 @@ __all__ = ( 'infer_vegalite_type', - 'infer_encoding_types' + 'infer_encoding_types', 'sanitize_dataframe', 'spec_to_html', 'parse_shorthand', diff --git a/altair/utils/core.py b/altair/utils/core.py index a173cfa0d..0e440a373 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -9,6 +9,7 @@ import traceback import warnings +import jsonschema import six import pandas as pd import numpy as np