diff --git a/keras_nlp/src/utils/tensor_utils.py b/keras_nlp/src/utils/tensor_utils.py index 7502c38bcf..26d603a5d2 100644 --- a/keras_nlp/src/utils/tensor_utils.py +++ b/keras_nlp/src/utils/tensor_utils.py @@ -30,20 +30,19 @@ NO_CONVERT_COUNTER = threading.local() -NO_CONVERT_COUNTER.count = 0 @contextlib.contextmanager def no_convert_scope(): try: - NO_CONVERT_COUNTER.count += 1 + NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) + 1 yield finally: - NO_CONVERT_COUNTER.count -= 1 + NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) - 1 def in_no_convert_scope(): - return NO_CONVERT_COUNTER.count > 0 + return getattr(NO_CONVERT_COUNTER, "count", 0) > 0 def preprocessing_function(fn): @@ -119,7 +118,7 @@ def convert_preprocessing_inputs(x): return {k: convert_preprocessing_inputs(x[k]) for k, v in x.items()} if isinstance(x, tuple): return tuple(convert_preprocessing_inputs(v) for v in x) - if isinstance(x, str): + if isinstance(x, (str, bytes)): return tf.constant(x) if isinstance(x, list): try: @@ -132,7 +131,7 @@ def convert_preprocessing_inputs(x): # If ragged conversion failed return to the numpy error. raise e # If we have a string input, use tf.tensor. - if numpy_x.dtype.type is np.str_: + if numpy_x.dtype.type is np.str_ or numpy_x.dtype.type is np.bytes_: return tf.convert_to_tensor(x) # Numpy will default to int64, int32 works with more ops. if numpy_x.dtype == np.int64: diff --git a/keras_nlp/src/utils/tensor_utils_test.py b/keras_nlp/src/utils/tensor_utils_test.py index 463a267292..e9c5e97844 100644 --- a/keras_nlp/src/utils/tensor_utils_test.py +++ b/keras_nlp/src/utils/tensor_utils_test.py @@ -49,6 +49,17 @@ def test_strings(self): self.assertIsInstance(outputs, list) self.assertEqual(outputs, inputs) + def test_bytestrings(self): + inputs = ["one".encode("utf-8"), "two".encode("utf-8")] + # Convert to tf. + outputs = convert_preprocessing_inputs(inputs) + self.assertIsInstance(outputs, tf.Tensor) + self.assertAllEqual(outputs, tf.constant(inputs)) + # Convert from tf. + outputs = convert_preprocessing_outputs(outputs) + self.assertIsInstance(outputs, list) + self.assertEqual(outputs, [x.decode("utf-8") for x in inputs]) + def test_ragged(self): inputs = [np.ones((1, 3)), np.ones((1, 2))] # Convert to tf.