Skip to content

Commit

Permalink
[pytorch] Allows multiple native jars to package into fat jar
Browse files Browse the repository at this point in the history
Fixes #1422

Change-Id: I477d85725e84f30a7a79733b16d0224abe83f3ce
  • Loading branch information
frankfliu committed Mar 24, 2022
1 parent 33c29ed commit 1089b45
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,24 @@ private static LibTorch copyNativeLibraryFromClasspath(Platform platform) {
return new LibTorch(dir.toAbsolutePath(), platform, flavor);
}

Matcher m = VERSION_PATTERN.matcher(version);
if (!m.matches()) {
throw new IllegalArgumentException("Unexpected version: " + version);
}
String[] versions = m.group(1).split("\\.");
int minorVersion = Integer.parseInt(versions[1]);
int buildVersion = Integer.parseInt(versions[2]);
String pathPrefix;
if (minorVersion > 10 || (minorVersion == 10 && buildVersion == 2)) {
pathPrefix = "pytorch/" + flavor + '/' + classifier;
} else {
pathPrefix = "native/lib";
}

Files.createDirectories(cacheDir);
tmp = Files.createTempDirectory(cacheDir, "tmp");
for (String file : platform.getLibraries()) {
String libPath = "native/lib/" + file;
String libPath = pathPrefix + '/' + file;
logger.info("Extracting {} to cache ...", libPath);
try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) {
Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING);
Expand Down
11 changes: 8 additions & 3 deletions engines/pytorch/pytorch-native/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ flavorNames.each { flavor ->
platformNames.each { osName ->
tasks.create(name: "${flavor}-${osName}Jar", type: Jar) {
doFirst {
def propFile = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/pytorch.properties")
def propFile = file("${BINARY_ROOT}/pytorch.properties")
propFile.delete()
def dsStore = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/.DS_Store")
dsStore.delete()
Expand Down Expand Up @@ -315,8 +315,13 @@ flavorNames.each { flavor ->
libstd.text = new URL("https://publish.djl.ai/extra/THIRD-PARTY-LICENSES_qHnMKgbdWa.txt").text
}
}
from file("${BINARY_ROOT}/${flavor}/${osName}")
from file("src/main/resources")
from ("${BINARY_ROOT}/${flavor}/${osName}/native/lib") {
into ("pytorch/${flavor}/${osName}")
}
from ("${BINARY_ROOT}/pytorch.properties") {
into ("native/lib")
}
from "src/main/resources"
archiveClassifier = "${osName}"

manifest {
Expand Down

0 comments on commit 1089b45

Please sign in to comment.