From e6b5a90352b7456333df92dc9f7755b6cb8f300b Mon Sep 17 00:00:00 2001
From: Serena Ruan <82044803+serena-ruan@users.noreply.github.com>
Date: Tue, 22 Nov 2022 07:34:37 +0800
Subject: [PATCH] feat: add simple deep learning text classifier (#1591)
* refactor deep vision model params
* add Text Classifier and tests
* update text classifier
* add default values for transformation edit fields and removed fields
* add deep text classification notebook
* update horovod installation script
* update environment
* update env
* add package installation on dbx
* fix python environment
* add more tests
* skip deep text model test
* address comments
* add learning_rate param
* fix notebook style
* fix _find_num_proc
* update newtonsoft.json version in dotnet to resolve security issue
* fix missing learning rate param
* fix dataframe partition error and strange read output
* fix failing test
* fix param updates
* update notebook to make it run faster
* make train data size smaller
* update models
* remove minor version of python to avoid warnings in pipeline
* ignore dl notebooks in Synapse tests
* update markdown name
---
.../synapse/ml/codegen/DotnetCodegen.scala | 4 +
.../azure/synapse/ml/codegen/PyCodegen.scala | 6 +-
.../ml/nbtest/DatabricksGPUTests.scala | 2 +-
.../ml/nbtest/DatabricksUtilities.scala | 6 +-
.../synapse/ml/nbtest/SynapseTests.scala | 3 +-
.../src/main/python/horovod_installation.sh | 11 +-
.../synapse/ml/dl/DeepTextClassifier.py | 290 ++++++++++++++++++
.../python/synapse/ml/dl/DeepTextModel.py | 119 +++++++
.../synapse/ml/dl/DeepVisionClassifier.py | 62 ++--
.../python/synapse/ml/dl/DeepVisionModel.py | 32 +-
.../python/synapse/ml/dl/LitDeepTextModel.py | 176 +++++++++++
.../python/synapse/ml/dl/PredictionParams.py | 87 ++++--
.../src/main/python/synapse/ml/dl/__init__.py | 3 +
.../src/main/python/synapse/ml/dl/utils.py | 30 ++
.../test/python/synapsemltest/dl/conftest.py | 27 ++
.../dl/test_deep_text_classifier.py | 55 ++++
.../synapsemltest/dl/test_deep_text_model.py | 82 +++++
.../dl/test_deep_vision_classifier.py | 9 +-
environment.yml | 2 +
...pLearning - Deep Text Classification.ipynb | 256 ++++++++++++++++
...arning - Deep Vision Classification.ipynb} | 0
templates/update_cli.yml | 2 +-
.../features/simple_deep_learning/about.md | 4 +-
23 files changed, 1169 insertions(+), 99 deletions(-)
create mode 100644 deep-learning/src/main/python/synapse/ml/dl/DeepTextClassifier.py
create mode 100644 deep-learning/src/main/python/synapse/ml/dl/DeepTextModel.py
create mode 100644 deep-learning/src/main/python/synapse/ml/dl/LitDeepTextModel.py
create mode 100644 deep-learning/src/test/python/synapsemltest/dl/test_deep_text_classifier.py
create mode 100644 deep-learning/src/test/python/synapsemltest/dl/test_deep_text_model.py
create mode 100644 notebooks/features/simple_deep_learning/DeepLearning - Deep Text Classification.ipynb
rename notebooks/features/simple_deep_learning/{DeepLearning - Deep Vision Classifier.ipynb => DeepLearning - Deep Vision Classification.ipynb} (100%)
diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala
index f687bc73b4..d8c455b98c 100644
--- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala
+++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala
@@ -34,6 +34,9 @@ object DotnetCodegen {
if (!projectDir.exists()){
projectDir.mkdirs()
}
+ val newtonsoftDep = if(curProject == "DeepLearning") {
+ s"""""".stripMargin
+ } else ""
// TODO: update SynapseML.DotnetBase version whenever we upload a new one
writeFile(new File(projectDir, s"${curProject}ProjectSetup.csproj"),
s"""
@@ -52,6 +55,7 @@ object DotnetCodegen {
|
|
|
+ | $newtonsoftDep
|
|
|
diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala
index 8aaa28f8d3..b348bc0e48 100644
--- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala
+++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala
@@ -57,12 +57,16 @@ object PyCodegen {
s"""MINIMUM_SUPPORTED_PYTHON_VERSION = "3.8"""".stripMargin
} else ""
val extraRequirements = if (conf.name.contains("deep-learning")) {
+      // transformers 4.16.2 raises an `Already borrowed` error when using tokenizers, so pin to 4.15.0
s"""extras_require={"extras": [
| "cmake",
| "horovod==0.25.0",
| "pytorch_lightning>=1.5.0,<1.5.10",
| "torch==1.11.0",
- | "torchvision>=0.12.0"
+ | "torchvision>=0.12.0",
+ | "transformers==4.15.0",
+ | "petastorm>=0.12.0",
+ | "huggingface-hub>=0.8.1",
|]},
|python_requires=f">={MINIMUM_SUPPORTED_PYTHON_VERSION}",""".stripMargin
} else ""
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala
index 8b91be9fea..be308c7af7 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala
@@ -14,7 +14,7 @@ class DatabricksGPUTests extends DatabricksTestHelper {
val horovodInstallationScript: File = FileUtilities.join(
BuildInfo.baseDirectory.getParent, "deep-learning",
"src", "main", "python", "horovod_installation.sh").getCanonicalFile
- uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod/horovod_installation.sh")
+ uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod-fix-commit/horovod_installation.sh")
val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId, GPUInitScripts)
val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(
clusterId, GPULibraries, GPUNotebooks)
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala
index 22bccd1e04..54dea3efb4 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala
@@ -61,11 +61,13 @@ object DatabricksUtilities {
// TODO: install synapse.ml.dl wheel package here
val GPULibraries: String = List(
- Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository))
+ Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository)),
+ Map("pypi" -> Map("package" -> "transformers==4.15.0")),
+ Map("pypi" -> Map("package" -> "petastorm==0.12.0"))
).toJson.compactPrint
val GPUInitScripts: String = List(
- Map("dbfs" -> Map("destination" -> "dbfs:/FileStore/horovod/horovod_installation.sh"))
+ Map("dbfs" -> Map("destination" -> "dbfs:/FileStore/horovod-fix-commit/horovod_installation.sh"))
).toJson.compactPrint
// Execution Params
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala
index 734746a3ca..f89b97e624 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala
@@ -46,7 +46,8 @@ class SynapseTests extends TestBase {
.filter(_.getAbsolutePath.endsWith(".py"))
.filterNot(_.getAbsolutePath.contains("HyperParameterTuning"))
.filterNot(_.getAbsolutePath.contains("IsolationForest"))
- .filterNot(_.getAbsolutePath.contains("DeepLearningDeepVisionClassifier"))
+ .filterNot(_.getAbsolutePath.contains("DeepLearningDeepTextClassification"))
+ .filterNot(_.getAbsolutePath.contains("DeepLearningDeepVisionClassification"))
.filterNot(_.getAbsolutePath.contains("Interpretability")) //add more exclusion TODO: Remove when fixed
.sortBy(_.getAbsolutePath)
diff --git a/deep-learning/src/main/python/horovod_installation.sh b/deep-learning/src/main/python/horovod_installation.sh
index 822792fc38..758b4993de 100644
--- a/deep-learning/src/main/python/horovod_installation.sh
+++ b/deep-learning/src/main/python/horovod_installation.sh
@@ -8,6 +8,8 @@ set -eu
# Install prerequisite libraries that horovod depends on
pip install pytorch-lightning==1.5.0
pip install torchvision==0.12.0
+pip install transformers==4.15.0
+pip install "petastorm>=0.12.0"
# Remove Outdated Signing Key:
sudo apt-key del 7fa2af80
@@ -32,9 +34,12 @@ libcusparse-dev-11-0=11.1.1.245-1
git clone --recursive https://github.com/horovod/horovod.git
cd horovod
-# fix version 0.25.0
-git fetch origin refs/tags/v0.25.0:tags/v0.25.0
-git checkout tags/v0.25.0 -b v0.25.0-branch
+# # fix version 0.25.0
+# git fetch origin refs/tags/v0.25.0:tags/v0.25.0
+# git checkout tags/v0.25.0 -b v0.25.0-branch
+# pin to this commit until horovod releases a new version
+git checkout ab97fd15bbba3258adcdd12983f36a1cdeacbc94
+git checkout -b tmp-branch
rm -rf build/ dist/
HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_CUDA_HOME=/usr/local/cuda-11/ HOROVOD_WITH_PYTORCH=1 HOROVOD_WITHOUT_MXNET=1 \
/databricks/python3/bin/python setup.py bdist_wheel
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepTextClassifier.py b/deep-learning/src/main/python/synapse/ml/dl/DeepTextClassifier.py
new file mode 100644
index 0000000000..0702fc828b
--- /dev/null
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepTextClassifier.py
@@ -0,0 +1,290 @@
+from horovod.spark.lightning import TorchEstimator
+import torch
+from pyspark.ml.param.shared import Param, Params
+from pytorch_lightning.utilities import _module_available
+from synapse.ml.dl.DeepTextModel import DeepTextModel
+from synapse.ml.dl.LitDeepTextModel import LitDeepTextModel
+from synapse.ml.dl.utils import keywords_catch, get_or_create_backend
+from synapse.ml.dl.PredictionParams import TextPredictionParams
+
+_TRANSFORMERS_AVAILABLE = _module_available("transformers")
+if _TRANSFORMERS_AVAILABLE:
+ import transformers
+
+ _TRANSFORMERS_EQUAL_4_15_0 = transformers.__version__ == "4.15.0"
+ if _TRANSFORMERS_EQUAL_4_15_0:
+ from transformers import AutoTokenizer
+ else:
+ raise RuntimeError(
+ "transformers should be == 4.15.0, found: {}".format(
+ transformers.__version__
+ )
+ )
+else:
+ raise ModuleNotFoundError("module not found: transformers")
+
+
+class DeepTextClassifier(TorchEstimator, TextPredictionParams):
+
+ checkpoint = Param(
+ Params._dummy(), "checkpoint", "checkpoint of the deep text classifier"
+ )
+
+ additional_layers_to_train = Param(
+ Params._dummy(),
+ "additional_layers_to_train",
+        "number of last layers to fine-tune for the model; must be greater than or equal to 0. Defaults to 3.",
+ )
+
+ num_classes = Param(Params._dummy(), "num_classes", "number of target classes")
+
+ loss_name = Param(
+ Params._dummy(),
+ "loss_name",
+ "string representation of torch.nn.functional loss function for the underlying pytorch_lightning model, e.g. binary_cross_entropy",
+ )
+
+ optimizer_name = Param(
+ Params._dummy(),
+ "optimizer_name",
+ "string representation of optimizer function for the underlying pytorch_lightning model",
+ )
+
+ tokenizer = Param(Params._dummy(), "tokenizer", "tokenizer")
+
+ max_token_len = Param(Params._dummy(), "max_token_len", "max_token_len")
+
+ learning_rate = Param(
+ Params._dummy(), "learning_rate", "learning rate to be used for the optimizer"
+ )
+
+ train_from_scratch = Param(
+ Params._dummy(),
+ "train_from_scratch",
+        "whether to train the model from scratch; if set to False, the additional_layers_to_train param must be specified.",
+ )
+
+ @keywords_catch
+ def __init__(
+ self,
+ checkpoint=None,
+        additional_layers_to_train=3,  # fine-tuning some layers is needed, otherwise performance is usually poor
+ num_classes=None,
+ optimizer_name="adam",
+ loss_name="cross_entropy",
+ tokenizer=None,
+ max_token_len=128,
+ learning_rate=None,
+ train_from_scratch=True,
+ # Classifier args
+ label_col="label",
+ text_col="text",
+ prediction_col="prediction",
+ # TorchEstimator args
+ num_proc=None,
+ backend=None,
+ store=None,
+ metrics=None,
+ loss_weights=None,
+ sample_weight_col=None,
+ gradient_compression=None,
+ input_shapes=None,
+ validation=None,
+ callbacks=None,
+ batch_size=None,
+ val_batch_size=None,
+ epochs=None,
+ verbose=1,
+ random_seed=None,
+ shuffle_buffer_size=None,
+ partitions_per_process=None,
+ run_id=None,
+ train_minibatch_fn=None,
+ train_steps_per_epoch=None,
+ validation_steps_per_epoch=None,
+ transformation_fn=None,
+ transformation_edit_fields=None,
+ transformation_removed_fields=None,
+ train_reader_num_workers=None,
+ trainer_args=None,
+ val_reader_num_workers=None,
+ reader_pool_type=None,
+ label_shapes=None,
+ inmemory_cache_all=False,
+ num_gpus=None,
+ logger=None,
+ log_every_n_steps=50,
+ data_module=None,
+ loader_num_epochs=None,
+ terminate_on_nan=False,
+ profiler=None,
+ debug_data_loader=False,
+ train_async_data_loader_queue_size=None,
+ val_async_data_loader_queue_size=None,
+ use_gpu=True,
+ mp_start_method=None,
+ ):
+ super(DeepTextClassifier, self).__init__()
+
+ self._setDefault(
+ checkpoint=None,
+ additional_layers_to_train=3,
+ num_classes=None,
+ optimizer_name="adam",
+ loss_name="cross_entropy",
+ tokenizer=None,
+ max_token_len=128,
+ learning_rate=None,
+ train_from_scratch=True,
+ feature_cols=["text"],
+ label_cols=["label"],
+ label_col="label",
+ text_col="text",
+ prediction_col="prediction",
+ )
+
+ kwargs = self._kwargs
+ self._set(**kwargs)
+
+ self._update_cols()
+ self._update_transformation_fn()
+
+ model = LitDeepTextModel(
+ checkpoint=self.getCheckpoint(),
+ additional_layers_to_train=self.getAdditionalLayersToTrain(),
+ num_labels=self.getNumClasses(),
+ optimizer_name=self.getOptimizerName(),
+ loss_name=self.getLossName(),
+ label_col=self.getLabelCol(),
+ text_col=self.getTextCol(),
+ learning_rate=self.getLearningRate(),
+ train_from_scratch=self.getTrainFromScratch(),
+ )
+ self._set(model=model)
+
+ def setCheckpoint(self, value):
+ return self._set(checkpoint=value)
+
+ def getCheckpoint(self):
+ return self.getOrDefault(self.checkpoint)
+
+ def setAdditionalLayersToTrain(self, value):
+ return self._set(additional_layers_to_train=value)
+
+ def getAdditionalLayersToTrain(self):
+ return self.getOrDefault(self.additional_layers_to_train)
+
+ def setNumClasses(self, value):
+ return self._set(num_classes=value)
+
+ def getNumClasses(self):
+ return self.getOrDefault(self.num_classes)
+
+ def setLossName(self, value):
+ return self._set(loss_name=value)
+
+ def getLossName(self):
+ return self.getOrDefault(self.loss_name)
+
+ def setOptimizerName(self, value):
+ return self._set(optimizer_name=value)
+
+ def getOptimizerName(self):
+ return self.getOrDefault(self.optimizer_name)
+
+ def setTokenizer(self, value):
+ return self._set(tokenizer=value)
+
+ def getTokenizer(self):
+ return self.getOrDefault(self.tokenizer)
+
+ def setMaxTokenLen(self, value):
+ return self._set(max_token_len=value)
+
+ def getMaxTokenLen(self):
+ return self.getOrDefault(self.max_token_len)
+
+ def setLearningRate(self, value):
+ return self._set(learning_rate=value)
+
+ def getLearningRate(self):
+ return self.getOrDefault(self.learning_rate)
+
+ def setTrainFromScratch(self, value):
+ return self._set(train_from_scratch=value)
+
+ def getTrainFromScratch(self):
+ return self.getOrDefault(self.train_from_scratch)
+
+ def _update_cols(self):
+ self.setFeatureCols([self.getTextCol()])
+ self.setLabelCols([self.getLabelCol()])
+
+ def _fit(self, dataset):
+ return super()._fit(dataset)
+
+ # override this method to provide a correct default backend
+ def _get_or_create_backend(self):
+ return get_or_create_backend(
+ self.getBackend(), self.getNumProc(), self.getVerbose(), self.getUseGpu()
+ )
+
+ def _update_transformation_fn(self):
+
+ text_col = self.getTextCol()
+ label_col = self.getLabelCol()
+ max_token_len = self.getMaxTokenLen()
+ # load it inside to avoid `Already borrowed` error (https://github.com/huggingface/tokenizers/issues/537)
+ if self.getTokenizer() is None:
+ self.setTokenizer(AutoTokenizer.from_pretrained(self.getCheckpoint()))
+ tokenizer = self.getTokenizer()
+
+ def _encoding_text(row):
+ text = row[text_col]
+ label = row[label_col]
+ encoding = tokenizer(
+ text,
+ max_length=max_token_len,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ input_ids = encoding["input_ids"].flatten().numpy()
+ attention_mask = encoding["attention_mask"].flatten().numpy()
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": torch.tensor(label, dtype=int),
+ }
+
+ transformation_edit_fields = [
+ ("input_ids", int, None, True),
+ ("attention_mask", int, None, True),
+ ("labels", int, None, False),
+ ]
+ self.setTransformationEditFields(transformation_edit_fields)
+ transformation_removed_fields = [self.getTextCol(), self.getLabelCol()]
+ self.setTransformationRemovedFields(transformation_removed_fields)
+ self.setTransformationFn(_encoding_text)
+
+ def get_model_class(self):
+ return DeepTextModel
+
+ def _get_model_kwargs(self, model, history, optimizer, run_id, metadata):
+ return dict(
+ history=history,
+ model=model,
+ optimizer=optimizer,
+ input_shapes=self.getInputShapes(),
+ run_id=run_id,
+ _metadata=metadata,
+ loss=self.getLoss(),
+ loss_constructors=self.getLossConstructors(),
+ tokenizer=self.getTokenizer(),
+ checkpoint=self.getCheckpoint(),
+ max_token_len=self.getMaxTokenLen(),
+ label_col=self.getLabelCol(),
+ text_col=self.getTextCol(),
+ prediction_col=self.getPredictionCol(),
+ )
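A minimal usage sketch of the new estimator for reviewers (assumptions: `train_df`/`test_df` are Spark DataFrames with "text" and integer "label" columns, `/tmp/dl_store` is a writable scratch path, and the checkpoint is any HuggingFace AutoModelForSequenceClassification checkpoint):

from horovod.spark.common.store import LocalStore
from synapse.ml.dl import DeepTextClassifier

deep_text_classifier = DeepTextClassifier(
    checkpoint="bert-base-uncased",     # assumed checkpoint; any sequence-classification checkpoint works
    store=LocalStore("/tmp/dl_store"),  # assumed scratch directory for horovod artifacts
    num_classes=6,
    batch_size=16,
    epochs=1,
    validation=0.1,
)
deep_text_model = deep_text_classifier.fit(train_df)  # returns a DeepTextModel
pred_df = deep_text_model.transform(test_df)          # appends the "prediction" column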
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepTextModel.py b/deep-learning/src/main/python/synapse/ml/dl/DeepTextModel.py
new file mode 100644
index 0000000000..4bdc2c8d51
--- /dev/null
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepTextModel.py
@@ -0,0 +1,119 @@
+import numpy as np
+import torch
+from horovod.spark.lightning import TorchModel
+from synapse.ml.dl.PredictionParams import TextPredictionParams
+from pyspark.ml.param import Param, Params, TypeConverters
+from pyspark.sql.functions import col, udf
+from pyspark.sql.types import DoubleType
+from synapse.ml.dl.utils import keywords_catch
+from transformers import AutoTokenizer
+
+
+class DeepTextModel(TorchModel, TextPredictionParams):
+
+ tokenizer = Param(Params._dummy(), "tokenizer", "tokenizer")
+
+ checkpoint = Param(
+ Params._dummy(), "checkpoint", "checkpoint of the deep text classifier"
+ )
+
+ max_token_len = Param(Params._dummy(), "max_token_len", "max_token_len")
+
+ @keywords_catch
+ def __init__(
+ self,
+ history=None,
+ model=None,
+ input_shapes=None,
+ optimizer=None,
+ run_id=None,
+ _metadata=None,
+ loss=None,
+ loss_constructors=None,
+ # diff from horovod
+ checkpoint=None,
+ tokenizer=None,
+ max_token_len=128,
+ label_col="label",
+ text_col="text",
+ prediction_col="prediction",
+ ):
+ super(DeepTextModel, self).__init__()
+
+ self._setDefault(
+ optimizer=None,
+ loss=None,
+ loss_constructors=None,
+ input_shapes=None,
+ checkpoint=None,
+ max_token_len=128,
+ text_col="text",
+ label_col="label",
+ prediction_col="prediction",
+ feature_columns=["text"],
+ label_columns=["label"],
+ outputCols=["output"],
+ )
+
+ kwargs = self._kwargs
+ self._set(**kwargs)
+
+ def setTokenizer(self, value):
+ return self._set(tokenizer=value)
+
+ def getTokenizer(self):
+ return self.getOrDefault(self.tokenizer)
+
+ def setCheckpoint(self, value):
+ return self._set(checkpoint=value)
+
+ def getCheckpoint(self):
+ return self.getOrDefault(self.checkpoint)
+
+ def setMaxTokenLen(self, value):
+ return self._set(max_token_len=value)
+
+ def getMaxTokenLen(self):
+ return self.getOrDefault(self.max_token_len)
+
+ def _update_cols(self):
+ self.setFeatureColumns([self.getTextCol()])
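+        # note: "setLabelColoumns" (sic) follows the setter spelling defined in horovod's params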
+ self.setLabelColoumns([self.getLabelCol()])
+
+    # override this to encode text before prediction
+ def get_prediction_fn(self):
+ text_col = self.getTextCol()
+ max_token_len = self.getMaxTokenLen()
+ tokenizer = self.getTokenizer()
+
+ def predict_fn(model, row):
+ text = row[text_col]
+ data = tokenizer(
+ text,
+ max_length=max_token_len,
+ padding="max_length",
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ with torch.no_grad():
+ outputs = model(**data)
+ pred = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+ return pred
+
+ return predict_fn
+
+ # pytorch_lightning module has its own optimizer configuration
+ def getOptimizer(self):
+ return None
+
+ def _transform(self, df):
+ self._update_cols()
+ output_df = super()._transform(df)
+ argmax = udf(lambda v: float(np.argmax(v)), returnType=DoubleType())
+ pred_df = output_df.withColumn(
+ self.getPredictionCol(), argmax(col(self.getOutputCols()[0]))
+ )
+ return pred_df
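A self-contained sketch of the post-processing path above: predict_fn returns a softmax distribution per row, and _transform's argmax UDF collapses it into a double-valued class id (the logits below are illustrative):

import numpy as np
import torch

logits = torch.tensor([[0.1, 2.3, -1.0]])            # toy model output for one row
probs = torch.nn.functional.softmax(logits, dim=-1)  # what predict_fn returns
prediction = float(np.argmax(probs.numpy()))         # what the argmax UDF computes
print(prediction)  # 1.0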
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py
index 3a29483caa..2968fbd7a8 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py
@@ -12,8 +12,8 @@
from pytorch_lightning.utilities import _module_available
from synapse.ml.dl.DeepVisionModel import DeepVisionModel
from synapse.ml.dl.LitDeepVisionModel import LitDeepVisionModel
-from synapse.ml.dl.utils import keywords_catch
-from synapse.ml.dl.PredictionParams import PredictionParams
+from synapse.ml.dl.utils import keywords_catch, get_or_create_backend
+from synapse.ml.dl.PredictionParams import VisionPredictionParams
_HOROVOD_AVAILABLE = _module_available("horovod")
if _HOROVOD_AVAILABLE:
@@ -28,7 +28,7 @@
raise ModuleNotFoundError("module not found: horovod")
-class DeepVisionClassifier(TorchEstimator, PredictionParams):
+class DeepVisionClassifier(TorchEstimator, VisionPredictionParams):
backbone = Param(
Params._dummy(), "backbone", "backbone of the deep vision classifier"
@@ -217,30 +217,9 @@ def _fit(self, dataset):
# override this method to provide a correct default backend
def _get_or_create_backend(self):
- backend = self.getBackend()
- num_proc = self.getNumProc()
- if backend is None:
- if num_proc is None:
- num_proc = self._find_num_proc()
- backend = SparkBackend(
- num_proc,
- stdout=sys.stdout,
- stderr=sys.stderr,
- prefix_output_with_timestamp=True,
- verbose=self.getVerbose(),
- )
- elif num_proc is not None:
- raise ValueError(
- 'At most one of parameters "backend" and "num_proc" may be specified'
- )
- return backend
-
- def _find_num_proc(self):
- if self.getUseGpu():
- # set it as number of executors for now (ignoring num_gpus per executor)
- sc = SparkContext.getOrCreate()
- return sc._jsc.sc().getExecutorMemoryStatus().size() - 1
- return None
+ return get_or_create_backend(
+ self.getBackend(), self.getNumProc(), self.getVerbose(), self.getUseGpu()
+ )
def _update_transformation_fn(self):
if self.getTransformationFn() is None:
@@ -258,23 +237,18 @@ def _update_transformation_fn(self):
)
self.setTransformFn(transform)
- def _create_transform_row(image_col, label_col, transform):
- def _transform_row(row):
- path = row[image_col]
- label = row[label_col]
- image = Image.open(path).convert("RGB")
- image = transform(image).numpy()
- return {image_col: image, label_col: label}
-
- return _transform_row
-
- self.setTransformationFn(
- _create_transform_row(
- self.getImageCol(),
- self.getLabelCol(),
- self.getTransformFn(),
- )
- )
+ image_col = self.getImageCol()
+ label_col = self.getLabelCol()
+ transform = self.getTransformFn()
+
+ def _transform_row(row):
+ path = row[image_col]
+ label = row[label_col]
+ image = Image.open(path).convert("RGB")
+ image = transform(image).numpy()
+ return {image_col: image, label_col: label}
+
+ self.setTransformationFn(_transform_row)
def get_model_class(self):
return DeepVisionModel
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py
index 7f35112df4..1fa67dcb4f 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py
@@ -6,14 +6,14 @@
import torchvision.transforms as transforms
from horovod.spark.lightning import TorchModel
from PIL import Image
-from synapse.ml.dl.PredictionParams import PredictionParams
+from synapse.ml.dl.PredictionParams import VisionPredictionParams
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType
from synapse.ml.dl.utils import keywords_catch
-class DeepVisionModel(TorchModel, PredictionParams):
+class DeepVisionModel(TorchModel, VisionPredictionParams):
transform_fn = Param(
Params._dummy(),
@@ -56,8 +56,6 @@ def __init__(
kwargs = self._kwargs
self._set(**kwargs)
- self._update_transform_fn()
- self._update_cols()
def setTransformFn(self, value):
return self._set(transform_fn=value)
@@ -93,29 +91,29 @@ def _update_cols(self):
def get_prediction_fn(self):
input_shape = self.getInputShapes()[0]
image_col = self.getImageCol()
+ transform = self.getTransformFn()
- def _create_predict_fn(transform):
- def predict_fn(model, row):
- if type(row[image_col]) == str:
- image = Image.open(row[image_col]).convert("RGB")
- data = torch.tensor(transform(image).numpy()).reshape(input_shape)
- else:
- data = torch.tensor([row[image_col]]).reshape(input_shape)
-
- with torch.no_grad():
- pred = model(data)
+ def predict_fn(model, row):
+ if type(row[image_col]) == str:
+ image = Image.open(row[image_col]).convert("RGB")
+ data = torch.tensor(transform(image).numpy()).reshape(input_shape)
+ else:
+ data = torch.tensor([row[image_col]]).reshape(input_shape)
- return pred
+ with torch.no_grad():
+ pred = model(data)
- return predict_fn
+ return pred
- return _create_predict_fn(self.getTransformFn())
+ return predict_fn
# pytorch_lightning module has its own optimizer configuration
def getOptimizer(self):
return None
def _transform(self, df):
+ self._update_transform_fn()
+ self._update_cols()
output_df = super()._transform(df)
argmax = udf(lambda v: float(np.argmax(v)), returnType=DoubleType())
pred_df = output_df.withColumn(
diff --git a/deep-learning/src/main/python/synapse/ml/dl/LitDeepTextModel.py b/deep-learning/src/main/python/synapse/ml/dl/LitDeepTextModel.py
new file mode 100644
index 0000000000..2283281c0b
--- /dev/null
+++ b/deep-learning/src/main/python/synapse/ml/dl/LitDeepTextModel.py
@@ -0,0 +1,176 @@
+# Copyright (C) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE in project root for information.
+
+import inspect
+
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from pytorch_lightning.utilities import _module_available
+
+_TRANSFORMERS_AVAILABLE = _module_available("transformers")
+if _TRANSFORMERS_AVAILABLE:
+ import transformers
+
+ _TRANSFORMERS_EQUAL_4_15_0 = transformers.__version__ == "4.15.0"
+ if _TRANSFORMERS_EQUAL_4_15_0:
+ from transformers import AutoModelForSequenceClassification
+ else:
+ raise RuntimeError(
+ "transformers should be == 4.15.0, found: {}".format(
+ transformers.__version__
+ )
+ )
+else:
+ raise ModuleNotFoundError("module not found: transformers")
+
+
+class LitDeepTextModel(pl.LightningModule):
+ def __init__(
+ self,
+ checkpoint,
+ text_col,
+ label_col,
+ num_labels,
+ additional_layers_to_train,
+ optimizer_name,
+ loss_name,
+ learning_rate=None,
+ train_from_scratch=True,
+ ):
+ """
+ :param checkpoint: Checkpoint for pre-trained model. This is expected to
+ be a checkpoint you could find on [HuggingFace](https://huggingface.co/models)
+ and is of type `AutoModelForSequenceClassification`.
+ :param text_col: Text column name.
+ :param label_col: Label column name.
+ :param num_labels: Number of labels for classification.
+        :param additional_layers_to_train: Number of trailing layers to fine-tune. For deep text
+            models, a positive number is recommended for better performance.
+ :param optimizer_name: Name of the optimizer.
+ :param loss_name: Name of the loss function.
+ :param learning_rate: Learning rate for the optimizer.
+        :param train_from_scratch: Whether to train the whole model from scratch. If set to True,
+            the additional_layers_to_train param is ignored. Defaults to True.
+ """
+ super(LitDeepTextModel, self).__init__()
+
+ self.checkpoint = checkpoint
+ self.text_col = text_col
+ self.label_col = label_col
+ self.num_labels = num_labels
+ self.additional_layers_to_train = additional_layers_to_train
+ self.optimizer_name = optimizer_name
+ self.loss_name = loss_name
+ self.learning_rate = learning_rate
+ self.train_from_scratch = train_from_scratch
+
+ self._check_params()
+
+ self.save_hyperparameters(
+ "checkpoint",
+ "text_col",
+ "label_col",
+ "num_labels",
+ "additional_layers_to_train",
+ "optimizer_name",
+ "loss_name",
+ "learning_rate",
+ "train_from_scratch",
+ )
+
+ def _check_params(self):
+ try:
+ # TODO: Add other types of models here
+ self.model = AutoModelForSequenceClassification.from_pretrained(
+ self.checkpoint, num_labels=self.num_labels
+ )
+ self._update_learning_rate()
+ except Exception as err:
+ raise ValueError(
+ f"No checkpoint {self.checkpoint} found: {err=}, {type(err)=}"
+ )
+
+ if self.loss_name.lower() not in F.__dict__:
+ raise ValueError("No loss function: {} found".format(self.loss_name))
+ self.loss_fn = F.__dict__[self.loss_name.lower()]
+
+ optimizers_mapping = {
+ key.lower(): value
+ for key, value in optim.__dict__.items()
+ if inspect.isclass(value) and issubclass(value, optim.Optimizer)
+ }
+ if self.optimizer_name.lower() not in optimizers_mapping:
+ raise ValueError("No optimizer: {} found".format(self.optimizer_name))
+ self.optimizer_fn = optimizers_mapping[self.optimizer_name.lower()]
+
+ def forward(self, **inputs):
+ return self.model(**inputs)
+
+ def configure_optimizers(self):
+ if not self.train_from_scratch:
+ # Freeze those weights
+ for p in self.model.base_model.parameters():
+ p.requires_grad = False
+ self._fine_tune_layers()
+ params_to_update = filter(lambda p: p.requires_grad, self.model.parameters())
+ return self.optimizer_fn(params_to_update, self.learning_rate)
+
+ def _fine_tune_layers(self):
+ if self.additional_layers_to_train < 0:
+ raise ValueError(
+ "additional_layers_to_train has to be non-negative: {} found".format(
+ self.additional_layers_to_train
+ )
+ )
+ # base_model contains the real model to fine tune
+ children = list(self.model.base_model.children())
+ added_layer, cur_layer = 0, -1
+ while added_layer < self.additional_layers_to_train and -cur_layer < len(
+ children
+ ):
+ tunable = False
+ for p in children[cur_layer].parameters():
+ p.requires_grad = True
+ tunable = True
+            # only count layers that actually contain parameters
+ if tunable:
+ added_layer += 1
+ cur_layer -= 1
+
+ def _update_learning_rate(self):
+ ## TODO: add more default values for different models
+ if not self.learning_rate:
+ if "bert" in self.checkpoint:
+ self.learning_rate = 5e-5
+ else:
+ self.learning_rate = 0.01
+
+ def training_step(self, batch, batch_idx):
+ loss = self._step(batch, False)
+ self.log("train_loss", loss)
+ return loss
+
+ def _step(self, batch, validation):
+ inputs = batch
+ outputs = self(**inputs)
+ loss = outputs.loss
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss = self._step(batch, True)
+ self.log("val_loss", loss)
+
+ def validation_epoch_end(self, outputs):
+ avg_loss = (
+ torch.stack([x["val_loss"] for x in outputs]).mean()
+ if len(outputs) > 0
+ else float("inf")
+ )
+ self.log("avg_val_loss", avg_loss)
+
+ def test_step(self, batch, batch_idx):
+ loss = self._step(batch, False)
+ self.log("test_loss", loss)
+ return loss
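To illustrate the _fine_tune_layers walk, a sketch using a toy torch.nn.Sequential in place of a HuggingFace base_model (an assumption for brevity): it unfreezes the last additional_layers_to_train children that actually own parameters, skipping parameter-free ones like ReLU.

import torch.nn as nn

model = nn.Sequential(nn.Embedding(10, 4), nn.ReLU(), nn.Linear(4, 4), nn.Linear(4, 2))
for p in model.parameters():
    p.requires_grad = False  # start fully frozen, as configure_optimizers does

children = list(model.children())
added_layer, cur_layer = 0, -1
while added_layer < 2 and -cur_layer < len(children):  # additional_layers_to_train = 2
    tunable = False
    for p in children[cur_layer].parameters():
        p.requires_grad = True
        tunable = True
    if tunable:  # parameter-free layers such as ReLU don't count
        added_layer += 1
    cur_layer -= 1

print([any(p.requires_grad for p in c.parameters()) for c in children])
# [False, False, True, True] -- only the last two parameterized layers unfreeze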
diff --git a/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py b/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py
index e4c68845ec..c2e7782790 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py
@@ -4,7 +4,7 @@
from pyspark.ml.param import Param, Params, TypeConverters
-class PredictionParams(Params):
+class HasLabelColParam(Params):
label_col = Param(
Params._dummy(),
@@ -13,25 +13,9 @@ class PredictionParams(Params):
typeConverter=TypeConverters.toString,
)
- image_col = Param(
- Params._dummy(),
- "image_col",
- "image column name.",
- typeConverter=TypeConverters.toString,
- )
-
- prediction_col = Param(
- Params._dummy(),
- "prediction_col",
- "prediction column name.",
- typeConverter=TypeConverters.toString,
- )
-
def __init__(self):
- super(PredictionParams, self).__init__()
- self._setDefault(
- label_col="label", image_col="image", prediction_col="prediction"
- )
+ super(HasLabelColParam, self).__init__()
+ self._setDefault(label_col="label")
def setLabelCol(self, value):
"""
@@ -45,6 +29,20 @@ def getLabelCol(self):
"""
return self.getOrDefault(self.label_col)
+
+class HasImageColParam(Params):
+
+ image_col = Param(
+ Params._dummy(),
+ "image_col",
+ "image column name.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ def __init__(self):
+ super(HasImageColParam, self).__init__()
+ self._setDefault(image_col="image")
+
def setImageCol(self, value):
"""
Sets the value of :py:attr:`image_col`.
@@ -57,6 +55,47 @@ def getImageCol(self):
"""
return self.getOrDefault(self.image_col)
+
+## TODO: Potentially generalize to support multiple text columns as input
+class HasTextColParam(Params):
+
+ text_col = Param(
+ Params._dummy(),
+ "text_col",
+ "text column name.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ def __init__(self):
+ super(HasTextColParam, self).__init__()
+ self._setDefault(text_col="text")
+
+ def setTextCol(self, value):
+ """
+ Sets the value of :py:attr:`text_col`.
+ """
+ return self._set(text_col=value)
+
+ def getTextCol(self):
+ """
+ Gets the value of text_col or its default value.
+ """
+ return self.getOrDefault(self.text_col)
+
+
+class HasPredictionColParam(Params):
+
+ prediction_col = Param(
+ Params._dummy(),
+ "prediction_col",
+ "prediction column name.",
+ typeConverter=TypeConverters.toString,
+ )
+
+ def __init__(self):
+ super(HasPredictionColParam, self).__init__()
+ self._setDefault(prediction_col="prediction")
+
def setPredictionCol(self, value):
"""
Sets the value of :py:attr:`prediction_col`.
@@ -68,3 +107,13 @@ def getPredictionCol(self):
Gets the value of prediction_col or its default value.
"""
return self.getOrDefault(self.prediction_col)
+
+
+class VisionPredictionParams(HasLabelColParam, HasImageColParam, HasPredictionColParam):
+ def __init__(self):
+ super(VisionPredictionParams, self).__init__()
+
+
+class TextPredictionParams(HasLabelColParam, HasTextColParam, HasPredictionColParam):
+ def __init__(self):
+ super(TextPredictionParams, self).__init__()
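A quick sketch of how the split mixins compose (only pyspark needed): TextPredictionParams inherits the label, text, and prediction column params with their defaults, while VisionPredictionParams swaps the text column for an image column.

from synapse.ml.dl.PredictionParams import TextPredictionParams

params = TextPredictionParams()
print(params.getLabelCol(), params.getTextCol(), params.getPredictionCol())
# label text prediction  (the defaults set by each Has*Param mixin)
params.setTextCol("Text")
print(params.getTextCol())  # Text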
diff --git a/deep-learning/src/main/python/synapse/ml/dl/__init__.py b/deep-learning/src/main/python/synapse/ml/dl/__init__.py
index 1a8ad55d9e..7d97ae576d 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/__init__.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/__init__.py
@@ -1,6 +1,9 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
+from synapse.ml.dl.DeepTextClassifier import *
+from synapse.ml.dl.DeepTextModel import *
from synapse.ml.dl.DeepVisionClassifier import *
from synapse.ml.dl.DeepVisionModel import *
+from synapse.ml.dl.LitDeepTextModel import *
from synapse.ml.dl.LitDeepVisionModel import *
diff --git a/deep-learning/src/main/python/synapse/ml/dl/utils.py b/deep-learning/src/main/python/synapse/ml/dl/utils.py
index 956d98541c..e9fca931eb 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/utils.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/utils.py
@@ -1,7 +1,11 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
+import sys
+
from functools import wraps
+from horovod.spark.common.backend import SparkBackend
+from pyspark.context import SparkContext
def keywords_catch(func):
@@ -22,3 +26,29 @@ def wrapper(self, *args, **kwargs):
return func(self, **kwargs)
return wrapper
+
+
+def get_or_create_backend(backend, num_proc, verbose, use_gpu):
+ if backend is None:
+ if num_proc is None:
+ num_proc = _find_num_proc(use_gpu)
+ backend = SparkBackend(
+ num_proc,
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ prefix_output_with_timestamp=True,
+ verbose=verbose,
+ )
+ elif num_proc is not None:
+ raise ValueError(
+ 'At most one of parameters "backend" and "num_proc" may be specified'
+ )
+ return backend
+
+
+def _find_num_proc(use_gpu):
+ if use_gpu:
+ # set it as number of executors for now (ignoring num_gpus per executor)
+ sc = SparkContext.getOrCreate()
+ return sc._jsc.sc().getExecutorMemoryStatus().size() - 1
+ return None
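A sketch of the extracted helper's contract (assumes the deep-learning extras, i.e. horovod, are installed; DummyBackend is an illustrative stand-in): an explicit backend is returned unchanged, and supplying both backend and num_proc is rejected.

from synapse.ml.dl.utils import get_or_create_backend

class DummyBackend(object):  # illustrative stand-in for a horovod Backend
    def run(self, fn, args=(), kwargs={}, env={}):
        return [fn(*args, **kwargs)]

    def num_processes(self):
        return 1

backend = get_or_create_backend(DummyBackend(), None, 1, False)  # returned as-is
try:
    get_or_create_backend(DummyBackend(), 4, 1, False)
except ValueError as err:
    print(err)  # At most one of parameters "backend" and "num_proc" may be specified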
diff --git a/deep-learning/src/test/python/synapsemltest/dl/conftest.py b/deep-learning/src/test/python/synapsemltest/dl/conftest.py
index 200dd6b8f3..1542168271 100644
--- a/deep-learning/src/test/python/synapsemltest/dl/conftest.py
+++ b/deep-learning/src/test/python/synapsemltest/dl/conftest.py
@@ -8,8 +8,10 @@
from os.path import join
import numpy as np
+import pandas as pd
import pytest
import torchvision.transforms as transforms
+from pyspark.ml.feature import StringIndexer
IS_WINDOWS = os.name == "nt"
delimiter = "\\" if IS_WINDOWS else "/"
@@ -19,6 +21,14 @@
)
+class CallbackBackend(object):
+ def run(self, fn, args=(), kwargs={}, env={}):
+ return [fn(*args, **kwargs)] * self.num_processes()
+
+ def num_processes(self):
+ return 1
+
+
def _download_dataset():
urllib.request.urlretrieve(
@@ -82,3 +92,20 @@ def transform():
]
)
return transform
+
+
+def _prepare_text_data(spark):
+ df = (
+ spark.read.format("csv")
+ .option("header", "true")
+ .load(
+ "wasbs://publicwasb@mmlspark.blob.core.windows.net/text_classification/Emotion_classification.csv"
+ )
+ )
+ indexer = StringIndexer(inputCol="Emotion", outputCol="label")
+ indexer_model = indexer.fit(df)
+ df = indexer_model.transform(df).drop("Emotion")
+
+ train_df, test_df = df.randomSplit([0.85, 0.15], seed=1)
+
+ return train_df, test_df
diff --git a/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_classifier.py b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_classifier.py
new file mode 100644
index 0000000000..4cd464f34a
--- /dev/null
+++ b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_classifier.py
@@ -0,0 +1,55 @@
+# Copyright (C) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE in project root for information.
+
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+from pyspark.sql import SparkSession
+from pytorch_lightning.callbacks import ModelCheckpoint
+from synapse.ml.dl import *
+
+import pytest
+from .conftest import CallbackBackend, _prepare_text_data
+from .test_deep_vision_classifier import local_store
+from .test_deep_vision_model import MyDummyCallback
+
+
+@pytest.mark.skip(reason="not testing this for now")
+def test_bert_base_uncased():
+ spark = SparkSession.builder.master("local[*]").getOrCreate()
+
+ train_df, test_df = _prepare_text_data(spark)
+
+ ctx = CallbackBackend()
+
+ epochs = 2
+ callbacks = [
+ MyDummyCallback(epochs),
+ ModelCheckpoint(dirpath="target/bert_base_uncased/"),
+ ]
+
+ with local_store() as store:
+
+ checkpoint = "bert-base-uncased"
+
+ deep_text_classifier = DeepTextClassifier(
+ checkpoint=checkpoint,
+ store=store,
+ backend=ctx,
+ callbacks=callbacks,
+ num_classes=6,
+ batch_size=16,
+ epochs=epochs,
+ validation=0.1,
+ text_col="Text",
+ transformation_removed_fields=["Text", "Emotion", "label"],
+ )
+
+ deep_text_model = deep_text_classifier.fit(train_df)
+
+ pred_df = deep_text_model.transform(test_df)
+ evaluator = MulticlassClassificationEvaluator(
+ predictionCol="prediction", labelCol="label", metricName="accuracy"
+ )
+ accuracy = evaluator.evaluate(pred_df)
+ assert accuracy > 0.5
+
+ spark.stop()
diff --git a/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_model.py b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_model.py
new file mode 100644
index 0000000000..7466e7dbec
--- /dev/null
+++ b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_model.py
@@ -0,0 +1,82 @@
+# Copyright (C) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE in project root for information.
+
+import pytest
+import torch
+from pyspark.sql import SparkSession
+from pytorch_lightning import Trainer
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoTokenizer
+
+from .conftest import _prepare_text_data
+from .test_deep_vision_model import MyDummyCallback
+
+from synapse.ml.dl import *
+
+
+@pytest.mark.skip(reason="skip this as it takes too long")
+def test_lit_deep_text_model():
+ class TextDataset(Dataset):
+ def __init__(self, data, tokenizer, max_token_len):
+ super(TextDataset, self).__init__()
+ self.data = data
+ self.tokenizer = tokenizer
+ self.max_token_len = max_token_len
+
+ def __getitem__(self, index):
+ text = self.data["Text"][index]
+ label = self.data["label"][index]
+ encoding = self.tokenizer(
+ text,
+ truncation=True,
+ padding="max_length",
+ max_length=self.max_token_len,
+ return_tensors="pt",
+ )
+ return {
+ "input_ids": encoding["input_ids"].flatten(),
+ "attention_mask": encoding["attention_mask"].flatten(),
+ "labels": torch.tensor([label], dtype=int),
+ }
+
+ def __len__(self):
+ return len(self.data)
+
+ spark = SparkSession.builder.master("local[*]").getOrCreate()
+ train_df, test_df = _prepare_text_data(spark)
+
+ checkpoint = "bert-base-cased"
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+ max_token_len = 128
+
+ train_loader = DataLoader(
+ TextDataset(train_df.toPandas(), tokenizer, max_token_len),
+ batch_size=16,
+ shuffle=True,
+ num_workers=0,
+ pin_memory=True,
+ )
+
+ test_loader = DataLoader(
+ TextDataset(test_df.toPandas(), tokenizer, max_token_len),
+ batch_size=16,
+ shuffle=True,
+ num_workers=0,
+ pin_memory=True,
+ )
+
+ epochs = 1
+ model = LitDeepTextModel(
+ checkpoint=checkpoint,
+ additional_layers_to_train=10,
+ num_labels=6,
+ optimizer_name="adam",
+ loss_name="cross_entropy",
+ label_col="label",
+ text_col="Text",
+ )
+
+ callbacks = [MyDummyCallback(epochs)]
+ trainer = Trainer(callbacks=callbacks, max_epochs=epochs)
+ trainer.fit(model, train_dataloaders=train_loader)
+ trainer.test(model, dataloaders=test_loader)
diff --git a/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py b/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py
index 20e1d45d10..24dac4d80f 100644
--- a/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py
+++ b/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py
@@ -15,6 +15,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from synapse.ml.dl import *
+from .conftest import CallbackBackend
from .test_deep_vision_model import MyDummyCallback
@@ -64,14 +65,6 @@ def extract_path_and_label(path):
return train_df, test_df
-class CallbackBackend(object):
- def run(self, fn, args=(), kwargs={}, env={}):
- return [fn(*args, **kwargs)] * self.num_processes()
-
- def num_processes(self):
- return 1
-
-
@pytest.mark.skip(reason="not testing this for now")
def test_mobilenet_v2(get_data_path):
spark = SparkSession.builder.master("local[*]").getOrCreate()
diff --git a/environment.yml b/environment.yml
index 4da964a9eb..8b573642b8 100644
--- a/environment.yml
+++ b/environment.yml
@@ -35,3 +35,5 @@ dependencies:
- onnxmltools==1.7.0
- matplotlib
- Pillow
+ - transformers==4.15.0
+ - huggingface-hub>=0.8.1
diff --git a/notebooks/features/simple_deep_learning/DeepLearning - Deep Text Classification.ipynb b/notebooks/features/simple_deep_learning/DeepLearning - Deep Text Classification.ipynb
new file mode 100644
index 0000000000..9e237e9b3b
--- /dev/null
+++ b/notebooks/features/simple_deep_learning/DeepLearning - Deep Text Classification.ipynb
@@ -0,0 +1,256 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "a4b6a348-5155-4665-9616-3776bea40ff0",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "source": [
+    "## Deep Learning - Deep Text Classification"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "910eee89-ded8-4c36-90ae-e9b8539c5773",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "source": [
+ "### Environment Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "60a84fca-38ae-48dc-826a-1cc2011c3977",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+    "# install cloudpickle 2.0.0 so the synapse module can be pickled by value for horovod\n",
+ "%pip install cloudpickle==2.0.0 --force-reinstall --no-deps"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "cd1e438b-4b6e-4d92-8cd4-0c184afe0721",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import synapse\n",
+ "import cloudpickle\n",
+ "\n",
+ "cloudpickle.register_pickle_by_value(synapse)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "29b27e85-09c0-4e5f-8c58-af3a2bc9d373",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "! horovodrun --check-build"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "205531d2-6c06-49b4-828a-6f207371830b",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "source": [
+ "### Read Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import urllib\n",
+ "\n",
+ "urllib.request.urlretrieve(\n",
+ " \"https://mmlspark.blob.core.windows.net/publicwasb/text_classification/Emotion_classification.csv\",\n",
+ " \"/tmp/Emotion_classification.csv\",\n",
+ ")\n",
+ "\n",
+ "import pandas as pd\n",
+ "from pyspark.ml.feature import StringIndexer\n",
+ "\n",
+ "df = pd.read_csv(\"/tmp/Emotion_classification.csv\")\n",
+ "df = spark.createDataFrame(df)\n",
+ "\n",
+ "indexer = StringIndexer(inputCol=\"Emotion\", outputCol=\"label\")\n",
+ "indexer_model = indexer.fit(df)\n",
+    "df = indexer_model.transform(df).drop(\"Emotion\")\n",
+ "\n",
+ "train_df, test_df = df.randomSplit([0.85, 0.15], seed=1)\n",
+ "display(train_df)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "bc46d0f5-86b6-409d-b6f9-e3deae631d50",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "source": [
+ "### Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "1f6b513c-606b-4e32-b75e-2baaf19a11d9",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from horovod.spark.common.store import DBFSLocalStore\n",
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+ "from synapse.ml.dl import *\n",
+ "\n",
+ "checkpoint = \"bert-base-uncased\"\n",
+ "run_output_dir = f\"/dbfs/FileStore/test/{checkpoint}\"\n",
+ "store = DBFSLocalStore(run_output_dir)\n",
+ "\n",
+ "epochs = 1\n",
+ "\n",
+ "callbacks = [ModelCheckpoint(filename=\"{epoch}-{train_loss:.2f}\")]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "9450b5fe-ab0d-4f73-8eb2-f2428ad88b4e",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "deep_text_classifier = DeepTextClassifier(\n",
+ " checkpoint=checkpoint,\n",
+ " store=store,\n",
+ " callbacks=callbacks,\n",
+ " num_classes=6,\n",
+ " batch_size=16,\n",
+ " epochs=epochs,\n",
+ " validation=0.1,\n",
+ " text_col=\"Text\",\n",
+ ")\n",
+ "\n",
+ "deep_text_model = deep_text_classifier.fit(train_df.limit(6000).repartition(50))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "4168b8a6-330e-4a28-949a-16954e1ea757",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "source": [
+ "### Prediction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "inputWidgets": {},
+ "nuid": "d6f97c75-b814-4138-a2a8-512bc85a6f65",
+ "showTitle": false,
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
+ "\n",
+ "pred_df = deep_text_model.transform(test_df.limit(500))\n",
+ "evaluator = MulticlassClassificationEvaluator(\n",
+ " predictionCol=\"prediction\", labelCol=\"label\", metricName=\"accuracy\"\n",
+ ")\n",
+ "print(\"Test accuracy:\", evaluator.evaluate(pred_df))"
+ ]
+ }
+ ],
+ "metadata": {
+ "application/vnd.databricks.v1+notebook": {
+ "dashboards": [],
+ "language": "python",
+ "notebookMetadata": {
+ "pythonIndentUnit": 2
+ },
+ "notebookName": "DeepLearning - Deep Text Classification",
+ "notebookOrigID": 4390929852015145,
+ "widgets": {}
+ },
+ "kernelspec": {
+ "display_name": "Python 3.8.5 ('base')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.8.5"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "601a75c4c141f401603984f1538447337114e368c54c4d5b589ea94315afdca2"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classifier.ipynb b/notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classification.ipynb
similarity index 100%
rename from notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classifier.ipynb
rename to notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classification.ipynb
diff --git a/templates/update_cli.yml b/templates/update_cli.yml
index 4444074e53..67c845f847 100644
--- a/templates/update_cli.yml
+++ b/templates/update_cli.yml
@@ -1,7 +1,7 @@
steps:
- task: UsePythonVersion@0
inputs:
- versionSpec: '3.8.X'
+ versionSpec: '3.8'
architecture: 'x64'
- task: JavaToolInstaller@0
inputs:
diff --git a/website/docs/features/simple_deep_learning/about.md b/website/docs/features/simple_deep_learning/about.md
index f66fc886f4..6cb8c649d9 100644
--- a/website/docs/features/simple_deep_learning/about.md
+++ b/website/docs/features/simple_deep_learning/about.md
@@ -26,13 +26,13 @@ Coordinate: com.microsoft.azure:synapseml_2.12:SYNAPSEML_SCALA_VERSION
Repository: https://mmlspark.azureedge.net/maven
```
:::note
-If you install the jar package, you need to follow the first two cell of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classifier.md/#environment-setup----reinstall-horovod-based-on-new-version-of-pytorch)
+If you install the jar package, you need to follow the first two cells of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classification.md/#environment-setup----reinstall-horovod-based-on-new-version-of-pytorch)
to make horovod recognize our module.
:::
## 3. Try our sample notebook
-You could follow the rest of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classifier.md) and have a try on your own dataset.
+You could follow the rest of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classification.md) and try it on your own dataset.
Supported models (the `backbone` parameter for `DeepVisionClassifier`) should be the string name of one of the [torchvision supported models](https://github.com/pytorch/vision/blob/v0.12.0/torchvision/models/__init__.py);
You could also check by running `backbone in torchvision.models.__dict__`.
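For example, the check mentioned above (torchvision 0.12.0 assumed):

import torchvision

backbone = "resnet50"  # illustrative backbone name
print(backbone in torchvision.models.__dict__)  # True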