Closed
Description
When the backend is not TensorFlow, what is the proper way to use basic TensorFlow operations (e.g. `tf.concat`) inside `tf.data` API pipelines? The following code works with the TensorFlow backend, but not with the torch or jax backends.
```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # tensorflow, torch, jax

import numpy as np
import tensorflow as tf
import keras
from keras import layers

aug_model = keras.Sequential([
    keras.Input(shape=(224, 224, 3)),
    layers.RandomFlip("horizontal_and_vertical")
])

def augment_data_tf(x, y):
    combined = tf.concat([x, y], axis=-1)
    z = aug_model(combined)
    x = z[..., :3]
    y = z[..., 3:]
    return x, y

a = np.ones((4, 224, 224, 3)).astype(np.float32)
b = np.ones((4, 224, 224, 2)).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
)
```

```
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-7-2d25b0c0bbad> in <cell line: 3>()
      1 dataset = tf.data.Dataset.from_tensor_slices((a, b))
      2 dataset = dataset.batch(3, drop_remainder=True)
----> 3 dataset = dataset.map(
      4     augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
      5 )

25 frames

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _convert_to_array_if_dtype_fails(x)
   4102 dtypes.dtype(x)
   4103 except TypeError:
-> 4104 return np.asarray(x)
   4105 else:
   4106 return x

NotImplementedError: in user code:

    File "<ipython-input-5-ca4b074b58a5>", line 6, in augment_data_tf  *
        z = aug_model(combined)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.10/dist-packages/optree/ops.py", line 752, in tree_map
        return treespec.unflatten(map(func, *flat_args))
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4252, in asarray
        return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in array
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in <listcomp>
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4104, in _convert_to_array_if_dtype_fails
        return np.asarray(x)

NotImplementedError: Cannot convert a symbolic tf.Tensor (concat:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
```
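A possible workaround, offered as a sketch rather than a confirmed fix. The traceback shows that `tf.concat` itself is not the problem: the failure happens when the `Sequential` model call hands the symbolic `tf.Tensor` to JAX's `asarray`. Assuming a Keras 3 release in which image preprocessing layers such as `layers.RandomFlip` detect that they are being called inside a `tf.data` graph and run on TensorFlow ops regardless of `KERAS_BACKEND`, calling the layer directly (instead of through the `keras.Sequential` wrapper) should let the TensorFlow ops and the augmentation coexist in the map function. The name `flip` is illustrative only.

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # the failing case from the report

import numpy as np
import tensorflow as tf
from keras import layers

# Assumption: RandomFlip (a Keras 3 preprocessing layer) switches to TensorFlow
# ops when called inside a tf.data graph, so it is used directly here instead
# of being wrapped in keras.Sequential. `flip` is a hypothetical name.
flip = layers.RandomFlip("horizontal_and_vertical")

def augment_data_tf(x, y):
    # tf.data traces this function as a TensorFlow graph, so tf.concat is fine.
    combined = tf.concat([x, y], axis=-1)
    # Flip image and target together so they stay spatially aligned.
    z = flip(combined)
    # Split the 5 channels back into the image (3) and the target (2).
    return z[..., :3], z[..., 3:]

a = np.ones((4, 224, 224, 3), dtype=np.float32)
b = np.ones((4, 224, 224, 2), dtype=np.float32)

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE)

# Pull one batch to check that the pipeline actually runs.
x_batch, y_batch = next(iter(dataset))
print(x_batch.shape, y_batch.shape)
```

If that assumption does not hold for the installed Keras version, a backend-independent alternative is to keep the whole map function in TensorFlow, for example by applying `tf.image.random_flip_left_right` and `tf.image.random_flip_up_down` to `combined` instead of a Keras layer.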