diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala
index f687bc73b4..d8c455b98c 100644
--- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala
+++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/DotnetCodegen.scala
@@ -34,6 +34,9 @@ object DotnetCodegen {
     if (!projectDir.exists()){
       projectDir.mkdirs()
     }
+    val newtonsoftDep = if (curProject == "DeepLearning") {
+      s"""<PackageReference Include="Newtonsoft.Json" Version="13.0.1" />""".stripMargin
+    } else ""
     // TODO: update SynapseML.DotnetBase version whenever we upload a new one
     writeFile(new File(projectDir, s"${curProject}ProjectSetup.csproj"),
       s"""
          |
          |
+         |    $newtonsoftDep
          |
          |
diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala
index 8aaa28f8d3..b348bc0e48 100644
--- a/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala
+++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala
@@ -57,12 +57,16 @@ object PyCodegen {
       s"""MINIMUM_SUPPORTED_PYTHON_VERSION = "3.8"""".stripMargin
     } else ""
     val extraRequirements = if (conf.name.contains("deep-learning")) {
+      // There's an `Already borrowed` error in transformers 4.16.2 when using tokenizers, so pin to 4.15.0
       s"""extras_require={"extras": [
         |    "cmake",
         |    "horovod==0.25.0",
         |    "pytorch_lightning>=1.5.0,<1.5.10",
         |    "torch==1.11.0",
-        |    "torchvision>=0.12.0"
+        |    "torchvision>=0.12.0",
+        |    "transformers==4.15.0",
+        |    "petastorm>=0.12.0",
+        |    "huggingface-hub>=0.8.1",
         |]},
         |python_requires=f">={MINIMUM_SUPPORTED_PYTHON_VERSION}",""".stripMargin
     } else ""
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala
index 8b91be9fea..be308c7af7 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala
@@ -14,7 +14,7 @@ class DatabricksGPUTests extends DatabricksTestHelper {
   val horovodInstallationScript: File = FileUtilities.join(
     BuildInfo.baseDirectory.getParent, "deep-learning",
     "src", "main", "python", "horovod_installation.sh").getCanonicalFile
-  uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod/horovod_installation.sh")
+  uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod-fix-commit/horovod_installation.sh")
   val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId, GPUInitScripts)
   val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(
     clusterId, GPULibraries, GPUNotebooks)
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala
index 22bccd1e04..54dea3efb4 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala
@@ -61,11 +61,13 @@ object DatabricksUtilities {
   // TODO: install synapse.ml.dl wheel package here
   val GPULibraries: String = List(
-    Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository))
+    Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository)),
+    Map("pypi" -> Map("package" -> "transformers==4.15.0")),
+    Map("pypi" -> Map("package" -> "petastorm==0.12.0"))
   ).toJson.compactPrint

   val GPUInitScripts: String = List(
-    Map("dbfs" -> Map("destination" -> "dbfs:/FileStore/horovod/horovod_installation.sh"))
+    Map("dbfs" -> Map("destination" -> "dbfs:/FileStore/horovod-fix-commit/horovod_installation.sh"))
   ).toJson.compactPrint

   // Execution Params
diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala
index 734746a3ca..f89b97e624 100644
--- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala
+++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/SynapseTests.scala
@@ -46,7 +46,8 @@ class SynapseTests extends TestBase {
     .filter(_.getAbsolutePath.endsWith(".py"))
     .filterNot(_.getAbsolutePath.contains("HyperParameterTuning"))
     .filterNot(_.getAbsolutePath.contains("IsolationForest"))
-    .filterNot(_.getAbsolutePath.contains("DeepLearningDeepVisionClassifier"))
+    .filterNot(_.getAbsolutePath.contains("DeepLearningDeepTextClassification"))
+    .filterNot(_.getAbsolutePath.contains("DeepLearningDeepVisionClassification"))
     .filterNot(_.getAbsolutePath.contains("Interpretability")) //add more exclusion TODO: Remove when fixed
     .sortBy(_.getAbsolutePath)
diff --git a/deep-learning/src/main/python/horovod_installation.sh b/deep-learning/src/main/python/horovod_installation.sh
index 822792fc38..758b4993de 100644
--- a/deep-learning/src/main/python/horovod_installation.sh
+++ b/deep-learning/src/main/python/horovod_installation.sh
@@ -8,6 +8,8 @@ set -eu
 # Install prerequisite libraries that horovod depends on
 pip install pytorch-lightning==1.5.0
 pip install torchvision==0.12.0
+pip install transformers==4.15.0
+pip install "petastorm>=0.12.0"

 # Remove Outdated Signing Key:
 sudo apt-key del 7fa2af80
@@ -32,9 +34,12 @@ libcusparse-dev-11-0=11.1.1.245-1

 git clone --recursive https://github.com/horovod/horovod.git
 cd horovod
-# fix version 0.25.0
-git fetch origin refs/tags/v0.25.0:tags/v0.25.0
-git checkout tags/v0.25.0 -b v0.25.0-branch
+# # fix version 0.25.0
+# git fetch origin refs/tags/v0.25.0:tags/v0.25.0
+# git checkout tags/v0.25.0 -b v0.25.0-branch
+# pin to this commit until horovod releases a new version
+git checkout ab97fd15bbba3258adcdd12983f36a1cdeacbc94
+git checkout -b tmp-branch
 rm -rf build/ dist/
 HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_CUDA_HOME=/usr/local/cuda-11/ HOROVOD_WITH_PYTORCH=1 HOROVOD_WITHOUT_MXNET=1 \
     /databricks/python3/bin/python setup.py bdist_wheel
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepTextClassifier.py b/deep-learning/src/main/python/synapse/ml/dl/DeepTextClassifier.py
new file mode 100644
index 0000000000..0702fc828b
--- /dev/null
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepTextClassifier.py
@@ -0,0 +1,290 @@
+from horovod.spark.lightning import TorchEstimator
+import torch
+from pyspark.ml.param.shared import Param, Params
+from pytorch_lightning.utilities import _module_available
+from synapse.ml.dl.DeepTextModel import DeepTextModel
+from synapse.ml.dl.LitDeepTextModel import LitDeepTextModel
+from synapse.ml.dl.utils import keywords_catch, get_or_create_backend
+from synapse.ml.dl.PredictionParams import TextPredictionParams
+
+_TRANSFORMERS_AVAILABLE = _module_available("transformers")
+if _TRANSFORMERS_AVAILABLE:
+    import transformers
+
+    _TRANSFORMERS_EQUAL_4_15_0 = transformers.__version__ == "4.15.0"
+    if _TRANSFORMERS_EQUAL_4_15_0:
+        from transformers import AutoTokenizer
+    else:
+        raise RuntimeError(
+            "transformers should be == 4.15.0, found: {}".format(
+                transformers.__version__
+            )
+        )
+else:
+    raise ModuleNotFoundError("module not found: transformers")
+
+
+class DeepTextClassifier(TorchEstimator, TextPredictionParams):
+
+    checkpoint = Param(
+        Params._dummy(), "checkpoint", "checkpoint of the deep text classifier"
+    )
+
+    additional_layers_to_train = Param(
+        Params._dummy(),
+        "additional_layers_to_train",
+        "number of last layers to fine-tune for the model; must be greater than or equal to 0. Defaults to 3.",
+    )
+
+    num_classes = Param(Params._dummy(), "num_classes", "number of target classes")
+
+    loss_name = Param(
+        Params._dummy(),
+        "loss_name",
+        "string representation of torch.nn.functional loss function for the underlying pytorch_lightning model, e.g. binary_cross_entropy",
+    )
+
+    optimizer_name = Param(
+        Params._dummy(),
+        "optimizer_name",
+        "string representation of optimizer function for the underlying pytorch_lightning model",
+    )
+
+    tokenizer = Param(Params._dummy(), "tokenizer", "tokenizer used to encode the input text")
+
+    max_token_len = Param(Params._dummy(), "max_token_len", "maximum token length for the tokenized input")
+
+    learning_rate = Param(
+        Params._dummy(), "learning_rate", "learning rate to be used for the optimizer"
+    )
+
+    train_from_scratch = Param(
+        Params._dummy(),
+        "train_from_scratch",
+        "whether to train the model from scratch; if set to False, the additional_layers_to_train param needs to be specified.",
+    )
+
+    @keywords_catch
+    def __init__(
+        self,
+        checkpoint=None,
+        additional_layers_to_train=3,  # needed; otherwise performance is usually poor
+        num_classes=None,
+        optimizer_name="adam",
+        loss_name="cross_entropy",
+        tokenizer=None,
+        max_token_len=128,
+        learning_rate=None,
+        train_from_scratch=True,
+        # Classifier args
+        label_col="label",
+        text_col="text",
+        prediction_col="prediction",
+        # TorchEstimator args
+        num_proc=None,
+        backend=None,
+        store=None,
+        metrics=None,
+        loss_weights=None,
+        sample_weight_col=None,
+        gradient_compression=None,
+        input_shapes=None,
+        validation=None,
+        callbacks=None,
+        batch_size=None,
+        val_batch_size=None,
+        epochs=None,
+        verbose=1,
+        random_seed=None,
+        shuffle_buffer_size=None,
+        partitions_per_process=None,
+        run_id=None,
+        train_minibatch_fn=None,
+        train_steps_per_epoch=None,
+        validation_steps_per_epoch=None,
+        transformation_fn=None,
+        transformation_edit_fields=None,
+        transformation_removed_fields=None,
+        train_reader_num_workers=None,
+        trainer_args=None,
+        val_reader_num_workers=None,
+        reader_pool_type=None,
+        label_shapes=None,
+        inmemory_cache_all=False,
+        num_gpus=None,
+        logger=None,
+        log_every_n_steps=50,
+        data_module=None,
+        loader_num_epochs=None,
+        terminate_on_nan=False,
+        profiler=None,
+        debug_data_loader=False,
+        train_async_data_loader_queue_size=None,
+        val_async_data_loader_queue_size=None,
+        use_gpu=True,
+        mp_start_method=None,
+    ):
+        super(DeepTextClassifier, self).__init__()
+
+        self._setDefault(
+            checkpoint=None,
+            additional_layers_to_train=3,
+            num_classes=None,
+            optimizer_name="adam",
+            loss_name="cross_entropy",
+            tokenizer=None,
+            max_token_len=128,
+            learning_rate=None,
+            train_from_scratch=True,
+            feature_cols=["text"],
+            label_cols=["label"],
+            label_col="label",
+            text_col="text",
+            prediction_col="prediction",
+        )
+
+        kwargs = self._kwargs
+        self._set(**kwargs)
+
+        self._update_cols()
+        self._update_transformation_fn()
+
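+        # build the underlying LightningModule from the Spark-side params;
+        # note that the estimator's num_classes is passed on as num_labels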
+        model = LitDeepTextModel(
+            checkpoint=self.getCheckpoint(),
+            additional_layers_to_train=self.getAdditionalLayersToTrain(),
+            num_labels=self.getNumClasses(),
+            optimizer_name=self.getOptimizerName(),
+            loss_name=self.getLossName(),
+            label_col=self.getLabelCol(),
+            text_col=self.getTextCol(),
+            learning_rate=self.getLearningRate(),
+            train_from_scratch=self.getTrainFromScratch(),
+        )
+        self._set(model=model)
+
+    def setCheckpoint(self, value):
+        return self._set(checkpoint=value)
+
+    def getCheckpoint(self):
+        return self.getOrDefault(self.checkpoint)
+
+    def setAdditionalLayersToTrain(self, value):
+        return self._set(additional_layers_to_train=value)
+
+    def getAdditionalLayersToTrain(self):
+        return self.getOrDefault(self.additional_layers_to_train)
+
+    def setNumClasses(self, value):
+        return self._set(num_classes=value)
+
+    def getNumClasses(self):
+        return self.getOrDefault(self.num_classes)
+
+    def setLossName(self, value):
+        return self._set(loss_name=value)
+
+    def getLossName(self):
+        return self.getOrDefault(self.loss_name)
+
+    def setOptimizerName(self, value):
+        return self._set(optimizer_name=value)
+
+    def getOptimizerName(self):
+        return self.getOrDefault(self.optimizer_name)
+
+    def setTokenizer(self, value):
+        return self._set(tokenizer=value)
+
+    def getTokenizer(self):
+        return self.getOrDefault(self.tokenizer)
+
+    def setMaxTokenLen(self, value):
+        return self._set(max_token_len=value)
+
+    def getMaxTokenLen(self):
+        return self.getOrDefault(self.max_token_len)
+
+    def setLearningRate(self, value):
+        return self._set(learning_rate=value)
+
+    def getLearningRate(self):
+        return self.getOrDefault(self.learning_rate)
+
+    def setTrainFromScratch(self, value):
+        return self._set(train_from_scratch=value)
+
+    def getTrainFromScratch(self):
+        return self.getOrDefault(self.train_from_scratch)
+
+    def _update_cols(self):
+        self.setFeatureCols([self.getTextCol()])
+        self.setLabelCols([self.getLabelCol()])
+
+    def _fit(self, dataset):
+        return super()._fit(dataset)
+
+    # override this method to provide a correct default backend
+    def _get_or_create_backend(self):
+        return get_or_create_backend(
+            self.getBackend(), self.getNumProc(), self.getVerbose(), self.getUseGpu()
+        )
+
+    def _update_transformation_fn(self):
+        text_col = self.getTextCol()
+        label_col = self.getLabelCol()
+        max_token_len = self.getMaxTokenLen()
+        # load the tokenizer here to avoid the `Already borrowed` error
+        # (https://github.com/huggingface/tokenizers/issues/537)
+        if self.getTokenizer() is None:
+            self.setTokenizer(AutoTokenizer.from_pretrained(self.getCheckpoint()))
+        tokenizer = self.getTokenizer()
+
+        def _encoding_text(row):
+            text = row[text_col]
+            label = row[label_col]
+            encoding = tokenizer(
+                text,
+                max_length=max_token_len,
+                padding="max_length",
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors="pt",
+            )
+            input_ids = encoding["input_ids"].flatten().numpy()
+            attention_mask = encoding["attention_mask"].flatten().numpy()
+            return {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "labels": torch.tensor(label, dtype=int),
+            }
+
+        transformation_edit_fields = [
+            ("input_ids", int, None, True),
+            ("attention_mask", int, None, True),
+            ("labels", int, None, False),
+        ]
+        self.setTransformationEditFields(transformation_edit_fields)
+        transformation_removed_fields = [self.getTextCol(), self.getLabelCol()]
+        self.setTransformationRemovedFields(transformation_removed_fields)
+        self.setTransformationFn(_encoding_text)
+
+    def get_model_class(self):
+        return DeepTextModel
+
+    def _get_model_kwargs(self, model, history, optimizer, run_id, metadata):
+        return dict(
+            history=history,
+            model=model,
+            optimizer=optimizer,
+            input_shapes=self.getInputShapes(),
+            run_id=run_id,
+            _metadata=metadata,
+            loss=self.getLoss(),
+            loss_constructors=self.getLossConstructors(),
+            tokenizer=self.getTokenizer(),
+            checkpoint=self.getCheckpoint(),
+            max_token_len=self.getMaxTokenLen(),
+            label_col=self.getLabelCol(),
+            text_col=self.getTextCol(),
+            prediction_col=self.getPredictionCol(),
+        )
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepTextModel.py b/deep-learning/src/main/python/synapse/ml/dl/DeepTextModel.py
new file mode 100644
index 0000000000..4bdc2c8d51
--- /dev/null
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepTextModel.py
@@ -0,0 +1,119 @@
+import numpy as np
+import torch
+from horovod.spark.lightning import TorchModel
+from synapse.ml.dl.PredictionParams import TextPredictionParams
+from pyspark.ml.param import Param, Params, TypeConverters
+from pyspark.sql.functions import col, udf
+from pyspark.sql.types import DoubleType
+from synapse.ml.dl.utils import keywords_catch
+from transformers import AutoTokenizer
+
+
+class DeepTextModel(TorchModel, TextPredictionParams):
+
+    tokenizer = Param(Params._dummy(), "tokenizer", "tokenizer used to encode the input text")
+
+    checkpoint = Param(
+        Params._dummy(), "checkpoint", "checkpoint of the deep text classifier"
+    )
+
+    max_token_len = Param(Params._dummy(), "max_token_len", "maximum token length for the tokenized input")
+
+    @keywords_catch
+    def __init__(
+        self,
+        history=None,
+        model=None,
+        input_shapes=None,
+        optimizer=None,
+        run_id=None,
+        _metadata=None,
+        loss=None,
+        loss_constructors=None,
+        # diff from horovod
+        checkpoint=None,
+        tokenizer=None,
+        max_token_len=128,
+        label_col="label",
+        text_col="text",
+        prediction_col="prediction",
+    ):
+        super(DeepTextModel, self).__init__()
+
+        self._setDefault(
+            optimizer=None,
+            loss=None,
+            loss_constructors=None,
+            input_shapes=None,
+            checkpoint=None,
+            max_token_len=128,
+            text_col="text",
+            label_col="label",
+            prediction_col="prediction",
+            feature_columns=["text"],
+            label_columns=["label"],
+            outputCols=["output"],
+        )
+
+        kwargs = self._kwargs
+        self._set(**kwargs)
+
+    def setTokenizer(self, value):
+        return self._set(tokenizer=value)
+
+    def getTokenizer(self):
+        return self.getOrDefault(self.tokenizer)
+
+    def setCheckpoint(self, value):
+        return self._set(checkpoint=value)
+
+    def getCheckpoint(self):
+        return self.getOrDefault(self.checkpoint)
+
+    def setMaxTokenLen(self, value):
+        return self._set(max_token_len=value)
+
+    def getMaxTokenLen(self):
+        return self.getOrDefault(self.max_token_len)
+
+    def _update_cols(self):
+        self.setFeatureColumns([self.getTextCol()])
+        self.setLabelColumns([self.getLabelCol()])
+
+    # override this to encode text before prediction
+    def get_prediction_fn(self):
+        text_col = self.getTextCol()
+        max_token_len = self.getMaxTokenLen()
+        tokenizer = self.getTokenizer()
+
+        def predict_fn(model, row):
+            text = row[text_col]
+            data = tokenizer(
+                text,
+                max_length=max_token_len,
+                padding="max_length",
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors="pt",
+            )
+            with torch.no_grad():
+                outputs = model(**data)
+                pred = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+            return pred
+
+        return predict_fn
+
+    # pytorch_lightning module has its own optimizer configuration
+    def getOptimizer(self):
+        return None
+
+    def _transform(self, df):
+        self._update_cols()
+        output_df = super()._transform(df)
+        argmax = udf(lambda v: float(np.argmax(v)), returnType=DoubleType())
+        pred_df = output_df.withColumn(
+            self.getPredictionCol(), argmax(col(self.getOutputCols()[0]))
+        )
+        return pred_df
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py
index 3a29483caa..2968fbd7a8 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionClassifier.py
@@ -12,8 +12,8 @@
 from pytorch_lightning.utilities import _module_available
 from synapse.ml.dl.DeepVisionModel import DeepVisionModel
 from synapse.ml.dl.LitDeepVisionModel import LitDeepVisionModel
-from synapse.ml.dl.utils import keywords_catch
-from synapse.ml.dl.PredictionParams import PredictionParams
+from synapse.ml.dl.utils import keywords_catch, get_or_create_backend
+from synapse.ml.dl.PredictionParams import VisionPredictionParams

 _HOROVOD_AVAILABLE = _module_available("horovod")
 if _HOROVOD_AVAILABLE:
@@ -28,7 +28,7 @@
     raise ModuleNotFoundError("module not found: horovod")


-class DeepVisionClassifier(TorchEstimator, PredictionParams):
+class DeepVisionClassifier(TorchEstimator, VisionPredictionParams):

     backbone = Param(
         Params._dummy(), "backbone", "backbone of the deep vision classifier"
@@ -217,30 +217,9 @@ def _fit(self, dataset):

     # override this method to provide a correct default backend
     def _get_or_create_backend(self):
-        backend = self.getBackend()
-        num_proc = self.getNumProc()
-        if backend is None:
-            if num_proc is None:
-                num_proc = self._find_num_proc()
-            backend = SparkBackend(
-                num_proc,
-                stdout=sys.stdout,
-                stderr=sys.stderr,
-                prefix_output_with_timestamp=True,
-                verbose=self.getVerbose(),
-            )
-        elif num_proc is not None:
-            raise ValueError(
-                'At most one of parameters "backend" and "num_proc" may be specified'
-            )
-        return backend
-
-    def _find_num_proc(self):
-        if self.getUseGpu():
-            # set it as number of executors for now (ignoring num_gpus per executor)
-            sc = SparkContext.getOrCreate()
-            return sc._jsc.sc().getExecutorMemoryStatus().size() - 1
-        return None
+        return get_or_create_backend(
+            self.getBackend(), self.getNumProc(), self.getVerbose(), self.getUseGpu()
+        )

     def _update_transformation_fn(self):
         if self.getTransformationFn() is None:
@@ -258,23 +237,18 @@ def _update_transformation_fn(self):
             )
             self.setTransformFn(transform)

-        def _create_transform_row(image_col, label_col, transform):
-            def _transform_row(row):
-                path = row[image_col]
-                label = row[label_col]
-                image = Image.open(path).convert("RGB")
-                image = transform(image).numpy()
-                return {image_col: image, label_col: label}
-
-            return _transform_row
-
-        self.setTransformationFn(
-            _create_transform_row(
-                self.getImageCol(),
-                self.getLabelCol(),
-                self.getTransformFn(),
-            )
-        )
+        image_col = self.getImageCol()
+        label_col = self.getLabelCol()
+        transform = self.getTransformFn()
+
+        def _transform_row(row):
+            path = row[image_col]
+            label = row[label_col]
+            image = Image.open(path).convert("RGB")
+            image = transform(image).numpy()
+            return {image_col: image, label_col: label}
+
+        self.setTransformationFn(_transform_row)

     def get_model_class(self):
         return DeepVisionModel
diff --git a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py
index 7f35112df4..1fa67dcb4f 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/DeepVisionModel.py
@@ -6,14 +6,14 @@
 import torchvision.transforms as transforms
 from horovod.spark.lightning import TorchModel
 from PIL import Image
-from synapse.ml.dl.PredictionParams import PredictionParams
+from synapse.ml.dl.PredictionParams import VisionPredictionParams
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.sql.functions import col, udf
 from pyspark.sql.types import DoubleType
 from synapse.ml.dl.utils import keywords_catch


-class DeepVisionModel(TorchModel, PredictionParams):
+class DeepVisionModel(TorchModel, VisionPredictionParams):

     transform_fn = Param(
         Params._dummy(),
@@ -56,8 +56,6 @@ def __init__(
         kwargs = self._kwargs
         self._set(**kwargs)
-        self._update_transform_fn()
-        self._update_cols()

     def setTransformFn(self, value):
         return self._set(transform_fn=value)
@@ -93,29 +91,29 @@ def _update_cols(self):

     def get_prediction_fn(self):
         input_shape = self.getInputShapes()[0]
         image_col = self.getImageCol()
+        transform = self.getTransformFn()

-        def _create_predict_fn(transform):
-            def predict_fn(model, row):
-                if type(row[image_col]) == str:
-                    image = Image.open(row[image_col]).convert("RGB")
-                    data = torch.tensor(transform(image).numpy()).reshape(input_shape)
-                else:
-                    data = torch.tensor([row[image_col]]).reshape(input_shape)
-
-                with torch.no_grad():
-                    pred = model(data)
-
-                return pred
-
-            return predict_fn
-
-        return _create_predict_fn(self.getTransformFn())
+        def predict_fn(model, row):
+            if type(row[image_col]) == str:
+                image = Image.open(row[image_col]).convert("RGB")
+                data = torch.tensor(transform(image).numpy()).reshape(input_shape)
+            else:
+                data = torch.tensor([row[image_col]]).reshape(input_shape)
+
+            with torch.no_grad():
+                pred = model(data)
+
+            return pred
+
+        return predict_fn

     # pytorch_lightning module has its own optimizer configuration
     def getOptimizer(self):
         return None

     def _transform(self, df):
+        self._update_transform_fn()
+        self._update_cols()
         output_df = super()._transform(df)
         argmax = udf(lambda v: float(np.argmax(v)), returnType=DoubleType())
         pred_df = output_df.withColumn(
diff --git a/deep-learning/src/main/python/synapse/ml/dl/LitDeepTextModel.py b/deep-learning/src/main/python/synapse/ml/dl/LitDeepTextModel.py
new file mode 100644
index 0000000000..2283281c0b
--- /dev/null
+++ b/deep-learning/src/main/python/synapse/ml/dl/LitDeepTextModel.py
@@ -0,0 +1,176 @@
+# Copyright (C) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE in project root for information.
+
+import inspect
+
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from pytorch_lightning.utilities import _module_available
+
+_TRANSFORMERS_AVAILABLE = _module_available("transformers")
+if _TRANSFORMERS_AVAILABLE:
+    import transformers
+
+    _TRANSFORMERS_EQUAL_4_15_0 = transformers.__version__ == "4.15.0"
+    if _TRANSFORMERS_EQUAL_4_15_0:
+        from transformers import AutoModelForSequenceClassification
+    else:
+        raise RuntimeError(
+            "transformers should be == 4.15.0, found: {}".format(
+                transformers.__version__
+            )
+        )
+else:
+    raise ModuleNotFoundError("module not found: transformers")
+
+
+class LitDeepTextModel(pl.LightningModule):
+    def __init__(
+        self,
+        checkpoint,
+        text_col,
+        label_col,
+        num_labels,
+        additional_layers_to_train,
+        optimizer_name,
+        loss_name,
+        learning_rate=None,
+        train_from_scratch=True,
+    ):
+        """
+        :param checkpoint: Checkpoint for pre-trained model. This is expected to
+            be a checkpoint you could find on [HuggingFace](https://huggingface.co/models)
+            and is of type `AutoModelForSequenceClassification`.
+        :param text_col: Text column name.
+        :param label_col: Label column name.
+        :param num_labels: Number of labels for classification.
+        :param additional_layers_to_train: Number of additional layers to fine-tune.
+            For the deep text model, a positive number is recommended for better performance.
+        :param optimizer_name: Name of the optimizer.
+        :param loss_name: Name of the loss function.
+        :param learning_rate: Learning rate for the optimizer.
+        :param train_from_scratch: Whether to train the model from scratch. If set to True,
+            the additional_layers_to_train param is ignored. Defaults to True.
+        """
+        super(LitDeepTextModel, self).__init__()
+
+        self.checkpoint = checkpoint
+        self.text_col = text_col
+        self.label_col = label_col
+        self.num_labels = num_labels
+        self.additional_layers_to_train = additional_layers_to_train
+        self.optimizer_name = optimizer_name
+        self.loss_name = loss_name
+        self.learning_rate = learning_rate
+        self.train_from_scratch = train_from_scratch
+
+        self._check_params()
+
+        self.save_hyperparameters(
+            "checkpoint",
+            "text_col",
+            "label_col",
+            "num_labels",
+            "additional_layers_to_train",
+            "optimizer_name",
+            "loss_name",
+            "learning_rate",
+            "train_from_scratch",
+        )
+
+    def _check_params(self):
+        try:
+            # TODO: Add other types of models here
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                self.checkpoint, num_labels=self.num_labels
+            )
+            self._update_learning_rate()
+        except Exception as err:
+            raise ValueError(
+                f"No checkpoint {self.checkpoint} found: {err=}, {type(err)=}"
+            )
+
+        if self.loss_name.lower() not in F.__dict__:
+            raise ValueError("No loss function: {} found".format(self.loss_name))
+        self.loss_fn = F.__dict__[self.loss_name.lower()]
+
+        optimizers_mapping = {
+            key.lower(): value
+            for key, value in optim.__dict__.items()
+            if inspect.isclass(value) and issubclass(value, optim.Optimizer)
+        }
+        if self.optimizer_name.lower() not in optimizers_mapping:
+            raise ValueError("No optimizer: {} found".format(self.optimizer_name))
+        self.optimizer_fn = optimizers_mapping[self.optimizer_name.lower()]
+
+    def forward(self, **inputs):
+        return self.model(**inputs)
+
+    def configure_optimizers(self):
+        if not self.train_from_scratch:
+            # Freeze the base model's weights, then unfreeze the last few layers
+            for p in self.model.base_model.parameters():
+                p.requires_grad = False
+            self._fine_tune_layers()
+        params_to_update = filter(lambda p: p.requires_grad, self.model.parameters())
+        return self.optimizer_fn(params_to_update, self.learning_rate)
+
+    def _fine_tune_layers(self):
+        if self.additional_layers_to_train < 0:
+            raise ValueError(
+                "additional_layers_to_train has to be non-negative: {} found".format(
+                    self.additional_layers_to_train
+                )
+            )
+        # base_model contains the real model to fine-tune
+        children = list(self.model.base_model.children())
+        added_layer, cur_layer = 0, -1
+        while added_layer < self.additional_layers_to_train and -cur_layer < len(
+            children
+        ):
+            tunable = False
+            for p in children[cur_layer].parameters():
+                p.requires_grad = True
+                tunable = True
+            # only tune those layers that contain parameters
+            if tunable:
+                added_layer += 1
+            cur_layer -= 1
+
+    def _update_learning_rate(self):
+        # TODO: add more default values for different models
+        if not self.learning_rate:
+            if "bert" in self.checkpoint:
+                self.learning_rate = 5e-5
+            else:
+                self.learning_rate = 0.01
+
+    def training_step(self, batch, batch_idx):
+        loss = self._step(batch, False)
+        self.log("train_loss", loss)
+        return loss
+
+    def _step(self, batch, validation):
+        inputs = batch
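+        # `batch` already holds the tokenizer's input_ids / attention_mask /
+        # labels fields, so it can be unpacked directly into the underlying
+        # transformers model, which computes the loss internally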
+        outputs = self(**inputs)
+        loss = outputs.loss
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self._step(batch, True)
+        self.log("val_loss", loss)
+        return {"val_loss": loss}
+
+    def validation_epoch_end(self, outputs):
+        avg_loss = (
+            torch.stack([x["val_loss"] for x in outputs]).mean()
+            if len(outputs) > 0
+            else float("inf")
+        )
+        self.log("avg_val_loss", avg_loss)
+
+    def test_step(self, batch, batch_idx):
+        loss = self._step(batch, False)
+        self.log("test_loss", loss)
+        return loss
diff --git a/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py b/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py
index e4c68845ec..c2e7782790 100644
--- a/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py
+++ b/deep-learning/src/main/python/synapse/ml/dl/PredictionParams.py
@@ -4,7 +4,7 @@
 from pyspark.ml.param import Param, Params, TypeConverters


-class PredictionParams(Params):
+class HasLabelColParam(Params):

     label_col = Param(
         Params._dummy(),
@@ -13,25 +13,9 @@ class PredictionParams(Params):
         typeConverter=TypeConverters.toString,
     )

-    image_col = Param(
-        Params._dummy(),
-        "image_col",
-        "image column name.",
-        typeConverter=TypeConverters.toString,
-    )
-
-    prediction_col = Param(
-        Params._dummy(),
-        "prediction_col",
-        "prediction column name.",
-        typeConverter=TypeConverters.toString,
-    )
-
     def __init__(self):
-        super(PredictionParams, self).__init__()
-        self._setDefault(
-            label_col="label", image_col="image", prediction_col="prediction"
-        )
+        super(HasLabelColParam, self).__init__()
+        self._setDefault(label_col="label")

     def setLabelCol(self, value):
         """
@@ -45,6 +29,20 @@ def getLabelCol(self):
         """
         return self.getOrDefault(self.label_col)

+
+class HasImageColParam(Params):
+
+    image_col = Param(
+        Params._dummy(),
+        "image_col",
+        "image column name.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    def __init__(self):
+        super(HasImageColParam, self).__init__()
+        self._setDefault(image_col="image")
+
     def setImageCol(self, value):
         """
         Sets the value of :py:attr:`image_col`.
@@ -57,6 +55,47 @@ def getImageCol(self):
         """
         return self.getOrDefault(self.image_col)

+
+# TODO: Potentially generalize to support multiple text columns as input
+class HasTextColParam(Params):
+
+    text_col = Param(
+        Params._dummy(),
+        "text_col",
+        "text column name.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    def __init__(self):
+        super(HasTextColParam, self).__init__()
+        self._setDefault(text_col="text")
+
+    def setTextCol(self, value):
+        """
+        Sets the value of :py:attr:`text_col`.
+        """
+        return self._set(text_col=value)
+
+    def getTextCol(self):
+        """
+        Gets the value of text_col or its default value.
+        """
+        return self.getOrDefault(self.text_col)
+
+
+class HasPredictionColParam(Params):
+
+    prediction_col = Param(
+        Params._dummy(),
+        "prediction_col",
+        "prediction column name.",
+        typeConverter=TypeConverters.toString,
+    )
+
+    def __init__(self):
+        super(HasPredictionColParam, self).__init__()
+        self._setDefault(prediction_col="prediction")
+
     def setPredictionCol(self, value):
         """
         Sets the value of :py:attr:`prediction_col`.
@@ -68,3 +107,13 @@ def getPredictionCol(self):
         """
         Gets the value of prediction_col or its default value.
""" return self.getOrDefault(self.prediction_col) + + +class VisionPredictionParams(HasLabelColParam, HasImageColParam, HasPredictionColParam): + def __init__(self): + super(VisionPredictionParams, self).__init__() + + +class TextPredictionParams(HasLabelColParam, HasTextColParam, HasPredictionColParam): + def __init__(self): + super(TextPredictionParams, self).__init__() diff --git a/deep-learning/src/main/python/synapse/ml/dl/__init__.py b/deep-learning/src/main/python/synapse/ml/dl/__init__.py index 1a8ad55d9e..7d97ae576d 100644 --- a/deep-learning/src/main/python/synapse/ml/dl/__init__.py +++ b/deep-learning/src/main/python/synapse/ml/dl/__init__.py @@ -1,6 +1,9 @@ # Copyright (C) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See LICENSE in project root for information. +from synapse.ml.dl.DeepTextClassifier import * +from synapse.ml.dl.DeepTextModel import * from synapse.ml.dl.DeepVisionClassifier import * from synapse.ml.dl.DeepVisionModel import * +from synapse.ml.dl.LitDeepTextModel import * from synapse.ml.dl.LitDeepVisionModel import * diff --git a/deep-learning/src/main/python/synapse/ml/dl/utils.py b/deep-learning/src/main/python/synapse/ml/dl/utils.py index 956d98541c..e9fca931eb 100644 --- a/deep-learning/src/main/python/synapse/ml/dl/utils.py +++ b/deep-learning/src/main/python/synapse/ml/dl/utils.py @@ -1,7 +1,11 @@ # Copyright (C) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See LICENSE in project root for information. +import sys + from functools import wraps +from horovod.spark.common.backend import SparkBackend +from pyspark.context import SparkContext def keywords_catch(func): @@ -22,3 +26,29 @@ def wrapper(self, *args, **kwargs): return func(self, **kwargs) return wrapper + + +def get_or_create_backend(backend, num_proc, verbose, use_gpu): + if backend is None: + if num_proc is None: + num_proc = _find_num_proc(use_gpu) + backend = SparkBackend( + num_proc, + stdout=sys.stdout, + stderr=sys.stderr, + prefix_output_with_timestamp=True, + verbose=verbose, + ) + elif num_proc is not None: + raise ValueError( + 'At most one of parameters "backend" and "num_proc" may be specified' + ) + return backend + + +def _find_num_proc(use_gpu): + if use_gpu: + # set it as number of executors for now (ignoring num_gpus per executor) + sc = SparkContext.getOrCreate() + return sc._jsc.sc().getExecutorMemoryStatus().size() - 1 + return None diff --git a/deep-learning/src/test/python/synapsemltest/dl/conftest.py b/deep-learning/src/test/python/synapsemltest/dl/conftest.py index 200dd6b8f3..1542168271 100644 --- a/deep-learning/src/test/python/synapsemltest/dl/conftest.py +++ b/deep-learning/src/test/python/synapsemltest/dl/conftest.py @@ -8,8 +8,10 @@ from os.path import join import numpy as np +import pandas as pd import pytest import torchvision.transforms as transforms +from pyspark.ml.feature import StringIndexer IS_WINDOWS = os.name == "nt" delimiter = "\\" if IS_WINDOWS else "/" @@ -19,6 +21,14 @@ ) +class CallbackBackend(object): + def run(self, fn, args=(), kwargs={}, env={}): + return [fn(*args, **kwargs)] * self.num_processes() + + def num_processes(self): + return 1 + + def _download_dataset(): urllib.request.urlretrieve( @@ -82,3 +92,20 @@ def transform(): ] ) return transform + + +def _prepare_text_data(spark): + df = ( + spark.read.format("csv") + .option("header", "true") + .load( + "wasbs://publicwasb@mmlspark.blob.core.windows.net/text_classification/Emotion_classification.csv" + ) + ) + indexer = 
StringIndexer(inputCol="Emotion", outputCol="label") + indexer_model = indexer.fit(df) + df = indexer_model.transform(df).drop("Emotion") + + train_df, test_df = df.randomSplit([0.85, 0.15], seed=1) + + return train_df, test_df diff --git a/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_classifier.py b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_classifier.py new file mode 100644 index 0000000000..4cd464f34a --- /dev/null +++ b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_classifier.py @@ -0,0 +1,55 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. + +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +from pyspark.sql import SparkSession +from pytorch_lightning.callbacks import ModelCheckpoint +from synapse.ml.dl import * + +import pytest +from .conftest import CallbackBackend, _prepare_text_data +from .test_deep_vision_classifier import local_store +from .test_deep_vision_model import MyDummyCallback + + +@pytest.mark.skip(reason="not testing this for now") +def test_bert_base_cased(): + spark = SparkSession.builder.master("local[*]").getOrCreate() + + train_df, test_df = _prepare_text_data(spark) + + ctx = CallbackBackend() + + epochs = 2 + callbacks = [ + MyDummyCallback(epochs), + ModelCheckpoint(dirpath="target/bert_base_uncased/"), + ] + + with local_store() as store: + + checkpoint = "bert-base-uncased" + + deep_text_classifier = DeepTextClassifier( + checkpoint=checkpoint, + store=store, + backend=ctx, + callbacks=callbacks, + num_classes=6, + batch_size=16, + epochs=epochs, + validation=0.1, + text_col="Text", + transformation_removed_fields=["Text", "Emotion", "label"], + ) + + deep_text_model = deep_text_classifier.fit(train_df) + + pred_df = deep_text_model.transform(test_df) + evaluator = MulticlassClassificationEvaluator( + predictionCol="prediction", labelCol="label", metricName="accuracy" + ) + accuracy = evaluator.evaluate(pred_df) + assert accuracy > 0.5 + + spark.stop() diff --git a/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_model.py b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_model.py new file mode 100644 index 0000000000..7466e7dbec --- /dev/null +++ b/deep-learning/src/test/python/synapsemltest/dl/test_deep_text_model.py @@ -0,0 +1,82 @@ +# Copyright (C) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE in project root for information. 
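+# Exercises LitDeepTextModel directly with a plain PyTorch Lightning Trainer (no horovod).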
+
+import pytest
+import torch
+from pyspark.sql import SparkSession
+from pytorch_lightning import Trainer
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoTokenizer
+
+from .conftest import _prepare_text_data
+from .test_deep_vision_model import MyDummyCallback
+
+from synapse.ml.dl import *
+
+
+@pytest.mark.skip(reason="skip this as it takes too long")
+def test_lit_deep_text_model():
+    class TextDataset(Dataset):
+        def __init__(self, data, tokenizer, max_token_len):
+            super(TextDataset, self).__init__()
+            self.data = data
+            self.tokenizer = tokenizer
+            self.max_token_len = max_token_len
+
+        def __getitem__(self, index):
+            text = self.data["Text"][index]
+            label = self.data["label"][index]
+            encoding = self.tokenizer(
+                text,
+                truncation=True,
+                padding="max_length",
+                max_length=self.max_token_len,
+                return_tensors="pt",
+            )
+            return {
+                "input_ids": encoding["input_ids"].flatten(),
+                "attention_mask": encoding["attention_mask"].flatten(),
+                "labels": torch.tensor([label], dtype=int),
+            }
+
+        def __len__(self):
+            return len(self.data)
+
+    spark = SparkSession.builder.master("local[*]").getOrCreate()
+    train_df, test_df = _prepare_text_data(spark)
+
+    checkpoint = "bert-base-cased"
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+    max_token_len = 128
+
+    train_loader = DataLoader(
+        TextDataset(train_df.toPandas(), tokenizer, max_token_len),
+        batch_size=16,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=True,
+    )
+
+    test_loader = DataLoader(
+        TextDataset(test_df.toPandas(), tokenizer, max_token_len),
+        batch_size=16,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=True,
+    )
+
+    epochs = 1
+    model = LitDeepTextModel(
+        checkpoint=checkpoint,
+        additional_layers_to_train=10,
+        num_labels=6,
+        optimizer_name="adam",
+        loss_name="cross_entropy",
+        label_col="label",
+        text_col="Text",
+    )
+
+    callbacks = [MyDummyCallback(epochs)]
+    trainer = Trainer(callbacks=callbacks, max_epochs=epochs)
+    trainer.fit(model, train_dataloaders=train_loader)
+    trainer.test(model, dataloaders=test_loader)
diff --git a/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py b/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py
index 20e1d45d10..24dac4d80f 100644
--- a/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py
+++ b/deep-learning/src/test/python/synapsemltest/dl/test_deep_vision_classifier.py
@@ -15,6 +15,7 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from synapse.ml.dl import *

+from .conftest import CallbackBackend
 from .test_deep_vision_model import MyDummyCallback


@@ -64,14 +65,6 @@ def extract_path_and_label(path):
     return train_df, test_df


-class CallbackBackend(object):
-    def run(self, fn, args=(), kwargs={}, env={}):
-        return [fn(*args, **kwargs)] * self.num_processes()
-
-    def num_processes(self):
-        return 1
-
-
 @pytest.mark.skip(reason="not testing this for now")
 def test_mobilenet_v2(get_data_path):
     spark = SparkSession.builder.master("local[*]").getOrCreate()
diff --git a/environment.yml b/environment.yml
index 4da964a9eb..8b573642b8 100644
--- a/environment.yml
+++ b/environment.yml
@@ -35,3 +35,5 @@ dependencies:
   - onnxmltools==1.7.0
   - matplotlib
   - Pillow
+  - transformers==4.15.0
+  - huggingface-hub>=0.8.1
diff --git a/notebooks/features/simple_deep_learning/DeepLearning - Deep Text Classification.ipynb b/notebooks/features/simple_deep_learning/DeepLearning - Deep Text Classification.ipynb
new file mode 100644
index 0000000000..9e237e9b3b
--- /dev/null
+++ b/notebooks/features/simple_deep_learning/DeepLearning - Deep Text Classification.ipynb
@@ -0,0 +1,256 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "a4b6a348-5155-4665-9616-3776bea40ff0",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "source": [
+        "## Deep Learning - Deep Text Classifier"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "910eee89-ded8-4c36-90ae-e9b8539c5773",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "source": [
+        "### Environment Setup"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "60a84fca-38ae-48dc-826a-1cc2011c3977",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "outputs": [],
+      "source": [
+        "# install cloudpickle 2.0.0 so the synapse module can be pickled by value for horovod\n",
+        "%pip install cloudpickle==2.0.0 --force-reinstall --no-deps"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "cd1e438b-4b6e-4d92-8cd4-0c184afe0721",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "outputs": [],
+      "source": [
+        "import synapse\n",
+        "import cloudpickle\n",
+        "\n",
+        "cloudpickle.register_pickle_by_value(synapse)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "29b27e85-09c0-4e5f-8c58-af3a2bc9d373",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "outputs": [],
+      "source": [
+        "! horovodrun --check-build"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "205531d2-6c06-49b4-828a-6f207371830b",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "source": [
+        "### Read Dataset"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import urllib\n",
+        "\n",
+        "urllib.request.urlretrieve(\n",
+        "    \"https://mmlspark.blob.core.windows.net/publicwasb/text_classification/Emotion_classification.csv\",\n",
+        "    \"/tmp/Emotion_classification.csv\",\n",
+        ")\n",
+        "\n",
+        "import pandas as pd\n",
+        "from pyspark.ml.feature import StringIndexer\n",
+        "\n",
+        "df = pd.read_csv(\"/tmp/Emotion_classification.csv\")\n",
+        "df = spark.createDataFrame(df)\n",
+        "\n",
+        "indexer = StringIndexer(inputCol=\"Emotion\", outputCol=\"label\")\n",
+        "indexer_model = indexer.fit(df)\n",
+        "df = indexer_model.transform(df).drop(\"Emotion\")\n",
+        "\n",
+        "train_df, test_df = df.randomSplit([0.85, 0.15], seed=1)\n",
+        "display(train_df)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "bc46d0f5-86b6-409d-b6f9-e3deae631d50",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "source": [
+        "### Training"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "1f6b513c-606b-4e32-b75e-2baaf19a11d9",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "outputs": [],
+      "source": [
+        "from horovod.spark.common.store import DBFSLocalStore\n",
+        "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+        "from synapse.ml.dl import *\n",
+        "\n",
+        "checkpoint = \"bert-base-uncased\"\n",
+        "run_output_dir = f\"/dbfs/FileStore/test/{checkpoint}\"\n",
+        "store = DBFSLocalStore(run_output_dir)\n",
+        "\n",
+        "epochs = 1\n",
+        "\n",
+        "callbacks = [ModelCheckpoint(filename=\"{epoch}-{train_loss:.2f}\")]"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "9450b5fe-ab0d-4f73-8eb2-f2428ad88b4e",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "outputs": [],
+      "source": [
+        "deep_text_classifier = DeepTextClassifier(\n",
+        "    checkpoint=checkpoint,\n",
+        "    store=store,\n",
+        "    callbacks=callbacks,\n",
+        "    num_classes=6,\n",
+        "    batch_size=16,\n",
+        "    epochs=epochs,\n",
+        "    validation=0.1,\n",
+        "    text_col=\"Text\",\n",
+        ")\n",
+        "\n",
+        "deep_text_model = deep_text_classifier.fit(train_df.limit(6000).repartition(50))"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "4168b8a6-330e-4a28-949a-16954e1ea757",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "source": [
+        "### Prediction"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "application/vnd.databricks.v1+cell": {
+          "inputWidgets": {},
+          "nuid": "d6f97c75-b814-4138-a2a8-512bc85a6f65",
+          "showTitle": false,
+          "title": ""
+        }
+      },
+      "outputs": [],
+      "source": [
+        "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
+        "\n",
+        "pred_df = deep_text_model.transform(test_df.limit(500))\n",
+        "evaluator = MulticlassClassificationEvaluator(\n",
+        "    predictionCol=\"prediction\", labelCol=\"label\", metricName=\"accuracy\"\n",
+        ")\n",
+        "print(\"Test accuracy:\", evaluator.evaluate(pred_df))"
+      ]
+    }
+  ],
+  "metadata": {
+    "application/vnd.databricks.v1+notebook": {
+      "dashboards": [],
+      "language": "python",
+      "notebookMetadata": {
+        "pythonIndentUnit": 2
+      },
+      "notebookName": "DeepLearning - Deep Text Classification",
+      "notebookOrigID": 4390929852015145,
+      "widgets": {}
+    },
+    "kernelspec": {
+      "display_name": "Python 3.8.5 ('base')",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python",
+      "version": "3.8.5"
+    },
+    "vscode": {
+      "interpreter": {
+        "hash": "601a75c4c141f401603984f1538447337114e368c54c4d5b589ea94315afdca2"
+      }
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
diff --git a/notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classifier.ipynb b/notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classification.ipynb
similarity index 100%
rename from notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classifier.ipynb
rename to notebooks/features/simple_deep_learning/DeepLearning - Deep Vision Classification.ipynb
diff --git a/templates/update_cli.yml b/templates/update_cli.yml
index 4444074e53..67c845f847 100644
--- a/templates/update_cli.yml
+++ b/templates/update_cli.yml
@@ -1,7 +1,7 @@
 steps:
   - task: UsePythonVersion@0
     inputs:
-      versionSpec: '3.8.X'
+      versionSpec: '3.8'
       architecture: 'x64'
   - task: JavaToolInstaller@0
     inputs:
diff --git a/website/docs/features/simple_deep_learning/about.md b/website/docs/features/simple_deep_learning/about.md
index f66fc886f4..6cb8c649d9 100644
--- a/website/docs/features/simple_deep_learning/about.md
+++ b/website/docs/features/simple_deep_learning/about.md
@@ -26,13 +26,13 @@ Coordinate: com.microsoft.azure:synapseml_2.12:SYNAPSEML_SCALA_VERSION
 Repository: https://mmlspark.azureedge.net/maven
 ```
 :::note
-If you install the jar package, you need to follow the first two cell of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classifier.md/#environment-setup----reinstall-horovod-based-on-new-version-of-pytorch)
+If you install the jar package, you need to follow the first two cells of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classification.md/#environment-setup----reinstall-horovod-based-on-new-version-of-pytorch)
 to make horovod recognize our module.
 :::

 ## 3. Try our sample notebook

-You could follow the rest of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classifier.md) and have a try on your own dataset.
+You can follow the rest of this [sample](./DeepLearning%20-%20Deep%20Vision%20Classification.md) and try it on your own dataset.

 Supported models (`backbone` parameter for `DeepVisionClassifier`) should be the string name of a [torchvision supported model](https://github.com/pytorch/vision/blob/v0.12.0/torchvision/models/__init__.py);
 you can also check by running `backbone in torchvision.models.__dict__`.
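The new `DeepTextClassifier` added in this PR follows the same fit/transform pattern. As a quick reference, here is a minimal sketch distilled from the `DeepLearning - Deep Text Classification` notebook; the checkpoint name, column names, and class count are that sample's assumptions, not fixed API values:

```python
from horovod.spark.common.store import DBFSLocalStore
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from synapse.ml.dl import DeepTextClassifier

# assumes train_df / test_df carry a "Text" column and a numeric "label"
# column, prepared as in the sample notebook
store = DBFSLocalStore("/dbfs/FileStore/test/bert-base-uncased")

deep_text_classifier = DeepTextClassifier(
    checkpoint="bert-base-uncased",  # any AutoModelForSequenceClassification checkpoint
    store=store,
    num_classes=6,  # number of distinct labels in the sample dataset
    batch_size=16,
    epochs=1,
    validation=0.1,
    text_col="Text",
)

deep_text_model = deep_text_classifier.fit(train_df)

pred_df = deep_text_model.transform(test_df)
evaluator = MulticlassClassificationEvaluator(
    predictionCol="prediction", labelCol="label", metricName="accuracy"
)
print("Test accuracy:", evaluator.evaluate(pred_df))
```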