From c181383c3435528a59be49ac1a4ea2ef57bba781 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Mon, 7 Mar 2022 06:36:23 -0800 Subject: [PATCH 1/3] Refactor ml extension support This makes a few changes, specifically to fasttext, but it can apply to other DJL ml wrappers as well. First, it creates several passthrough utility classes with the goal that ml models can be loaded through the model zoo and run through the predictor. Next, it modifies the base of the ml construct to better support engines that have multiple applications. Each application now manifests as a type of SymbolBlock rather than a model. Then, the single model class can run any of them. The model is still created for each engine because it contains general loading functionality that can determine which block should be used to load the given target. It also has to update the fasttext JNI to support this. First, it fixes the modelType to actually return different results. Then, it modifies the predictProba method to use an ArrayList instead of an Array. The main difference is because it is possible to pass a topk of -1 in order to load all elements. However, this doesn't work with the previous array setup. Change-Id: I5fbaceb2ac4711a942b9d15dac6e6fed386dd46e --- .../util/passthrough/PassthroughNDArray.java | 64 ++++++ .../passthrough/PassthroughNDManager.java | 209 ++++++++++++++++++ .../passthrough/PassthroughTranslator.java | 38 ++++ .../ai/djl/util/passthrough/package-info.java | 15 ++ .../java/ai/djl/fasttext/FtAbstractBlock.java | 63 ++++++ .../main/java/ai/djl/fasttext/FtModel.java | 89 +++----- .../ai/djl/fasttext/jni/FastTextLibrary.java | 8 +- .../java/ai/djl/fasttext/jni/FtWrapper.java | 11 +- .../FtTextClassification.java | 144 ++++++++++++ .../TextClassificationModelLoader.java | 3 +- .../FtWord2VecWordEmbedding.java | 35 ++- .../word_embedding/FtWordEmbeddingBlock.java | 45 ++++ .../zoo/nlp/word_embedding/package-info.java | 18 ++ .../ai_djl_fasttext_jni_FastTextLibrary.cc | 16 +- .../fasttext/CookingStackExchangeTest.java | 63 ++++-- 15 files changed, 717 insertions(+), 104 deletions(-) create mode 100644 api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java create mode 100644 api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java create mode 100644 api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java create mode 100644 api/src/main/java/ai/djl/util/passthrough/package-info.java create mode 100644 extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java create mode 100644 extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java rename extensions/fasttext/src/main/java/ai/djl/fasttext/{ => zoo/nlp/word_embedding}/FtWord2VecWordEmbedding.java (64%) create mode 100644 extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java create mode 100644 extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java new file mode 100644 index 00000000000..eb7d3e699e0 --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDArray.java @@ -0,0 +1,64 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util.passthrough; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDArrayAdapter; +import java.nio.ByteBuffer; + +/** + * An {@link NDArray} that stores an arbitrary Java object. + * + *

