bfloat16 arrays are incompatible with pickle #8505

Closed
jakevdp opened this issue Nov 10, 2021 · 7 comments
Labels
bug Something isn't working P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR


@jakevdp
Collaborator

jakevdp commented Nov 10, 2021

import pickle
import jax.numpy as jnp
pickle.dumps(jnp.bfloat16(0))
---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
<ipython-input-13-9fa7e9015eb7> in <module>()
      1 import pickle
      2 import jax.numpy as jnp
----> 3 pickle.dumps(jnp.bfloat16(0))

PicklingError: Can't pickle <class 'bfloat16'>: attribute lookup bfloat16 on builtins failed

I think this is because the bfloat16 custom type does not correctly define its fully qualified class name, so pickle looks for it in builtins and fails. Compare these:

import numpy as np
import jax.numpy as jnp
from jax._src.lib import xla_client

print(type(xla_client.bfloat16(0.0)))
# <class 'bfloat16'>

print(type(jnp.array(0)))
# <class 'jaxlib.xla_extension.DeviceArray'>

print(type(np.float16(0)))
# <class 'numpy.float16'>
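For context, pickle serializes classes by reference: it records the class's `__module__` and `__qualname__` and checks, already while dumping, that looking that name up in that module returns the class. A C type whose `tp_name` has no dotted prefix (here just `bfloat16`) reports `__module__ == 'builtins'`, so the lookup fails. The sketch below reproduces the same failure in pure Python; `FakeBFloat16` is a hypothetical stand-in, not JAX's actual type.

```python
import pickle

# Hypothetical stand-in for the C-level bfloat16 type: with no dotted module
# prefix in its name, Python reports its module as 'builtins'.
FakeBFloat16 = type("bfloat16", (), {"__module__": "builtins"})

print(FakeBFloat16)             # <class 'bfloat16'>, same repr as the real type
print(FakeBFloat16.__module__)  # builtins

# pickle stores the class as (module, name) and verifies that
# builtins.bfloat16 resolves back to this class; it does not exist.
try:
    pickle.dumps(FakeBFloat16())
    failed = False
except pickle.PicklingError as exc:
    failed = True
    print(exc)  # e.g. "attribute lookup bfloat16 on builtins failed" (wording varies by Python version)
```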
jakevdp added the bug label Nov 10, 2021
@jakevdp
Collaborator Author

jakevdp commented Nov 10, 2021

I think for this to work correctly, we'd need the tp_name attribute of the bfloat16 type to correspond to its actual import location within JAX, i.e. jaxlib.xla_client.bfloat16.

That's not straightforward, though, because this type definition is shared with tensorflow and exported at a different package path within that library.
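To illustrate why the reported module matters, here is a pure-Python sketch in which the stand-in type reports a real, importable module and is actually exposed there. The module name `hypothetical_dtypes` is invented for illustration (it plays the role that jaxlib.xla_client would); this only mimics the effect of a dotted tp_name and is not how the C type itself would be changed.

```python
import pickle
import sys
import types

# Hypothetical module standing in for the package that really exports the type.
mod = types.ModuleType("hypothetical_dtypes")
sys.modules["hypothetical_dtypes"] = mod

def _init(self, value=0.0):
    self.value = float(value)

# The class reports that module, mimicking a dotted tp_name such as
# "hypothetical_dtypes.bfloat16" on a C extension type.
bfloat16 = type("bfloat16", (), {"__module__": "hypothetical_dtypes", "__init__": _init})
mod.bfloat16 = bfloat16  # the type must really be reachable under that name

data = pickle.dumps(bfloat16(1.5))
restored = pickle.loads(data)
print(type(restored), restored.value)  # <class 'hypothetical_dtypes.bfloat16'> 1.5
```

Both halves matter: the reported module path and the attribute being present at that path.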

@jakevdp
Collaborator Author

jakevdp commented Nov 10, 2021

Indeed, the TensorFlow bfloat16 dtype has the same issue:

import tensorflow as tf
import pickle
import numpy as np
pickle.dumps(np.ones(10, dtype=tf.bfloat16.as_numpy_dtype))
---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
<ipython-input-24-272a81341603> in <module>()
      2 import pickle
      3 import numpy as np
----> 4 pickle.dumps(np.ones(10, dtype=tf.bfloat16.as_numpy_dtype))

PicklingError: Can't pickle <class 'bfloat16'>: attribute lookup bfloat16 on builtins failed

@jakevdp
Collaborator Author

jakevdp commented Nov 10, 2021

I think once this bfloat16 name issue is resolved, NumPy's pickle serialization and deserialization should work. For example, the manual reduce/reconstruct round trip already succeeds:

import jax.numpy as jnp
x = jnp.arange(4, dtype='bfloat16')
constructor, args, state = x.__reduce__()
x2 = constructor(*args)
x2.__setstate__(state)
print(repr(x2))
# array([0, 1, 2, 3], dtype=bfloat16)

The pickling error arises when pickle tries to serialize the bfloat16 dtype contained in state, for the reasons described above.
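The mechanism can be sketched without NumPy: pickle calls `__reduce__`, gets back `(constructor, args, state)`, and must then serialize everything in `state`, including the dtype. `TinyArray` below is a hypothetical toy container (not NumPy's ndarray) whose state holds the unpicklable stand-in type; the manual round trip works, but `pickle.dumps` fails while serializing `state`.

```python
import pickle

# Hypothetical stand-in for the bfloat16 scalar type (reports module 'builtins').
FakeBFloat16 = type("bfloat16", (), {"__module__": "builtins"})

class TinyArray:
    """Toy container that, like ndarray, keeps its dtype in its pickled state."""

    def __init__(self, values=(), dtype=None):
        self.values = list(values)
        self.dtype = dtype

    def __reduce__(self):
        # (constructor, args, state): on load, pickle calls __setstate__(state).
        return (TinyArray, (), (self.values, self.dtype))

    def __setstate__(self, state):
        self.values, self.dtype = state

x = TinyArray([0, 1, 2, 3], dtype=FakeBFloat16)

# Manual round trip, mirroring the snippet above: nothing is serialized.
constructor, args, state = x.__reduce__()
x2 = constructor(*args)
x2.__setstate__(state)
print(x2.values, x2.dtype)  # [0, 1, 2, 3] <class 'bfloat16'>

# Real pickling must serialize `state`, which references the stand-in type.
try:
    pickle.dumps(x)
    failed = False
except pickle.PicklingError:
    failed = True

# With an importable dtype (here the builtin float) the same container pickles fine.
control = pickle.dumps(TinyArray([0.0, 1.0], dtype=float))
```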

jakevdp added the P3 (no schedule) label Nov 10, 2021
@jakevdp
Collaborator Author

jakevdp commented May 10, 2022

Here's a hack that can be used to work around this:

import builtins
import pickle
from jax._src.lib import xla_client
import jax.numpy as jnp

# Hack: the type reports 'builtins' as its module, so make bfloat16 resolvable there.
builtins.bfloat16 = xla_client.bfloat16

s = pickle.dumps(jnp.arange(10, dtype='bfloat16'))
y = pickle.loads(s)
print(repr(y))
# array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=bfloat16)

One thing we could do to fix this (though it's kind of terrible) is to set this attribute of builtins when jax is imported.
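The hack works because pickle resolves the type through the module it reports, `builtins`. A minimal sketch with a hypothetical stand-in class (not JAX's type) shows both halves: registering the type on `builtins` makes dumps/loads succeed, and removing it restores the original failure. Because this mutates a global module, the sketch undoes the change in a `finally` block.

```python
import builtins
import pickle

# Hypothetical stand-in for the bfloat16 type; like the real one, it claims
# to live in the 'builtins' module.
FakeBFloat16 = type("bfloat16", (), {"__module__": "builtins"})

builtins.bfloat16 = FakeBFloat16  # the workaround: make builtins.bfloat16 resolvable
try:
    payload = pickle.dumps(FakeBFloat16())
    restored = pickle.loads(payload)
    print(type(restored))  # <class 'bfloat16'>
finally:
    del builtins.bfloat16  # undo the global mutation

# Without the registration, pickling fails again.
try:
    pickle.dumps(FakeBFloat16())
    fails_without_hack = False
except pickle.PicklingError:
    fails_without_hack = True
```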

@hamzamerzic
Contributor

A marginally simpler hack that we've been using is

builtins.bfloat16 = jnp.dtype('bfloat16').type
instead of
builtins.bfloat16 = xla_client.bfloat16

as it avoids the private jax._src.lib import.

Setting builtins.bfloat16 = jnp.bfloat16 does not seem to work, even though xla_client.bfloat16 == jnp.bfloat16 evaluates to True.
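One plausible explanation, assuming jnp.bfloat16 is a separate wrapper object whose `==` is overridden to compare equal to the underlying scalar type: when pickling a class, pickle looks up `module.name` and requires the result to be the same object (an `is` check), not merely an equal one. The pure-Python sketch below models this with a hypothetical wrapper metaclass `_WrapperMeta`; it is not JAX's actual implementation.

```python
import builtins
import pickle

# Hypothetical stand-in for the real scalar type (reports module 'builtins').
RealBFloat16 = type("bfloat16", (), {"__module__": "builtins"})

class _WrapperMeta(type):
    """Hypothetical metaclass whose == compares classes by name only."""

    def __eq__(cls, other):
        return isinstance(other, type) and cls.__name__ == other.__name__

    def __hash__(cls):
        return type.__hash__(cls)

# A distinct class object that merely compares equal to the real type.
WrapperBFloat16 = _WrapperMeta("bfloat16", (), {"__module__": "builtins"})
print(WrapperBFloat16 == RealBFloat16)  # True
print(WrapperBFloat16 is RealBFloat16)  # False

def try_pickle(registered):
    """Register `registered` as builtins.bfloat16, then pickle a real instance."""
    builtins.bfloat16 = registered
    try:
        pickle.dumps(RealBFloat16())
        return True
    except pickle.PicklingError:
        return False
    finally:
        del builtins.bfloat16  # always undo the global mutation

print(try_pickle(WrapperBFloat16))  # False: equal, but not the same object
print(try_pickle(RealBFloat16))     # True: the identity check passes
```

Under this model, `jnp.dtype('bfloat16').type` works because it returns the underlying type object itself rather than a wrapper.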

@hawkinsp
Collaborator

hawkinsp commented Mar 23, 2023

#15122 fixed this!

The type is now just ml_dtypes.bfloat16, and it should be pickleable (there is a test for it).

@jakevdp
Collaborator Author

jakevdp commented Mar 23, 2023

Pickle is solved, but there are still problems with numpy-native array serialization: jax-ml/ml_dtypes#41

4 participants