From dda4d874dc4b31ca13426cef39f0309ee3e05f59 Mon Sep 17 00:00:00 2001 From: Tom Burke Date: Thu, 17 Dec 2020 11:49:31 +0100 Subject: [PATCH 1/2] Add devicelist getter --- .../main/java/org/tensorflow/TensorFlow.java | 33 +++++++++++++------ .../java/org/tensorflow/TensorFlowTest.java | 17 +++++++--- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 338101c962b..ba0afb63612 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -15,21 +15,19 @@ package org.tensorflow; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteBuffer; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteLibraryHandle; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetAllOpList; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetOpList; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadLibrary; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version; - import com.google.protobuf.InvalidProtocolBufferException; +import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerScope; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.c_api.TF_Buffer; -import org.tensorflow.internal.c_api.TF_Library; -import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.internal.c_api.*; import org.tensorflow.proto.framework.OpList; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + /** Static utility methods describing the TensorFlow runtime. */ public final class TensorFlow { /** Returns the version of the underlying TensorFlow runtime. */ @@ -103,6 +101,21 @@ private static OpList libraryOpList(TF_Library handle) { } } + public static List listDevices(DeviceSpec.DeviceType deviceType, TFE_Context ctx) { + List deviceList = new ArrayList(); + TF_Status status = TF_Status.newStatus(); + TF_DeviceList devices = TFE_ContextListDevices(ctx, status); + for(int i = 0; i d.deviceType().equals(deviceType)).collect(Collectors.toList()); + } + private TensorFlow() {} /** Load the TensorFlow runtime C library. */ diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java index f8eeb84de90..e18f4edf8ae 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java @@ -15,16 +15,15 @@ package org.tensorflow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assumptions.assumeTrue; import java.io.File; import java.nio.file.Paths; +import java.util.List; import org.junit.jupiter.api.Test; +import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.proto.framework.OpList; /** Unit tests for {@link org.tensorflow.TensorFlow}. */ @@ -67,4 +66,14 @@ public void loadLibrary() { g.opBuilder("MyTest", "MyTest").build(); } } + + @Test + public void getDeviceListTest(){ + List devices = TensorFlow.listDevices(DeviceSpec.DeviceType.GPU, new TFE_Context()); + assertFalse(devices.isEmpty()); + System.out.println(devices.toString()); + devices = TensorFlow.listDevices(DeviceSpec.DeviceType.CPU, new TFE_Context()); + assertFalse(devices.isEmpty()); + System.out.println(devices.toString()); + } } From 37c0c38e7fae5a248c22f7b959974754b0d05756 Mon Sep 17 00:00:00 2001 From: Tom Burke Date: Thu, 17 Dec 2020 17:54:05 +0100 Subject: [PATCH 2/2] using optionals as default without selecting devicetype --- .../src/main/java/org/tensorflow/TensorFlow.java | 6 ++++-- .../src/test/java/org/tensorflow/TensorFlowTest.java | 9 +++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index ba0afb63612..fe30130d13c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import static org.tensorflow.internal.c_api.global.tensorflow.*; @@ -101,7 +102,7 @@ private static OpList libraryOpList(TF_Library handle) { } } - public static List listDevices(DeviceSpec.DeviceType deviceType, TFE_Context ctx) { + public static List listDevices(Optional deviceType, TFE_Context ctx) { List deviceList = new ArrayList(); TF_Status status = TF_Status.newStatus(); TF_DeviceList devices = TFE_ContextListDevices(ctx, status); @@ -113,7 +114,8 @@ public static List listDevices(DeviceSpec.DeviceType deviceType, TFE deviceList.add(devSpec); } TF_DeleteDeviceList(devices); - return deviceList.stream().filter(d -> d.deviceType().equals(deviceType)).collect(Collectors.toList()); + if(deviceType.isPresent()) return deviceList; + return deviceList.stream().filter(d -> d.deviceType().equals(deviceType.get())).collect(Collectors.toList()); } private TensorFlow() {} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java index e18f4edf8ae..76dc3987a78 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java @@ -21,6 +21,7 @@ import java.io.File; import java.nio.file.Paths; import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.Test; import org.tensorflow.internal.c_api.TFE_Context; @@ -69,11 +70,11 @@ public void loadLibrary() { @Test public void getDeviceListTest(){ - List devices = TensorFlow.listDevices(DeviceSpec.DeviceType.GPU, new TFE_Context()); + TFE_Context context = new TFE_Context(); + List devices = TensorFlow.listDevices(Optional.empty(), context); assertFalse(devices.isEmpty()); - System.out.println(devices.toString()); - devices = TensorFlow.listDevices(DeviceSpec.DeviceType.CPU, new TFE_Context()); + devices = TensorFlow.listDevices(Optional.of(DeviceSpec.DeviceType.CPU), context); assertFalse(devices.isEmpty()); - System.out.println(devices.toString()); + } }