diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java index ce383e02865..0deff350d9f 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/LibUtils.java @@ -290,10 +290,24 @@ private static LibTorch copyNativeLibraryFromClasspath(Platform platform) { return new LibTorch(dir.toAbsolutePath(), platform, flavor); } + Matcher m = VERSION_PATTERN.matcher(version); + if (!m.matches()) { + throw new IllegalArgumentException("Unexpected version: " + version); + } + String[] versions = m.group(1).split("\\."); + int minorVersion = Integer.parseInt(versions[1]); + int buildVersion = Integer.parseInt(versions[2]); + String pathPrefix; + if (minorVersion > 10 || (minorVersion == 10 && buildVersion == 2)) { + pathPrefix = "pytorch/" + flavor + '/' + classifier; + } else { + pathPrefix = "native/lib"; + } + Files.createDirectories(cacheDir); tmp = Files.createTempDirectory(cacheDir, "tmp"); for (String file : platform.getLibraries()) { - String libPath = "native/lib/" + file; + String libPath = pathPrefix + '/' + file; logger.info("Extracting {} to cache ...", libPath); try (InputStream is = ClassLoaderUtils.getResourceAsStream(libPath)) { Files.copy(is, tmp.resolve(file), StandardCopyOption.REPLACE_EXISTING); diff --git a/engines/pytorch/pytorch-native/build.gradle b/engines/pytorch/pytorch-native/build.gradle index 9c193fec037..73fbed5f068 100644 --- a/engines/pytorch/pytorch-native/build.gradle +++ b/engines/pytorch/pytorch-native/build.gradle @@ -283,7 +283,7 @@ flavorNames.each { flavor -> platformNames.each { osName -> tasks.create(name: "${flavor}-${osName}Jar", type: Jar) { doFirst { - def propFile = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/pytorch.properties") + def propFile = file("${BINARY_ROOT}/pytorch.properties") propFile.delete() def dsStore = file("${BINARY_ROOT}/${flavor}/${osName}/native/lib/.DS_Store") dsStore.delete() @@ -315,8 +315,13 @@ flavorNames.each { flavor -> libstd.text = new URL("https://publish.djl.ai/extra/THIRD-PARTY-LICENSES_qHnMKgbdWa.txt").text } } - from file("${BINARY_ROOT}/${flavor}/${osName}") - from file("src/main/resources") + from ("${BINARY_ROOT}/${flavor}/${osName}/native/lib") { + into ("pytorch/${flavor}/${osName}") + } + from ("${BINARY_ROOT}/pytorch.properties") { + into ("native/lib") + } + from "src/main/resources" archiveClassifier = "${osName}" manifest {