From 52de2ec53c36f19e1259ec121997c58095b82158 Mon Sep 17 00:00:00 2001 From: Aziz Zayed Date: Fri, 20 Aug 2021 15:47:55 -0700 Subject: [PATCH 1/3] Add Style Transfer example with CycleGAN --- .../ai/djl/pytorch/cyclegan/metadata.json | 94 +++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/image-generation/ai/djl/pytorch/cyclegan/metadata.json diff --git a/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/image-generation/ai/djl/pytorch/cyclegan/metadata.json b/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/image-generation/ai/djl/pytorch/cyclegan/metadata.json new file mode 100644 index 00000000000..d0146aa16f4 --- /dev/null +++ b/pytorch/pytorch-model-zoo/src/test/resources/mlrepo/model/cv/image-generation/ai/djl/pytorch/cyclegan/metadata.json @@ -0,0 +1,94 @@ +{ + "metadataVersion": "0.1", + "resourceType": "model", + "application": "cv/image_generation", + "groupId": "ai.djl.pytorch", + "artifactId": "cyclegan", + "name": "CycleGAN", + "description": "CycleGAN style transfer", + "website": "http://www.djl.ai/pytorch/model-zoo", + "licenses": { + "license": { + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" + } + }, + "artifacts": [ + { + "version": "0.0.1", + "snapshot": false, + "name": "style_cezanne", + "properties": { + "artist": "cezanne" + }, + "arguments": { + "translatorFactory": "ai.djl.modality.cv.translator.StyleTransferTranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/style_cezanne.zip", + "sha1Hash": "3214f612c3d13eb834173d1df44334d7a72fc3a0", + "name": "", + "size": 42274056 + } + } + }, + { + "version": "0.0.1", + "snapshot": false, + "name": "style_monet", + "properties": { + "artist": "monet" + }, + "arguments": { + "translatorFactory": "ai.djl.modality.cv.translator.StyleTransferTranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/style_monet.zip", + 
"sha1Hash": "46e3241e91c310289f68f80ae948f083056d6034", + "name": "", + "size": 42260634 + } + } + }, + { + "version": "0.0.1", + "snapshot": false, + "name": "style_ukiyoe", + "properties": { + "artist": "ukiyoe" + }, + "arguments": { + "translatorFactory": "ai.djl.modality.cv.translator.StyleTransferTranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/style_ukiyoe.zip", + "sha1Hash": "06776c94eb48db5b86cfb6f69b24d1748047c866", + "name": "", + "size": 42264548 + } + } + }, + { + "version": "0.0.1", + "snapshot": false, + "name": "style_vangogh", + "properties": { + "artist": "vangogh" + }, + "arguments": { + "translatorFactory": "ai.djl.modality.cv.translator.StyleTransferTranslatorFactory" + }, + "files": { + "model": { + "uri": "0.0.1/style_vangogh.zip", + "sha1Hash": "81a2828e224fc7ff617e8022d80a392776d3d2d9", + "name": "", + "size": 42272108 + } + } + } + ] +} From 7726ea4dd1db8aceb49d4f7286163ce102fb3c1f Mon Sep 17 00:00:00 2001 From: Aziz Zayed Date: Fri, 20 Aug 2021 15:51:54 -0700 Subject: [PATCH 2/3] Add models to README --- pytorch/pytorch-model-zoo/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch/pytorch-model-zoo/README.md b/pytorch/pytorch-model-zoo/README.md index 28e525f2069..1253d29b05f 100644 --- a/pytorch/pytorch-model-zoo/README.md +++ b/pytorch/pytorch-model-zoo/README.md @@ -36,6 +36,8 @@ The PyTorch model zoo contains Computer Vision (CV) models. 
All the models are g * CV * Image Classification * Object Detection + * Style Transfer + * Image Generation ### How to find a pre-trained model in model zoo From 7c3d644ec35faea963002883eb8a181cef4cdb43 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 20 Aug 2021 17:06:48 -0700 Subject: [PATCH 3/3] Add StyleTransferTranslatorFactory Change-Id: Id44c176f3772868bd4f333dbd1669abad2268583 --- .../translator}/StyleTransferTranslator.java | 15 ++++-- .../StyleTransferTranslatorFactory.java | 51 +++++++++++++++++++ .../inference/cyclegan/StyleTransfer.java | 9 ++-- .../java/ai/djl/pytorch/zoo/PtModelZoo.java | 3 ++ 4 files changed, 68 insertions(+), 10 deletions(-) rename {examples/src/main/java/ai/djl/examples/inference/cyclegan => api/src/main/java/ai/djl/modality/cv/translator}/StyleTransferTranslator.java (88%) create mode 100644 api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java diff --git a/examples/src/main/java/ai/djl/examples/inference/cyclegan/StyleTransferTranslator.java b/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslator.java similarity index 88% rename from examples/src/main/java/ai/djl/examples/inference/cyclegan/StyleTransferTranslator.java rename to api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslator.java index 8ed68044e5a..c45c0762ac1 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cyclegan/StyleTransferTranslator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslator.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ -package ai.djl.examples.inference.cyclegan; +package ai.djl.modality.cv.translator; import ai.djl.modality.cv.Image; import ai.djl.modality.cv.ImageFactory; @@ -22,25 +22,30 @@ import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +/** Built-in {@code Translator} that provides preprocessing and postprocessing for StyleTransfer. */ public class StyleTransferTranslator implements Translator { + + /** {@inheritDoc} */ @Override public NDList processInput(TranslatorContext ctx, Image input) { NDArray image = switchFormat(input.toNDArray(ctx.getNDManager())).expandDims(0); return new NDList(image.toType(DataType.FLOAT32, false)); } + /** {@inheritDoc} */ @Override public Image processOutput(TranslatorContext ctx, NDList list) { NDArray output = list.get(0).addi(1).muli(128).toType(DataType.UINT8, false); return ImageFactory.getInstance().fromNDArray(output.squeeze()); } - private NDArray switchFormat(NDArray array) { - return NDArrays.stack(array.split(3, 2)).squeeze(); - } - + /** {@inheritDoc} */ @Override public Batchifier getBatchifier() { return null; } + + private NDArray switchFormat(NDArray array) { + return NDArrays.stack(array.split(3, 2)).squeeze(); + } } diff --git a/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java b/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java new file mode 100644 index 00000000000..c2391b53ab5 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/cv/translator/StyleTransferTranslatorFactory.java @@ -0,0 +1,51 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.cv.translator; + +import ai.djl.Model; +import ai.djl.modality.cv.Image; +import ai.djl.translate.TranslateException; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; +import java.lang.reflect.Type; +import java.util.Collections; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link StyleTransferTranslator} instance. */ +public class StyleTransferTranslatorFactory implements TranslatorFactory { + + /** {@inheritDoc} */ + @Override + public Set<Pair<Type, Type>> getSupportedTypes() { + return Collections.singleton(new Pair<>(Image.class, Image.class)); + } + + /** + * {@inheritDoc} + * + * Always creates a new {@link StyleTransferTranslator}; {@code arguments} is not consulted. + * + * @throws IllegalArgumentException if {@code isSupported} rejects the input/output types + */ + @Override + public Translator<?, ?> newInstance( + Class<?> input, Class<?> output, Model model, Map<String, ?> arguments) + throws TranslateException { + if (!isSupported(input, output)) { + throw new IllegalArgumentException("Unsupported input/output types."); + } + return new StyleTransferTranslator(); + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/cyclegan/StyleTransfer.java b/examples/src/main/java/ai/djl/examples/inference/cyclegan/StyleTransfer.java index ffa821ca19d..1d08ae61c86 100644 --- a/examples/src/main/java/ai/djl/examples/inference/cyclegan/StyleTransfer.java +++ b/examples/src/main/java/ai/djl/examples/inference/cyclegan/StyleTransfer.java @@ -14,10 +14,12 @@ import ai.djl.Application; import ai.djl.MalformedModelException; +import ai.djl.ModelException; import ai.djl.engine.Engine; import ai.djl.inference.Predictor; import ai.djl.modality.cv.Image; import 
ai.djl.modality.cv.ImageFactory; +import ai.djl.modality.cv.translator.StyleTransferTranslatorFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.repository.zoo.ZooModel; @@ -43,10 +45,7 @@ public enum Artist { VANGOGH } - public static void main(String[] args) - throws IOException, ModelNotFoundException, MalformedModelException, - TranslateException { - + public static void main(String[] args) throws IOException, ModelException, TranslateException { Artist artist = Artist.MONET; String imagePath = "src/test/resources/mountains.png"; Image input = ImageFactory.getInstance().fromFile(Paths.get(imagePath)); @@ -79,7 +78,7 @@ public static Image transfer(Image image, Artist artist) .setTypes(Image.class, Image.class) .optModelUrls(modelUrl) .optProgress(new ProgressBar()) - .optTranslator(new StyleTransferTranslator()) + .optTranslatorFactory(new StyleTransferTranslatorFactory()) .build(); try (ZooModel model = criteria.loadModel(); diff --git a/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index ccba066bff5..632d83af594 100644 --- a/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -53,6 +53,9 @@ public class PtModelZoo extends ModelZoo { MRL bigGan = REPOSITORY.model(CV.IMAGE_GENERATION, GROUP_ID, "biggan-deep", "0.0.1"); MODEL_LOADERS.add(new BaseModelLoader(bigGan)); + + MRL cyclegan = REPOSITORY.model(CV.IMAGE_GENERATION, GROUP_ID, "cyclegan", "0.0.1"); + MODEL_LOADERS.add(new BaseModelLoader(cyclegan)); } /** {@inheritDoc} */