Implement basic aggregation, group_by and window (ibis-project#3)

icexelloss committed Aug 7, 2019
1 parent 09f4325 commit 0064606
Showing 2 changed files with 96 additions and 2 deletions.
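In brief, the user-facing surface this commit adds (a minimal sketch distilled from the tests below, assuming the session-scoped `client` fixture from test_basic.py):

import ibis

table = client.table('table1')

# Whole-table and grouped aggregation.
agg = table.aggregate(table['id'].max())
grouped = table.groupby('id').aggregate(table['id'].max())

# Windowing: demean a column over a (for now, unpartitioned) window.
w = ibis.window()
demeaned = table.mutate(
    grouped_demeaned=table['id'] - table['id'].mean().over(w)
)

# Each expression compiles to a PySpark DataFrame.
df = demeaned.compile()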
64 changes: 64 additions & 0 deletions ibis/pyspark/compiler.py
@@ -1,11 +1,16 @@
import ibis.common as com
import ibis.sql.compiler as comp
import ibis.expr.window as window
import ibis.expr.operations as ops


from ibis.pyspark.operations import PysparkTable

from ibis.sql.compiler import Dialect

import pyspark.sql.functions as F
from pyspark.sql.window import Window

_operation_registry = {
}
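# The @compiles handlers below presumably register into this table; the
# decorator and the PysparkExprTranslator.translate dispatch machinery
# are elided from this hunk.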

@@ -45,6 +50,7 @@ def compile_datasource(t, expr):
    name, _, client = op.args
    return client._session.table(name)


@compiles(ops.Selection)
def compile_selection(t, expr):
    op = expr.op()
@@ -63,6 +69,64 @@ def compile_column(t, expr):
    return t.translate(op.table)[op.name]


@compiles(ops.Multiply)
def compile_multiply(t, expr):
    op = expr.op()
    return t.translate(op.left) * t.translate(op.right)


@compiles(ops.Subtract)
def compile_subtract(t, expr):
    op = expr.op()
    return t.translate(op.left) - t.translate(op.right)


@compiles(ops.Aggregation)
def compile_aggregation(t, expr):
    op = expr.op()

    src_table = t.translate(op.table)
    aggs = [t.translate(m) for m in op.metrics]

    if op.by:
        bys = [t.translate(b) for b in op.by]
        return src_table.groupby(*bys).agg(*aggs)
    else:
        return src_table.agg(*aggs)


@compiles(ops.Max)
def compile_max(t, expr):
    op = expr.op()

    # TODO: Derive the UDF output type from schema
    @F.pandas_udf('long', F.PandasUDFType.GROUPED_AGG)
    def max(v):
        return v.max()

    src_column = t.translate(op.arg)
    return max(src_column)


@compiles(ops.Mean)
def compile_mean(t, expr):
    op = expr.op()
    src_column = t.translate(op.arg)

    return F.mean(src_column)


@compiles(ops.WindowOp)
def compile_window_op(t, expr):
    op = expr.op()
    return t.translate(op.expr).over(compile_window(op.window))


# Cannot register with @compiles because a window is not an expression
# and has no op() to dispatch on.
def compile_window(expr):
    # Only the default window is supported for now: partitionBy() with no
    # keys places every row in a single, unpartitioned window.
    spark_window = Window.partitionBy()
    return spark_window


t = PysparkExprTranslator()


def translate(expr):
    ...  # remainder of the function is elided in this diff
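Taken together, the compiler would be exercised end to end along these lines (a hypothetical sketch: it assumes translate hands the expression to the translator instance above, and that the tests' .compile() ultimately routes here):

expr = table.groupby('id').aggregate(table['id'].max())
spark_df = translate(expr)  # assumption: returns a pyspark.sql.DataFrame
spark_df.toPandas()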
34 changes: 32 additions & 2 deletions ibis/pyspark/tests/test_basic.py
@@ -3,6 +3,8 @@
import pytest

import ibis
import pyspark.sql.functions as F
from pyspark.sql.window import Window


@pytest.fixture(scope='session')
def client():
@@ -21,7 +23,6 @@ def test_basic(client):


def test_projection(client):
    table = client.table('table1')
    result1 = table.mutate(v=table['id']).compile().toPandas()

@@ -46,5 +47,34 @@
    tm.assert_frame_equal(result2, expected2)


def test_aggregation(client):
    table = client.table('table1')
    result = table.aggregate(table['id'].max()).compile()
    expected = table.compile().agg(F.max('id'))

    tm.assert_frame_equal(result.toPandas(), expected.toPandas())


def test_groupby(client):
    table = client.table('table1')
    result = table.groupby('id').aggregate(table['id'].max()).compile()
    expected = table.compile().groupby('id').agg(F.max('id'))

    tm.assert_frame_equal(result.toPandas(), expected.toPandas())


def test_window(client):
    table = client.table('table1')
    w = ibis.window()
    result = table.mutate(
        grouped_demeaned=table['id'] - table['id'].mean().over(w)
    ).compile()
    result2 = table.groupby('id').mutate(
        grouped_demeaned=table['id'] - table['id'].mean()
    ).compile()

    spark_window = Window.partitionBy()
    spark_table = table.compile()
    expected = spark_table.withColumn(
        'grouped_demeaned',
        spark_table['id'] - F.mean(spark_table['id']).over(spark_window)
    )

    tm.assert_frame_equal(result.toPandas(), expected.toPandas())
    tm.assert_frame_equal(result2.toPandas(), expected.toPandas())
