diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md index 64f582f8b07..f39e64b5730 100644 --- a/extensions/fasttext/README.md +++ b/extensions/fasttext/README.md @@ -5,8 +5,10 @@ This module contains the NLP support with fastText implementation. fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor. -The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html) -class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). +Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html). +This produces a special block which can perform inference on its own or by using a model and predictor. +Pre-trained fastText models can also be loaded by using the standard DJL criteria. +You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). Current implementation has the following limitations: diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java index a146c29cc33..d7b4451b739 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java @@ -41,7 +41,8 @@ /** * {@code FtModel} is the fastText implementation of {@link Model}. * - *

<p>FtModel contains all the methods in Model to load and process a model. + *

<p>FtModel contains all the methods in Model to load and process a model. However, it only + * supports training by using {@link TrainFastText}. */ public class FtModel implements Model { diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java new file mode 100644 index 00000000000..6250063932a --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext; + +import ai.djl.basicdataset.RawDataset; +import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification; +import java.io.IOException; +import java.nio.file.Path; + +/** A utility to aggregate options for training with fasttext. */ +public final class TrainFastText { + + private TrainFastText() {} + + /** + * Trains a fastText {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION} model.
+ * + * @param config the training configuration to use + * @param dataset the training dataset + * @return the result of the training + * @throws IOException when IO operation fails in loading a resource + * @see FtTextClassification#fit(FtTrainingConfig, RawDataset) + */ + public static FtTextClassification textClassification( + FtTrainingConfig config, RawDataset<Path> dataset) throws IOException { + return FtTextClassification.fit(config, dataset); + } +} diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java index 17c24d92b81..c0bb2fc3485 100644 --- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java +++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java @@ -66,7 +66,7 @@ public void testTrainTextClassification() throws IOException, TranslateException .optLoss(FtTrainingConfig.FtLoss.HS) .build(); - FtTextClassification block = FtTextClassification.fit(config, dataset); + FtTextClassification block = TrainFastText.textClassification(config, dataset); TrainingResult result = block.getTrainingResult(); Assert.assertEquals(result.getEpoch(), 5); Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin")));