diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md index 64f582f8b07..f39e64b5730 100644 --- a/extensions/fasttext/README.md +++ b/extensions/fasttext/README.md @@ -5,8 +5,10 @@ This module contains the NLP support with fastText implementation. fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor. -The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html) -class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). +Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html). +This produces a special block which can perform inference on its own or by using a model and predictor. +Pre-trained FastText models can also be loaded by using the standard DJL criteria. +You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). Current implementation has the following limitations: diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java index a146c29cc33..d7b4451b739 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java @@ -41,7 +41,8 @@ /** * {@code FtModel} is the fastText implementation of {@link Model}. * - *
FtModel contains all the methods in Model to load and process a model. + *
FtModel contains all the methods in Model to load and process a model. However, it only
+ * supports training by using {@link TrainFastText}.
*/
public class FtModel implements Model {
diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java
new file mode 100644
index 00000000000..6250063932a
--- /dev/null
+++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java
@@ -0,0 +1,38 @@
+/*
+ * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.fasttext;
+
+import ai.djl.basicdataset.RawDataset;
+import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification;
+import java.io.IOException;
+import java.nio.file.Path;
+
+/** A utility to aggregate options for training with fasttext. */
+public final class TrainFastText {
+
+ private TrainFastText() {}
+
+ /**
+ * Trains a fastText {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION} model.
+ *
+ * @param config the training configuration to use
+ * @param dataset the training dataset
+ * @return the result of the training
+ * @throws IOException when IO operation fails in loading a resource
+ * @see FtTextClassification#fit(FtTrainingConfig, RawDataset)
+ */
+ public static FtTextClassification textClassification(
+ FtTrainingConfig config, RawDataset