@@ -8,42 +8,38 @@
 from flaml.tune.spark.utils import check_spark
 
 warnings.simplefilter(action="ignore")
-if sys.platform == "darwin" or "nt" in os.name:
-    # skip this test if the platform is not linux
-    skip_spark = True
-else:
-    try:
-        import pyspark
-        from pyspark.ml.feature import VectorAssembler
-        from flaml.automl.spark.utils import to_pandas_on_spark
-
-        spark = (
-            pyspark.sql.SparkSession.builder.appName("MyApp")
-            .master("local[2]")
-            .config(
-                "spark.jars.packages",
-                (
-                    "com.microsoft.azure:synapseml_2.12:0.10.2,"
-                    "org.apache.hadoop:hadoop-azure:3.3.5,"
-                    "com.microsoft.azure:azure-storage:8.6.6,"
-                    f"org.mlflow:mlflow-spark:{mlflow.__version__}"
-                ),
-            )
-            .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
-            .config("spark.sql.debug.maxToStringFields", "100")
-            .config("spark.driver.extraJavaOptions", "-Xss1m")
-            .config("spark.executor.extraJavaOptions", "-Xss1m")
-            .getOrCreate()
-        )
-        spark.sparkContext._conf.set(
-            "spark.mlflow.pysparkml.autolog.logModelAllowlistFile",
-            "https://mmlspark.blob.core.windows.net/publicwasb/log_model_allowlist.txt",
-        )
-        # spark.sparkContext.setLogLevel("ERROR")
-        spark_available, _ = check_spark()
-        skip_spark = not spark_available
-    except ImportError:
-        skip_spark = True
+try:
+    import pyspark
+    from pyspark.ml.feature import VectorAssembler
+    from flaml.automl.spark.utils import to_pandas_on_spark
+
+    spark = (
+        pyspark.sql.SparkSession.builder.appName("MyApp")
+        .master("local[2]")
+        .config(
+            "spark.jars.packages",
+            (
+                "com.microsoft.azure:synapseml_2.12:0.10.2,"
+                "org.apache.hadoop:hadoop-azure:3.3.5,"
+                "com.microsoft.azure:azure-storage:8.6.6,"
+                f"org.mlflow:mlflow-spark:{mlflow.__version__}"
+            ),
+        )
+        .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
+        .config("spark.sql.debug.maxToStringFields", "100")
+        .config("spark.driver.extraJavaOptions", "-Xss1m")
+        .config("spark.executor.extraJavaOptions", "-Xss1m")
+        .getOrCreate()
+    )
+    spark.sparkContext._conf.set(
+        "spark.mlflow.pysparkml.autolog.logModelAllowlistFile",
+        "https://mmlspark.blob.core.windows.net/publicwasb/log_model_allowlist.txt",
+    )
+    # spark.sparkContext.setLogLevel("ERROR")
+    spark_available, _ = check_spark()
+    skip_spark = not spark_available
+except ImportError:
+    skip_spark = True
 
 
 pytestmark = pytest.mark.skipif(skip_spark, reason="Spark is not installed. Skip all spark tests.")
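
Note: this change drops the old platform gate (`sys.platform == "darwin" or "nt" in os.name`) and relies solely on the `try`/`except ImportError` probe plus `check_spark()` to decide whether the module's tests run. For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of how a module-level `pytestmark` skip of this kind behaves. It uses only the import probe to stay self-contained (the real file additionally consults `flaml.tune.spark.utils.check_spark()`); the test name `test_needs_spark` is hypothetical and not part of this commit.

import pytest

try:
    import pyspark  # probe the optional dependency at import time

    skip_spark = False
except ImportError:
    skip_spark = True

# A module-level `pytestmark` applies this marker to every test in the file,
# so all Spark tests are skipped together when pyspark cannot be imported.
pytestmark = pytest.mark.skipif(skip_spark, reason="Spark is not installed. Skip all spark tests.")


def test_needs_spark():  # hypothetical example test, not from this commit
    assert pyspark.__version__  # only runs when the import above succeeded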