diff --git a/ibis/__init__.py b/ibis/__init__.py
index bd4cb7596053d..be2c227bef6ca 100644
--- a/ibis/__init__.py
+++ b/ibis/__init__.py
@@ -57,6 +57,9 @@
     # pip install ibis-framework[spark]
     import ibis.spark.api as spark  # noqa: F401
 
+with suppress(ImportError):
+    import ibis.pyspark.api as pyspark  # noqa: F401
+
 
 def hdfs_connect(
     host='localhost',
diff --git a/ibis/pyspark/api.py b/ibis/pyspark/api.py
new file mode 100644
index 0000000000000..8539532849427
--- /dev/null
+++ b/ibis/pyspark/api.py
@@ -0,0 +1,12 @@
+from ibis.pyspark.client import PysparkClient
+
+
+def connect(**kwargs):
+    """
+    Create a `PysparkClient` for use with Ibis. Pipes **kwargs into
+    PysparkClient, which pipes them into SparkContext. See the SparkContext docs:
+    https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext
+    """
+    client = PysparkClient(**kwargs)
+
+    return client
diff --git a/ibis/pyspark/client.py b/ibis/pyspark/client.py
new file mode 100644
index 0000000000000..fd0890a6f9a09
--- /dev/null
+++ b/ibis/pyspark/client.py
@@ -0,0 +1,19 @@
+from ibis.pyspark.compiler import translate
+from ibis.pyspark.operations import PysparkTable
+from ibis.spark.client import SparkClient
+
+
+class PysparkClient(SparkClient):
+    """
+    An Ibis client backed by PySpark SQL DataFrames.
+    """
+
+    dialect = None
+    table_class = PysparkTable
+
+    def compile(self, expr, *args, **kwargs):
+        """Compile an ibis expression to a PySpark DataFrame object."""
+        return translate(expr)
+
+    def execute(self, df, params=None, limit='default', **kwargs):
+        return df.toPandas()
diff --git a/ibis/pyspark/compiler.py b/ibis/pyspark/compiler.py
new file mode 100644
index 0000000000000..fa02e1be9d7b8
--- /dev/null
+++ b/ibis/pyspark/compiler.py
@@ -0,0 +1,70 @@
+import ibis.common as com
+import ibis.expr.operations as ops
+
+from ibis.pyspark.operations import PysparkTable
+from ibis.sql.compiler import Dialect
+
+_operation_registry = {}
+
+
+class PysparkExprTranslator:
+    _registry = _operation_registry
+
+    @classmethod
+    def compiles(cls, klass):
+        def decorator(f):
+            cls._registry[klass] = f
+            return f
+
+        return decorator
+
+    def translate(self, expr):
+        # The operation node type the typed expression wraps
+        op = expr.op()
+
+        if type(op) in self._registry:
+            formatter = self._registry[type(op)]
+            return formatter(self, expr)
+        else:
+            raise com.OperationNotDefinedError(
+                'No translation rule for {}'.format(type(op))
+            )
+
+
+class PysparkDialect(Dialect):
+    translator = PysparkExprTranslator
+
+
+compiles = PysparkExprTranslator.compiles
+
+
+@compiles(PysparkTable)
+def compile_datasource(t, expr):
+    op = expr.op()
+    name, _, client = op.args
+    return client._session.table(name)
+
+
+@compiles(ops.Selection)
+def compile_selection(t, expr):
+    op = expr.op()
+    src_table = t.translate(op.selections[0])
+    for selection in op.selections[1:]:
+        column_name = selection.get_name()
+        column = t.translate(selection)
+        src_table = src_table.withColumn(column_name, column)
+
+    return src_table
+
+
+@compiles(ops.TableColumn)
+def compile_column(t, expr):
+    op = expr.op()
+    return t.translate(op.table)[op.name]
+
+
+_translator = PysparkExprTranslator()
+
+
+def translate(expr):
+    return _translator.translate(expr)
diff --git a/ibis/pyspark/operations.py b/ibis/pyspark/operations.py
new file mode 100644
index 0000000000000..ccc8025d5d6ec
--- /dev/null
+++ b/ibis/pyspark/operations.py
@@ -0,0 +1,5 @@
+import ibis.expr.operations as ops
+
+
+class PysparkTable(ops.DatabaseTable):
+    pass
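The compiler dispatches on the type of the wrapped operation node via `_registry`, so the backend grows one registered translation function at a time through the `compiles` decorator. As a rough sketch of what a future rule could look like (illustrative only, not part of this diff; `ops.Multiply` is an existing ibis node, but this backend does not register it yet):

import ibis.expr.operations as ops

from ibis.pyspark.compiler import compiles


# Hypothetical rule: translate each operand to a PySpark Column,
# then multiply the columns.
@compiles(ops.Multiply)
def compile_multiply(t, expr):
    op = expr.op()
    return t.translate(op.left) * t.translate(op.right)
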
diff --git a/ibis/pyspark/tests/test_basic.py b/ibis/pyspark/tests/test_basic.py
new file mode 100644
index 0000000000000..0b69f4edb0d92
--- /dev/null
+++ b/ibis/pyspark/tests/test_basic.py
@@ -0,0 +1,50 @@
+import pandas as pd
+import pandas.util.testing as tm
+import pytest
+
+import ibis
+
+
+@pytest.fixture(scope='session')
+def client():
+    client = ibis.pyspark.connect()
+    df = client._session.range(0, 10)
+    df.createTempView('table1')
+    return client
+
+
+def test_basic(client):
+    table = client.table('table1')
+    result = table.compile().toPandas()
+    expected = pd.DataFrame({'id': range(0, 10)})
+
+    tm.assert_frame_equal(result, expected)
+
+
+def test_projection(client):
+    table = client.table('table1')
+    result1 = table.mutate(v=table['id']).compile().toPandas()
+
+    expected1 = pd.DataFrame(
+        {
+            'id': range(0, 10),
+            'v': range(0, 10)
+        }
+    )
+
+    result2 = (
+        table
+        .mutate(v=table['id'])
+        .mutate(v2=table['id'])
+        .compile()
+        .toPandas()
+    )
+
+    expected2 = pd.DataFrame(
+        {
+            'id': range(0, 10),
+            'v': range(0, 10),
+            'v2': range(0, 10)
+        }
+    )
+
+    tm.assert_frame_equal(result1, expected1)
+    tm.assert_frame_equal(result2, expected2)
+
+
+def test_udf(client):
+    table = client.table('table1')
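Putting the pieces together, a minimal end-to-end sketch of how this backend is exercised, using only the API added in this diff (the temp-view setup mirrors the test fixture above):

import ibis

# **kwargs are forwarded through PysparkClient to SparkContext.
client = ibis.pyspark.connect()

# Register a small DataFrame as a temp view, as the test fixture does.
client._session.range(0, 10).createTempView('table1')

table = client.table('table1')
expr = table.mutate(v=table['id'])

df = expr.compile()          # a PySpark DataFrame produced by translate()
result = client.execute(df)  # calls df.toPandas()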