diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java index 338101c962b..fe30130d13c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java @@ -15,21 +15,20 @@ package org.tensorflow; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteBuffer; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteLibraryHandle; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetAllOpList; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_GetOpList; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_LoadLibrary; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version; - import com.google.protobuf.InvalidProtocolBufferException; +import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.PointerScope; import org.tensorflow.exceptions.TensorFlowException; -import org.tensorflow.internal.c_api.TF_Buffer; -import org.tensorflow.internal.c_api.TF_Library; -import org.tensorflow.internal.c_api.TF_Status; +import org.tensorflow.internal.c_api.*; import org.tensorflow.proto.framework.OpList; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + /** Static utility methods describing the TensorFlow runtime. */ public final class TensorFlow { /** Returns the version of the underlying TensorFlow runtime. */ @@ -103,6 +102,22 @@ private static OpList libraryOpList(TF_Library handle) { } } + public static List listDevices(Optional deviceType, TFE_Context ctx) { + List deviceList = new ArrayList(); + TF_Status status = TF_Status.newStatus(); + TF_DeviceList devices = TFE_ContextListDevices(ctx, status); + for(int i = 0; i d.deviceType().equals(deviceType.get())).collect(Collectors.toList()); + } + private TensorFlow() {} /** Load the TensorFlow runtime C library. */ diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java index f8eeb84de90..76dc3987a78 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorFlowTest.java @@ -15,16 +15,16 @@ package org.tensorflow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assumptions.assumeTrue; import java.io.File; import java.nio.file.Paths; +import java.util.List; +import java.util.Optional; import org.junit.jupiter.api.Test; +import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.proto.framework.OpList; /** Unit tests for {@link org.tensorflow.TensorFlow}. */ @@ -67,4 +67,14 @@ public void loadLibrary() { g.opBuilder("MyTest", "MyTest").build(); } } + + @Test + public void getDeviceListTest(){ + TFE_Context context = new TFE_Context(); + List devices = TensorFlow.listDevices(Optional.empty(), context); + assertFalse(devices.isEmpty()); + devices = TensorFlow.listDevices(Optional.of(DeviceSpec.DeviceType.CPU), context); + assertFalse(devices.isEmpty()); + + } }