This class is mainly for use in extensions and hybrid engines. Despite it's name, it will + * often not contain actual {@link NDArray}s but just any object necessary to conform to the DJL + * predictor API. + */ +public class PassthroughNDArray extends NDArrayAdapter { + + private Object object; + + /** + * Constructs a {@link PassthroughNDArray} storing an object. + * + * @param object the object to store + */ + public PassthroughNDArray(Object object) { + super(null, null, null, null, null); + this.object = object; + } + + /** + * Returns the object stored. + * + * @return the object stored + */ + public Object getObject() { + return object; + } + + /** {@inheritDoc} */ + @Override + public ByteBuffer toByteBuffer() { + throw new UnsupportedOperationException("Operation not supported for FastText"); + } + + /** {@inheritDoc} */ + @Override + public void intern(NDArray replaced) { + throw new UnsupportedOperationException("Operation not supported for FastText"); + } + + /** {@inheritDoc} */ + @Override + public void detach() {} +} diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java new file mode 100644 index 00000000000..493b2fa0e9a --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughNDManager.java @@ -0,0 +1,209 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util.passthrough; + +import ai.djl.Device; +import ai.djl.engine.Engine; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDResource; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.util.PairList; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.nio.file.Path; + +/** An {@link NDManager} that does nothing, for use in extensions and hybrid engines. */ +public final class PassthroughNDManager implements NDManager { + + private static final String UNSUPPORTED = "Not supported by PassthroughNDManager"; + public static final PassthroughNDManager INSTANCE = new PassthroughNDManager(); + + private PassthroughNDManager() {} + + @Override + public Device defaultDevice() { + return Device.cpu(); + } + + @Override + public ByteBuffer allocateDirect(int capacity) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray from(NDArray array) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray create(String[] data, Charset charset, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray create(Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray createCSR(Buffer data, long[] indptr, long[] indices, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray createRowSparse(Buffer data, Shape dataShape, long[] indices, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray createCoo(Buffer data, long[][] indices, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDList load(Path path) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public void setName(String name) {} + + @Override + public String getName() { + return "PassthroughNDManager"; + } + + @Override + public NDArray zeros(Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray ones(Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray full(Shape shape, float value, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray arange(float start, float stop, float step, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray eye(int rows, int cols, int k, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray linspace(float start, float stop, int num, boolean endpoint) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomInteger(long low, long high, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomUniform(float low, float high, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomNormal(float loc, float scale, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray truncatedNormal(float loc, float scale, Shape shape, DataType dataType) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomMultinomial(int n, NDArray pValues) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDArray randomMultinomial(int n, NDArray pValues, Shape shape) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public NDManager getParentManager() { + return this; + } + + @Override + public NDManager newSubManager() { + return this; + } + + @Override + public NDManager newSubManager(Device device) { + return this; + } + + @Override + public Device getDevice() { + return Device.cpu(); + } + + @Override + public void attachInternal(String resourceId, AutoCloseable resource) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public void tempAttachInternal( + NDManager originalManager, String resourceId, NDResource resource) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public void detachInternal(String resourceId) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public void invoke( + String operation, NDArray[] src, NDArray[] dest, PairList params) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public NDList invoke(String operation, NDList src, PairList params) { + throw new UnsupportedOperationException(UNSUPPORTED); + } + + @Override + public Engine getEngine() { + return null; + } + + @Override + public void close() {} +} diff --git a/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java b/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java new file mode 100644 index 00000000000..55f6692a2b9 --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/PassthroughTranslator.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.util.passthrough; + +import ai.djl.ndarray.NDList; +import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.TranslatorContext; + +/** + * A translator that stores and removes data from a {@link PassthroughNDArray}. + * + * @param translator input type + * @param translator output type + */ +public class PassthroughTranslator implements NoBatchifyTranslator { + + @Override + public NDList processInput(TranslatorContext ctx, I input) throws Exception { + return new NDList(new PassthroughNDArray(input)); + } + + @Override + @SuppressWarnings("unchecked") + public O processOutput(TranslatorContext ctx, NDList list) { + PassthroughNDArray wrapper = (PassthroughNDArray) list.singletonOrThrow(); + return (O) wrapper.getObject(); + } +} diff --git a/api/src/main/java/ai/djl/util/passthrough/package-info.java b/api/src/main/java/ai/djl/util/passthrough/package-info.java new file mode 100644 index 00000000000..62a0fd37ce9 --- /dev/null +++ b/api/src/main/java/ai/djl/util/passthrough/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains passthrough DJL classes for use in extensions and hybrid engines. */ +package ai.djl.util.passthrough; diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java new file mode 100644 index 00000000000..c7ffec04a95 --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtAbstractBlock.java @@ -0,0 +1,63 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext; + +import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.nn.AbstractSymbolBlock; +import java.nio.file.Path; + +/** + * A parent class containing shared behavior for {@link ai.djl.nn.SymbolBlock}s based on fasttext + * models. + */ +public abstract class FtAbstractBlock extends AbstractSymbolBlock implements AutoCloseable { + + protected FtWrapper fta; + + protected Path modelFile; + + /** + * Constructs a {@link FtAbstractBlock}. + * + * @param fta the {@link FtWrapper} containing the "fasttext model" + */ + public FtAbstractBlock(FtWrapper fta) { + this.fta = fta; + } + + /** + * Returns the fasttext model file for the block. + * + * @return the fasttext model file for the block + */ + public Path getModelFile() { + return modelFile; + } + + /** + * Embeds a word using fasttext. + * + * @param word the word to embed + * @return the embedding + * @see ai.djl.modality.nlp.embedding.WordEmbedding + */ + public float[] embedWord(String word) { + return fta.getWordVector(word); + } + + @Override + public void close() { + fta.unloadModel(); + fta.close(); + } +} diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java index fed7a38b4d8..a146c29cc33 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java @@ -15,19 +15,19 @@ import ai.djl.Device; import ai.djl.MalformedModelException; import ai.djl.Model; -import ai.djl.basicdataset.RawDataset; import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification; +import ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock; import ai.djl.inference.Predictor; -import ai.djl.modality.Classifications; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.nn.Block; import ai.djl.training.Trainer; import ai.djl.training.TrainingConfig; -import ai.djl.training.TrainingResult; import ai.djl.translate.Translator; import ai.djl.util.PairList; +import ai.djl.util.passthrough.PassthroughNDManager; import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; @@ -45,7 +45,7 @@ */ public class FtModel implements Model { - FtWrapper fta; + FtAbstractBlock block; private Path modelDir; private String modelName; @@ -58,7 +58,6 @@ public class FtModel implements Model { */ public FtModel(String name) { this.modelName = name; - fta = FtWrapper.newInstance(); properties = new ConcurrentHashMap<>(); } @@ -80,6 +79,7 @@ public void load(Path modelPath, String prefix, Map options) } String modelFilePath = modelFile.toString(); + FtWrapper fta = FtWrapper.newInstance(); if (!fta.checkModel(modelFilePath)) { throw new MalformedModelException("Malformed FastText model file:" + modelFilePath); } @@ -90,7 +90,21 @@ public void load(Path modelPath, String prefix, Map options) properties.put(entry.getKey(), entry.getValue().toString()); } } - properties.put("model-type", fta.getModelType()); + String modelType = fta.getModelType(); + properties.put("model-type", modelType); + + if ("sup".equals(modelType)) { + String labelPrefix = + properties.getOrDefault( + "label-prefix", FtTextClassification.DEFAULT_LABEL_PREFIX); + block = new FtTextClassification(fta, labelPrefix); + modelDir = block.getModelFile(); + } else if ("cbow".equals(modelType) || "sg".equals(modelType)) { + block = new FtWordEmbeddingBlock(fta); + modelDir = block.getModelFile(); + } else { + throw new MalformedModelException("Unexpected FastText model type: " + modelType); + } } /** {@inheritDoc} */ @@ -130,49 +144,6 @@ private Path findModelFile(String prefix) { return modelFile; } - /** - * Returns top K number of classifications of the input text. - * - * @param text the input text to be classified - * @param topK the value of K - * @return classifications of the input text - */ - public Classifications classify(String text, int topK) { - String labelPrefix = properties.getOrDefault("label-prefix", "__label__"); - return fta.predictProba(text, topK, labelPrefix); - } - - /** - * Train the fastText model. - * - * @param config the training configuration to use - * @param dataset the training dataset - * @return the result of the training - * @throws IOException when IO operation fails in loading a resource - */ - public TrainingResult fit(FtTrainingConfig config, RawDataset dataset) - throws IOException { - Path outputDir = config.getOutputDir(); - if (Files.notExists(outputDir)) { - Files.createDirectory(outputDir); - } - String fitModelName = config.getModelName(); - Path modelFile = outputDir.resolve(fitModelName).toAbsolutePath(); - - String[] args = config.toCommand(dataset.getData().toString()); - - fta.runCmd(args); - setModelFile(modelFile); - - TrainingResult result = new TrainingResult(); - int epoch = config.getEpoch(); - if (epoch <= 0) { - epoch = 5; - } - result.setEpoch(epoch); - return result; - } - /** {@inheritDoc} */ @Override public void save(Path modelDir, String newModelName) {} @@ -185,14 +156,17 @@ public Path getModelPath() { /** {@inheritDoc} */ @Override - public Block getBlock() { - throw new UnsupportedOperationException("Fasttext doesn't support Block."); + public FtAbstractBlock getBlock() { + return block; } /** {@inheritDoc} */ @Override public void setBlock(Block block) { - throw new UnsupportedOperationException("Fasttext doesn't support setting the Block."); + if (!(block instanceof FtAbstractBlock)) { + throw new IllegalArgumentException("Expected a FtAbstractBlock Block"); + } + this.block = (FtAbstractBlock) block; } /** {@inheritDoc} */ @@ -205,7 +179,7 @@ public String getName() { @Override public Trainer newTrainer(TrainingConfig trainingConfig) { throw new UnsupportedOperationException( - "FastText only supports training using FtModel.fit"); + "FastText only supports training using the FtAbstractBlocks"); } /** {@inheritDoc} */ @@ -263,7 +237,7 @@ public InputStream getArtifactAsStream(String name) { /** {@inheritDoc} */ @Override public NDManager getNDManager() { - return null; + return PassthroughNDManager.INSTANCE; } /** {@inheritDoc} */ @@ -278,15 +252,10 @@ public String getProperty(String key) { return properties.get(key); } - void setModelFile(Path modelFile) { - this.modelDir = modelFile; - } - /** {@inheritDoc} */ @Override public void close() { - fta.unloadModel(); - fta.close(); + block.close(); } /** {@inheritDoc} */ diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java index 42fecd3e219..b10af276606 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java @@ -12,6 +12,8 @@ */ package ai.djl.fasttext.jni; +import java.util.ArrayList; + /** A class containing utilities to interact with the SentencePiece Engine's JNI layer. */ @SuppressWarnings("MissingJavadocMethod") final class FastTextLibrary { @@ -33,7 +35,11 @@ private FastTextLibrary() {} native String getModelType(long handle); native int predictProba( - long handle, String text, int topK, String[] classes, float[] probabilities); + long handle, + String text, + int topK, + ArrayList classes, + ArrayList probabilities); native float[] getWordVector(long handle, String word); diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java index 7bbfcb1fe58..03d9499698d 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FtWrapper.java @@ -59,25 +59,26 @@ public String getModelType() { } public Classifications predictProba(String text, int topK, String labelPrefix) { - String[] labels = new String[topK]; - float[] probs = new float[topK]; + int cap = topK != -1 ? topK : 10; + ArrayList labels = new ArrayList<>(cap); + ArrayList probs = new ArrayList<>(cap); int size = FastTextLibrary.LIB.predictProba(getHandle(), text, topK, labels, probs); List classes = new ArrayList<>(size); List probabilities = new ArrayList<>(size); for (int i = 0; i < size; ++i) { - String label = labels[i]; + String label = labels.get(i); if (label.startsWith(labelPrefix)) { label = label.substring(labelPrefix.length()); } classes.add(label); - probabilities.add((double) probs[i]); + probabilities.add((double) probs.get(i)); } return new Classifications(classes, probabilities); } - public float[] getDataVector(String word) { + public float[] getWordVector(String word) { return FastTextLibrary.LIB.getWordVector(getHandle(), word); } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java new file mode 100644 index 00000000000..c19b1370bfb --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/FtTextClassification.java @@ -0,0 +1,144 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext.zoo.nlp.textclassification; + +import ai.djl.basicdataset.RawDataset; +import ai.djl.fasttext.FtAbstractBlock; +import ai.djl.fasttext.FtTrainingConfig; +import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.fasttext.zoo.nlp.word_embedding.FtWordEmbeddingBlock; +import ai.djl.modality.Classifications; +import ai.djl.ndarray.NDList; +import ai.djl.training.ParameterStore; +import ai.djl.training.TrainingResult; +import ai.djl.util.PairList; +import ai.djl.util.passthrough.PassthroughNDArray; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +/** A {@link FtAbstractBlock} for {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION}. */ +public class FtTextClassification extends FtAbstractBlock { + + public static final String DEFAULT_LABEL_PREFIX = "__label__"; + + private String labelPrefix; + + private TrainingResult trainingResult; + + /** + * Constructs a {@link FtTextClassification}. + * + * @param fta the {@link FtWrapper} containing the "fasttext model" + * @param labelPrefix the prefix to use for labels + */ + public FtTextClassification(FtWrapper fta, String labelPrefix) { + super(fta); + this.labelPrefix = labelPrefix; + } + + /** + * Trains the fastText model. + * + * @param config the training configuration to use + * @param dataset the training dataset + * @return the result of the training + * @throws IOException when IO operation fails in loading a resource + */ + public static FtTextClassification fit(FtTrainingConfig config, RawDataset dataset) + throws IOException { + Path outputDir = config.getOutputDir(); + if (Files.notExists(outputDir)) { + Files.createDirectory(outputDir); + } + String fitModelName = config.getModelName(); + FtWrapper fta = FtWrapper.newInstance(); + Path modelFile = outputDir.resolve(fitModelName).toAbsolutePath(); + + String[] args = config.toCommand(dataset.getData().toString()); + + fta.runCmd(args); + + TrainingResult result = new TrainingResult(); + int epoch = config.getEpoch(); + if (epoch <= 0) { + epoch = 5; + } + result.setEpoch(epoch); + + FtTextClassification block = new FtTextClassification(fta, config.getLabelPrefix()); + block.modelFile = modelFile; + block.trainingResult = result; + return block; + } + + /** + * Returns the fasttext label prefix. + * + * @return the fasttext label prefix + */ + public String getLabelPrefix() { + return labelPrefix; + } + + /** + * Returns the results of training, or null if not trained. + * + * @return the results of training, or null if not trained + */ + public TrainingResult getTrainingResult() { + return trainingResult; + } + + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + PassthroughNDArray inputWrapper = (PassthroughNDArray) inputs.singletonOrThrow(); + String input = (String) inputWrapper.getObject(); + Classifications result = fta.predictProba(input, -1, labelPrefix); + return new NDList(new PassthroughNDArray(result)); + } + + /** + * Converts the block into the equivalent {@link FtWordEmbeddingBlock}. + * + * @return the equivalent {@link FtWordEmbeddingBlock} + */ + public FtWordEmbeddingBlock toWordEmbedding() { + return new FtWordEmbeddingBlock(fta); + } + + /** + * Returns the classifications of the input text. + * + * @param text the input text to be classified + * @return classifications of the input text + */ + public Classifications classify(String text) { + return classify(text, -1); + } + + /** + * Returns top K classifications of the input text. + * + * @param text the input text to be classified + * @param topK the value of K + * @return classifications of the input text + */ + public Classifications classify(String text, int topK) { + return fta.predictProba(text, topK, labelPrefix); + } +} diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java index 06c863301bd..4219420c585 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/textclassification/TextClassificationModelLoader.java @@ -24,6 +24,7 @@ import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; import ai.djl.util.Progress; +import ai.djl.util.passthrough.PassthroughTranslator; import java.io.IOException; import java.nio.file.Path; @@ -66,6 +67,6 @@ public ZooModel loadModel(Criteria criteria) Model model = new FtModel(modelName); Path modelPath = mrl.getRepository().getResourceDirectory(artifact); model.load(modelPath, modelName, criteria.getOptions()); - return new ZooModel<>(model, null); + return new ZooModel<>(model, new PassthroughTranslator<>()); } } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java similarity index 64% rename from extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java rename to extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java index 18e765092f2..230079d79cd 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtWord2VecWordEmbedding.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWord2VecWordEmbedding.java @@ -10,27 +10,50 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.fasttext; +package ai.djl.fasttext.zoo.nlp.word_embedding; +import ai.djl.Model; +import ai.djl.fasttext.FtAbstractBlock; +import ai.djl.fasttext.FtModel; import ai.djl.modality.nlp.Vocabulary; import ai.djl.modality.nlp.embedding.WordEmbedding; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; +import ai.djl.repository.zoo.ZooModel; /** An implementation of {@link WordEmbedding} for FastText word embeddings. */ public class FtWord2VecWordEmbedding implements WordEmbedding { - private FtModel model; + private FtAbstractBlock embedding; private Vocabulary vocabulary; /** * Constructs a {@link FtWord2VecWordEmbedding}. * - * @param model a loaded FastText model + * @param model a loaded FastText wordEmbedding model or a ZooModel containing one * @param vocabulary the {@link Vocabulary} to get indices from */ - public FtWord2VecWordEmbedding(FtModel model, Vocabulary vocabulary) { - this.model = model; + public FtWord2VecWordEmbedding(Model model, Vocabulary vocabulary) { + if (model instanceof ZooModel) { + model = ((ZooModel) model).getWrappedModel(); + } + + if (!(model instanceof FtModel)) { + throw new IllegalArgumentException("The FtWord2VecWordEmbedding requires an FtModel"); + } + + this.embedding = (FtAbstractBlock) model.getBlock(); + this.vocabulary = vocabulary; + } + + /** + * Constructs a {@link FtWord2VecWordEmbedding}. + * + * @param embedding the word embedding + * @param vocabulary the {@link Vocabulary} to get indices from + */ + public FtWord2VecWordEmbedding(FtAbstractBlock embedding, Vocabulary vocabulary) { + this.embedding = embedding; this.vocabulary = vocabulary; } @@ -56,7 +79,7 @@ public NDArray embedWord(NDArray index) { @Override public NDArray embedWord(NDManager manager, long index) { String word = vocabulary.getToken(index); - float[] buf = model.fta.getDataVector(word); + float[] buf = embedding.embedWord(word); return manager.create(buf); } diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java new file mode 100644 index 00000000000..8f18558858d --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/FtWordEmbeddingBlock.java @@ -0,0 +1,45 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext.zoo.nlp.word_embedding; + +import ai.djl.fasttext.FtAbstractBlock; +import ai.djl.fasttext.jni.FtWrapper; +import ai.djl.ndarray.NDList; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; +import ai.djl.util.passthrough.PassthroughNDArray; + +/** A {@link FtAbstractBlock} for {@link ai.djl.Application.NLP#WORD_EMBEDDING}. */ +public class FtWordEmbeddingBlock extends FtAbstractBlock { + + /** + * Constructs a {@link FtWordEmbeddingBlock}. + * + * @param fta the {@link FtWrapper} for the "fasttext model". + */ + public FtWordEmbeddingBlock(FtWrapper fta) { + super(fta); + } + + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + PassthroughNDArray inputWrapper = (PassthroughNDArray) inputs.singletonOrThrow(); + String input = (String) inputWrapper.getObject(); + float[] result = embedWord(input); + return new NDList(new PassthroughNDArray(result)); + } +} diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java new file mode 100644 index 00000000000..4bbd761e559 --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/zoo/nlp/word_embedding/package-info.java @@ -0,0 +1,18 @@ +/* + * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** + * Contains classes for the {@link ai.djl.Application.NLP#WORD_EMBEDDING} models in the {@link + * ai.djl.fasttext.zoo.FtModelZoo}. + */ +package ai.djl.fasttext.zoo.nlp.word_embedding; diff --git a/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc b/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc index c4cb4354cee..8a14c1078b2 100644 --- a/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc +++ b/extensions/fasttext/src/main/native/ai_djl_fasttext_jni_FastTextLibrary.cc @@ -97,9 +97,9 @@ JNIEXPORT jstring JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getModelType( if (modelName == model_name::cbow) { return env->NewStringUTF("cbow"); } else if (modelName == model_name::sg) { - return env->NewStringUTF("cbow"); + return env->NewStringUTF("sg"); } else if (modelName == model_name::sup) { - return env->NewStringUTF("cbow"); + return env->NewStringUTF("sup"); } else { jclass jexception = env->FindClass("ai/djl/engine/EngineException"); env->ThrowNew(jexception, "Unrecognized model type"); @@ -108,7 +108,7 @@ JNIEXPORT jstring JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_getModelType( } JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba( - JNIEnv* env, jobject jthis, jlong jhandle, jstring jtext, jint top_k, jobjectArray jclasses, jfloatArray jprob) { + JNIEnv* env, jobject jthis, jlong jhandle, jstring jtext, jint top_k, jobject jclasses, jobject jprob) { auto* fasttext_ptr = reinterpret_cast(jhandle); std::string text = djl::utils::jni::GetStringFromJString(env, jtext); std::istringstream in(text); @@ -116,13 +116,15 @@ JNIEXPORT jint JNICALL Java_ai_djl_fasttext_jni_FastTextLibrary_predictProba( fasttext_ptr->predictLine(in, predictions, top_k, 0.0); int size = predictions.size(); - std::vector prob; + jclass java_lang_Float = static_cast(env->NewGlobalRef(env->FindClass("java/lang/Float"))); + jmethodID java_lang_Float_ = env->GetMethodID(java_lang_Float, "", "(F)V"); + jclass java_util_ArrayList = static_cast(env->NewGlobalRef(env->FindClass("java/util/ArrayList"))); + jmethodID java_util_ArrayList_add = env->GetMethodID(java_util_ArrayList, "add", "(Ljava/lang/Object;)Z"); for (int i = 0; i < size; ++i) { std::pair pair = predictions[i]; - env->SetObjectArrayElement(jclasses, i, env->NewStringUTF(pair.second.c_str())); - prob.push_back(pair.first); + env->CallBooleanMethod(jclasses, java_util_ArrayList_add, env->NewStringUTF(pair.second.c_str())); + env->CallBooleanMethod(jprob, java_util_ArrayList_add, env->NewObject(java_lang_Float, java_lang_Float_, pair.first)); } - env->SetFloatArrayRegion(jprob, 0, size, prob.data()); return size; } diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java index 0ab6d1180d4..17c24d92b81 100644 --- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java +++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java @@ -12,19 +12,26 @@ */ package ai.djl.fasttext; +import ai.djl.Application; import ai.djl.MalformedModelException; import ai.djl.ModelException; import ai.djl.basicdataset.nlp.CookingStackExchange; +import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification; +import ai.djl.fasttext.zoo.nlp.word_embedding.FtWord2VecWordEmbedding; +import ai.djl.inference.Predictor; import ai.djl.modality.Classifications; import ai.djl.modality.nlp.DefaultVocabulary; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; +import ai.djl.repository.Artifact; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.repository.zoo.ModelZoo; import ai.djl.repository.zoo.ZooModel; import ai.djl.testing.TestRequirements; import ai.djl.training.TrainingResult; +import ai.djl.translate.TranslateException; import java.io.IOException; import java.io.InputStream; import java.net.URL; @@ -33,6 +40,8 @@ import java.nio.file.Paths; import java.nio.file.StandardCopyOption; import java.util.Collections; +import java.util.List; +import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testng.Assert; @@ -43,30 +52,30 @@ public class CookingStackExchangeTest { private static final Logger logger = LoggerFactory.getLogger(CookingStackExchangeTest.class); @Test - public void testTrainTextClassification() throws IOException { + public void testTrainTextClassification() throws IOException, TranslateException { TestRequirements.notWindows(); // fastText is not supported on windows - try (FtModel model = new FtModel("cooking")) { - CookingStackExchange dataset = CookingStackExchange.builder().build(); - - // setup training configuration - FtTrainingConfig config = - FtTrainingConfig.builder() - .setOutputDir(Paths.get("build")) - .setModelName("cooking") - .optEpoch(5) - .optLoss(FtTrainingConfig.FtLoss.HS) - .build(); - - TrainingResult result = model.fit(config, dataset); - Assert.assertEquals(result.getEpoch(), 5); - Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin"))); - } + CookingStackExchange dataset = CookingStackExchange.builder().build(); + + // setup training configuration + FtTrainingConfig config = + FtTrainingConfig.builder() + .setOutputDir(Paths.get("build")) + .setModelName("cooking") + .optEpoch(5) + .optLoss(FtTrainingConfig.FtLoss.HS) + .build(); + + FtTextClassification block = FtTextClassification.fit(config, dataset); + TrainingResult result = block.getTrainingResult(); + Assert.assertEquals(result.getEpoch(), 5); + Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin"))); } @Test public void testTextClassification() - throws IOException, MalformedModelException, ModelNotFoundException { + throws IOException, MalformedModelException, ModelNotFoundException, + TranslateException { TestRequirements.notWindows(); // fastText is not supported on windows Criteria criteria = @@ -75,11 +84,18 @@ public void testTextClassification() .optArtifactId("ai.djl.fasttext:cooking_stackexchange") .optOption("label-prefix", "__label") .build(); + Map> models = ModelZoo.listModels(criteria); + models.forEach( + (app, list) -> { + String appName = app.toString(); + list.forEach(artifact -> logger.info("{} {}", appName, artifact)); + }); try (ZooModel model = criteria.loadModel()) { String input = "Which baking dish is best to bake a banana bread ?"; - FtModel ftModel = (FtModel) model.getWrappedModel(); - Classifications result = ftModel.classify(input, 8); - Assert.assertEquals(result.item(0).getClassName(), "__bread"); + try (Predictor predictor = model.newPredictor()) { + Classifications result = predictor.predict(input); + Assert.assertEquals(result.item(0).getClassName(), "__bread"); + } } } @@ -95,10 +111,9 @@ public void testWord2Vec() throws IOException, MalformedModelException, ModelNot try (ZooModel model = criteria.loadModel(); NDManager manager = NDManager.newBaseManager()) { - FtModel ftModel = (FtModel) model.getWrappedModel(); FtWord2VecWordEmbedding fasttextWord2VecWordEmbedding = new FtWord2VecWordEmbedding( - ftModel, new DefaultVocabulary(Collections.singletonList("bread"))); + model, new DefaultVocabulary(Collections.singletonList("bread"))); long index = fasttextWord2VecWordEmbedding.preprocessWordToEmbed("bread"); NDArray embedding = fasttextWord2VecWordEmbedding.embedWord(manager, index); Assert.assertEquals(embedding.getShape(), new Shape(100)); @@ -125,7 +140,7 @@ public void testBlazingText() throws IOException, ModelException { model.load(modelFile); String text = "Convair was an american aircraft manufacturing company which later expanded into rockets and spacecraft ."; - Classifications result = model.classify(text, 5); + Classifications result = ((FtTextClassification) model.getBlock()).classify(text, 5); logger.info("{}", result); Assert.assertEquals(result.item(0).getClassName(), "Company"); From 46d8e452750a1c2f6b434b05112b6ecde10ea9b5 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Wed, 9 Mar 2022 09:08:09 -0800 Subject: [PATCH 2/3] Fix PMD Change-Id: I7b03f8958594c01d3ab2f722dbb7a573c6d3109e --- .../src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java | 1 + 1 file changed, 1 insertion(+) diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java index b10af276606..d874d3ad1aa 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/jni/FastTextLibrary.java @@ -34,6 +34,7 @@ private FastTextLibrary() {} native String getModelType(long handle); + @SuppressWarnings("PMD.LooseCoupling") native int predictProba( long handle, String text, From 87de7a9c2b08d4db210171c749eceac34a7a3fee Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Wed, 9 Mar 2022 15:28:16 -0800 Subject: [PATCH 3/3] Add TrainFastText utility to aggregate links to training functions Change-Id: Iadf598d930dca954ee1a0f450012cd4b4b2baab5 --- extensions/fasttext/README.md | 6 ++- .../main/java/ai/djl/fasttext/FtModel.java | 3 +- .../java/ai/djl/fasttext/TrainFastText.java | 38 +++++++++++++++++++ .../fasttext/CookingStackExchangeTest.java | 2 +- 4 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java diff --git a/extensions/fasttext/README.md b/extensions/fasttext/README.md index 64f582f8b07..f39e64b5730 100644 --- a/extensions/fasttext/README.md +++ b/extensions/fasttext/README.md @@ -5,8 +5,10 @@ This module contains the NLP support with fastText implementation. fastText module's implementation in DJL is not considered as an Engine, it doesn't support Trainer and Predictor. -The training and inference functionality is directly provided through [FtModel](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/FtModel.html) -class. You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). +Training is only supported by using [TrainFastText](https://javadoc.io/doc/ai.djl.fasttext/fasttext-engine/latest/ai/djl/fasttext/TrainFastText.html). +This produces a special block which can perform inference on its own or by using a model and predictor. +Pre-trained FastText models can also be loaded by using the standard DJL criteria. +You can find examples [here](https://github.com/deepjavalibrary/djl/blob/master/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java). Current implementation has the following limitations: diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java index a146c29cc33..d7b4451b739 100644 --- a/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/FtModel.java @@ -41,7 +41,8 @@ /** * {@code FtModel} is the fastText implementation of {@link Model}. * - *

FtModel contains all the methods in Model to load and process a model. + *

FtModel contains all the methods in Model to load and process a model. However, it only + * supports training by using {@link TrainFastText}. */ public class FtModel implements Model { diff --git a/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java new file mode 100644 index 00000000000..6250063932a --- /dev/null +++ b/extensions/fasttext/src/main/java/ai/djl/fasttext/TrainFastText.java @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.fasttext; + +import ai.djl.basicdataset.RawDataset; +import ai.djl.fasttext.zoo.nlp.textclassification.FtTextClassification; +import java.io.IOException; +import java.nio.file.Path; + +/** A utility to aggregate options for training with fasttext. */ +public final class TrainFastText { + + private TrainFastText() {} + + /** + * Trains a fastText {@link ai.djl.Application.NLP#TEXT_CLASSIFICATION} model. + * + * @param config the training configuration to use + * @param dataset the training dataset + * @return the result of the training + * @throws IOException when IO operation fails in loading a resource + * @see FtTextClassification#fit(FtTrainingConfig, RawDataset) + */ + public static FtTextClassification textClassification( + FtTrainingConfig config, RawDataset dataset) throws IOException { + return FtTextClassification.fit(config, dataset); + } +} diff --git a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java index 17c24d92b81..c0bb2fc3485 100644 --- a/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java +++ b/extensions/fasttext/src/test/java/ai/djl/fasttext/CookingStackExchangeTest.java @@ -66,7 +66,7 @@ public void testTrainTextClassification() throws IOException, TranslateException .optLoss(FtTrainingConfig.FtLoss.HS) .build(); - FtTextClassification block = FtTextClassification.fit(config, dataset); + FtTextClassification block = TrainFastText.textClassification(config, dataset); TrainingResult result = block.getTrainingResult(); Assert.assertEquals(result.getEpoch(), 5); Assert.assertTrue(Files.exists(Paths.get("build/cooking.bin")));