Skip to content

Commit

Permalink
Add TrainFastText utility to aggregate links to training functions
Browse files Browse the repository at this point in the history
Change-Id: Iadf598d930dca954ee1a0f450012cd4b4b2baab5
  • Loading branch information
zachgk committed Mar 9, 2022
1 parent 46d8e45 commit 87de7a9
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 4 deletions.
6 changes: 4 additions & 2 deletions extensions/fasttext/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
This module contains the NLP support with fastText implementation.

fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor.
The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html)
class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).
Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html).
This produces a special block which can perform inference on its own or by using a model and predictor.
Pre-trained FastText models can also be loaded by using the standard DJL criteria.
You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java).

Current implementation has the following limitations:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
/**
* {@code FtModel} is the fastText implementation of {@link Model}.
*
* <p>FtModel contains all the methods in Model to load and process a model.
* <p>FtModel contains all the methods in Model to load and process a model. However, it only
* supports training by using {@link TrainFastText}.
*/
public class FtModel implements Model {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.fasttext;

import ai.djl.basicdataset.RawDataset;
import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification;
import java.io.IOException;
import java.nio.file.Path;

/** A utility to aggregate options for training with fasttext. */
public final class TrainFastText {

private TrainFastText() {}

/**
* Trains a fastText {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION} model.
*
* @param config the training configuration to use
* @param dataset the training dataset
* @return the result of the training
* @throws IOException when IO operation fails in loading a resource
* @see FtTextClassification#fit(FtTrainingConfig, RawDataset)
*/
public static FtTextClassification textClassification(
FtTrainingConfig config, RawDataset<Path> dataset) throws IOException {
return FtTextClassification.fit(config, dataset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void testTrainTextClassification() throws IOException, TranslateException
.optLoss(FtTrainingConfig.FtLoss.HS)
.build();

FtTextClassification block = FtTextClassification.fit(config, dataset);
FtTextClassification block = TrainFastText.textClassification(config, dataset);
TrainingResult result = block.getTrainingResult();
Assert.assertEquals(result.getEpoch(), 5);
Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin")));
Expand Down

0 comments on commit 87de7a9

Please sign in to comment.