Implement compiler rules to pass all/test_aggregation.py
icexelloss committed Aug 22, 2019
1 parent 215c0d9 commit 675a89f
Showing 4 changed files with 205 additions and 29 deletions.
9 changes: 8 additions & 1 deletion ibis/pyspark/client.py
@@ -1,3 +1,4 @@
import ibis.expr.types as types
from ibis.pyspark.compiler import translate
from ibis.pyspark.operations import PysparkTable
from ibis.spark.client import SparkClient
@@ -17,4 +18,10 @@ def compile(self, expr, *args, **kwargs):
return translate(expr)

def execute(self, expr, params=None, limit='default', **kwargs):
return self.compile(expr).toPandas()

if isinstance(expr, types.TableExpr):
return self.compile(expr).toPandas()
elif isinstance(expr, types.ScalarExpr):
return self.compile(expr).toPandas().iloc[0, 0]
else:
raise ValueError("Unexpected type: {}".format(type(expr)))
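
A minimal usage sketch of the new dispatch (not part of the commit), assuming a client and 'table1' set up as in ibis/pyspark/tests/test_basic.py below:

table = client.table('table1')

# TableExpr: compile to a Spark DataFrame, then convert with toPandas().
df = table.execute()

# ScalarExpr: the 1x1 pandas frame is unwrapped to a plain scalar via iloc[0, 0].
n = table['id'].count().execute()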
172 changes: 153 additions & 19 deletions ibis/pyspark/compiler.py
@@ -1,18 +1,15 @@
import collections
import functools

import pyspark.sql.functions as F
from pyspark.sql.window import Window

import ibis.common as com
import ibis.sql.compiler as comp
import ibis.expr.window as window
import ibis.expr.operations as ops
import ibis.expr.types as types


from ibis.pyspark.operations import PysparkTable

from ibis.sql.compiler import Dialect

import pyspark.sql.functions as F
from pyspark.sql.window import Window


_operation_registry = {}


@@ -27,13 +24,13 @@ def decorator(f):

return decorator

def translate(self, expr):
def translate(self, expr, **kwargs):
# The operation node type the typed expression wraps
op = expr.op()

if type(op) in self._registry:
formatter = self._registry[type(op)]
return formatter(self, expr)
return formatter(self, expr, **kwargs)
else:
raise com.OperationNotDefinedError(
'No translation rule for {}'.format(type(op))
@@ -46,6 +43,7 @@ class PysparkDialect(Dialect):

compiles = PysparkExprTranslator.compiles


@compiles(PysparkTable)
def compile_datasource(t, expr):
op = expr.op()
@@ -88,6 +86,18 @@ def compile_equals(t, expr):
return t.translate(op.left) == t.translate(op.right)


@compiles(ops.Greater)
def compile_greater(t, expr):
op = expr.op()
return t.translate(op.left) > t.translate(op.right)


@compiles(ops.GreaterEqual)
def compile_greater_equal(t, expr):
op = expr.op()
return t.translate(op.left) >= t.translate(op.right)


@compiles(ops.Multiply)
def compile_multiply(t, expr):
op = expr.op()
@@ -100,12 +110,28 @@ def compile_subtract(t, expr):
return t.translate(op.left) - t.translate(op.right)


@compiles(ops.Literal)
def compile_literal(t, expr):
value = expr.op().value

if isinstance(value, collections.abc.Set):
# Don't wrap set with F.lit
if isinstance(value, frozenset):
# Spark doesn't like frozenset
return set(value)
else:
return value
else:
return F.lit(expr.op().value)
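
The set special case matters because Spark's F.lit does not accept a Python set, and the Contains rule below feeds translated options straight into Column.isin. In short, a restatement of the branch above:

# numeric / string literal  -> F.lit(value), a Spark Column
# set literal               -> the raw Python set, passed through untouched
# frozenset literal         -> set(value), since Spark rejects frozenset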


@compiles(ops.Aggregation)
def compile_aggregation(t, expr):
op = expr.op()

src_table = t.translate(op.table)
aggs = [t.translate(m) for m in op.metrics]
aggs = [t.translate(m, context="agg")
for m in op.metrics]

if op.by:
bys = [t.translate(b) for b in op.by]
@@ -114,18 +140,127 @@ def compile_aggregation(t, expr):
return src_table.agg(*aggs)


@compiles(ops.Max)
def compile_max(t, expr):
@compiles(ops.Contains)
def compile_contains(t, expr):
col = t.translate(expr.op().value)
return col.isin(t.translate(expr.op().options))


def compile_aggregator(t, expr, fn, context=None):
op = expr.op()
return F.max(t.translate(op.arg))
src_col = t.translate(op.arg)

if getattr(op, "where", None) is not None:
condition = t.translate(op.where)
src_col = F.when(condition, src_col)

col = fn(src_col)
if context:
return col
else:
return t.translate(expr.op().arg.op().table).select(col)
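
The context keyword lets the same reducer rules serve two call sites: compile_aggregation above passes context="agg" so each metric stays a bare Spark Column for agg(), while a standalone reduction selects the column from its source table. A rough sketch, reusing the table from this commit's tests:

table = client.table('table1')

# context="agg": metrics remain Spark Columns inside groupBy(...).agg(...).
grouped = table.groupby('str_col').aggregate(table['id'].max()).compile()

# No context: roughly spark_table.select(F.max(spark_table['id'])), a 1x1
# DataFrame that execute() later unwraps via .iloc[0, 0].
single = table['id'].max().compile()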


@compiles(ops.GroupConcat)
def compile_group_concat(t, expr, context=None):
sep = expr.op().sep.op().value

def fn(col):
return F.concat_ws(sep, F.collect_list(col))
return compile_aggregator(t, expr, fn, context)


@compiles(ops.Any)
def compile_any(t, expr, context=None):
return compile_aggregator(t, expr, F.max, context)


@compiles(ops.NotAny)
def compile_notany(t, expr, context=None):

def fn(col):
return ~F.max(col)
return compile_aggregator(t, expr, fn, context)


@compiles(ops.All)
def compile_all(t, expr, context=None):
return compile_aggregator(t, expr, F.min, context)


@compiles(ops.NotAll)
def compile_notall(t, expr, context=None):

def fn(col):
return ~F.min(col)
return compile_aggregator(t, expr, fn, context)
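
These four rules lean on Spark ordering booleans with False < True: any() reduces with F.max, all() with F.min, and the Not* variants negate the resulting Column. Against the boolean table2 view created in this commit's test fixture (one True row, one False row):

t2 = client.table('table2')

t2.v.any().execute()     # True: F.max over the boolean column
t2.v.notall().execute()  # True: ~F.min(...), i.e. not every row is True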


@compiles(ops.Count)
def compile_count(t, expr, context=None):
return compile_aggregator(t, expr, F.count, context)


@compiles(ops.Max)
def compile_max(t, expr, context=None):
return compile_aggregator(t, expr, F.max, context)


@compiles(ops.Min)
def compile_min(t, expr, context=None):
return compile_aggregator(t, expr, F.min, context)


@compiles(ops.Mean)
def compile_mean(t, expr):
op = expr.op()
return F.mean(t.translate(op.arg))
def compile_mean(t, expr, context=None):
return compile_aggregator(t, expr, F.mean, context)


@compiles(ops.Sum)
def compile_sum(t, expr, context=None):
return compile_aggregator(t, expr, F.sum, context)


@compiles(ops.StandardDev)
def compile_std(t, expr, context=None):
how = expr.op().how

if how == 'sample':
fn = F.stddev_samp
elif how == 'pop':
fn = F.stddev_pop
else:
raise AssertionError("Unexpected how: {}".format(how))

return compile_aggregator(t, expr, fn, context)


@compiles(ops.Variance)
def compile_variance(t, expr, context=None):
how = expr.op().how

if how == 'sample':
fn = F.var_samp
elif how == 'pop':
fn = F.var_pop
else:
raise AssertionError("Unexpected how: {}".format(how))

return compile_aggregator(t, expr, fn, context)


@compiles(ops.Arbitrary)
def compile_arbitrary(t, expr, context=None):
how = expr.op().how

if how == 'first':
fn = functools.partial(F.first, ignorenulls=True)
elif how == 'last':
fn = functools.partial(F.last, ignorenulls=True)
else:
raise NotImplementedError

return compile_aggregator(t, expr, fn, context)
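
All three of these reductions dispatch on the expression's how attribute and fail loudly on anything unexpected. As a rough guide to the mapping (table.col stands for any suitable column; 'sample' and 'first' are ibis's usual defaults):

# table.col.std()                  -> F.stddev_samp
# table.col.std(how='pop')         -> F.stddev_pop
# table.col.var()                  -> F.var_samp
# table.col.var(how='pop')         -> F.var_pop
# table.col.arbitrary()            -> F.first(col, ignorenulls=True)
# table.col.arbitrary(how='last')  -> F.last(col, ignorenulls=True)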


@compiles(ops.WindowOp)
@@ -170,7 +305,6 @@ def compile_join(t, expr, how):
# Cannot register with @compiles because window doesn't have an
# op() object
def compile_window(expr):
window = expr
spark_window = Window.partitionBy()
return spark_window

50 changes: 42 additions & 8 deletions ibis/pyspark/tests/test_basic.py
@@ -1,16 +1,23 @@
import pandas as pd
import pandas.util.testing as tm
import pyspark.sql.functions as F
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

import ibis
import pyspark.sql.functions as F
from pyspark.sql.window import Window


@pytest.fixture(scope='session')
def client():
client = ibis.pyspark.connect()
session = SparkSession.builder.getOrCreate()
client = ibis.pyspark.connect(session)
df = client._session.range(0, 10)
df = df.withColumn("str_col", F.lit('value'))
df.createTempView('table1')

df1 = client._session.createDataFrame([(True,), (False,)]).toDF('v')
df1.createTempView('table2')
return client


@@ -33,7 +40,10 @@ def test_projection(client):
}
)

result2 = table.mutate(v=table['id']).mutate(v2=table['id']).compile().toPandas()
result2 = (
table.mutate(v=table['id']).mutate(v2=table['id'])
.compile().toPandas()
)

expected2 = pd.DataFrame(
{
@@ -47,6 +57,12 @@ def test_projection(client):
tm.assert_frame_equal(result2, expected2)


def test_aggregation_col(client):
table = client.table('table1')
result = table['id'].count().execute()
assert result == table.compile().count()


def test_aggregation(client):
table = client.table('table1')
result = table.aggregate(table['id'].max()).compile()
@@ -66,8 +82,19 @@ def test_groupby(client):
def test_window(client):
table = client.table('table1')
w = ibis.window()
result = table.mutate(grouped_demeaned = table['id'] - table['id'].mean().over(w)).compile()
result2 = table.groupby('id').mutate(grouped_demeaned = table['id'] - table['id'].mean()).compile()
result = (
table
.mutate(
grouped_demeaned=table['id'] - table['id'].mean().over(w))
.compile()
)
result2 = (
table
.groupby('id')
.mutate(
grouped_demeaned=table['id'] - table['id'].mean())
.compile()
)

spark_window = Window.partitionBy()
spark_table = table.compile()
@@ -82,7 +109,11 @@ def test_window(client):

def test_greatest(client):
table = client.table('table1')
result = table.mutate(greatest = ibis.greatest(table.id)).compile()
result = (
table
.mutate(greatest=ibis.greatest(table.id))
.compile()
)
df = table.compile()
expected = table.compile().withColumn('greatest', df.id)

@@ -105,6 +136,9 @@ def test_join(client):
table = client.table('table1')
result = table.join(table, 'id').compile()
spark_table = table.compile()
expected = spark_table.join(spark_table, spark_table['id'] == spark_table['id'])
expected = (
spark_table
.join(spark_table, spark_table['id'] == spark_table['id'])
)

tm.assert_frame_equal(result.toPandas(), expected.toPandas())
3 changes: 2 additions & 1 deletion ibis/tests/all/test_aggregation.py
@@ -154,4 +154,5 @@ def test_group_concat(backend, alltypes, df, result_fn, expected_fn):
expr = result_fn(alltypes)
result = expr.execute()
expected = expected_fn(df)
assert set(result) == set(expected)

assert set(result.iloc[:, 1]) == set(expected.iloc[:, 1])
