Skip to content

Commit

Permalink
[util/misc][BugFix] transfer GPU arrays to CPU prior to save_zarr call
Browse files Browse the repository at this point in the history
  • Loading branch information
joanrue committed Mar 20, 2024
1 parent 5b53bd8 commit 37ac94f
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/pyxu/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import inspect
import os
import types
import typing as typ

import dask.array as da
import zarr

import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.util.array_module as pxam

__all__ = [
"copy_if_unsafe",
Expand Down Expand Up @@ -250,12 +250,12 @@ def save_zarr(filedir: pxt.Path, kw_in: dict[str, pxt.NDArray]) -> None:
if array is not None:
if pxd.NDArrayInfo.from_obj(array) == pxd.NDArrayInfo.DASK:
array.to_zarr(
filedir / ("dask_" + filename),
overwrite=True,
compute=True,
filedir / ("dask_" + filename),
overwrite=True,
compute=True,
)
else:
zarr.save(filedir / filename, array)
zarr.save(filedir / filename, pxam.to_NUMPY(array))
except Exception as e:
print(f"Failed to save {filename}: {e}")

Expand Down

0 comments on commit 37ac94f

Please sign in to comment.