From dcb9e118e1cea806320a9900d8ea62699b5210e4 Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Fri, 31 May 2024 15:40:45 -0400 Subject: [PATCH] Fix broadcastMask/Update Accept partially unknown shaped mask --- .../org/tensorflow/op/core/BooleanMask.java | 2 +- .../tensorflow/op/core/BooleanMaskUpdate.java | 2 +- .../tensorflow/op/core/BooleanMaskTest.java | 36 ++++++++++++++++ .../op/core/BooleanMaskUpdateTest.java | 41 +++++++++++++++++++ 4 files changed, 79 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java index 5c20bc7c9e4..f83cf577889 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java @@ -78,7 +78,7 @@ public static Operand create( if (maskShape.numDimensions() == 0) { throw new IllegalArgumentException("Mask cannot be a scalar."); } - if (maskShape.hasUnknownDimension()) { + if (maskShape.isUnknown()) { throw new IllegalArgumentException("Mask cannot have unknown number of dimensions"); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java index 81eb5c507ea..d402bda432a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java @@ -86,7 +86,7 @@ public static Operand create( if (maskShape.numDimensions() == 0) { throw new IllegalArgumentException("Mask cannot be a scalar."); } - if (maskShape.hasUnknownDimension()) { + if (maskShape.isUnknown()) { throw new IllegalArgumentException("Mask cannot have unknown number of dimensions"); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java index 246b44b8077..53af60d44bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java @@ -17,6 +17,7 @@ package org.tensorflow.op.core; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; @@ -66,4 +67,39 @@ public void testBooleanMask() { } } } + + @Test + public void testBooleanMaskWithPartiallyUnknownShape() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new OpScope(g); + + Operand input = Constant.arrayOf(scope, 1, 2, 3, 4); + Placeholder inputMask = + Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE))); + + Operand output = BooleanMask.create(scope, input, inputMask); + + try (TBool mask = TBool.vectorOf(true, false, false, true); + TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(2), result.shape()); + assertEquals(1, result.getInt(0)); + assertEquals(4, result.getInt(1)); + } + } + } + + @Test + public void testBooleanMaskWithUnknownShape() { + try (Graph g = new Graph()) { + Scope scope = new OpScope(g); + + Operand input = Constant.arrayOf(scope, 1, 2, 3, 4); + Placeholder inputMask = Placeholder.create(scope, TBool.class); + + assertThrows( + IllegalArgumentException.class, () -> BooleanMask.create(scope, input, inputMask)); + } + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index 4edbea33b0d..84f4229144b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -17,6 +17,7 @@ package org.tensorflow.op.core; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; @@ -151,4 +152,44 @@ public void testBooleanMaskUpdateAxis() { } } } + + @Test + public void testBooleanMaskUpdateWithPartiallyUnknownShape() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new OpScope(g); + + Operand input = Constant.arrayOf(scope, 1, 2, 3, 4); + Operand updates = Constant.arrayOf(scope, -1, 2); + Placeholder inputMask = + Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE))); + + Operand output = BooleanMaskUpdate.create(scope, input, inputMask, updates); + + try (TBool mask = TBool.vectorOf(false, true, false, true); + TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(4), result.shape()); + assertEquals(1, result.getInt(0)); + assertEquals(-1, result.getInt(1)); + assertEquals(3, result.getInt(2)); + assertEquals(2, result.getInt(3)); + } + } + } + + @Test + public void testBooleanMaskUpdateWithUnknownShape() { + try (Graph g = new Graph()) { + Scope scope = new OpScope(g); + + Operand input = Constant.arrayOf(scope, 1, 2, 3, 4); + Operand updates = Constant.arrayOf(scope, -1, 2); + Placeholder inputMask = Placeholder.create(scope, TBool.class); + + assertThrows( + IllegalArgumentException.class, + () -> BooleanMaskUpdate.create(scope, input, inputMask, updates)); + } + } }