From a3fc1217b6dbc53fc2964f82c315ff1cbd9c5100 Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Tue, 16 Jul 2019 17:57:28 -0400
Subject: [PATCH] Implement basic aggregation, group_by and window (#3)

---
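Note: a minimal usage sketch of the path this patch adds, mirroring
test_groupby below. The `client` fixture is defined in
ibis/pyspark/tests/test_basic.py and is assumed here to be connected to
a SparkSession; its construction is not part of this diff.

    # `client` comes from the session-scoped pytest fixture (assumed).
    table = client.table('table1')
    expr = table.groupby('id').aggregate(table['id'].max())
    df = expr.compile()   # compiles to a pyspark.sql.DataFrame
    df.toPandas()         # executes the aggregation on Spark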
 ibis/pyspark/compiler.py         | 64 ++++++++++++++++++++++++++++++++
 ibis/pyspark/tests/test_basic.py | 34 ++++++++++++++++-
 2 files changed, 96 insertions(+), 2 deletions(-)

diff --git a/ibis/pyspark/compiler.py b/ibis/pyspark/compiler.py
index fa02e1be9d7b..7ca071afebc1 100644
--- a/ibis/pyspark/compiler.py
+++ b/ibis/pyspark/compiler.py
@@ -1,11 +1,16 @@
 import ibis.common as com
 import ibis.sql.compiler as comp
+import ibis.expr.window as window
 import ibis.expr.operations as ops
 
 from ibis.pyspark.operations import PysparkTable
+
 from ibis.sql.compiler import Dialect
 
+import pyspark.sql.functions as F
+from pyspark.sql.window import Window
+
 
 _operation_registry = {
 }
 
@@ -45,6 +50,7 @@ def compile_datasource(t, expr):
     name, _, client = op.args
     return client._session.table(name)
 
+
 @compiles(ops.Selection)
 def compile_selection(t, expr):
     op = expr.op()
@@ -63,6 +69,64 @@ def compile_column(t, expr):
     return t.translate(op.table)[op.name]
 
+
+@compiles(ops.Multiply)
+def compile_multiply(t, expr):
+    op = expr.op()
+    return t.translate(op.left) * t.translate(op.right)
+
+
+@compiles(ops.Subtract)
+def compile_subtract(t, expr):
+    op = expr.op()
+    return t.translate(op.left) - t.translate(op.right)
+
+
+@compiles(ops.Aggregation)
+def compile_aggregation(t, expr):
+    op = expr.op()
+
+    src_table = t.translate(op.table)
+    aggs = [t.translate(m) for m in op.metrics]
+
+    if op.by:
+        bys = [t.translate(b) for b in op.by]
+        return src_table.groupby(*bys).agg(*aggs)
+    else:
+        return src_table.agg(*aggs)
+
+
+@compiles(ops.Max)
+def compile_max(t, expr):
+    op = expr.op()
+
+    # TODO: Derive the UDF output type from schema
+    @F.pandas_udf('long', F.PandasUDFType.GROUPED_AGG)
+    def max(v):
+        return v.max()
+
+    src_column = t.translate(op.arg)
+    return max(src_column)
+
+@compiles(ops.Mean)
+def compile_mean(t, expr):
+    op = expr.op()
+    src_column = t.translate(op.arg)
+
+    return F.mean(src_column)
+
+@compiles(ops.WindowOp)
+def compile_window_op(t, expr):
+    op = expr.op()
+    return t.translate(op.expr).over(compile_window(op.window))
+
+# Cannot register with @compiles because window doesn't have an
+# op() object
+def compile_window(expr):
+    window = expr
+    spark_window = Window.partitionBy()
+    return spark_window
+
 
 t = PysparkExprTranslator()
 
 def translate(expr):
diff --git a/ibis/pyspark/tests/test_basic.py b/ibis/pyspark/tests/test_basic.py
index 0b69f4edb0d9..4686d0df17cc 100644
--- a/ibis/pyspark/tests/test_basic.py
+++ b/ibis/pyspark/tests/test_basic.py
@@ -3,6 +3,8 @@
 import pytest
 
 import ibis
+import pyspark.sql.functions as F
+from pyspark.sql.window import Window
 
 @pytest.fixture(scope='session')
 def client():
@@ -21,7 +23,6 @@ def test_basic(client):
 
 
 def test_projection(client):
-    import ipdb; ipdb.set_trace()
     table = client.table('table1')
 
     result1 = table.mutate(v=table['id']).compile().toPandas()
@@ -46,5 +47,34 @@ def test_projection(client):
     tm.assert_frame_equal(result2, expected2)
 
 
-def test_udf(client):
+def test_aggregation(client):
     table = client.table('table1')
+    result = table.aggregate(table['id'].max()).compile()
+    expected = table.compile().agg(F.max('id'))
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
+
+
+def test_groupby(client):
+    table = client.table('table1')
+    result = table.groupby('id').aggregate(table['id'].max()).compile()
+    expected = table.compile().groupby('id').agg(F.max('id'))
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
+
+
+def test_window(client):
+    table = client.table('table1')
+    w = ibis.window()
+    result = table.mutate(grouped_demeaned = table['id'] - table['id'].mean().over(w)).compile()
+    result2 = table.groupby('id').mutate(grouped_demeaned = table['id'] - table['id'].mean()).compile()
+
+    spark_window = Window.partitionBy()
+    spark_table = table.compile()
+    expected = spark_table.withColumn(
+        'grouped_demeaned',
+        spark_table['id'] - F.mean(spark_table['id']).over(spark_window)
+    )
+
+    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
+    tm.assert_frame_equal(result2.toPandas(), expected.toPandas())