Skip to content

Commit

Permalink
Allows load ModelZoo model using url
Browse files Browse the repository at this point in the history
Change-Id: I2727d6fef2675f3f417d35322244dd1228b3816b
  • Loading branch information
frankfliu committed Jul 30, 2021
1 parent 154ae6e commit 01a9488
Show file tree
Hide file tree
Showing 32 changed files with 465 additions and 248 deletions.
2 changes: 1 addition & 1 deletion api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dependencies {
exclude group: "junit", module: "junit"
}
testImplementation "org.slf4j:slf4j-simple:${slf4j_version}"
testRuntimeOnly project(":mxnet:mxnet-engine")
testRuntimeOnly project(":mxnet:mxnet-model-zoo")
testRuntimeOnly "ai.djl.mxnet:mxnet-native-auto:${mxnet_version}"
}

Expand Down
14 changes: 6 additions & 8 deletions api/src/main/java/ai/djl/repository/AbstractRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipInputStream;
Expand Down Expand Up @@ -111,7 +109,7 @@ public void prepare(Artifact artifact, Progress progress) throws IOException {

Path parentDir = resourceDir.toAbsolutePath().getParent();
if (parentDir == null) {
throw new AssertionError("Parent path should never be null: " + resourceDir.toString());
throw new AssertionError("Parent path should never be null: " + resourceDir);
}

Files.createDirectories(parentDir);
Expand Down Expand Up @@ -145,15 +143,16 @@ public Path getCacheDirectory() throws IOException {
if (Files.notExists(dir)) {
Files.createDirectories(dir);
} else if (!Files.isDirectory(dir)) {
throw new IOException("Failed initialize cache directory: " + dir.toString());
throw new IOException("Failed initialize cache directory: " + dir);
}
return dir;
}

/** {@inheritDoc} */
@Override
public List<MRL> getResources() {
return Collections.emptyList();
public void addResource(MRL mrl) {
throw new IllegalArgumentException(
getClass().getSimpleName() + " doesn't support addResource.");
}

protected void download(Path tmp, URI baseUri, Artifact.Item item, Progress progress)
Expand Down Expand Up @@ -228,8 +227,7 @@ private void untar(InputStream is, Path dir, boolean gzip) throws IOException {
} else {
Path parentFile = file.getParent();
if (parentFile == null) {
throw new AssertionError(
"Parent path should never be null: " + file.toString());
throw new AssertionError("Parent path should never be null: " + file);
}
Files.createDirectories(parentFile);
Files.copy(tis, file, StandardCopyOption.REPLACE_EXISTING);
Expand Down
25 changes: 20 additions & 5 deletions api/src/main/java/ai/djl/repository/MRL.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public final class MRL {
private String groupId;
private String artifactId;
private String version;
private String artifactName;
private Repository repository;
private Metadata metadata;

Expand All @@ -61,20 +62,23 @@ public final class MRL {
* @param groupId the desired groupId
* @param artifactId the desired artifactId
* @param version the resource version
* @param artifactName the desired artifact name
*/
private MRL(
Repository repository,
String type,
Application application,
String groupId,
String artifactId,
String version) {
String version,
String artifactName) {
this.repository = repository;
this.type = type;
this.application = application;
this.groupId = groupId;
this.artifactId = artifactId;
this.version = version;
this.artifactName = artifactName;
}

/**
Expand All @@ -85,15 +89,18 @@ private MRL(
* @param groupId the desired groupId
* @param artifactId the desired artifactId
* @param version the resource version
* @param artifactName the desired artifact name
* @return a model {@code MRL}
*/
public static MRL model(
Repository repository,
Application application,
String groupId,
String artifactId,
String version) {
return new MRL(repository, "model", application, groupId, artifactId, version);
String version,
String artifactName) {
return new MRL(
repository, "model", application, groupId, artifactId, version, artifactName);
}

/**
Expand All @@ -112,7 +119,7 @@ public static MRL dataset(
String groupId,
String artifactId,
String version) {
return new MRL(repository, "dataset", application, groupId, artifactId, version);
return new MRL(repository, "dataset", application, groupId, artifactId, null, version);
}

/**
Expand All @@ -124,7 +131,7 @@ public static MRL dataset(
* @return a dataset {@code MRL}
*/
public static MRL undefined(Repository repository, String groupId, String artifactId) {
return new MRL(repository, "", Application.UNDEFINED, groupId, artifactId, null);
return new MRL(repository, "", Application.UNDEFINED, groupId, artifactId, null, null);
}

/**
Expand Down Expand Up @@ -215,6 +222,14 @@ public Artifact match(Map<String, String> criteria) throws IOException {
if (list.isEmpty()) {
return null;
}
if (artifactName != null) {
for (Artifact artifact : list) {
if (artifactName.equals(artifact.getName())) {
return artifact;
}
}
return null;
}
return list.get(0);
}

Expand Down
22 changes: 20 additions & 2 deletions api/src/main/java/ai/djl/repository/RemoteRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Map;
Expand All @@ -37,6 +40,7 @@ public class RemoteRepository extends AbstractRepository {

private String name;
private URI uri;
private List<MRL> resources;

/**
* (Internal) Constructs a remote repository.
Expand Down Expand Up @@ -118,7 +122,21 @@ public Artifact resolve(MRL mrl, Map<String, String> filter) throws IOException
if (artifacts.isEmpty()) {
return null;
}
// TODO: find highest version.
return artifacts.get(0);
return artifacts.stream().max(Comparator.comparing(o -> new Version(o.getVersion()))).get();
}

/** {@inheritDoc} */
@Override
public List<MRL> getResources() {
return resources == null ? Collections.emptyList() : resources;
}

/** {@inheritDoc} */
@Override
public void addResource(MRL mrl) {
if (resources == null) {
resources = new ArrayList<>();
}
resources.add(mrl);
}
}
34 changes: 30 additions & 4 deletions api/src/main/java/ai/djl/repository/Repository.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public interface Repository {
* @return the new repository
*/
static Repository newInstance(String name, Path path) {
return RepositoryFactoryImpl.getFactory().newInstance(name, path.toUri().toString());
return RepositoryFactoryImpl.getFactory().newInstance(name, path.toUri());
}

/**
Expand All @@ -86,7 +86,7 @@ static Repository newInstance(String name, Path path) {
* @return the new repository
*/
static Repository newInstance(String name, String url) {
return RepositoryFactoryImpl.getFactory().newInstance(name, url);
return RepositoryFactoryImpl.getFactory().newInstance(name, URI.create(url));
}

/**
Expand All @@ -107,7 +107,7 @@ static void registerRepositoryFactory(RepositoryFactory factory) {
* @return a model {@code MRL}
*/
default MRL model(Application application, String groupId, String artifactId) {
return model(application, groupId, artifactId, null);
return model(application, groupId, artifactId, null, null);
}

/**
Expand All @@ -120,7 +120,26 @@ default MRL model(Application application, String groupId, String artifactId) {
* @return a model {@code MRL}
*/
default MRL model(Application application, String groupId, String artifactId, String version) {
return MRL.model(this, application, groupId, artifactId, version);
return MRL.model(this, application, groupId, artifactId, version, null);
}

/**
* Creates a model {@code MRL} with specified application.
*
* @param application the desired application
* @param groupId the desired groupId
* @param artifactId the desired artifactId
* @param version the resource version
* @param artifactName the desired artifact name
* @return a model {@code MRL}
*/
default MRL model(
Application application,
String groupId,
String artifactId,
String version,
String artifactName) {
return MRL.model(this, application, groupId, artifactId, version, artifactName);
}

/**
Expand Down Expand Up @@ -266,4 +285,11 @@ default Path getResourceDirectory(Artifact artifact) throws IOException {
* @return a list of {@link MRL}s in the repository
*/
List<MRL> getResources();

/**
* Adds resource to the repository.
*
* @param mrl the resource to add
*/
void addResource(MRL mrl);
}
5 changes: 3 additions & 2 deletions api/src/main/java/ai/djl/repository/RepositoryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.repository;

import java.net.URI;
import java.util.Set;

/** A interface responsible to create {@link ai.djl.repository.Repository} instances. */
Expand All @@ -21,10 +22,10 @@ public interface RepositoryFactory {
* Creates a new instance of a repository with a name and url.
*
* @param name the repository name
* @param url the repository location
* @param uri the repository location
* @return the new repository
*/
Repository newInstance(String name, String url);
Repository newInstance(String name, URI uri);

/**
* Returns a set of URI scheme that the {@code RepositoryFactory} supports.
Expand Down
Loading

0 comments on commit 01a9488

Please sign in to comment.