From 1f9409bb25f6b9e5a50b4906f6981cf099e15f77 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 13 Aug 2019 14:13:47 -0400 Subject: [PATCH] Change pyspark imports to optional --- ibis/pyspark/tests/test_basic.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ibis/pyspark/tests/test_basic.py b/ibis/pyspark/tests/test_basic.py index 2686a3ef4e91..765502760791 100644 --- a/ibis/pyspark/tests/test_basic.py +++ b/ibis/pyspark/tests/test_basic.py @@ -1,18 +1,18 @@ import pandas as pd import pandas.util.testing as tm -import pyspark.sql.functions as F import pytest -from pyspark.sql import SparkSession -from pyspark.sql.window import Window import ibis pytestmark = pytest.mark.pyspark -pytest.importorskip('pyspark') @pytest.fixture(scope='session') def client(): + pytest.importorskip('pyspark') + from pyspark.sql import SparkSession + import pyspark.sql.functions as F + session = SparkSession.builder.getOrCreate() client = ibis.pyspark.connect(session) df = client._session.range(0, 10) @@ -72,6 +72,8 @@ def test_aggregation_col(client): def test_aggregation(client): + import pyspark.sql.functions as F + table = client.table('table1') result = table.aggregate(table['id'].max()).compile() expected = table.compile().agg(F.max('id')) @@ -80,6 +82,8 @@ def test_aggregation(client): def test_groupby(client): + import pyspark.sql.functions as F + table = client.table('table1') result = table.groupby('id').aggregate(table['id'].max()).compile() expected = table.compile().groupby('id').agg(F.max('id')) @@ -88,6 +92,9 @@ def test_groupby(client): def test_window(client): + import pyspark.sql.functions as F + from pyspark.sql.window import Window + table = client.table('table1') w = ibis.window() result = (