From 99a2f2e4685a137f64f8242bee0ec44695ede34f Mon Sep 17 00:00:00 2001 From: Li Jin Date: Thu, 22 Aug 2019 15:34:08 -0400 Subject: [PATCH] FEAT: PySpark backend This is a Pyspark backend for ibis. This is different from the spark backend where the ibis expr is compiled to SQL string. Instead, the pyspark backend compiles the ibis expr to pyspark.DataFrame exprs. Author: Li Jin Author: Hyonjee Joo <5000208+hjoo@users.noreply.github.com> Closes #1913 from icexelloss/pyspark-backend-prototype and squashes the following commits: 213e3719 [Li Jin] Add pyspark/__init__.py 8f1c35ea [Li Jin] Address comments f1734252 [Li Jin] Fix tests 0969b0a9 [Li Jin] Skip unimplemented tests 1f9409bb [Li Jin] Change pyspark imports to optional 26b041c7 [Li Jin] Add importskip 108ccd85 [Li Jin] Add scope e00dc000 [Li Jin] Address PR comments 4764a4e7 [Li Jin] Add pyspark marker to setup.cfg 7cc2a9ea [Li Jin] Remove dead code 72b45f87 [Li Jin] Fix rebase errors 9ad663f4 [Hyonjee Joo] implement pyspark numeric operations to pass all/test_numeric.py (#9) 675a89fc [Li Jin] Implement compiler rules to pass all/test_aggregation.py 215c0d99 [Li Jin] Link existing tests with PySpark backend (#7) 88705fe4 [Li Jin] Implement basic join c4a2b79f [Hyonjee Joo] add pyspark compile rule for greatest, fix bug with selection (#4) fa4ad23a [Li Jin] Implement basic aggregation, group_by and window (#3) 54c2f2d8 [Li Jin] Initial commit of pyspark DataFrame backend (#1) --- ibis/__init__.py | 3 + ibis/pyspark/__init__.py | 0 ibis/pyspark/api.py | 18 + ibis/pyspark/client.py | 46 +++ ibis/pyspark/compiler.py | 519 +++++++++++++++++++++++++++++ ibis/pyspark/operations.py | 5 + ibis/pyspark/tests/test_basic.py | 176 ++++++++++ ibis/spark/api.py | 4 +- ibis/spark/client.py | 8 +- ibis/tests/all/conftest.py | 24 +- ibis/tests/all/test_aggregation.py | 3 +- ibis/tests/all/test_client.py | 3 +- ibis/tests/backends.py | 23 ++ setup.cfg | 1 + 14 files changed, 822 insertions(+), 11 deletions(-) create mode 100644 ibis/pyspark/__init__.py create mode 100644 ibis/pyspark/api.py create mode 100644 ibis/pyspark/client.py create mode 100644 ibis/pyspark/compiler.py create mode 100644 ibis/pyspark/operations.py create mode 100644 ibis/pyspark/tests/test_basic.py diff --git a/ibis/__init__.py b/ibis/__init__.py index bd4a73ce00fb..40ec9f46dd77 100644 --- a/ibis/__init__.py +++ b/ibis/__init__.py @@ -57,6 +57,9 @@ # pip install ibis-framework[spark] import ibis.spark.api as spark # noqa: F401 +with suppress(ImportError): + import ibis.pyspark.api as pyspark # noqa: F401 + def hdfs_connect( host='localhost', diff --git a/ibis/pyspark/__init__.py b/ibis/pyspark/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/ibis/pyspark/api.py b/ibis/pyspark/api.py new file mode 100644 index 000000000000..36439c6371da --- /dev/null +++ b/ibis/pyspark/api.py @@ -0,0 +1,18 @@ +from ibis.pyspark.client import PySparkClient + + +def connect(session): + """ + Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient, + which pipes them into SparkContext. See documentation for SparkContext: + https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext + """ + client = PySparkClient(session) + + # Spark internally stores timestamps as UTC values, and timestamp data that + # is brought in without a specified time zone is converted as local time to + # UTC with microsecond resolution. 
+ # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics + client._session.conf.set('spark.sql.session.timeZone', 'UTC') + + return client diff --git a/ibis/pyspark/client.py b/ibis/pyspark/client.py new file mode 100644 index 000000000000..63062cfcee55 --- /dev/null +++ b/ibis/pyspark/client.py @@ -0,0 +1,46 @@ +from pyspark.sql.column import Column + +import ibis.common.exceptions as com +import ibis.expr.types as types +from ibis.pyspark.compiler import PySparkExprTranslator +from ibis.pyspark.operations import PySparkTable +from ibis.spark.client import SparkClient + + +class PySparkClient(SparkClient): + """ + An ibis client that uses PySpark SQL Dataframe + """ + + table_class = PySparkTable + + def __init__(self, session): + super().__init__(session) + self.translator = PySparkExprTranslator() + + def compile(self, expr, *args, **kwargs): + """Compile an ibis expression to a PySpark DataFrame object + """ + return self.translator.translate(expr, scope={}) + + def execute(self, expr, params=None, limit='default', **kwargs): + if isinstance(expr, types.TableExpr): + return self.compile(expr).toPandas() + elif isinstance(expr, types.ColumnExpr): + # expression must be named for the projection + expr = expr.name('tmp') + return self.compile(expr.to_projection()).toPandas()['tmp'] + elif isinstance(expr, types.ScalarExpr): + compiled = self.compile(expr) + if isinstance(compiled, Column): + # attach result column to a fake DataFrame and + # select the result + compiled = self._session.range(0, 1).select(compiled) + return compiled.toPandas().iloc[0, 0] + else: + raise com.IbisError( + "Cannot execute expression of type: {}".format(type(expr))) + + def sql(self, query): + raise NotImplementedError( + "PySpark backend doesn't support sql query") diff --git a/ibis/pyspark/compiler.py b/ibis/pyspark/compiler.py new file mode 100644 index 000000000000..72e0028318ce --- /dev/null +++ b/ibis/pyspark/compiler.py @@ -0,0 +1,519 @@ +import collections +import enum +import functools + +import pyspark.sql.functions as F + +import ibis.common.exceptions as com +import ibis.expr.operations as ops +import ibis.expr.types as types +from ibis.pyspark.operations import PySparkTable + + +class AggregationContext(enum.Enum): + ENTIRE = 0 + WINDOW = 1 + GROUP = 2 + + +class PySparkExprTranslator: + _registry = {} + + @classmethod + def compiles(cls, klass): + def decorator(f): + cls._registry[klass] = f + return f + + return decorator + + def translate(self, expr, scope, **kwargs): + # The operation node type the typed expression wraps + op = expr.op() + + if type(op) in self._registry: + formatter = self._registry[type(op)] + return formatter(self, expr, scope, **kwargs) + else: + raise com.OperationNotDefinedError( + 'No translation rule for {}'.format(type(op)) + ) + + +compiles = PySparkExprTranslator.compiles + + +def compile_with_scope(t, expr, scope): + """Compile a expression and put the result in scope. + + If the expression is already in scope, return it. 
+ """ + op = expr.op() + + if op in scope: + result = scope[op] + else: + result = t.translate(expr, scope) + scope[op] = result + + return result + + +@compiles(PySparkTable) +def compile_datasource(t, expr, scope): + op = expr.op() + name, _, client = op.args + return client._session.table(name) + + +@compiles(ops.Selection) +def compile_selection(t, expr, scope, **kwargs): + # Cache compile results for tables + op = expr.op() + + # TODO: Support predicates and sort_keys + if op.predicates or op.sort_keys: + raise NotImplementedError( + "predicates and sort_keys are not supported with Selection") + + src_table = compile_with_scope(t, op.table, scope) + col_names_in_selection_order = [] + + for selection in op.selections: + if isinstance(selection, types.TableExpr): + col_names_in_selection_order.extend(selection.columns) + elif isinstance(selection, types.ColumnExpr): + column_name = selection.get_name() + col_names_in_selection_order.append(column_name) + column = t.translate(selection, scope=scope) + src_table = src_table.withColumn(column_name, column) + + return src_table[col_names_in_selection_order] + + +@compiles(ops.TableColumn) +def compile_column(t, expr, scope, **kwargs): + op = expr.op() + table = compile_with_scope(t, op.table, scope) + return table[op.name] + + +@compiles(ops.SelfReference) +def compile_self_reference(t, expr, scope, **kwargs): + op = expr.op() + return t.translate(op.table, scope) + + +@compiles(ops.Equals) +def compile_equals(t, expr, scope, **kwargs): + op = expr.op() + return t.translate(op.left, scope) == t.translate(op.right, scope) + + +@compiles(ops.Greater) +def compile_greater(t, expr, scope, **kwargs): + op = expr.op() + return t.translate(op.left, scope) > t.translate(op.right, scope) + + +@compiles(ops.GreaterEqual) +def compile_greater_equal(t, expr, scope, **kwargs): + op = expr.op() + return t.translate(op.left, scope) >= t.translate(op.right, scope) + + +@compiles(ops.Multiply) +def compile_multiply(t, expr, scope, **kwargs): + op = expr.op() + return t.translate(op.left, scope) * t.translate(op.right, scope) + + +@compiles(ops.Subtract) +def compile_subtract(t, expr, scope, **kwargs): + op = expr.op() + return t.translate(op.left, scope) - t.translate(op.right, scope) + + +@compiles(ops.Literal) +def compile_literal(t, expr, scope, raw=False, **kwargs): + """ If raw is True, don't wrap the result with F.lit() + """ + value = expr.op().value + + if raw: + return value + + if isinstance(value, collections.abc.Set): + # Don't wrap set with F.lit + if isinstance(value, frozenset): + # Spark doens't like frozenset + return set(value) + else: + return value + else: + return F.lit(expr.op().value) + + +@compiles(ops.Aggregation) +def compile_aggregation(t, expr, scope, **kwargs): + op = expr.op() + + src_table = t.translate(op.table, scope) + + if op.by: + context = AggregationContext.GROUP + aggs = [t.translate(m, scope, context=context) + for m in op.metrics] + bys = [t.translate(b, scope) for b in op.by] + return src_table.groupby(*bys).agg(*aggs) + else: + context = AggregationContext.ENTIRE + aggs = [t.translate(m, scope, context=context) + for m in op.metrics] + return src_table.agg(*aggs) + + +@compiles(ops.Contains) +def compile_contains(t, expr, scope, **kwargs): + op = expr.op() + col = t.translate(op.value, scope) + return col.isin(t.translate(op.options, scope)) + + +def compile_aggregator(t, expr, scope, fn, context=None, **kwargs): + op = expr.op() + src_col = t.translate(op.arg, scope) + + if getattr(op, "where", None) is not None: 
+ condition = t.translate(op.where, scope) + src_col = F.when(condition, src_col) + + col = fn(src_col) + if context: + return col + else: + # We are trying to compile a expr such as some_col.max() + # to a Spark expression. + # Here we get the root table df of that column and compile + # the expr to: + # df.select(max(some_col)) + return t.translate(expr.op().arg.op().table, scope).select(col) + + +@compiles(ops.GroupConcat) +def compile_group_concat(t, expr, scope, context=None, **kwargs): + sep = expr.op().sep.op().value + + def fn(col): + return F.concat_ws(sep, F.collect_list(col)) + return compile_aggregator(t, expr, scope, fn, context) + + +@compiles(ops.Any) +def compile_any(t, expr, scope, context=None, **kwargs): + return compile_aggregator(t, expr, scope, F.max, context) + + +@compiles(ops.NotAny) +def compile_notany(t, expr, scope, context=None, **kwargs): + + def fn(col): + return ~F.max(col) + return compile_aggregator(t, expr, scope, fn, context) + + +@compiles(ops.All) +def compile_all(t, expr, scope, context=None, **kwargs): + return compile_aggregator(t, expr, scope, F.min, context) + + +@compiles(ops.NotAll) +def compile_notall(t, expr, scope, context=None, **kwargs): + + def fn(col): + return ~F.min(col) + return compile_aggregator(t, expr, scope, fn, context) + + +@compiles(ops.Count) +def compile_count(t, expr, scope, context=None, **kwargs): + return compile_aggregator(t, expr, scope, F.count, context) + + +@compiles(ops.Max) +def compile_max(t, expr, scope, context=None, **kwargs): + return compile_aggregator(t, expr, scope, F.max, context) + + +@compiles(ops.Min) +def compile_min(t, expr, scope, context=None, **kwargs): + return compile_aggregator(t, expr, scope, F.min, context) + + +@compiles(ops.Mean) +def compile_mean(t, expr, scope, context=None, **kwargs): + return compile_aggregator(t, expr, scope, F.mean, context) + + +@compiles(ops.Sum) +def compile_sum(t, expr, scope, context=None, **kwargs): + return compile_aggregator(t, expr, scope, F.sum, context) + + +@compiles(ops.StandardDev) +def compile_std(t, expr, scope, context=None, **kwargs): + how = expr.op().how + + if how == 'sample': + fn = F.stddev_samp + elif how == 'pop': + fn = F.stddev_pop + else: + raise com.TranslationError( + "Unexpected 'how' in translation: {}" + .format(how) + ) + + return compile_aggregator(t, expr, scope, fn, context) + + +@compiles(ops.Variance) +def compile_variance(t, expr, scope, context=None, **kwargs): + how = expr.op().how + + if how == 'sample': + fn = F.var_samp + elif how == 'pop': + fn = F.var_pop + else: + raise com.TranslationError( + "Unexpected 'how' in translation: {}" + .format(how) + ) + + return compile_aggregator(t, expr, scope, fn, context) + + +@compiles(ops.Arbitrary) +def compile_arbitrary(t, expr, scope, context=None, **kwargs): + how = expr.op().how + + if how == 'first': + fn = functools.partial(F.first, ignorenulls=True) + elif how == 'last': + fn = functools.partial(F.last, ignorenulls=True) + else: + raise NotImplementedError( + "Does not support 'how': {}".format(how) + ) + + return compile_aggregator(t, expr, scope, fn, context) + + +@compiles(ops.Greatest) +def compile_greatest(t, expr, scope, **kwargs): + op = expr.op() + + src_columns = t.translate(op.arg, scope) + if len(src_columns) == 1: + return src_columns[0] + else: + return F.greatest(*src_columns) + + +@compiles(ops.Least) +def compile_least(t, expr, scope, **kwargs): + op = expr.op() + + src_columns = t.translate(op.arg, scope) + if len(src_columns) == 1: + return src_columns[0] + 
else: + return F.least(*src_columns) + + +@compiles(ops.Abs) +def compile_abs(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.abs(src_column) + + +@compiles(ops.Round) +def compile_round(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + scale = (t.translate(op.digits, scope, raw=True) + if op.digits is not None else 0) + rounded = F.round(src_column, scale=scale) + if scale == 0: + rounded = rounded.astype('long') + return rounded + + +@compiles(ops.Ceil) +def compile_ceil(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.ceil(src_column) + + +@compiles(ops.Floor) +def compile_floor(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.floor(src_column) + + +@compiles(ops.Exp) +def compile_exp(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.exp(src_column) + + +@compiles(ops.Sign) +def compile_sign(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + + return F.when(src_column == 0, F.lit(0.0)) \ + .otherwise(F.when(src_column > 0, F.lit(1.0)).otherwise(-1.0)) + + +@compiles(ops.Sqrt) +def compile_sqrt(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.sqrt(src_column) + + +@compiles(ops.Log) +def compile_log(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + # Spark log method only takes float + return F.log(float(t.translate(op.base, scope, raw=True)), src_column) + + +@compiles(ops.Ln) +def compile_ln(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.log(src_column) + + +@compiles(ops.Log2) +def compile_log2(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.log2(src_column) + + +@compiles(ops.Log10) +def compile_log10(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.log10(src_column) + + +@compiles(ops.Modulus) +def compile_modulus(t, expr, scope, **kwargs): + op = expr.op() + + left = t.translate(op.left, scope) + right = t.translate(op.right, scope) + return left % right + + +@compiles(ops.Negate) +def compile_negate(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return -src_column + + +@compiles(ops.Add) +def compile_add(t, expr, scope, **kwargs): + op = expr.op() + + left = t.translate(op.left, scope) + right = t.translate(op.right, scope) + return left + right + + +@compiles(ops.Divide) +def compile_divide(t, expr, scope, **kwargs): + op = expr.op() + + left = t.translate(op.left, scope) + right = t.translate(op.right, scope) + return left / right + + +@compiles(ops.FloorDivide) +def compile_floor_divide(t, expr, scope, **kwargs): + op = expr.op() + + left = t.translate(op.left, scope) + right = t.translate(op.right, scope) + return F.floor(left / right) + + +@compiles(ops.Power) +def compile_power(t, expr, scope, **kwargs): + op = expr.op() + + left = t.translate(op.left, scope) + right = t.translate(op.right, scope) + return F.pow(left, right) + + +@compiles(ops.IsNan) +def compile_isnan(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + return F.isnan(src_column) + + +@compiles(ops.IsInf) +def compile_isinf(t, expr, scope, **kwargs): + op = expr.op() + + src_column = t.translate(op.arg, scope) + 
return (src_column == float('inf')) | (src_column == float('-inf')) + + +@compiles(ops.ValueList) +def compile_value_list(t, expr, scope, **kwargs): + op = expr.op() + return [t.translate(col, scope) for col in op.values] + + +@compiles(ops.InnerJoin) +def compile_inner_join(t, expr, scope, **kwargs): + return compile_join(t, expr, scope, 'inner') + + +def compile_join(t, expr, scope, how): + op = expr.op() + + left_df = t.translate(op.left, scope) + right_df = t.translate(op.right, scope) + # TODO: Handle multiple predicates + predicates = t.translate(op.predicates[0], scope) + + return left_df.join(right_df, predicates, how) diff --git a/ibis/pyspark/operations.py b/ibis/pyspark/operations.py new file mode 100644 index 000000000000..6491c4e058c1 --- /dev/null +++ b/ibis/pyspark/operations.py @@ -0,0 +1,5 @@ +import ibis.expr.operations as ops + + +class PySparkTable(ops.DatabaseTable): + pass diff --git a/ibis/pyspark/tests/test_basic.py b/ibis/pyspark/tests/test_basic.py new file mode 100644 index 000000000000..b83612d12d34 --- /dev/null +++ b/ibis/pyspark/tests/test_basic.py @@ -0,0 +1,176 @@ +import pandas as pd +import pandas.util.testing as tm +import pytest + +import ibis +import ibis.common.exceptions as comm + +pytest.importorskip('pyspark') +pytestmark = pytest.mark.pyspark + + +@pytest.fixture(scope='session') +def client(): + from pyspark.sql import SparkSession + import pyspark.sql.functions as F + + session = SparkSession.builder.getOrCreate() + client = ibis.pyspark.connect(session) + df = client._session.range(0, 10) + df = df.withColumn("str_col", F.lit('value')) + df.createTempView('table1') + + df1 = client._session.createDataFrame([(True,), (False,)]).toDF('v') + df1.createTempView('table2') + return client + + +def test_basic(client): + table = client.table('table1') + result = table.compile().toPandas() + expected = pd.DataFrame({'id': range(0, 10), 'str_col': 'value'}) + + tm.assert_frame_equal(result, expected) + + +def test_projection(client): + table = client.table('table1') + result1 = table.mutate(v=table['id']).compile().toPandas() + + expected1 = pd.DataFrame( + { + 'id': range(0, 10), + 'str_col': 'value', + 'v': range(0, 10), + } + ) + + result2 = ( + table + .mutate(v=table['id']) + .mutate(v2=table['id']) + .mutate(id=table['id'] * 2) + .compile().toPandas() + ) + + expected2 = pd.DataFrame( + { + 'id': range(0, 20, 2), + 'str_col': 'value', + 'v': range(0, 10), + 'v2': range(0, 10), + } + ) + + tm.assert_frame_equal(result1, expected1) + tm.assert_frame_equal(result2, expected2) + + +def test_aggregation_col(client): + table = client.table('table1') + result = table['id'].count().execute() + assert result == table.compile().count() + + +def test_aggregation(client): + import pyspark.sql.functions as F + + table = client.table('table1') + result = table.aggregate(table['id'].max()).compile() + expected = table.compile().agg(F.max('id')) + + tm.assert_frame_equal(result.toPandas(), expected.toPandas()) + + +def test_groupby(client): + import pyspark.sql.functions as F + + table = client.table('table1') + result = table.groupby('id').aggregate(table['id'].max()).compile() + expected = table.compile().groupby('id').agg(F.max('id')) + + tm.assert_frame_equal(result.toPandas(), expected.toPandas()) + + +@pytest.mark.xfail( + reason='This is not implemented yet', + raises=comm.OperationNotDefinedError +) +def test_window(client): + import pyspark.sql.functions as F + from pyspark.sql.window import Window + + table = client.table('table1') + w = ibis.window() + 
result = ( + table + .mutate( + grouped_demeaned=table['id'] - table['id'].mean().over(w)) + .compile() + ) + result2 = ( + table + .groupby('id') + .mutate( + grouped_demeaned=table['id'] - table['id'].mean()) + .compile() + ) + + spark_window = Window.partitionBy() + spark_table = table.compile() + expected = spark_table.withColumn( + 'grouped_demeaned', + spark_table['id'] - F.mean(spark_table['id']).over(spark_window) + ) + + tm.assert_frame_equal(result.toPandas(), expected.toPandas()) + tm.assert_frame_equal(result2.toPandas(), expected.toPandas()) + + +def test_greatest(client): + table = client.table('table1') + result = ( + table + .mutate(greatest=ibis.greatest(table.id)) + .compile() + ) + df = table.compile() + expected = table.compile().withColumn('greatest', df.id) + + tm.assert_frame_equal(result.toPandas(), expected.toPandas()) + + +def test_selection(client): + table = client.table('table1') + table = table.mutate(id2=table['id'] * 2) + + result1 = table[['id']].compile() + result2 = table[['id', 'id2']].compile() + result3 = table[[table, (table.id + 1).name('plus1')]].compile() + result4 = table[[(table.id + 1).name('plus1'), table]].compile() + + df = table.compile() + tm.assert_frame_equal(result1.toPandas(), df[['id']].toPandas()) + tm.assert_frame_equal(result2.toPandas(), df[['id', 'id2']].toPandas()) + tm.assert_frame_equal(result3.toPandas(), + df[[df.columns]].withColumn('plus1', df.id + 1) + .toPandas()) + tm.assert_frame_equal(result4.toPandas(), + df.withColumn('plus1', df.id + 1) + [['plus1', *df.columns]].toPandas()) + + +@pytest.mark.xfail( + reason='Join is not fully implemented', + raises=AssertionError +) +def test_join(client): + table = client.table('table1') + result = table.join(table, ['id', 'str_col']).compile() + spark_table = table.compile() + expected = ( + spark_table + .join(spark_table, ['id', 'str_col']) + ) + + tm.assert_frame_equal(result.toPandas(), expected.toPandas()) diff --git a/ibis/spark/api.py b/ibis/spark/api.py index 284de775efeb..5cbf92bde083 100644 --- a/ibis/spark/api.py +++ b/ibis/spark/api.py @@ -28,13 +28,13 @@ def verify(expr, params=None): return False -def connect(**kwargs): +def connect(spark_session): """ Create a `SparkClient` for use with Ibis. Pipes **kwargs into SparkClient, which pipes them into SparkContext. 
See documentation for SparkContext: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/context.html#SparkContext """ - client = SparkClient(**kwargs) + client = SparkClient(spark_session) # Spark internally stores timestamps as UTC values, and timestamp data that # is brought in without a specified time zone is converted as local time to diff --git a/ibis/spark/client.py b/ibis/spark/client.py index 895c78788469..2056c8a8c1a4 100644 --- a/ibis/spark/client.py +++ b/ibis/spark/client.py @@ -294,10 +294,10 @@ class SparkClient(SQLClient): table_class = SparkDatabaseTable table_expr_class = SparkTable - def __init__(self, **kwargs): - self._context = ps.SparkContext(**kwargs) - self._session = ps.sql.SparkSession(self._context) - self._catalog = self._session.catalog + def __init__(self, session): + self._context = session.sparkContext + self._session = session + self._catalog = session.catalog def close(self): """ diff --git a/ibis/tests/all/conftest.py b/ibis/tests/all/conftest.py index 542d1ef48e2f..89ee85c71276 100644 --- a/ibis/tests/all/conftest.py +++ b/ibis/tests/all/conftest.py @@ -197,18 +197,36 @@ def geo_df(geo): _spark_testing_client = None +_pyspark_testing_client = None def get_spark_testing_client(data_directory): global _spark_testing_client + if _spark_testing_client is None: + _spark_testing_client = get_common_spark_testing_client( + data_directory, + lambda session: ibis.spark.connect(session) + ) + return _spark_testing_client + + +def get_pyspark_testing_client(data_directory): + global _pyspark_testing_client + if _pyspark_testing_client is None: + _pyspark_testing_client = get_common_spark_testing_client( + data_directory, + lambda session: ibis.pyspark.connect(session) + ) + return _pyspark_testing_client - if _spark_testing_client is not None: - return _spark_testing_client +def get_common_spark_testing_client(data_directory, connect): pytest.importorskip('pyspark') import pyspark.sql.types as pt + from pyspark.sql import SparkSession - _spark_testing_client = ibis.spark.connect() + spark = SparkSession.builder.getOrCreate() + _spark_testing_client = connect(spark) s = _spark_testing_client._session df_functional_alltypes = s.read.csv( diff --git a/ibis/tests/all/test_aggregation.py b/ibis/tests/all/test_aggregation.py index e13af9e5915b..a8aee75b699d 100644 --- a/ibis/tests/all/test_aggregation.py +++ b/ibis/tests/all/test_aggregation.py @@ -154,4 +154,5 @@ def test_group_concat(backend, alltypes, df, result_fn, expected_fn): expr = result_fn(alltypes) result = expr.execute() expected = expected_fn(df) - assert set(result) == set(expected) + + assert set(result.iloc[:, 1]) == set(expected.iloc[:, 1]) diff --git a/ibis/tests/all/test_client.py b/ibis/tests/all/test_client.py index c7bca9da848d..1bd7d3193b3a 100644 --- a/ibis/tests/all/test_client.py +++ b/ibis/tests/all/test_client.py @@ -3,7 +3,7 @@ import ibis import ibis.expr.datatypes as dt -from ibis.tests.backends import BigQuery +from ibis.tests.backends import BigQuery, PySpark @pytest.mark.xfail_unsupported @@ -25,6 +25,7 @@ def test_version(backend, con): ), ], ) +@pytest.mark.xfail_backends((PySpark,)) def test_query_schema(backend, con, alltypes, expr_fn, expected): if not hasattr(con, '_build_ast'): pytest.skip( diff --git a/ibis/tests/backends.py b/ibis/tests/backends.py index 6bea2074b7da..b9b87034e01d 100644 --- a/ibis/tests/backends.py +++ b/ibis/tests/backends.py @@ -549,3 +549,26 @@ def batting(self) -> ir.TableExpr: @property def awards_players(self) -> ir.TableExpr: return 
self.connection.table('awards_players') + + +class PySpark(Backend, RoundAwayFromZero): + @staticmethod + def skip_if_missing_dependencies() -> None: + pytest.importorskip('pyspark') + + @staticmethod + def connect(data_directory): + from ibis.tests.all.conftest import get_pyspark_testing_client + return get_pyspark_testing_client(data_directory) + + @property + def functional_alltypes(self) -> ir.TableExpr: + return self.connection.table('functional_alltypes') + + @property + def batting(self) -> ir.TableExpr: + return self.connection.table('batting') + + @property + def awards_players(self) -> ir.TableExpr: + return self.connection.table('awards_players') diff --git a/setup.cfg b/setup.cfg index 3d9461987eff..124f3c83b7c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,6 +48,7 @@ markers = postgis postgresql postgres_extensions + pyspark skip_backends skip_missing_feature spark
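
A minimal usage sketch of the new backend, assuming an existing SparkSession; it mirrors ibis/pyspark/tests/test_basic.py, and the view name 'demo' and its sample column are illustrative only:

    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    import ibis

    # Wrap an existing SparkSession; connect() also pins the session time zone to UTC.
    session = SparkSession.builder.getOrCreate()
    client = ibis.pyspark.connect(session)

    # Register a Spark DataFrame as a temp view so ibis can refer to it by name.
    df = session.range(0, 10).withColumn('str_col', F.lit('value'))
    df.createTempView('demo')

    table = client.table('demo')
    expr = table.mutate(doubled=table['id'] * 2)

    spark_df = expr.compile()   # a pyspark.sql.DataFrame expression, not a SQL string
    result = expr.execute()     # a pandas.DataFrame, materialized via toPandas()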
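
A sketch of how an additional translation rule could be registered on PySparkExprTranslator, following the pattern used throughout ibis/pyspark/compiler.py; ops.StringLength and F.length are assumptions for illustration and are not part of this patch:

    import pyspark.sql.functions as F

    import ibis.expr.operations as ops
    from ibis.pyspark.compiler import compiles


    @compiles(ops.StringLength)
    def compile_string_length(t, expr, scope, **kwargs):
        # Translate the argument column, then apply the corresponding
        # pyspark.sql.functions call, as the other unary rules do.
        op = expr.op()
        src_column = t.translate(op.arg, scope)
        return F.length(src_column)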