diff --git a/core/src/main/python/synapse/ml/core/platform/Platform.py b/core/src/main/python/synapse/ml/core/platform/Platform.py index 2636679d73..48a2f948d2 100644 --- a/core/src/main/python/synapse/ml/core/platform/Platform.py +++ b/core/src/main/python/synapse/ml/core/platform/Platform.py @@ -3,6 +3,7 @@ import os +PLATFORM_SYNAPSE_INTERNAL = "synapse_internal" PLATFORM_SYNAPSE = "synapse" PLATFORM_BINDER = "binder" PLATFORM_DATABRICKS = "databricks" @@ -13,7 +14,14 @@ def current_platform(): if os.environ.get("AZURE_SERVICE", None) == SYNAPSE_PROJECT_NAME: - return PLATFORM_SYNAPSE + from pyspark.sql import SparkSession + + sc = SparkSession.builder.getOrCreate().sparkContext + cluster_type = sc.getConf().get("spark.cluster.type") + if cluster_type == "synapse": + return PLATFORM_SYNAPSE + else: + return PLATFORM_SYNAPSE_INTERNAL elif "dbfs" in os.listdir("/"): return PLATFORM_DATABRICKS elif os.environ.get("BINDER_LAUNCH_HOST", None) is not None: @@ -22,6 +30,10 @@ def current_platform(): return PLATFORM_UNKNOWN +def running_on_synapse_internal(): + return current_platform() is PLATFORM_SYNAPSE_INTERNAL + + def running_on_synapse(): return current_platform() is PLATFORM_SYNAPSE