bfloat16 arrays are incompatible with pickle #8505
I think for this to work correctly, we'd need the […]. That's not straightforward, though, because this type definition is shared with TensorFlow and exported at a different package path within that library.
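Pickle saves a class by reference, recording its `__module__` and `__qualname__` and looking it up again by that path, so a type that reports the wrong module cannot be pickled. Below is a minimal, dependency-free sketch of that failure mode and of the general shape of a fix. `FakeBFloat16` and the `fake_ml_dtypes` module are hypothetical stand-ins, not the real JAX or TensorFlow type.

```python
import pickle
import sys
import types


class FakeBFloat16:
    """Hypothetical stand-in for the C-extension bfloat16 scalar type."""


def can_pickle(obj):
    """Return True if `obj` survives a pickle round trip."""
    try:
        return pickle.loads(pickle.dumps(obj)) is obj
    except pickle.PicklingError:
        return False


# 1. The type claims to live in `builtins`, but builtins has no such
#    attribute, so pickle's "attribute lookup ... on builtins" fails.
FakeBFloat16.__module__ = "builtins"
broken = can_pickle(FakeBFloat16)

# 2. Give the type a real home: a (hypothetical) importable module that
#    exposes it under the same name, and point __module__ at that module.
fake_pkg = types.ModuleType("fake_ml_dtypes")
fake_pkg.FakeBFloat16 = FakeBFloat16
sys.modules["fake_ml_dtypes"] = fake_pkg
FakeBFloat16.__module__ = "fake_ml_dtypes"
fixed = can_pickle(FakeBFloat16)

print(broken, fixed)  # False True
```

The difficulty described above is that the real type would need one such canonical, importable path even though two libraries export it.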
Indeed, the TensorFlow bfloat16 dtype has the same issue:

```python
import tensorflow as tf
import pickle
import numpy as np

pickle.dumps(np.ones(10, dtype=tf.bfloat16.as_numpy_dtype))
```

```
---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
<ipython-input-24-272a81341603> in <module>()
      2 import pickle
      3 import numpy as np
----> 4 pickle.dumps(np.ones(10, dtype=tf.bfloat16.as_numpy_dtype))

PicklingError: Can't pickle <class 'bfloat16'>: attribute lookup bfloat16 on builtins failed
```
I think once this bfloat16 name issue is resolved, NumPy's pickle serialization/deserialization should work. For example:

```python
import jax.numpy as jnp

x = jnp.arange(4, dtype='bfloat16')
constructor, args, state = x.__reduce__()
x2 = constructor(*args)
x2.__setstate__(state)
print(repr(x2))
# array([0, 1, 2, 3], dtype=bfloat16)
```

The error arises when pickle tries to serialize the …
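A rough way to check whether a given type will hit this error is to replay the lookup pickle performs for a class: import `cls.__module__`, walk `cls.__qualname__`, and confirm the result is the same object. The `pickle_path` helper below is a hypothetical, simplified mirror of that check (not part of pickle, NumPy, or JAX); `Misplaced` is a made-up type that misreports its module the way bfloat16 did in the traceback above.

```python
import fractions
import importlib


def pickle_path(cls):
    """Return the dotted path pickle would record for `cls`, or None if the
    by-reference lookup would fail (a simplified mirror of pickle's check)."""
    module_name = getattr(cls, "__module__", None)
    qualname = getattr(cls, "__qualname__", None)
    if not module_name or not qualname:
        return None
    try:
        obj = importlib.import_module(module_name)
        # Nested classes have dotted qualnames, so walk them one level at a time.
        for part in qualname.split("."):
            obj = getattr(obj, part)
    except (ImportError, AttributeError):
        return None
    # Pickle also requires that the lookup finds this exact object.
    return f"{module_name}.{qualname}" if obj is cls else None


class Misplaced:
    """Hypothetical type that reports the wrong module, like bfloat16 did."""


Misplaced.__module__ = "builtins"

print(pickle_path(fractions.Fraction))  # fractions.Fraction
print(pickle_path(Misplaced))  # None
```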
Here's a hack that can be used to work around this:

```python
import builtins
import pickle

from jax._src.lib import xla_client
import jax.numpy as jnp

# Hack: the type reports `builtins` as its module, so make
# `builtins.bfloat16` resolve to it.
builtins.bfloat16 = xla_client.bfloat16

s = pickle.dumps(jnp.arange(10, dtype='bfloat16'))
y = pickle.loads(s)
print(repr(y))
# array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=bfloat16)
```

One thing we could do to fix this (though it's kind of terrible) is to set this attribute of …
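The hack works because pickle resolves the class through `builtins.bfloat16` both when dumping and when loading. Here is a sketch of the same idea applied to a hypothetical stand-in type, scoped with a context manager (our own helper, not something JAX provides) so that the global `builtins` namespace is restored afterwards:

```python
import builtins
import contextlib
import pickle


class FakeBFloat16:
    """Hypothetical stand-in for a scalar type that claims to live in builtins."""


FakeBFloat16.__module__ = "builtins"


@contextlib.contextmanager
def exposed_in_builtins(cls):
    """Temporarily make builtins.<name> resolve to `cls` (the hack, but scoped)."""
    name = cls.__name__
    missing = object()
    previous = getattr(builtins, name, missing)
    setattr(builtins, name, cls)
    try:
        yield
    finally:
        # Undo the patch, restoring any attribute that was already there.
        if previous is missing:
            delattr(builtins, name)
        else:
            setattr(builtins, name, previous)


with exposed_in_builtins(FakeBFloat16):
    restored = pickle.loads(pickle.dumps(FakeBFloat16))

print(restored is FakeBFloat16)  # True
print(hasattr(builtins, "FakeBFloat16"))  # False: the patch was undone
```

Scoping only limits the side effect: any process that later loads such a pickle must make `builtins.<name>` resolvable too, just as with the unscoped hack.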
A marginally simpler hack that we've been using is […], as it avoids the private `jax._src` import.
#15122 fixed this! The type is now just …
Pickle is solved, but there are still problems with NumPy-native array serialization: jax-ml/ml_dtypes#41
I think this is because the `bfloat16` custom type does not correctly define its full class name, so pickle cannot find it when serializing. Compare these: …