Skip to content

Commit

Permalink
Implement basic join
Browse files Browse the repository at this point in the history
  • Loading branch information
icexelloss committed Aug 22, 2019
1 parent c4a2b79 commit 88705fe
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 11 deletions.
45 changes: 34 additions & 11 deletions ibis/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
import pyspark.sql.functions as F
from pyspark.sql.window import Window

_operation_registry = {
}

_operation_registry = {}


class PysparkExprTranslator:
_registry = _operation_registry
Expand Down Expand Up @@ -75,6 +76,18 @@ def compile_column(t, expr):
return t.translate(op.table)[op.name]


@compiles(ops.SelfReference)
def compile_self_reference(t, expr):
    """A self-reference compiles to the translation of its underlying table."""
    return t.translate(expr.op().table)


@compiles(ops.Equals)
def compile_equals(t, expr):
    """Compile an equality comparison of two translated operands."""
    op = expr.op()
    left = t.translate(op.left)
    right = t.translate(op.right)
    return left == right


@compiles(ops.Multiply)
def compile_multiply(t, expr):
op = expr.op()
Expand Down Expand Up @@ -104,22 +117,15 @@ def compile_aggregation(t, expr):
@compiles(ops.Max)
def compile_max(t, expr):
op = expr.op()
return F.max(t.translate(op.arg))

# TODO: Derive the UDF output type from schema
@F.pandas_udf('long', F.PandasUDFType.GROUPED_AGG)
def max(v):
return v.max()

src_column = t.translate(op.arg)
return max(src_column)


@compiles(ops.Mean)
def compile_mean(t, expr):
op = expr.op()
src_column = t.translate(op.arg)
return F.mean(t.translate(op.arg))

return F.mean(src_column)


@compiles(ops.WindowOp)
Expand All @@ -145,6 +151,22 @@ def compile_value_list(t, expr):
return [t.translate(col) for col in op.values]


@compiles(ops.InnerJoin)
def compile_inner_join(t, expr):
    """Compile an inner join by delegating to the generic join compiler."""
    return compile_join(t, expr, how='inner')


def compile_join(t, expr, how):
    """Compile a join expression into a Spark DataFrame join.

    Parameters
    ----------
    t : PysparkExprTranslator
        Translator used to compile the join operands and predicates.
    expr : ibis join expression
        Join expression whose op() carries ``left``, ``right`` and
        ``predicates``.
    how : str
        Spark join type passed straight through to ``DataFrame.join``
        (e.g. ``'inner'``).

    Returns
    -------
    The joined Spark DataFrame.
    """
    op = expr.op()

    left_df = t.translate(op.left)
    right_df = t.translate(op.right)

    # AND all join predicates together (Spark Columns support `&`),
    # instead of only honoring the first predicate.
    predicates = t.translate(op.predicates[0])
    for pred in op.predicates[1:]:
        predicates = predicates & t.translate(pred)

    return left_df.join(right_df, predicates, how)


# Cannot register with @compiles because window doesn't have an
# op() object
def compile_window(expr):
Expand All @@ -155,5 +177,6 @@ def compile_window(expr):

t = PysparkExprTranslator()


def translate(expr):
    """Compile *expr* using the module-level PysparkExprTranslator instance."""
    return t.translate(expr)
9 changes: 9 additions & 0 deletions ibis/pyspark/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,12 @@ def test_selection(client):
df = table.compile()
tm.assert_frame_equal(result1.toPandas(), df[['id']].toPandas())
tm.assert_frame_equal(result2.toPandas(), df[['id', 'id2']].toPandas())


def test_join(client):
    """A self-join on 'id' compiles to the equivalent raw Spark join."""
    ibis_table = client.table('table1')
    result = ibis_table.join(ibis_table, 'id').compile()

    sdf = ibis_table.compile()
    expected = sdf.join(sdf, sdf['id'] == sdf['id'])

    tm.assert_frame_equal(result.toPandas(), expected.toPandas())

0 comments on commit 88705fe

Please sign in to comment.