diff --git a/ibis/expr/datatypes.py b/ibis/expr/datatypes.py index 25e875521453..599a04d6d539 100644 --- a/ibis/expr/datatypes.py +++ b/ibis/expr/datatypes.py @@ -406,6 +406,21 @@ class Interval(DataType): ns='nanosecond', ) + _timedelta_to_interval_units = dict( + days='D', + hours='h', + minutes='m', + seconds='s', + milliseconds='ms', + microseconds='us', + nanoseconds='ns', + ) + + def _convert_timedelta_unit_to_interval_unit(self, unit: str): + if unit not in self._timedelta_to_interval_units: + raise ValueError + return self._timedelta_to_interval_units[unit] + def __init__( self, unit: str = 's', @@ -414,7 +429,10 @@ def __init__( ) -> None: super().__init__(nullable=nullable) if unit not in self._units: - raise ValueError('Unsupported interval unit `{}`'.format(unit)) + try: + unit = self._convert_timedelta_unit_to_interval_unit(unit) + except ValueError: + raise ValueError('Unsupported interval unit `{}`'.format(unit)) if value_type is None: value_type = int32 @@ -1422,6 +1440,25 @@ def type(self) -> DataType: validate_type = dtype +def _get_timedelta_units(timedelta: datetime.timedelta) -> List[str]: + # pandas Timedelta has more granularity + if hasattr(timedelta, 'components'): + unit_fields = timedelta.components._fields + base_object = timedelta.components + # datetime.timedelta only stores days, seconds, and microseconds internally + else: + unit_fields = ['days', 'seconds', 'microseconds'] + base_object = timedelta + + time_units = [] + [ + time_units.append(field) + for field in unit_fields + if getattr(base_object, field) > 0 + ] + return time_units + + @dtype.register(object) def default(value, **kwargs) -> DataType: raise com.IbisTypeError('Value {!r} is not a valid datatype'.format(value)) @@ -1536,7 +1573,14 @@ def infer_timestamp(value: datetime.datetime) -> Timestamp: @infer.register(datetime.timedelta) def infer_interval(value: datetime.timedelta) -> Interval: - return interval + time_units = _get_timedelta_units(value) + # we can attempt a conversion in the simplest case, i.e. there is exactly + # one unit (e.g. pd.Timedelta('2 days') vs. pd.Timedelta('2 days 3 hours') + if len(time_units) == 1: + unit = time_units[0] + return Interval(unit) + else: + return interval @infer.register(str) diff --git a/ibis/expr/operations.py b/ibis/expr/operations.py index 34abb12501a6..0cc1f33392c7 100644 --- a/ibis/expr/operations.py +++ b/ibis/expr/operations.py @@ -3,6 +3,7 @@ import itertools import operator from contextlib import suppress +from typing import List import toolz @@ -1732,6 +1733,13 @@ def __init__(self, left, right, predicates, by, tolerance): super().__init__(left, right, predicates) self.by = _clean_join_predicates(self.left, self.right, by) self.tolerance = tolerance + self._validate_args(['by', 'tolerance']) + + def _validate_args(self, args: List[str]): + for arg in args: + argument = self.signature[arg] + value = argument.validate(getattr(self, arg)) + setattr(self, arg, value) class Union(TableNode, HasSchema): diff --git a/ibis/expr/tests/test_datatypes.py b/ibis/expr/tests/test_datatypes.py index e0d1dadd8bb3..16ee60fe407f 100644 --- a/ibis/expr/tests/test_datatypes.py +++ b/ibis/expr/tests/test_datatypes.py @@ -1,9 +1,11 @@ import datetime from collections import OrderedDict +import pandas as pd import pytest import pytz from multipledispatch.conflict import ambiguities +from pytest import param import ibis import ibis.expr.datatypes as dt @@ -367,7 +369,13 @@ def test_time_valid(): ('foo', dt.string), (datetime.date.today(), dt.date), (datetime.datetime.now(), dt.timestamp), - (datetime.timedelta(days=3), dt.interval), + (datetime.timedelta(days=3), dt.Interval(unit='D')), + (pd.Timedelta('5 hours'), dt.Interval(unit='h')), + (pd.Timedelta('7 minutes'), dt.Interval(unit='m')), + (datetime.timedelta(seconds=9), dt.Interval(unit='s')), + (pd.Timedelta('11 milliseconds'), dt.Interval(unit='ms')), + (datetime.timedelta(microseconds=15), dt.Interval(unit='us')), + (pd.Timedelta('17 nanoseconds'), dt.Interval(unit='ns')), # numeric types (5, dt.int8), (5, dt.int8), @@ -417,6 +425,51 @@ def test_time_valid(): ] ), ), + param( + datetime.timedelta(hours=5), + dt.Interval(unit='h'), + id='dateime hours', + marks=pytest.mark.xfail( + reason='Hour conversion from datetime.timedelta to ibis ' + 'interval not supported' + ), + ), + param( + datetime.timedelta(minutes=7), + dt.Interval(unit='m'), + id='dateime minutes', + marks=pytest.mark.xfail( + reason='Minute conversion from datetime.timedelta to ibis ' + 'interval not supported' + ), + ), + param( + datetime.timedelta(milliseconds=11), + dt.Interval(unit='ms'), + id='dateime milliseconds', + marks=pytest.mark.xfail( + reason='Millisecond conversion from datetime.timedelta to ' + 'ibis interval not supported' + ), + ), + param( + pd.Timedelta('3', unit='W'), + dt.Interval(unit='W'), + id='weeks', + marks=pytest.mark.xfail( + reason='Week conversion from Timedelta to ibis interval ' + 'not supported' + ), + ), + param( + pd.Timedelta('3', unit='Y'), + dt.Interval(unit='Y'), + id='years', + marks=pytest.mark.xfail( + reason='Year conversion from Timedelta to ibis interval ' + 'not supported' + ), + ), ], ) def test_infer_dtype(value, expected_dtype): diff --git a/ibis/expr/tests/test_table.py b/ibis/expr/tests/test_table.py index ce121617efcf..e8950fac00d6 100644 --- a/ibis/expr/tests/test_table.py +++ b/ibis/expr/tests/test_table.py @@ -1,6 +1,8 @@ +import datetime import pickle import re +import pandas as pd import pytest import ibis @@ -695,6 +697,42 @@ def test_asof_join_with_by(): assert by.left.op().name == by.right.op().name == 'key' +@pytest.mark.parametrize( + ('ibis_interval', 'timedelta_interval'), + [ + [ibis.interval(days=2), pd.Timedelta('2 days')], + [ibis.interval(days=2), datetime.timedelta(days=2)], + [ibis.interval(hours=5), pd.Timedelta('5 hours')], + [ibis.interval(hours=5), datetime.timedelta(hours=5)], + [ibis.interval(minutes=7), pd.Timedelta('7 minutes')], + [ibis.interval(minutes=7), datetime.timedelta(minutes=7)], + [ibis.interval(seconds=9), pd.Timedelta('9 seconds')], + [ibis.interval(seconds=9), datetime.timedelta(seconds=9)], + [ibis.interval(milliseconds=11), pd.Timedelta('11 milliseconds')], + [ibis.interval(milliseconds=11), datetime.timedelta(milliseconds=11)], + [ibis.interval(microseconds=15), pd.Timedelta('15 microseconds')], + [ibis.interval(microseconds=15), datetime.timedelta(microseconds=15)], + [ibis.interval(nanoseconds=17), pd.Timedelta('17 nanoseconds')], + ], +) +def test_asof_join_with_tolerance(ibis_interval, timedelta_interval): + left = ibis.table( + [('time', 'int32'), ('key', 'int32'), ('value', 'double')] + ) + right = ibis.table( + [('time', 'int32'), ('key', 'int32'), ('value2', 'double')] + ) + + joined = api.asof_join(left, right, 'time', tolerance=ibis_interval) + tolerance = joined.op().tolerance + assert_equal(tolerance, ibis_interval) + + joined = api.asof_join(left, right, 'time', tolerance=timedelta_interval) + tolerance = joined.op().tolerance + assert isinstance(tolerance, ir.IntervalScalar) + assert isinstance(tolerance.op(), ops.Literal) + + def test_equijoin_schema_merge(): table1 = ibis.table([('key1', 'string'), ('value1', 'double')]) table2 = ibis.table([('key2', 'string'), ('stuff', 'int32')])