chore: improve local python development (#2252)
* chore: improve python development experience

mhamilton723 authored Jul 19, 2024
1 parent 8fb3e0a commit 5a191b5
Showing 18 changed files with 111 additions and 43 deletions.
@@ -6,8 +6,11 @@
from langchain.prompts import PromptTemplate
from langchain.llms import AzureOpenAI
from synapse.ml.services.langchain import LangchainTransformer
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)

#######################################################
# this part is to correct a bug in langchain,
@@ -5,10 +5,16 @@
import unittest

from synapse.ml.io.http import *
from synapsemltest.spark import *
from synapse.ml.core.init_spark import *
from pyspark.sql.functions import struct
from pyspark.sql.types import *

from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class SimpleHTTPTransformerSmokeTest(unittest.TestCase):
    def test_simple(self):
24 changes: 24 additions & 0 deletions core/src/main/python/synapse/ml/core/init_spark.py
@@ -0,0 +1,24 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

from synapse.ml.core import __spark_package_version__


def init_spark():
    from pyspark.sql import SparkSession, SQLContext

    return (
        SparkSession.builder.master("local[*]")
        .appName("PysparkTests")
        .config(
            "spark.jars.packages",
            "com.microsoft.azure:synapseml_2.12:"
            + __spark_package_version__
            + ",org.apache.spark:spark-avro_2.12:3.4.1",
        )
        .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
        .config("spark.executor.heartbeatInterval", "60s")
        .config("spark.sql.shuffle.partitions", 10)
        .config("spark.sql.crossJoin.enabled", "true")
        .getOrCreate()
    )
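
Every test module touched by this commit adopts the same three-line preamble built on this helper. As a minimal sketch of the pattern (the test class and assertion below are illustrative, not part of the commit):

import unittest

from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()                 # one shared local[*] session per test module
sc = SQLContext(spark.sparkContext)  # sc preserved for tests that still reference an SQLContext


class ExampleSpec(unittest.TestCase):
    def test_session_is_alive(self):
        df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
        self.assertEqual(df.count(), 2)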
6 changes: 5 additions & 1 deletion core/src/test/python/synapsemltest/core/test_logging.py
@@ -6,7 +6,11 @@

import logging
from synapse.ml.core.logging.SynapseMLLogger import SynapseMLLogger
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class SampleTransformer(SynapseMLLogger):
6 changes: 5 additions & 1 deletion core/src/test/python/synapsemltest/core/test_template.py
@@ -6,7 +6,11 @@

from pyspark.sql import types as t, functions as f
import synapse.ml.core.spark.functions as SF
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TemplateSpec(unittest.TestCase):
@@ -19,8 +19,11 @@
)

from synapsemltest.cyber.explain_tester import ExplainTester
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)
epsilon = 10**-3


@@ -6,7 +6,11 @@
from pyspark.sql import DataFrame, types as t, functions as f
from synapse.ml.cyber.anomaly.complement_access import ComplementAccessTransformer
from synapsemltest.cyber.explain_tester import ExplainTester
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestComplementAccessTransformer(unittest.TestCase):
6 changes: 5 additions & 1 deletion core/src/test/python/synapsemltest/cyber/explain_tester.py
@@ -3,7 +3,11 @@
from typing import Any, Callable, List

from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class ExplainTester:
@@ -6,7 +6,11 @@
from pyspark.sql import types as t, functions as f
from synapse.ml.cyber.feature import indexers
from synapsemltest.cyber.explain_tester import ExplainTester
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestIndexers(unittest.TestCase):
@@ -6,7 +6,11 @@
from pyspark.sql import functions as f, types as t
from synapse.ml.cyber.feature import LinearScalarScaler, StandardScalarScaler
from synapsemltest.cyber.explain_tester import ExplainTester
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestScalers(unittest.TestCase):
@@ -14,7 +14,11 @@
HasSetInputCol,
HasSetOutputCol,
)
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class TestDataFrameUtils(unittest.TestCase):
6 changes: 5 additions & 1 deletion core/src/test/python/synapsemltest/nn/test_ball_tree.py
@@ -5,7 +5,11 @@
import unittest

from synapse.ml.nn.ConditionalBallTree import ConditionalBallTree
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class NNSpec(unittest.TestCase):
@@ -4,17 +4,21 @@
# Prepare training and test data.
import unittest

from pyspark.sql import SQLContext
from synapse.ml.recommendation import RankingAdapter
from synapse.ml.recommendation import RankingEvaluator
from synapse.ml.recommendation import RankingTrainValidationSplit
from synapse.ml.recommendation import RecommendationIndexer
from synapse.ml.recommendation import SAR
from synapsemltest.spark import *
from synapse.ml.core.init_spark import *
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import ParamGridBuilder

spark = init_spark()
sc = SQLContext(spark.sparkContext)

USER_ID = "originalCustomerID"
ITEM_ID = "newCategoryID"
RATING_ID = "rating"
@@ -49,29 +49,6 @@ object PyTestGen {
if (!dir.exists()) {
dir.mkdirs()
}
writeFile(join(dir, "spark.py"),
s"""
|# Copyright (C) Microsoft Corporation. All rights reserved.
|# Licensed under the MIT License. See LICENSE in project root for information.
|
|from pyspark.sql import SparkSession, SQLContext
|import os
|import synapse.ml
|from synapse.ml.core import __spark_package_version__
|
|spark = (SparkSession.builder
| .master("local[*]")
| .appName("PysparkTests")
| .config("spark.jars.packages", "$SparkMavenPackageList")
| .config("spark.jars.repositories", "$SparkMavenRepositoryList")
| .config("spark.executor.heartbeatInterval", "60s")
| .config("spark.sql.shuffle.partitions", 10)
| .config("spark.sql.crossJoin.enabled", "true")
| .getOrCreate())
|
|sc = SQLContext(spark.sparkContext)
|
|""".stripMargin, StandardOpenOption.CREATE)
}

def main(args: Array[String]): Unit = {
@@ -189,13 +189,17 @@ trait PyTestFuzzing[S <: PipelineStage] extends TestBase with DataFrameEquality
val importPathString = importPath.mkString(".").replaceAllLiterally("com.microsoft.azure.synapse.ml", "synapse.ml")
val testClass =
s"""import unittest
|from synapsemltest.spark import *
|from pyspark.sql import SQLContext
|from synapse.ml.core.init_spark import *
|from $importPathString import $stageName
|from os.path import join
|import json
|import mlflow
|from pyspark.ml import PipelineModel
|
|spark = init_spark()
|sc = SQLContext(spark.sparkContext)
|
|test_data_dir = "${pyTestDataDir(conf).toString.replaceAllLiterally("\\", "\\\\")}"
|
|
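For reference, the Python file this Scala template emits would begin roughly as follows; the stage name (Featurize), its import path, and the data directory are made-up placeholders standing in for the interpolated $stageName, $importPathString, and test-data values:

# Hypothetical rendering of the generated test header (placeholder values).
import unittest
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *
from synapse.ml.featurize import Featurize
from os.path import join
import json
import mlflow
from pyspark.ml import PipelineModel

spark = init_spark()
sc = SQLContext(spark.sparkContext)

test_data_dir = "/tmp/synapseml/test-data"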
16 changes: 12 additions & 4 deletions project/CodegenPlugin.scala
@@ -65,6 +65,9 @@ object CodegenPlugin extends AutoPlugin {

val packagePython = TaskKey[Unit]("packagePython", "Package python sdk")
val installPipPackage = TaskKey[Unit]("installPipPackage", "install python sdk")
val removePipPackage = TaskKey[Unit]("removePipPackage",
"remove the installed synapseml pip package from local env")

val publishPython = TaskKey[Unit]("publishPython", "publish python wheel")
val testPython = TaskKey[Unit]("testPython", "test python sdk")
val pyCodegen = TaskKey[Unit]("pyCodegen", "Generate python code")
@@ -236,17 +239,22 @@
FileUtils.copyDirectory(sourcePyDir, destPyDir)
packagePythonWheelCmd(packageDir, pythonSrcDir)
},
removePipPackage := {
runCmd(activateCondaEnv ++ Seq("pip", "uninstall", "-y", name.value))
},
installPipPackage := {
packagePython.value
publishLocal.value
val packagePythonResult: Unit = packagePython.value
val publishLocalResult: Unit = (publishLocal dependsOn packagePython).value
val rootPublishLocalResult: Unit = (LocalRootProject / Compile / publishLocal).value
runCmd(
activateCondaEnv ++ Seq("pip", "install", "-I",
s"${name.value.replace("-", "_")}-${pythonizedVersion(version.value)}-py2.py3-none-any.whl"),
join(codegenDir.value, "package", "python"))
},
publishPython := {
publishLocal.value
packagePython.value
val packagePythonResult: Unit = packagePython.value
val publishLocalResult: Unit = (publishLocal dependsOn packagePython).value
val rootPublishLocalResult: Unit = (LocalRootProject / Compile / publishLocal).value
val fn = s"${name.value.replace("-", "_")}-${pythonizedVersion(version.value)}-py2.py3-none-any.whl"
singleUploadToBlob(
join(codegenDir.value, "package", "python", fn).toString,
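Taken together, the sbt changes give local development two entry points: sbt installPipPackage packages the wheel, publishes the jars to the local repository, and pip-installs the wheel into the active conda environment, while the new sbt removePipPackage uninstalls it again. The explicit (publishLocal dependsOn packagePython).value wiring appears intended to guarantee the wheel is packaged before local publishing runs, since plain .value references carry no ordering guarantee in sbt.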
6 changes: 5 additions & 1 deletion vw/src/test/python/synapsemltest/vw/test_vw.py
@@ -9,7 +9,11 @@
from synapse.ml.vw.VowpalWabbitFeaturizer import VowpalWabbitFeaturizer

from pyspark.sql.types import *
from synapsemltest.spark import *
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


class VowpalWabbitSpec(unittest.TestCase):
7 changes: 5 additions & 2 deletions vw/src/test/python/synapsemltest/vw/test_vw_cb.py
@@ -4,8 +4,6 @@
import tempfile
import pyspark

from synapsemltest.spark import *

from synapse.ml.vw import VowpalWabbitContextualBandit
from synapse.ml.vw import VowpalWabbitFeaturizer
from synapse.ml.vw import VectorZipper
@@ -16,6 +14,11 @@
from pyspark.ml.wrapper import *
from pyspark.ml.common import inherit_doc, _java2py, _py2java
from pyspark.sql.utils import AnalysisException
from pyspark.sql import SQLContext
from synapse.ml.core.init_spark import *

spark = init_spark()
sc = SQLContext(spark.sparkContext)


def has_column(df, col):