# Definitions added or changed by the patch to ibis/pyspark/compiler.py, laid
# out as plain Python. `compiles`, `F` and `compile_window` are defined in
# parts of the module that the patch hunks do not show.
import ibis.sql.compiler as comp
import ibis.expr.window as window
import ibis.expr.operations as ops
import ibis.expr.types as types

from ibis.pyspark.operations import PysparkTable


def _project_column(t, src_table, selection):
    """Return the Spark column for one column selection, under its ibis name."""
    selection_op = selection.op()
    if isinstance(selection_op, ops.TableColumn):
        # Look the column up by name on the DataFrame being projected, so the
        # reference is bound to that exact DataFrame.
        column = src_table[selection_op.name]
    else:
        column = t.translate(selection)
    # get_name() reflects a rename applied to the expression; the name stored
    # on the underlying op does not.
    return column.alias(selection.get_name())


@compiles(ops.Selection)
def compile_selection(t, expr):
    """Compile an ibis ``Selection`` into a Spark DataFrame.

    Two shapes of ``op.selections`` are supported:

    * the first entry is a table expression: the translated source table is
      kept and every following entry is attached with ``withColumn``;
    * every entry is a column expression: the source table is projected onto
      exactly those columns, each under the name ibis reports for it.

    Raises
    ------
    NotImplementedError
        If there are no selections, or the selections have a shape other
        than the two described above.
    """
    op = expr.op()
    selections = op.selections

    if not selections:
        # Previously surfaced as a bare IndexError on ``op.selections[0]``.
        raise NotImplementedError(
            'Selection without any projected table or column is not '
            'supported by the PySpark compiler'
        )

    src_table = t.translate(op.table)
    first = selections[0]

    if isinstance(first, types.TableExpr):
        for selection in selections[1:]:
            src_table = src_table.withColumn(
                selection.get_name(), t.translate(selection)
            )
        return src_table

    if all(isinstance(sel, types.ColumnExpr) for sel in selections):
        columns = [_project_column(t, src_table, sel) for sel in selections]
        return src_table.select(*columns)

    # Previously fell through to ``return src_table`` with the name unbound.
    raise NotImplementedError(
        'Unsupported selection types: {}'.format(
            [type(sel).__name__ for sel in selections]
        )
    )


@compiles(ops.WindowOp)
def compile_window_op(t, expr):
    """Compile a windowed expression.

    Translates the wrapped expression and applies it over the window
    specification returned by ``compile_window``.
    """
    op = expr.op()
    return t.translate(op.expr).over(compile_window(op.window))


@compiles(ops.Greatest)
def compile_greatest(t, expr):
    """Compile ``Greatest`` into a Spark column.

    ``op.arg`` translates to a list of Spark columns. A single column is
    returned unchanged because ``F.greatest`` requires at least two columns.
    """
    op = expr.op()

    src_columns = t.translate(op.arg)
    if len(src_columns) == 1:
        return src_columns[0]
    return F.greatest(*src_columns)


@compiles(ops.ValueList)
def compile_value_list(t, expr):
    """Compile a ``ValueList`` into a Python list of translated values."""
    op = expr.op()
    return [t.translate(col) for col in op.values]
# Tests added by the patch to ibis/pyspark/tests/test_basic.py, laid out as
# plain Python. `ibis`, `tm` and the `client` fixture come from parts of the
# test module that the patch hunk does not show.


def test_greatest(client):
    """``greatest`` of a single column compiles to that column itself."""
    table = client.table('table1')
    result = table.mutate(greatest=ibis.greatest(table.id)).compile()

    # Build the expectation from one compiled DataFrame, so the column being
    # attached belongs to the DataFrame it is attached to.
    df = table.compile()
    expected = df.withColumn('greatest', df.id)

    tm.assert_frame_equal(result.toPandas(), expected.toPandas())


def test_selection(client):
    """Selecting columns by name matches the same projection done in Spark."""
    table = client.table('table1')
    table = table.mutate(id2=table['id'])

    result1 = table[['id']].compile()
    result2 = table[['id', 'id2']].compile()

    df = table.compile()
    tm.assert_frame_equal(result1.toPandas(), df[['id']].toPandas())
    tm.assert_frame_equal(result2.toPandas(), df[['id', 'id2']].toPandas())