Skip to content

Commit

Permalink
validate AsOfJoin tolerance and attempt interval unit conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
emilyreff7 committed Sep 6, 2019
1 parent ffe68dd commit 97fa826
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 14 deletions.
64 changes: 53 additions & 11 deletions ibis/expr/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
try:
if sys.version_info >= (3, 6):
import shapely.geometry

IS_SHAPELY_AVAILABLE = True
except ImportError:
...
Expand Down Expand Up @@ -405,6 +406,21 @@ class Interval(DataType):
ns='nanosecond',
)

# Plural timedelta component names (the field names found on
# ``Timedelta.components``) mapped to the short interval unit codes.
_timedelta_to_interval_units = {
    'days': 'D',
    'hours': 'h',
    'minutes': 'm',
    'seconds': 's',
    'milliseconds': 'ms',
    'microseconds': 'us',
    'nanoseconds': 'ns',
}

def _convert_timedelta_unit_to_interval_unit(self, unit: str):
    """Translate a timedelta component name into an interval unit code.

    Parameters
    ----------
    unit : str
        A key of ``_timedelta_to_interval_units``, e.g. ``'days'``.

    Returns
    -------
    str
        The matching interval unit code, e.g. ``'D'``.

    Raises
    ------
    ValueError
        If `unit` is not a known timedelta component name.
    """
    try:
        return self._timedelta_to_interval_units[unit]
    except KeyError:
        # Callers catch ValueError, so keep that type but say what failed.
        raise ValueError(
            'Unsupported timedelta unit `{}`'.format(unit)
        ) from None

def __init__(
self,
unit: str = 's',
Expand All @@ -413,7 +429,10 @@ def __init__(
) -> None:
super().__init__(nullable=nullable)
if unit not in self._units:
raise ValueError('Unsupported interval unit `{}`'.format(unit))
try:
unit = self._convert_timedelta_unit_to_interval_unit(unit)
except ValueError:
raise ValueError('Unsupported interval unit `{}`'.format(unit))

if value_type is None:
value_type = int32
Expand Down Expand Up @@ -536,11 +555,10 @@ def _literal_value_hash_key(self, value):

def _tuplize(values):
"""Recursively convert `values` to a tuple of tuples."""

def tuplize_iter(values):
yield from (
tuple(tuplize_iter(value))
if util.is_iterable(value)
else value
tuple(tuplize_iter(value)) if util.is_iterable(value) else value
for value in values
)

Expand Down Expand Up @@ -665,7 +683,7 @@ def _literal_value_hash_key(self, value):
shapely.geometry.Polygon,
shapely.geometry.MultiLineString,
shapely.geometry.MultiPoint,
shapely.geometry.MultiPolygon
shapely.geometry.MultiPolygon,
)
if isinstance(value, geo_shapes):
return self, value.wkt
Expand Down Expand Up @@ -1422,6 +1440,17 @@ def type(self) -> DataType:
validate_type = dtype


def _get_timedelta_units(timedelta: datetime.timedelta) -> List[str]:
    """Return the names of the components of `timedelta` that are positive.

    Parameters
    ----------
    timedelta : datetime.timedelta
        Either an object exposing a ``components`` namedtuple (as
        ``pandas.Timedelta`` does) or a plain ``datetime.timedelta``.

    Returns
    -------
    List[str]
        Component names such as ``'days'`` or ``'hours'``, in the order the
        components are listed. Empty when no component is greater than zero.
    """
    components = getattr(timedelta, 'components', None)
    if components is None:
        # A plain ``datetime.timedelta`` has no ``components`` attribute, so
        # split its (days, seconds, microseconds) fields by hand.
        minutes, seconds = divmod(timedelta.seconds, 60)
        hours, minutes = divmod(minutes, 60)
        milliseconds, microseconds = divmod(timedelta.microseconds, 1000)
        amounts = [
            ('days', timedelta.days),
            ('hours', hours),
            ('minutes', minutes),
            ('seconds', seconds),
            ('milliseconds', milliseconds),
            ('microseconds', microseconds),
        ]
    else:
        amounts = zip(components._fields, components)
    return [name for name, amount in amounts if amount > 0]


