Skip to content

Commit

Permalink
Add tfhub url support
Browse files Browse the repository at this point in the history
Fixes #1122

Change-Id: Ia1a2fafc502cb07878ed23dea66f1914b8b3159a
  • Loading branch information
frankfliu committed Sep 19, 2021
1 parent 6defabb commit 0beaa8f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 15 deletions.
18 changes: 18 additions & 0 deletions api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
Expand Down Expand Up @@ -59,6 +60,23 @@ public Repository newInstance(String name, URI uri) {
throw new IllegalArgumentException("Malformed URL: " + uri, e);
}

if ("tfhub.dev".equals(uri.getHost().toLowerCase(Locale.ROOT))) {
// Handle tfhub case
String path = uri.getPath();
if (path.endsWith("/")) {
path = path.substring(0, path.length() - 1);
}
path = "/tfhub-modules" + path + ".tar.gz";
try {
uri = new URI("https", null, "storage.googleapis.com", -1, path, null, null);
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Failed to append query string: " + uri, e);
}
String[] tokens = path.split("/");
String modelName = tokens[tokens.length - 2];
return new SimpleUrlRepository(name, uri, modelName);
}

Path path = parseFilePath(uri);
String fileName = path.toFile().getName();
if (FilenameUtils.isArchiveFile(fileName)) {
Expand Down
28 changes: 28 additions & 0 deletions api/src/test/java/ai/djl/repository/TfhubRepositoryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.repository;

import org.testng.Assert;
import org.testng.annotations.Test;

public class TfhubRepositoryTest {

@Test
public void testResource() {
Repository repo =
Repository.newInstance(
"tfhub",
"https://tfhub.dev/tensorflow/faster_rcnn/inception_resnet_v2_640x640/1/");
Assert.assertEquals(repo.getResources().size(), 1);
}
}
45 changes: 30 additions & 15 deletions extensions/benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ djl-bench currently support benchmark the following type of models:
- ONNX model
- PaddlePaddle model
- TFLite model
- TensorRT model
- Neo DLR (TVM) model
- XGBoost model

Expand Down Expand Up @@ -50,12 +51,24 @@ curl -O https://publish.djl.ai/djl-bench/0.12.0/djl-bench_0.12.0-1_all.deb
sudo dpkg -i djl-bench_0.12.0-1_all.deb
```

For centOS or Amazon Linux 2

You can download djl-bench zip file from [here](https://publish.djl.ai/djl-bench/0.12.0/benchmark-0.12.0.zip).

```
curl -O https://publish.djl.ai/djl-bench/0.12.0/benchmark-0.12.0.zip
unzip benchmark-0.12.0.zip
rm benchmark-0.12.0.zip
sudo ln -s $PWD/benchmark-0.12.0/bin/benchmark /usr/bin/djl-bench
```

For Windows

We are considering to create a `chocolatey` package for Windows. For the time being, you can
download djl-bench zip file from [here](https://publish.djl.ai/djl-bench/0.12.0/benchmark-0.12.0.zip).

Or you can run benchmark using gradle:

```
cd djl
Expand Down Expand Up @@ -87,10 +100,10 @@ they have different CUDA version to support. Please check the individual Engine
Here is a few sample benchmark script for you to refer. You can also skip this and directly follow
the 4-step instructions for your own model.

Benchmark on a Tensorflow model from http url with all-ones NDArray input for 10 times:
Benchmark on a Tensorflow model from [tfhub](https://tfhub.dev/) url with all-zeros NDArray input for 10 times:

```
djl-bench -e TensorFlow -u https://storage.googleapis.com/tfhub-modules/tensorflow/resnet_50/classification/1.tar.gz -c 10 -s 1,224,224,3
djl-bench -e TensorFlow -u https://tfhub.dev/tensorflow/resnet_50/classification/1 -c 10 -s 1,224,224,3
```

Similarly, this is for PyTorch
Expand All @@ -117,7 +130,6 @@ SSD object detection model:
djl-bench -e PyTorch -c 2 -s 1,3,300,300 -u djl://ai.djl.pytorch/ssd/0.0.1/ssd_300_resnet50
```


## Configuration of Benchmark script

To start your benchmarking, we need to make sure we provide the following information.
Expand All @@ -140,18 +152,20 @@ This will print out the possible arguments to pass in:

```
usage: djl-bench [-p MODEL-PATH] -s INPUT-SHAPES [OPTIONS]
-c,--iteration <ITERATION> Number of total iterations (per thread).
-d,--duration <DURATION> Duration of the test in minutes.
-e,--engine <ENGINE-NAME> Choose an Engine for the benchmark.
-g,--gpus <NUMBER_GPUS> Number of GPUS to run multithreading inference.
-h,--help Print this help.
-l,--delay <DELAY> Delay of incremental threads.
-n,--model-name <MODEL-NAME> Specify model file name.
-o,--output-dir <OUTPUT-DIR> Directory for output logs.
-p,--model-path <MODEL-PATH> Model directory file path.
-s,--input-shapes <INPUT-SHAPES> Input data shapes for the model.
-t,--threads <NUMBER_THREADS> Number of inference threads.
-u,--model-url <MODEL-URL> Model archive file URL.
-c,--iteration <ITERATION> Number of total iterations.
-d,--duration <DURATION> Duration of the test in minutes.
-e,--engine <ENGINE-NAME> Choose an Engine for the benchmark.
-g,--gpus <NUMBER_GPUS> Number of GPUS to run multithreading inference.
-h,--help Print this help.
-l,--delay <DELAY> Delay of incremental threads.
--model-arguments <MODEL-ARGUMENTS> Specify model loading arguments.
--model-options <MODEL-OPTIONS> Specify model loading options.
-n,--model-name <MODEL-NAME> Specify model file name.
-o,--output-dir <OUTPUT-DIR> Directory for output logs.
-p,--model-path <MODEL-PATH> Model directory file path.
-s,--input-shapes <INPUT-SHAPES> Input data shapes for the model.
-t,--threads <NUMBER_THREADS> Number of inference threads.
-u,--model-url <MODEL-URL> Model archive file URL.
```

### Step 1: Pick your deep engine
Expand All @@ -165,6 +179,7 @@ By default, the above script will use MXNet as the default Engine, but you can a
-e PaddlePaddle # PaddlePaddle
-e OnnxRuntime # pytorch
-e TFLite # TFLite
-e TensorRT # TensorRT
-e DLR # Neo DLR
-e XGBoost # XGBoost
```
Expand Down

0 comments on commit 0beaa8f

Please sign in to comment.