Add device argument for safetensors.flax.load_file #399

Closed
mar-muel opened this issue Dec 8, 2023 · 4 comments · Fixed by #427

mar-muel commented Dec 8, 2023

Feature request

Hey there, love this library! 👍

Is there a reason the device argument is not accepted (anymore?) by load_file for Flax?

I'm also a bit confused, since it is listed as an argument in the docs 🤔 https://huggingface.co/docs/safetensors/main/en/api/flax#safetensors.flax.load_file

I'm using safetensors==0.4.1

Motivation

It's useful to have control over device placement when loading a model.

Your contribution

Probably not...

github-actions bot commented Jan 8, 2024

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions bot added the Stale label Jan 8, 2024
github-actions bot closed this as not planned Jan 14, 2024
Narsil (Collaborator) commented Jan 17, 2024

Thanks for the note; the docstring is outdated or was copy-pasted incorrectly.

The argument is missing because Flax doesn't provide a way to create tensors directly on a device (AFAIK), so it wouldn't make any difference compared with loading on CPU and then moving the tensors to whatever device you want.

Also, I thought lazy tensor placement was more idiomatic in Flax. How exactly do you move the tensors?

mar-muel (Author) commented

@Narsil I later found out I can load my Flax msgpack models directly onto the CPU with this:

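import jax
from flax.serialization import from_bytes

cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):  # arrays created in this block land on the CPU
    with open(msgpack_file, "rb") as state_f:
        state = from_bytes(cls, state_f.read())  # cls: target (template) for the restored state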
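The context-manager trick above is not specific to msgpack: `jax.default_device` changes where newly created (uncommitted) arrays are placed for the duration of the `with` block. A minimal sketch using plain `jax.numpy`, with no checkpoint involved:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Inside the context, array creation defaults to the CPU device,
# even on a machine where an accelerator is normally the default.
with jax.default_device(cpu):
    x = jnp.zeros((5, 5))

# Outside the context, creation goes back to the normal default device.
y = jnp.zeros((5, 5))
```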

In any case, I've since moved to saving my Flax model in NumPy format, which is what you get if you use jax.device_get():

>>> x = jnp.zeros((5,5))
>>> type(x)
<class 'jaxlib.xla_extension.ArrayImpl'>   # array is on device
>>> type(jax.device_get(x))
<class 'numpy.ndarray'>

Then, to load the model:

import safetensors.numpy
from flax import jax_utils
state = safetensors.numpy.load_file(st_file)   # np arrays on CPU
state = jax_utils.replicate(state)  # jax arrays copied to every local device, with a new leading device axis

Narsil (Collaborator) commented Jan 18, 2024

Thanks for sharing your fix.
