
Is it possible to use tf.data with tf operations while utilizing jax or torch as the backend? #20722

@innat

When the backend is not TensorFlow, what is the proper approach to using basic operations (e.g. tf.concat) inside tf.data pipelines? The following code works with the TensorFlow backend, but not with torch or jax.

import os
os.environ["KERAS_BACKEND"] = "jax"  # one of: tensorflow, torch, jax

import numpy as np
import keras
from keras import layers
import tensorflow as tf

aug_model = keras.Sequential([
    keras.Input(shape=(224, 224, 3)),
    layers.RandomFlip("horizontal_and_vertical")
])

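def augment_data_tf(x, y):
    # Concatenate image (3 channels) and mask (2 channels) so both get the same flip.
    combined = tf.concat([x, y], axis=-1)
    z = aug_model(combined)
    # Split the augmented tensor back into image and mask.
    x = z[..., :3]
    y = z[..., 3:]
    return x, y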

a = np.ones((4, 224, 224, 3)).astype(np.float32)
b = np.ones((4, 224, 224, 2)).astype(np.float32)

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(
    augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-7-2d25b0c0bbad> in <cell line: 3>()
      1 dataset = tf.data.Dataset.from_tensor_slices((a, b))
      2 dataset = dataset.batch(3, drop_remainder=True)
----> 3 dataset = dataset.map(
      4     augment_data_tf, num_parallel_calls=tf.data.AUTOTUNE
      5 )

25 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _convert_to_array_if_dtype_fails(x)
   4102     dtypes.dtype(x)
   4103   except TypeError:
-> 4104     return np.asarray(x)
   4105   else:
   4106     return x

NotImplementedError: in user code:

    File "<ipython-input-5-ca4b074b58a5>", line 6, in augment_data_tf  *
        z = aug_model(combined)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.10/dist-packages/optree/ops.py", line 752, in tree_map
        return treespec.unflatten(map(func, *flat_args))
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4252, in asarray
        return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in array
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4058, in <listcomp>
        leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
    File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 4104, in _convert_to_array_if_dtype_fails
        return np.asarray(x)

    NotImplementedError: Cannot convert a symbolic tf.Tensor (concat:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
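The last line of the traceback points at the cause. `dataset.map` traces `augment_data_tf` into a TensorFlow graph, so `combined` is a symbolic `tf.Tensor` with no concrete value. Judging from the traceback, the `keras.Sequential` model under the JAX backend converts its inputs with `jnp.asarray`, which falls back to `np.asarray`, and a symbolic tensor cannot be materialized as a NumPy array. The sketch below reproduces only that step with TensorFlow and NumPy. The `probe` function and `seen` dict are hypothetical names for illustration, and the exact exception type may vary with the NumPy/TensorFlow versions (both candidates are caught).

```python
import numpy as np
import tensorflow as tf

# Hypothetical record of what the map function observes while it is traced.
seen = {}

def probe(x):
    # Inside Dataset.map the function is traced into a graph, not run eagerly.
    seen["eager"] = tf.executing_eagerly()
    try:
        # Roughly what the JAX backend ends up doing with the model input.
        np.asarray(x)
        seen["numpy_ok"] = True
    except (NotImplementedError, TypeError) as err:
        # A symbolic tensor has no concrete value to copy into NumPy.
        seen["numpy_ok"] = False
        seen["error"] = type(err).__name__
    return x

# Tracing happens when .map() is called, so `seen` is filled immediately.
ds = tf.data.Dataset.from_tensor_slices(np.ones((4, 2), np.float32)).map(probe)
print(seen)
```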

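One possible workaround, assuming Keras 3 behaves as documented, is to rely on its preprocessing layers' ability to detect when they run inside a `tf.data` pipeline and compute with TensorFlow regardless of `KERAS_BACKEND`. To do that, call the augmentation layer directly instead of wrapping it in `keras.Sequential`, which is a model and converts its inputs to backend tensors. The sketch below makes that change. The `flip` and `augment_data` names are mine. The backend defaults to `"tensorflow"` here only so the snippet runs anywhere; set `KERAS_BACKEND` to `jax` or `torch` before running to match the setup in the issue.

```python
import os
# Select the backend before importing keras, as in the original snippet.
# Fallback "tensorflow" is an assumption for portability; use "jax" or "torch" to match the issue.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import numpy as np
import tensorflow as tf
from keras import layers

# Use the preprocessing layer itself, not a keras.Sequential model wrapping it.
flip = layers.RandomFlip("horizontal_and_vertical")

def augment_data(x, y):
    # Concatenate image (3 channels) and mask (2 channels) so both get the same flip.
    combined = tf.concat([x, y], axis=-1)
    z = flip(combined)
    # Split the augmented tensor back into image and mask.
    return z[..., :3], z[..., 3:]

a = np.ones((4, 224, 224, 3), dtype=np.float32)
b = np.ones((4, 224, 224, 2), dtype=np.float32)

dataset = tf.data.Dataset.from_tensor_slices((a, b))
dataset = dataset.batch(3, drop_remainder=True)
dataset = dataset.map(augment_data, num_parallel_calls=tf.data.AUTOTUNE)

for x_batch, y_batch in dataset.take(1):
    print(x_batch.shape, y_batch.shape)
```

If several layers need to be chained, newer Keras 3 releases also document a `keras.layers.Pipeline` layer intended for running a sequence of preprocessing layers inside `tf.data`. Check whether your installed version provides it before relying on it.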