def _get_timedelta_units(timedelta: datetime.timedelta) -> List[str]:
    """Return the names of the strictly positive components of a timedelta.

    Parameters
    ----------
    timedelta
        Either an object exposing a ``components`` namedtuple (the shape
        ``pandas.Timedelta`` provides) or a plain ``datetime.timedelta``.

    Returns
    -------
    List[str]
        Component names (``'days'``, ``'hours'``, ...) whose value is
        greater than zero, ordered from the largest unit to the smallest.
        Empty when no component is positive.
    """
    components = getattr(timedelta, 'components', None)
    if components is not None:
        named_values = [
            (field, getattr(components, field))
            for field in components._fields
        ]
    else:
        # A plain ``datetime.timedelta`` has no ``components`` attribute; it
        # only stores days, seconds and microseconds, so derive the finer
        # grained fields from those instead of raising AttributeError.
        hours, remainder = divmod(timedelta.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        milliseconds, microseconds = divmod(timedelta.microseconds, 1000)
        named_values = [
            ('days', timedelta.days),
            ('hours', hours),
            ('minutes', minutes),
            ('seconds', seconds),
            ('milliseconds', milliseconds),
            ('microseconds', microseconds),
        ]
    # NOTE(review): negative components are deliberately not reported,
    # matching the original ``> 0`` filter -- confirm that is intended for
    # negative timedeltas.
    return [name for name, value in named_values if value > 0]
Polygon, MultiLineString, MultiPoint, MultiPolygon + Point, + LineString, + Polygon, + MultiLineString, + MultiPoint, + MultiPolygon, ) diff --git a/ibis/expr/operations.py b/ibis/expr/operations.py index d86ca2b73276..0cc1f33392c7 100644 --- a/ibis/expr/operations.py +++ b/ibis/expr/operations.py @@ -3,6 +3,7 @@ import itertools import operator from contextlib import suppress +from typing import List import toolz @@ -1026,8 +1027,10 @@ def __init__(self, expr, window): window = window.bind(table) if window.max_lookback is not None: - error_msg = ("'max lookback' windows must be ordered " - "by a timestamp column") + error_msg = ( + "'max lookback' windows must be ordered " + "by a timestamp column" + ) if len(window._order_by) != 1: raise com.IbisInputError(error_msg) order_var = window._order_by[0].op().args[0] @@ -1730,6 +1733,13 @@ def __init__(self, left, right, predicates, by, tolerance): super().__init__(left, right, predicates) self.by = _clean_join_predicates(self.left, self.right, by) self.tolerance = tolerance + self._validate_args(['by', 'tolerance']) + + def _validate_args(self, args: List[str]): + for arg in args: + argument = self.signature[arg] + value = argument.validate(getattr(self, arg)) + setattr(self, arg, value) class Union(TableNode, HasSchema): @@ -3196,6 +3206,7 @@ class GeoSRID(GeoSpatialUnOp): class GeoSetSRID(GeoSpatialUnOp): """Set the spatial reference identifier for the ST_Geometry.""" + srid = Arg(rlz.integer) output_type = rlz.shape_like('args', dt.geometry) @@ -3221,6 +3232,7 @@ class GeoDFullyWithin(GeoSpatialBinOp): """Returns True if the geometries are fully within the specified distance of one another. """ + distance = Arg(rlz.floating) output_type = rlz.shape_like('args', dt.boolean) @@ -3230,6 +3242,7 @@ class GeoDWithin(GeoSpatialBinOp): """Returns True if the geometries are within the specified distance of one another. 
""" + distance = Arg(rlz.floating) output_type = rlz.shape_like('args', dt.boolean) diff --git a/ibis/expr/tests/test_table.py b/ibis/expr/tests/test_table.py index a70640b2ad0d..1a48ff859099 100644 --- a/ibis/expr/tests/test_table.py +++ b/ibis/expr/tests/test_table.py @@ -1,7 +1,9 @@ import pickle import re +import pandas as pd import pytest +from pytest import param import ibis import ibis.common.exceptions as com @@ -695,6 +697,59 @@ def test_asof_join_with_by(): assert by.left.op().name == by.right.op().name == 'key' +@pytest.mark.parametrize( + ('ibis_interval', 'timedelta_interval'), + [ + [ibis.interval(days=2), pd.Timedelta('2 days')], + [ibis.interval(hours=5), pd.Timedelta('5 hours')], + [ibis.interval(minutes=7), pd.Timedelta('7 minutes')], + [ibis.interval(seconds=9), pd.Timedelta('7 seconds')], + [ibis.interval(milliseconds=9), pd.Timedelta('9 milliseconds')], + [ibis.interval(microseconds=11), pd.Timedelta('11 microseconds')], + [ibis.interval(nanoseconds=17), pd.Timedelta('17 nanoseconds')], + param( + ibis.interval(weeks=3), + pd.Timedelta('3 W'), + id='weeks', + marks=pytest.mark.xfail( + reason='Week conversion from Timedelta to ibis interval ' + 'not supported' + ), + ), + param( + ibis.interval(years=3), + pd.Timedelta('3 Y'), + id='years', + marks=pytest.mark.xfail( + reason='Year conversion from Timedelta to ibis interval ' + 'not supported' + ), + ), + ], +) +def test_asof_join_with_tolerance(ibis_interval, timedelta_interval): + left = ibis.table( + [('time', 'int32'), ('key', 'int32'), ('value', 'double')] + ) + right = ibis.table( + [('time', 'int32'), ('key', 'int32'), ('value2', 'double')] + ) + + joined = api.asof_join(left, right, 'time', tolerance=ibis_interval) + tolerance = joined.op().tolerance + assert_equal(tolerance, ibis_interval) + + joined = api.asof_join(left, right, 'time', tolerance=timedelta_interval) + tolerance = joined.op().tolerance + + assert isinstance(tolerance, ir.IntervalScalar) + assert 
isinstance(tolerance.op(), ops.Literal) + + ibis_interval_unit = ibis_interval.op().dtype.unit + timedelta_unit = tolerance.op().dtype.unit + assert timedelta_unit == ibis_interval_unit + + def test_equijoin_schema_merge(): table1 = ibis.table([('key1', 'string'), ('value1', 'double')]) table2 = ibis.table([('key2', 'string'), ('stuff', 'int32')]) @@ -1064,7 +1119,7 @@ def test_cannot_use_existence_expression_in_join(table): def test_not_exists_predicate(t1, t2): - cond = -(t1.key1 == t2.key1).any() + cond = -((t1.key1 == t2.key1).any()) assert isinstance(cond.op(), ops.NotAny)