@dtype.register(object)
def default(value, **kwargs) -> DataType:
    """Reject `value` as a datatype.

    This handler is registered on ``dtype`` for ``object``; it is presumably
    reached only when no more specific handler matches (verify against the
    dispatcher).

    Raises
    ------
    com.IbisTypeError
        Always, naming the offending value.
    """
    message = 'Value {!r} is not a valid datatype'.format(value)
    raise com.IbisTypeError(message)
Expand Down Expand Up @@ -1536,7 +1565,14 @@ def infer_timestamp(value: datetime.datetime) -> Timestamp:

@infer.register(datetime.timedelta)
def infer_interval(value: datetime.timedelta) -> Interval:
    """Infer the interval datatype of a timedelta value.

    When exactly one component of `value` is set (e.g.
    ``pd.Timedelta('2 days')``), an ``Interval`` built from that single unit
    is returned. Values that mix several units (e.g.
    ``pd.Timedelta('2 days 3 hours')``), or that have no positive component,
    fall back to the module-level ``interval`` type.
    """
    time_units = _get_timedelta_units(value)
    # A conversion is only attempted in the simplest case, i.e. when there is
    # exactly one unit; anything else keeps the generic interval type.
    if len(time_units) == 1:
        return Interval(time_units[0])
    return interval


@infer.register(str)
Expand Down Expand Up @@ -1572,13 +1608,14 @@ def infer_null(value: Optional[Null]) -> Null:


if IS_SHAPELY_AVAILABLE:

@infer.register(shapely.geometry.Point)
def infer_shapely_point(value: shapely.geometry.Point) -> Point:
return point

@infer.register(shapely.geometry.LineString)
def infer_shapely_linestring(
value: shapely.geometry.LineString
value: shapely.geometry.LineString,
) -> LineString:
return linestring

Expand All @@ -1588,19 +1625,19 @@ def infer_shapely_polygon(value: shapely.geometry.Polygon) -> Polygon:

@infer.register(shapely.geometry.MultiLineString)
def infer_shapely_multilinestring(
value: shapely.geometry.MultiLineString
value: shapely.geometry.MultiLineString,
) -> MultiLineString:
return multilinestring

@infer.register(shapely.geometry.MultiPoint)
def infer_shapely_multipoint(
value: shapely.geometry.MultiPoint
value: shapely.geometry.MultiPoint,
) -> MultiPoint:
return multipoint

@infer.register(shapely.geometry.MultiPolygon)
def infer_shapely_multipolygon(
value: shapely.geometry.MultiPolygon
value: shapely.geometry.MultiPolygon,
) -> MultiPolygon:
return multipolygon

