diff --git a/tensorflow-core/tensorflow-core-api/pom.xml b/tensorflow-core/tensorflow-core-api/pom.xml index 1d43bd43454..70150b68fde 100644 --- a/tensorflow-core/tensorflow-core-api/pom.xml +++ b/tensorflow-core/tensorflow-core-api/pom.xml @@ -15,7 +15,7 @@ Platform-dependent native code and pure-Java code for the TensorFlow machine intelligence library. - 1.0.0-rc.1 + 1.0.0 1.1.5 false ${project.build.directory}/tf-text-download/ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java index 9f524ef2544..0d6866c4e70 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java @@ -158,11 +158,15 @@ public Signature build() { return new Signature(key, signatureBuilder.build()); } - private static TensorInfo toTensorInfo(Output operand) { + static TensorInfo toTensorInfo(Output operand) { Shape shape = operand.shape(); TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder(); - for (int i = 0; i < shape.numDimensions(); ++i) { - tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i))); + if (shape.isUnknown()) { + tensorShapeBuilder.setUnknownRank(true); + } else { + for (int i = 0; i < shape.numDimensions(); ++i) { + tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.get(i))); + } } return TensorInfo.newBuilder() .setDtype(operand.dataType()) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java index 8f636fb7459..819ce473b20 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SignatureTest.java @@ -19,8 +19,12 @@ import java.util.Map; import org.junit.jupiter.api.Test; import org.tensorflow.Signature.TensorDescription; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.op.math.Sign; import org.tensorflow.proto.DataType; +import org.tensorflow.types.TInt32; public class SignatureTest { @@ -80,4 +84,28 @@ public void emptyMethodNameConvertedToNull() { signature = Signature.builder().key("f").methodName(null).build(); assertNull(signature.methodName()); } + + @Test + public void createTensorInfoFromOperandWithUnknownShape() { + try (Graph g = new Graph()) { + var tf = Ops.create(g); + var placeholder = tf.placeholder(TInt32.class); + var tensorInfo = Signature.Builder.toTensorInfo(placeholder.asOutput()); + assertTrue(tensorInfo.getTensorShape().getUnknownRank()); + assertEquals(0, tensorInfo.getTensorShape().getDimCount()); + } + } + + @Test + public void createTensorInfoFromOperandWithPartiallyUnknownShape() { + try (Graph g = new Graph()) { + var tf = Ops.create(g); + var placeholder = tf.placeholder(TInt32.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE, 10))); + var tensorInfo = Signature.Builder.toTensorInfo(placeholder.asOutput()); + assertFalse(tensorInfo.getTensorShape().getUnknownRank()); + assertEquals(2, tensorInfo.getTensorShape().getDimCount()); + assertEquals(-1, tensorInfo.getTensorShape().getDim(0).getSize()); + assertEquals(10, tensorInfo.getTensorShape().getDim(1).getSize()); + } + } }