From 5dcb0af9694a6a7048fbcf35f72ec83294803add Mon Sep 17 00:00:00 2001
From: Devin Ha
Date: Tue, 11 Nov 2025 14:12:32 +0100
Subject: [PATCH 1/4] NerDLGraphChecker: add missing setter on Scala side

---
 .../johnsnowlabs/nlp/annotators/ner/dl/NerDLGraphChecker.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLGraphChecker.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLGraphChecker.scala
index eab7cf5b2cf249..a97d286704f0e2 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLGraphChecker.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLGraphChecker.scala
@@ -167,6 +167,9 @@ class NerDLGraphChecker(override val uid: String)
   /** @group getParam */
   protected def getGraphFolder: Option[String] = get(graphFolder)
 
+  /** @group setParam */
+  protected def setGraphFolder(value: String): this.type = set(graphFolder, value)
+
   /** Extracts the graph hyperparameters from the training data (dataset).
     *
     * @param dataset
     *   the training dataset

From 60227def5f48eaa998bce0d997d8cdc3e056856c Mon Sep 17 00:00:00 2001
From: Devin Ha
Date: Tue, 11 Nov 2025 14:12:38 +0100
Subject: [PATCH 2/4] Introduce NerDLDataLoader for NerDLApproach

A threaded NerDLDataLoader fetches batches in the background while
NerDLApproach is training, reducing idle time in the driver thread.
---
 .../nlp/annotators/common/Tagged.scala        |  25 +-
 .../nlp/annotators/ner/dl/NerDLApproach.scala |  72 ++--
 .../nlp/training/NerDLDataLoader.scala        | 314 ++++++++++++++++++
 .../nlp/annotators/ner/dl/NerDLSpec.scala     |  50 +++
 .../nlp/training/NerDLDataLoaderTest.scala    |  49 +++
 5 files changed, 482 insertions(+), 28 deletions(-)
 create mode 100644 src/main/scala/com/johnsnowlabs/nlp/training/NerDLDataLoader.scala
 create mode 100644 src/test/scala/com/johnsnowlabs/nlp/training/NerDLDataLoaderTest.scala
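
Note for reviewers: a minimal sketch of how the new prefetching is meant to be
switched on from the public API, assuming a pipeline that already produces
sentence, token and embeddings columns (names and data are illustrative):

    val ner = new NerDLApproach()
      .setInputCols("sentence", "token", "embeddings")
      .setOutputCol("ner")
      .setLabelColumn("label")
      .setEnableMemoryOptimizer(true) // stream batches instead of collecting to the driver
      .setPrefetchBatches(8) // prefetch depth; 0 keeps the previous single-threaded path
    val model = ner.fit(trainData) // trainData: DataFrame from an embeddings pipeline
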
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/common/Tagged.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/common/Tagged.scala
index b496d48428f8a7..525d5b044d8c3a 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/common/Tagged.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/common/Tagged.scala
@@ -118,7 +118,11 @@ trait Tagged[T >: TaggedSentence <: TaggedSentence] extends Annotated[T] {
     row.getAs[Seq[Row]](colNum).map(obj => Annotation(obj))
   }
 
-  protected def getLabelsFromSentences(
+  def getAnnotations(row: Row, col: String): Seq[Annotation] = {
+    row.getAs[Seq[Row]](col).map(obj => Annotation(obj))
+  }
+
+  def getLabelsFromSentences(
       sentences: Seq[WordpieceEmbeddingsSentence],
       labelAnnotations: Seq[Annotation]): Seq[TextSentenceLabels] = {
     val sortedLabels = labelAnnotations.sortBy(a => a.begin).toArray
@@ -203,16 +207,25 @@ object NerTagged extends Tagged[NerTaggedSentence] {
       dataset: Dataset[Row],
       sentenceCols: Seq[String],
       labelColumn: String,
-      batchSize: Int): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
+      batchSize: Int,
+      shuffleInPartition: Boolean = true)
+      : Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
 
     new Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] {
       import com.johnsnowlabs.nlp.annotators.common.DatasetHelpers._
 
       // Send batches, don't collect(), only keeping a single batch in memory anytime
-      val it: util.Iterator[Row] = dataset
-        .select(labelColumn, sentenceCols: _*)
-        .randomize // to improve training
-        .toLocalIterator() // Uses as much memory as the largest partition, potentially all data if not careful
+      val it: util.Iterator[Row] = {
+        val selected = dataset
+          .select(labelColumn, sentenceCols: _*)
+        // Shuffling rows within partitions improves training.
+        // NOTE: This might affect model performance, as the partitions themselves are not shuffled.
+        val maybeShuffled =
+          if (shuffleInPartition) selected.randomize
+          else selected
+        // Uses as much memory as the largest partition, potentially all data if not careful
+        maybeShuffled.toLocalIterator()
+      }
 
       // create a batch
       override def next(): Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)] = {

diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala
index ac73fa331eb21c..b1c839138ccd0d 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala
@@ -24,6 +24,7 @@ import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN, WORD_E
 import com.johnsnowlabs.nlp.annotators.common.{NerTagged, WordpieceEmbeddingsSentence}
 import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
 import com.johnsnowlabs.nlp.annotators.param.EvaluationDLParams
+import com.johnsnowlabs.nlp.training.NerDLDataLoader
 import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
 import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
 import com.johnsnowlabs.storage.HasStorageRef
@@ -450,6 +451,14 @@ class NerDLApproach(override val uid: String)
 
   }
 
+  val prefetchBatches = new IntParam(
+    this,
+    "prefetchBatches",
+    "Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled.")
+
+  def getPrefetchBatches: Int = $(this.prefetchBatches)
+  def setPrefetchBatches(value: Int): this.type = set(this.prefetchBatches, value)
+
   setDefault(
     minEpochs -> 0,
     maxEpochs -> 70,
@@ -462,7 +471,8 @@ class NerDLApproach(override val uid: String)
     includeAllConfidenceScores -> false,
     enableMemoryOptimizer -> false,
     useBestModel -> false,
-    bestModelMetric -> ModelMetrics.loss)
+    bestModelMetric -> ModelMetrics.loss,
+    prefetchBatches -> 0)
 
   override val verboseLevel: Verbose.Level = Verbose($(verbose))
 
@@ -485,6 +495,24 @@ class NerDLApproach(override val uid: String)
       $(validationSplit) <= 1f | $(validationSplit) >= 0f,
       "The validationSplit must be between 0f and 1f")
 
+    def getIteratorFunc(split: Dataset[Row]) = if (!getEnableMemoryOptimizer) {
+      // No memory optimizer
+      NerDLApproach.getIteratorFunc(
+        split,
+        inputColumns = getInputCols,
+        labelColumn = $(labelColumn),
+        batchSize = $(batchSize),
+        enableMemoryOptimizer = $(enableMemoryOptimizer))
+    } else {
+      logger.info(s"Using memory optimizer with $getPrefetchBatches prefetch batches.")
+      NerDLApproach.getIteratorFunc(
+        split,
+        inputColumns = getInputCols,
+        labelColumn = $(labelColumn),
+        batchSize = $(batchSize),
+        prefetchBatches = getPrefetchBatches)
+    }
+
     val embeddingsRef =
       HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)
 
@@ -506,26 +534,10 @@ class NerDLApproach(override val uid: String)
       (cacheIfNeeded(trainSplit), cacheIfNeeded(validSplit), cacheIfNeeded(test))
     }
 
-    val trainIteratorFunc = NerDLApproach.getIteratorFunc(
-      trainSplit,
-      inputColumns = getInputCols,
-      labelColumn = $(labelColumn),
-      batchSize = $(batchSize),
-      enableMemoryOptimizer = $(enableMemoryOptimizer))
-
-    val validIteratorFunc = NerDLApproach.getIteratorFunc(
-      validSplit,
-      inputColumns = getInputCols,
-      labelColumn = $(labelColumn),
-      batchSize = $(batchSize),
-      enableMemoryOptimizer = $(enableMemoryOptimizer))
-
-    val testIteratorFunc = NerDLApproach.getIteratorFunc(
-      test,
-      inputColumns = getInputCols,
-      labelColumn = $(labelColumn),
-      batchSize = $(batchSize),
-      enableMemoryOptimizer = $(enableMemoryOptimizer))
+    // TODO DHA: Better way to do this?
+    val trainIteratorFunc = getIteratorFunc(trainSplit)
+    val validIteratorFunc = getIteratorFunc(validSplit)
+    val testIteratorFunc = getIteratorFunc(test)
 
     val (
       labels: mutable.Set[AnnotatorType],
@@ -752,8 +764,9 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
     : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
     if (enableMemoryOptimizer) { () =>
+      // Old implementation, kept for backward compatibility, but no longer called from
+      // NerDLApproach.train, which now uses NerDLDataLoader with the memory optimizer
       NerTagged.iterateOnDataframe(dataset, inputColumns, labelColumn, batchSize)
-
     } else {
       val inMemory = dataset
         .select(labelColumn, inputColumns.toSeq: _*)
@@ -763,6 +776,21 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
     }
   }
 
+  def getIteratorFunc(
+      dataset: Dataset[Row],
+      inputColumns: Array[String],
+      labelColumn: String,
+      batchSize: Int,
+      prefetchBatches: Int)
+      : () => Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = { () =>
+    NerDLDataLoader.iterateOnDataframe(
+      dataset = dataset,
+      inputColumns = inputColumns,
+      labelColumn = labelColumn,
+      batchSize = batchSize,
+      prefetchBatches = prefetchBatches)
+  }
+
   def getDataSetParams(dsIt: Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]])
     : (mutable.Set[String], mutable.Set[Char], Int, Long) = {

diff --git a/src/main/scala/com/johnsnowlabs/nlp/training/NerDLDataLoader.scala b/src/main/scala/com/johnsnowlabs/nlp/training/NerDLDataLoader.scala
new file mode 100644
index 00000000000000..36eafd7a8e2596
--- /dev/null
+++ b/src/main/scala/com/johnsnowlabs/nlp/training/NerDLDataLoader.scala
@@ -0,0 +1,314 @@
+/*
+ * Copyright 2017-2025 John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.nlp.training
+
+import com.johnsnowlabs.ml.crf.TextSentenceLabels
+import com.johnsnowlabs.nlp.annotators.common.DatasetHelpers._
+import com.johnsnowlabs.nlp.annotators.common.NerTagged.{getAnnotations, getLabelsFromSentences}
+import com.johnsnowlabs.nlp.annotators.common.WordpieceEmbeddingsSentence
+import org.apache.spark.sql.{Dataset, Row}
+import org.slf4j.LoggerFactory
+
+import java.util.concurrent.{ExecutorService, Executors, LinkedBlockingQueue, TimeUnit}
+import scala.annotation.tailrec
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext
+import scala.jdk.CollectionConverters._
+
+/** Configuration for the NerDLDataLoader.
+  *
+  * @param batchSize
+  *   Number of sentences per batch (default: 16)
+  * @param prefetchBatches
+  *   Number of batches to prefetch (default: 20). This sets the capacity of the prefetch queue,
+  *   so at most this many batches are buffered ahead of the training loop.
+  * @param shuffleInPartition
+  *   Whether to shuffle the data within partitions (default: true). Improves training
+  *   convergence.
+  * @param timeoutMillis
+  *   Timeout in milliseconds for fetching a batch (default: 10000). Prevents hanging on slow
+  *   operations.
+  */
+case class DataLoaderConfig(
+    batchSize: Int = 16,
+    prefetchBatches: Int = 20,
+    shuffleInPartition: Boolean = true,
+    timeoutMillis: Long = 10000)
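+
+// Sizing example (hypothetical values): DataLoaderConfig(batchSize = 32, prefetchBatches = 4)
+// buffers at most 4 * 32 = 128 sentences ahead of the training loop, plus the batch the
+// producer thread is currently assembling.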
+
+/** DataLoader for NerDLApproach with threaded prefetching.
+  *
+  * This class provides an efficient way to load training data for NER models by:
+  *   - Prefetching batches in background threads to overlap I/O with computation
+  *   - Using a bounded queue to prevent excessive memory usage
+  *
+  * @param config
+  *   Configuration for the data loader
+  */
+class NerDLDataLoader(config: DataLoaderConfig = DataLoaderConfig()) {
+  import com.johnsnowlabs.nlp.util.io.ResourceHelper.spark.implicits._
+
+  @volatile private var isShutdown = false
+  private var executorService: Option[ExecutorService] = None
+  private var executionContext: Option[ExecutionContext] = None
+  private val logger = LoggerFactory.getLogger(this.getClass)
+
+  /** Creates an iterator that prefetches and yields batches of NER training data from a Spark
+    * DataFrame.
+    *
+    * The iterator uses background threads to prefetch batches while the main thread consumes
+    * them, improving throughput by overlapping I/O with computation.
+    *
+    * @param dataset
+    *   Spark DataFrame containing the training data
+    * @param inputCols
+    *   TOKEN and EMBEDDING type input columns
+    * @param labelColumn
+    *   Column name containing the NER labels
+    * @return
+    *   Iterator over batches, where each batch is an Array of (labels, embeddings) pairs
+    */
+  def createIterator(
+      dataset: Dataset[Row],
+      inputCols: Seq[String],
+      labelColumn: String): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
+    if (config.prefetchBatches <= 0) {
+      // Single-threaded mode - no prefetching, consume directly from the dataset
+      createSourceBatchIterator(dataset, inputCols, labelColumn, config.batchSize)
+    } else {
+      // Threaded mode with prefetching
+      createThreadedIterator(dataset, inputCols, labelColumn)
+    }
+  }
+
+  /** Creates an iterator over batches of NER training data directly from a Spark DataFrame using
+    * `toLocalIterator`.
+    *
+    * Consumers should take from this iterator as much as possible (within RAM limits) to
+    * trigger partition computation across the cluster.
+    *
+    * @param dataset
+    *   Spark DataFrame containing the training data
+    * @param inputCols
+    *   TOKEN and EMBEDDING type input columns
+    * @param labelColumn
+    *   Column name containing the NER labels
+    * @param batchSize
+    *   Number of sentences per batch
+    * @return
+    *   Iterator over batches, where each batch is an Array of (labels, embeddings) pairs
+    */
+  private def createSourceBatchIterator(
+      dataset: Dataset[Row],
+      inputCols: Seq[String],
+      labelColumn: String,
+      batchSize: Int): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
+
+    def processPartition(
+        it: Iterator[Row]): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] =
+      new Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] {
+        // create a batch
+        override def next(): Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)] = {
+          var count = 0
+          val thisBatch = new ArrayBuffer[(TextSentenceLabels, WordpieceEmbeddingsSentence)]
+
+          while (it.hasNext && count < batchSize) {
+            count += 1
+            val nextRow = it.next()
+
+            val labelAnnotations = getAnnotations(nextRow, labelColumn)
+            val sentenceAnnotations =
+              inputCols.flatMap(s => getAnnotations(nextRow, s))
+            val sentences = WordpieceEmbeddingsSentence.unpack(sentenceAnnotations)
+            val labels = getLabelsFromSentences(sentences, labelAnnotations)
+            val thisOne = labels.zip(sentences)
+
+            thisBatch ++= thisOne
+          }
+          thisBatch.toArray
+        }
+
+        override def hasNext: Boolean = it.hasNext
+      }
+
+    // Process each partition on worker nodes
+    val selected = dataset.select(labelColumn, inputCols: _*)
+    // Shuffling rows within partitions improves training.
+    // NOTE: This might affect model performance, as the partitions themselves are not shuffled.
+    val maybeShuffled =
+      if (config.shuffleInPartition) selected.randomize
+      else selected
+    maybeShuffled
+      .mapPartitions(processPartition) // create batches in each partition
+      .as[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]]
+      .toLocalIterator()
+      .asScala
+  }
+
+  /** Worker runnable that loads batches and puts them in the queue.
+    *
+    * @param batchQueue
+    *   Blocking queue to hold loaded batches
+    * @param sourceIterator
+    *   Iterator over source batches
+    */
+  private class BatchLoaderThread(
+      batchQueue: LinkedBlockingQueue[
+        Option[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]]],
+      sourceIterator: Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]])
+      extends Runnable {
+    private val logger = LoggerFactory.getLogger(this.getClass)
+
+    override def run(): Unit = {
+      try {
+        while (!isShutdown && sourceIterator.hasNext) {
+          // Fetch the next batch
+          val batch = sourceIterator.next()
+
+          // Offer to queue (blocking with timeout)
+          var offered = false
+          while (!offered && !isShutdown) {
+            offered = batchQueue.offer(Some(batch), config.timeoutMillis, TimeUnit.MILLISECONDS)
+          }
+        }
+      } catch {
+        case _: InterruptedException =>
+          Thread.currentThread().interrupt()
+        case e: Exception =>
+          logger.error(s"Fetcher error: ${e.getMessage}", e)
+      } finally {
+        // Sentinel: signal end of data,
+        // either due to completion or shutdown
+        try {
+          batchQueue.put(None)
+        } catch {
+          case _: InterruptedException => // Ignore during shutdown
+        }
+      }
+    }
+  }
+
+  private def createThreadedIterator(
+      dataset: Dataset[Row],
+      sentenceCols: Seq[String],
+      labelColumn: String): Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
+
+    // Queue capacity: holds completed batches.
+    val queueCapacity = config.prefetchBatches
+    val batchQueue =
+      new LinkedBlockingQueue[Option[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]]](
+        queueCapacity)
+
+    // Source data iterator
+    val sourceBatchIterator =
+      createSourceBatchIterator(dataset, sentenceCols, labelColumn, config.batchSize)
+
+    // Create a producer thread for prefetching.
+    val executor = Executors.newSingleThreadExecutor()
+    executorService = Some(executor)
+    logger.info(s"Starting data loader thread with prefetch buffer size: $queueCapacity batches.")
+    executor.submit(new BatchLoaderThread(batchQueue, sourceBatchIterator))
+
+    // Consumer iterator (main thread)
+    new BatchLoaderIterator(batchQueue)
+  }
+
+  private class BatchLoaderIterator(
+      batchQueue: LinkedBlockingQueue[
+        Option[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]]])
+      extends Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] {
+    private var nextBatch: Option[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] =
+      None
+    private var endOfData = false
+
+    @tailrec
+    final override def hasNext: Boolean = {
+      if (endOfData) false
+      else if (nextBatch.isDefined) true // Already have a batch ready
+      else {
+        // Poll from queue in advance (to avoid blocking in next())
+        val result = batchQueue.poll(config.timeoutMillis, TimeUnit.MILLISECONDS)
+
+        result match {
+          case null =>
+            // Timed out waiting for Spark: training is faster than data loading.
+            // Keep waiting for the next batch.
+            if (isShutdown) false
+            else hasNext
+          case None =>
+            endOfData = true // Signal: no more data
+            false
+          case Some(batch) =>
+            nextBatch = Some(batch)
+            true
+        }
+      }
+    }
+
+    override def next(): Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)] = {
+      if (!hasNext) throw new NoSuchElementException("No more batches")
+      val batch = nextBatch.get
+      nextBatch = None
+      batch
+    }
+  }
+
+  /** Shuts down the data loader and releases all resources.
+    *
+    * Call this when the loader is no longer needed, even if unconsumed data remains, to prevent
+    * resource leaks. It is safe to call multiple times.
+    */
+  def shutdown(): Unit = {
+    isShutdown = true
+    executorService.foreach { executor =>
+      executor.shutdown()
+      try {
+        if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
+          executor.shutdownNow()
+        }
+      } catch {
+        case _: InterruptedException =>
+          executor.shutdownNow()
+          Thread.currentThread().interrupt()
+      }
+    }
+    executorService = None
+    executionContext = None
+  }
+}
+
+/** Companion object providing factory methods for NerDLDataLoader.
+  */
+object NerDLDataLoader {
+
+  def iterateOnDataframe(
+      dataset: Dataset[Row],
+      inputColumns: Array[String],
+      labelColumn: String,
+      batchSize: Int,
+      prefetchBatches: Int,
+      shuffleInPartition: Boolean = true)
+      : Iterator[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = {
+    new NerDLDataLoader(
+      DataLoaderConfig(
+        batchSize = batchSize,
+        prefetchBatches = prefetchBatches,
+        shuffleInPartition = shuffleInPartition))
+      .createIterator(dataset, inputColumns, labelColumn)
+  }
+}

diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLSpec.scala
index 33e03690781fcf..80588db8c97d17 100644
--- a/src/test/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLSpec.scala
+++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLSpec.scala
@@ -415,4 +415,54 @@ class NerDLSpec extends AnyFlatSpec {
       "Embedding dim from metadata does not match expected dim")
     assert(dsLen == expectedDsLen, "Dataset length from metadata does not match expected length")
   }
+
+  def trainRun(ner: NerDLApproach): Unit = {
+    val conll = CoNLL()
+    val training_data = conll.readDataset(
+      ResourceHelper.spark,
+      "src/test/resources/ner-corpus/test_ner_dataset.txt")
+
+    val embeddings = AnnotatorBuilder.getGLoveEmbeddings(training_data.toDF())
+    val trainData = embeddings.transform(training_data)
+
+    ner.fit(trainData)
+  }
+
+  "NerDLApproach" should "train with memory optimizer and prefetchBatches enabled" taggedAs SlowTest in {
+    val ner = new NerDLApproach()
+      .setInputCols("sentence", "token", "embeddings")
+      .setOutputCol("ner")
+      .setLabelColumn("label")
+      .setLr(1e-1f)
+      .setPo(5e-3f)
+      .setDropout(5e-1f)
+      .setMaxEpochs(2)
+      .setRandomSeed(0)
+      .setVerbose(0)
+      .setBatchSize(8)
+      .setEnableMemoryOptimizer(true)
+      .setPrefetchBatches(10)
+      .setGraphFolder("src/test/resources/graph/")
+
+    trainRun(ner)
+  }
+
+  it should "train with memory optimizer and optimizePartitioning disabled" taggedAs SlowTest in {
+    val ner = new NerDLApproach()
+      .setInputCols("sentence", "token", "embeddings")
+      .setOutputCol("ner")
+      .setLabelColumn("label")
+      .setLr(1e-1f)
+      .setPo(5e-3f)
+      .setDropout(5e-1f)
+      .setMaxEpochs(2)
+      .setRandomSeed(0)
+      .setVerbose(0)
+      .setBatchSize(8)
+      .setEnableMemoryOptimizer(true)
+      .setOptimizePartitioning(false)
+      .setGraphFolder("src/test/resources/graph/")
+
+    trainRun(ner)
+  }
 }

diff --git a/src/test/scala/com/johnsnowlabs/nlp/training/NerDLDataLoaderTest.scala b/src/test/scala/com/johnsnowlabs/nlp/training/NerDLDataLoaderTest.scala
new file mode 100644
index 00000000000000..20b25fb71c3044
--- /dev/null
+++ b/src/test/scala/com/johnsnowlabs/nlp/training/NerDLDataLoaderTest.scala
@@ -0,0 +1,49 @@
+package com.johnsnowlabs.nlp.training
+
+import com.johnsnowlabs.ml.crf.TextSentenceLabels
+import com.johnsnowlabs.nlp.annotators.SparkSessionTest
+import com.johnsnowlabs.nlp.annotators.common.{NerTagged, WordpieceEmbeddingsSentence}
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
+import com.johnsnowlabs.tags.FastTest
+import org.apache.spark.sql.{Dataset, Row}
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should._
+
+class NerDLDataLoaderTest extends AnyFlatSpec with SparkSessionTest with Matchers {
+
+  lazy val textCols: Array[String] = Array("token", "sentence")
+  lazy val labelCol = "label"
+  lazy val data: Dataset[Row] =
+    CoNLL()
+      .readDataset(ResourceHelper.spark, "src/test/resources/ner-corpus/test_ner_dataset.txt")
+      .limit(100)
+      .select(labelCol, textCols: _*)
+
+  val batchSize = 16
+
+  behavior of "NerDLDataLoader"
+
+  it should "create the same batches as the non-threaded iterator" taggedAs FastTest in {
+
+    val expectedData: Array[Array[(TextSentenceLabels, WordpieceEmbeddingsSentence)]] = NerTagged
+      .iterateOnDataframe(
+        data,
+        textCols,
+        labelCol,
+        batchSize = batchSize,
+        shuffleInPartition = false)
+      .toArray
+
+    val nerDLDataLoader = NerDLDataLoader.iterateOnDataframe(
+      data,
+      textCols,
+      labelCol,
+      batchSize = batchSize,
+      prefetchBatches = 10,
+      shuffleInPartition = false)
+    val loaderData = nerDLDataLoader.toArray
+
+    loaderData.length shouldBe expectedData.length
+    loaderData should contain theSameElementsInOrderAs expectedData
+  }
+}
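
For completeness, a hedged sketch of driving NerDLDataLoader directly (outside
NerDLApproach), keeping a handle on the instance so the prefetch thread can be
stopped if iteration ends early; trainData stands in for any DataFrame with
label, token and embeddings columns:

    val loader = new NerDLDataLoader(DataLoaderConfig(batchSize = 16, prefetchBatches = 4))
    val batches = loader.createIterator(trainData, Seq("token", "embeddings"), "label")
    try {
      // consume only part of the stream, e.g. for a dry run
      batches.take(10).foreach(batch => println(s"batch of ${batch.length} sentences"))
    } finally {
      loader.shutdown() // stop the background thread even if the iterator was not drained
    }

Note that the companion's iterateOnDataframe does not expose the loader
instance, so there the background thread only terminates once the source
iterator is fully consumed.
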
From 6580ad55a4ea2131f1637966ad7944660274e847 Mon Sep 17 00:00:00 2001
From: Devin Ha
Date: Fri, 21 Nov 2025 17:36:08 +0100
Subject: [PATCH 3/4] NerDLApproach: Optimize partitioning flag

Allow NerDLApproach to repartition the input dataset, so the driver does
not run out of memory when training on large partitions.
---
 .../nlp/annotators/ner/dl/NerDLApproach.scala | 170 +++++++++++-------
 1 file changed, 109 insertions(+), 61 deletions(-)
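
The repartitioning added below boils down to one line of arithmetic; a hedged
sketch of the heuristic from prepareData, with illustrative dataset sizes:

    // Assume ~1 MB per row and target ~500 MB per partition,
    // so one partition should hold about 500 rows.
    def partitionsFor(dsLen: Long): Int = math.ceil(dsLen / 500.0).toInt

    partitionsFor(100L)   // 1  (tiny dataset: a single partition)
    partitionsFor(750L)   // 2
    partitionsFor(12500L) // 25
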
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala
index b1c839138ccd0d..eb6dde956a1762 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/NerDLApproach.scala
@@ -26,7 +26,12 @@ import com.johnsnowlabs.nlp.annotators.ner.{ModelMetrics, NerApproach, Verbose}
 import com.johnsnowlabs.nlp.annotators.param.EvaluationDLParams
 import com.johnsnowlabs.nlp.training.NerDLDataLoader
 import com.johnsnowlabs.nlp.util.io.{OutputHelper, ResourceHelper}
-import com.johnsnowlabs.nlp.{AnnotatorApproach, AnnotatorType, ParamsAndFeaturesWritable}
+import com.johnsnowlabs.nlp.{
+  Annotation,
+  AnnotatorApproach,
+  AnnotatorType,
+  ParamsAndFeaturesWritable
+}
 import com.johnsnowlabs.storage.HasStorageRef
 import org.apache.commons.io.IOUtils
 import org.apache.commons.lang3.SystemUtils
@@ -457,8 +462,26 @@ class NerDLApproach(override val uid: String)
     "Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled.")
 
   def getPrefetchBatches: Int = $(this.prefetchBatches)
+
+  /** Sets the number of batches to prefetch while training using memory optimizer. Has no
+    * effect if memory optimizer is disabled.
+    * @group setParam
+    */
   def setPrefetchBatches(value: Int): this.type = set(this.prefetchBatches, value)
 
+  val optimizePartitioning = new BooleanParam(
+    this,
+    "optimizePartitioning",
+    "Whether to repartition the dataset before training for optimal performance. Has no effect if memory optimizer is disabled.")
+
+  def getOptimizePartitioning: Boolean = $(this.optimizePartitioning)
+
+  /** Sets whether to repartition the dataset before training for optimal performance. Has no
+    * effect if memory optimizer is disabled.
+    * @group setParam
+    */
+  def setOptimizePartitioning(value: Boolean): this.type = set(this.optimizePartitioning, value)
+
   setDefault(
     minEpochs -> 0,
     maxEpochs -> 70,
@@ -472,7 +495,8 @@ class NerDLApproach(override val uid: String)
     enableMemoryOptimizer -> false,
     useBestModel -> false,
     bestModelMetric -> ModelMetrics.loss,
-    prefetchBatches -> 0)
+    prefetchBatches -> 0,
+    optimizePartitioning -> true)
 
   override val verboseLevel: Verbose.Level = Verbose($(verbose))
 
@@ -488,6 +512,74 @@ class NerDLApproach(override val uid: String)
     LoadsContrib.loadContribToTensorflow()
   }
 
+  private def getIteratorFunc(split: Dataset[Row]) = if (!getEnableMemoryOptimizer) {
+    // No memory optimizer
+    NerDLApproach.getIteratorFunc(
+      split,
+      inputColumns = getInputCols,
+      labelColumn = $(labelColumn),
+      batchSize = $(batchSize),
+      enableMemoryOptimizer = $(enableMemoryOptimizer))
+  } else {
+    logger.info(s"Using memory optimizer with $getPrefetchBatches prefetch batches.")
+    NerDLApproach.getIteratorFunc(
+      split,
+      inputColumns = getInputCols,
+      labelColumn = $(labelColumn),
+      batchSize = $(batchSize),
+      prefetchBatches = getPrefetchBatches)
+  }
+
+  /** Extracts graph parameters and returns an optimized dataframe for training.
+    *
+    * @param dataset
+    *   input dataset
+    * @return
+    *   (labels, chars, embeddingsDim, dsLen, optimizedDataset)
+    */
+  private def prepareData(dataset: Dataset[Row])
+      : (mutable.Set[AnnotatorType], mutable.Set[Char], Int, Long, Dataset[Row]) = {
+    def optimizePartitioning(ds: Dataset[Row], dsLen: Long): Dataset[Row] = {
+      if (getEnableMemoryOptimizer && getOptimizePartitioning) {
+        // Repartition the cached dataset according to a heuristic:
+        // assume one row holds about 1 MB of data (BertEmbeddings), and Spark recommends 100 MB to 1 GB partitions.
+        // We'll go for the middle ground of 500 MB, which means one partition should hold about 500 rows.
+        val numPartitions = math.ceil(dsLen / 500.0).toInt
+        logger.info(
+          s"Repartitioning input dataset to $numPartitions partitions for NerDL training.")
+        ds.repartition(numPartitions)
+      } else ds
+    }
+
+    val cachedDataset: Dataset[Row] = dataset.cache().toDF()
+    NerDLApproach.getDataSetParamsFromMetadata(cachedDataset, $(labelColumn)) match {
+      // metadata contains the length of the entire dataset, so we can avoid a count() action
+      case Some(
+            (
+              labels: mutable.Set[AnnotatorType],
+              chars: mutable.Set[Char],
+              embeddingsDim: Int,
+              dsLen: Long)) =>
+        // Only repartition if using memory optimizer
+        val repartitionedDataset = optimizePartitioning(cachedDataset, dsLen)
+        (labels, chars, embeddingsDim, dsLen, repartitionedDataset)
+      case None => // Legacy way of getting dataset params
+        logger.info("Dataset metadata does not contain graph parameters, extracting from data.")
+        val docColumn =
+          Annotation.getColumnByType(dataset, getInputCols, AnnotatorType.DOCUMENT).name
+        val dsLen = cachedDataset.selectExpr(s"explode($docColumn)").count()
+        // Repartition now, so we don't OOM when extracting params
+        val repartitionedDataset = optimizePartitioning(cachedDataset, dsLen)
+        val (
+          labels: mutable.Set[AnnotatorType],
+          chars: mutable.Set[Char],
+          embeddingsDim: Int,
+          _) =
+          NerDLApproach.getDataSetParams(getIteratorFunc(repartitionedDataset)())
+        (labels, chars, embeddingsDim, dsLen, repartitionedDataset)
+    }
+  }
+
   override def train(
       dataset: Dataset[_],
       recursivePipeline: Option[PipelineModel]): NerDLModel = {
@@ -495,79 +587,37 @@ class NerDLApproach(override val uid: String)
       $(validationSplit) <= 1f | $(validationSplit) >= 0f,
       "The validationSplit must be between 0f and 1f")
 
-    def getIteratorFunc(split: Dataset[Row]) = if (!getEnableMemoryOptimizer) {
-      // No memory optimizer
-      NerDLApproach.getIteratorFunc(
-        split,
-        inputColumns = getInputCols,
-        labelColumn = $(labelColumn),
-        batchSize = $(batchSize),
-        enableMemoryOptimizer = $(enableMemoryOptimizer))
-    } else {
-      logger.info(s"Using memory optimizer with $getPrefetchBatches prefetch batches.")
-      NerDLApproach.getIteratorFunc(
-        split,
-        inputColumns = getInputCols,
-        labelColumn = $(labelColumn),
-        batchSize = $(batchSize),
-        prefetchBatches = getPrefetchBatches)
-    }
-
     val embeddingsRef =
       HasStorageRef.getStorageRefFromInput(dataset, $(inputCols), AnnotatorType.WORD_EMBEDDINGS)
 
+    val (
+      labels: mutable.Set[AnnotatorType],
+      chars: mutable.Set[Char],
+      embeddingsDim: Int,
+      dsLen: Long,
+      optimizedDataset: Dataset[Row]) = prepareData(dataset.toDF())
+
+    val trainDsLen = math.round(dsLen * (1.0f - $(validationSplit)))
+    val valDsLen = dsLen - trainDsLen
+
+    // Get the data splits
     val (trainSplit: Dataset[Row], validSplit: Dataset[Row], test: Dataset[Row]) = {
-      def cacheIfNeeded(ds: Dataset[Row]): Dataset[Row] =
-        if (getEnableMemoryOptimizer && getMaxEpochs > 1) ds.cache() else ds
-
-      val train = dataset.toDF()
       val (validSplit, trainSplit) =
-        train.randomSplit(Array($(validationSplit), 1.0f - $(validationSplit))) match {
+        optimizedDataset.randomSplit(Array($(validationSplit), 1.0f - $(validationSplit))) match {
           case Array(validSplit, trainSplit) => (validSplit, trainSplit)
         }
 
       val test =
-        if (!isDefined(testDataset)) train.limit(0) // keep the schema only
+        if (!isDefined(testDataset)) optimizedDataset.limit(0) // keep the schema only
         else ResourceHelper.readSparkDataFrame($(testDataset))
 
-      (cacheIfNeeded(trainSplit), cacheIfNeeded(validSplit), cacheIfNeeded(test))
+      (trainSplit, validSplit, test)
     }
 
-    // TODO DHA: Better way to do this?
+    // Get iterators
     val trainIteratorFunc = getIteratorFunc(trainSplit)
     val validIteratorFunc = getIteratorFunc(validSplit)
     val testIteratorFunc = getIteratorFunc(test)
 
-    val (
-      labels: mutable.Set[AnnotatorType],
-      chars: mutable.Set[Char],
-      embeddingsDim: Int,
-      trainDsLen: Long,
-      valDsLen: Long) = {
-      NerDLApproach.getDataSetParamsFromMetadata(trainSplit, $(labelColumn)) match {
-        // metadata contains the length of the entire dataset
-        case Some(
-              (
-                labels: mutable.Set[AnnotatorType],
-                chars: mutable.Set[Char],
-                embeddingsDim: Int,
-                dsLen: Long)) =>
-          val trainDsLen = math.round(dsLen * (1.0f - $(validationSplit)))
-          val valDsLen = dsLen - trainDsLen
-          (labels, chars, embeddingsDim, trainDsLen.toLong, valDsLen)
-        case None => // Legacy way of getting dataset params
-          val (
-            labels: mutable.Set[AnnotatorType],
-            chars: mutable.Set[Char],
-            embeddingsDim: Int,
-            trainDsLen: Long) = NerDLApproach.getDataSetParams(trainIteratorFunc())
-          val valDsLen: Long =
-            math.round(trainDsLen / (1 - $(validationSplit)) * $(validationSplit))
-          (labels, chars, embeddingsDim, trainDsLen, valDsLen)
-      }
-    }
-
     val settings = DatasetEncoderParams(
       labels.toList,
       chars.toList,
@@ -641,9 +691,7 @@ class NerDLApproach(override val uid: String)
     if (get(configProtoBytes).isDefined)
       model.setConfigProtoBytes($(configProtoBytes))
 
-    trainSplit.unpersist()
-    validSplit.unpersist()
-    test.unpersist()
+    optimizedDataset.unpersist()
     model
   }
 }
@@ -860,7 +908,7 @@ object NerDLApproach extends DefaultParamsReadable[NerDLApproach] with WithGraph
     val chars = metadata.getStringArray(NerDLGraphCheckerModel.charsKey).map(_.head)
     val embeddingsDim = metadata.getLong(NerDLGraphCheckerModel.embeddingsDimKey).toInt
     val dsLen = metadata.getLong(NerDLGraphCheckerModel.dsLenKey)
-    logger.info(s"NerDLApproach: Found graph params in label column metadata:" +
+    logger.info(s"Found graph params in label column metadata:" +
       s" labels=${labels.length}, chars=${chars.length}, embeddingsDim=$embeddingsDim, dsLen=$dsLen")
 
     Some(
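
With prepareData returning the full dataset length, the split bookkeeping in
train() reduces to plain arithmetic; a small worked example with illustrative
numbers:

    val dsLen = 1000L // from metadata, or a count() in the legacy path
    val validationSplit = 0.2f
    val trainDsLen = math.round(dsLen * (1.0f - validationSplit)) // 800
    val valDsLen = dsLen - trainDsLen // 200
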
From 631b35013cc022792f2396d51a47cd4a6f770168 Mon Sep 17 00:00:00 2001
From: Devin Ha
Date: Mon, 24 Nov 2025 12:47:16 +0100
Subject: [PATCH 4/4] NerDL optimizations: Python side

---
 python/sparknlp/annotator/ner/ner_dl.py       | 34 ++++++++++-
 .../annotator/ner/ner_dl_approach_test.py     | 59 +++++++++++++++++++
 2 files changed, 92 insertions(+), 1 deletion(-)
 create mode 100644 python/test/annotator/ner/ner_dl_approach_test.py

diff --git a/python/sparknlp/annotator/ner/ner_dl.py b/python/sparknlp/annotator/ner/ner_dl.py
index 194468b16496c1..00c8bad88e2ea8 100755
--- a/python/sparknlp/annotator/ner/ner_dl.py
+++ b/python/sparknlp/annotator/ner/ner_dl.py
@@ -238,6 +238,14 @@ class NerDLApproach(AnnotatorApproach, NerApproach, EvaluationDLParams):
                               "Whether to check F1 Micro-average or F1 Macro-average as a final metric for the best model.",
                               TypeConverters.toString)
 
+    prefetchBatches = Param(Params._dummy(), "prefetchBatches",
+                            "Number of batches to prefetch while training using memory optimizer. Has no effect if memory optimizer is disabled.",
+                            TypeConverters.toInt)
+
+    optimizePartitioning = Param(Params._dummy(), "optimizePartitioning",
+                                 "Whether to repartition the dataset before training for optimal performance. Has no effect if memory optimizer is disabled.",
+                                 TypeConverters.toBoolean)
+
     def setConfigProtoBytes(self, b):
         """Sets configProto from tensorflow, serialized into byte array.
@@ -377,6 +385,28 @@ class NerDLApproach(AnnotatorApproach, NerApproach, EvaluationDLParams):
         """
         return self._set(bestModelMetric=value)
 
+    def setPrefetchBatches(self, value):
+        """Sets the number of batches to prefetch while training using the memory
+        optimizer. Has no effect if the memory optimizer is disabled.
+
+        Parameters
+        ----------
+        value : int
+            Number of batches to prefetch
+        """
+        return self._set(prefetchBatches=value)
+
+    def setOptimizePartitioning(self, value):
+        """Sets whether to repartition the dataset before training for optimal
+        performance. Has no effect if the memory optimizer is disabled.
+
+        Parameters
+        ----------
+        value : bool
+            Whether to optimize partitioning
+        """
+        return self._set(optimizePartitioning=value)
+
     def _create_model(self, java_model):
         return NerDLModel(java_model=java_model)
 
@@ -400,7 +430,9 @@ class NerDLApproach(AnnotatorApproach, NerApproach, EvaluationDLParams):
             enableOutputLogs=False,
             enableMemoryOptimizer=False,
             useBestModel=False,
-            bestModelMetric="f1_micro"
+            bestModelMetric="f1_micro",
+            prefetchBatches=0,
+            optimizePartitioning=True
         )

diff --git a/python/test/annotator/ner/ner_dl_approach_test.py b/python/test/annotator/ner/ner_dl_approach_test.py
new file mode 100644
index 00000000000000..c9f03c1e89ba63
--- /dev/null
+++ b/python/test/annotator/ner/ner_dl_approach_test.py
@@ -0,0 +1,59 @@
+# Copyright 2017-2025 John Snow Labs
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import pytest
+
+from sparknlp.annotator import *
+from test.util import SparkSessionForTest
+
+
+@pytest.mark.fast
+class NerDLApproachTestSpec(unittest.TestCase):
+    def setUp(self):
+        self.spark = SparkSessionForTest.spark
+
+    def test_setters(self):
+        ner_approach = (
+            NerDLApproach()
+            .setLr(0.01)
+            .setPo(0.005)
+            .setBatchSize(16)
+            .setDropout(0.01)
+            .setGraphFolder("graph_folder")
+            .setConfigProtoBytes([])
+            .setUseContrib(False)
+            .setEnableMemoryOptimizer(True)
+            .setIncludeConfidence(True)
+            .setIncludeAllConfidenceScores(True)
+            .setUseBestModel(True)
+            .setPrefetchBatches(20)
+            .setOptimizePartitioning(True)
+        )
+
+        # Check param map
+        param_map = ner_approach.extractParamMap()
+        self.assertEqual(param_map[ner_approach.lr], 0.01)
+        self.assertEqual(param_map[ner_approach.po], 0.005)
+        self.assertEqual(param_map[ner_approach.batchSize], 16)
+        self.assertEqual(param_map[ner_approach.dropout], 0.01)
+        self.assertEqual(param_map[ner_approach.graphFolder], "graph_folder")
+        self.assertEqual(param_map[ner_approach.configProtoBytes], [])
+        self.assertEqual(param_map[ner_approach.useContrib], False)
+        self.assertEqual(param_map[ner_approach.enableMemoryOptimizer], True)
+        self.assertEqual(param_map[ner_approach.includeConfidence], True)
+        self.assertEqual(param_map[ner_approach.includeAllConfidenceScores], True)
+        self.assertEqual(param_map[ner_approach.useBestModel], True)
+        self.assertEqual(param_map[ner_approach.prefetchBatches], 20)
+        self.assertEqual(param_map[ner_approach.optimizePartitioning], True)
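
Taken together, a hedged end-to-end sketch of the new training knobs on the
Scala side (the Python setters above mirror these one-to-one; column names and
data are illustrative):

    val ner = new NerDLApproach()
      .setInputCols("sentence", "token", "embeddings")
      .setOutputCol("ner")
      .setLabelColumn("label")
      .setMaxEpochs(5)
      .setEnableMemoryOptimizer(true) // stream batches instead of collecting to the driver
      .setPrefetchBatches(10) // overlap Spark fetching with TensorFlow training
      .setOptimizePartitioning(true) // repartition to roughly 500 rows per partition first
    val model = ner.fit(trainData)

Setting prefetchBatches to 0 falls back to the single-threaded iterator, and
setOptimizePartitioning(false) skips the repartition shuffle when the input
partitioning is already acceptable.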