Expand Down Expand Up @@ -1721,7 +1758,12 @@ def can_cast_variadic(
# Geospatial data types.
# Casting between values of the same type is used to cast from/to geometry
# and geography.
GEO_TYPES = (
    Point,
    LineString,
    Polygon,
    MultiLineString,
    MultiPoint,
    MultiPolygon,
)


Expand Down
17 changes: 15 additions & 2 deletions ibis/expr/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import operator
from contextlib import suppress
from typing import List

import toolz

Expand Down Expand Up @@ -1026,8 +1027,10 @@ def __init__(self, expr, window):
window = window.bind(table)

if window.max_lookback is not None:
error_msg = ("'max lookback' windows must be ordered "
"by a timestamp column")
error_msg = (
"'max lookback' windows must be ordered "
"by a timestamp column"
)
if len(window._order_by) != 1:
raise com.IbisInputError(error_msg)
order_var = window._order_by[0].op().args[0]
Expand Down Expand Up @@ -1730,6 +1733,13 @@ def __init__(self, left, right, predicates, by, tolerance):
super().__init__(left, right, predicates)
self.by = _clean_join_predicates(self.left, self.right, by)
self.tolerance = tolerance
self._validate_args(['by', 'tolerance'])

def _validate_args(self, args: List[str]):
    """Re-validate the named attributes against their declared arguments.

    For every name in `args`, the matching entry of ``self.signature`` is
    asked to validate the current attribute value, and whatever it returns
    is stored back on the instance under the same name.
    """
    for name in args:
        validated = self.signature[name].validate(getattr(self, name))
        setattr(self, name, validated)


class Union(TableNode, HasSchema):
Expand Down Expand Up @@ -3196,6 +3206,7 @@ class GeoSRID(GeoSpatialUnOp):

class GeoSetSRID(GeoSpatialUnOp):
"""Set the spatial reference identifier for the ST_Geometry."""

srid = Arg(rlz.integer)
output_type = rlz.shape_like('args', dt.geometry)

Expand All @@ -3221,6 +3232,7 @@ class GeoDFullyWithin(GeoSpatialBinOp):
"""Returns True if the geometries are fully within the specified distance
of one another.
"""

distance = Arg(rlz.floating)

output_type = rlz.shape_like('args', dt.boolean)
Expand All @@ -3230,6 +3242,7 @@ class GeoDWithin(GeoSpatialBinOp):
"""Returns True if the geometries are within the specified distance
of one another.
"""

distance = Arg(rlz.floating)

output_type = rlz.shape_like('args', dt.boolean)
Expand Down
57 changes: 56 additions & 1 deletion ibis/expr/tests/test_table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pickle
import re

import pandas as pd
import pytest
from pytest import param

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -695,6 +697,59 @@ def test_asof_join_with_by():
assert by.left.op().name == by.right.op().name == 'key'


@pytest.mark.parametrize(
    ('ibis_interval', 'timedelta_interval'),
    [
        [ibis.interval(days=2), pd.Timedelta('2 days')],
        [ibis.interval(hours=5), pd.Timedelta('5 hours')],
        [ibis.interval(minutes=7), pd.Timedelta('7 minutes')],
        # Both sides describe the same duration (previously 9 vs. 7 seconds).
        [ibis.interval(seconds=9), pd.Timedelta('9 seconds')],
        [ibis.interval(milliseconds=9), pd.Timedelta('9 milliseconds')],
        [ibis.interval(microseconds=11), pd.Timedelta('11 microseconds')],
        [ibis.interval(nanoseconds=17), pd.Timedelta('17 nanoseconds')],
        param(
            ibis.interval(weeks=3),
            pd.Timedelta('3 W'),
            id='weeks',
            marks=pytest.mark.xfail(
                reason='Week conversion from Timedelta to ibis interval '
                'not supported'
            ),
        ),
        param(
            ibis.interval(years=3),
            pd.Timedelta('3 Y'),
            id='years',
            marks=pytest.mark.xfail(
                reason='Year conversion from Timedelta to ibis interval '
                'not supported'
            ),
        ),
    ],
)
def test_asof_join_with_tolerance(ibis_interval, timedelta_interval):
    """Check that an as-of join accepts ibis and Timedelta tolerances.

    An ibis interval must be stored unchanged as the join tolerance. A
    ``pd.Timedelta`` must be turned into an interval literal whose unit
    matches that of the equivalent ibis interval.
    """
    left = ibis.table(
        [('time', 'int32'), ('key', 'int32'), ('value', 'double')]
    )
    right = ibis.table(
        [('time', 'int32'), ('key', 'int32'), ('value2', 'double')]
    )

    # Tolerance given as an ibis interval expression.
    joined = api.asof_join(left, right, 'time', tolerance=ibis_interval)
    tolerance = joined.op().tolerance
    assert_equal(tolerance, ibis_interval)

    # Tolerance given as a Timedelta, which must be validated and converted.
    joined = api.asof_join(left, right, 'time', tolerance=timedelta_interval)
    tolerance = joined.op().tolerance

    assert isinstance(tolerance, ir.IntervalScalar)
    assert isinstance(tolerance.op(), ops.Literal)

    ibis_interval_unit = ibis_interval.op().dtype.unit
    timedelta_unit = tolerance.op().dtype.unit
    assert timedelta_unit == ibis_interval_unit


def test_equijoin_schema_merge():
table1 = ibis.table([('key1', 'string'), ('value1', 'double')])
table2 = ibis.table([('key2', 'string'), ('stuff', 'int32')])
Expand Down Expand Up @@ -1064,7 +1119,7 @@ def test_cannot_use_existence_expression_in_join(table):


def test_not_exists_predicate(t1, t2):
    """Negating an ``any()`` reduction should produce a ``NotAny`` op."""
    # Explicit parentheses show that the whole ``any()`` result is negated.
    cond = -((t1.key1 == t2.key1).any())
    assert isinstance(cond.op(), ops.NotAny)


Expand Down

0 comments on commit 97fa826

Please sign in to comment.