Skip to content

Commit

Permalink
Add random key support
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711866172
  • Loading branch information
ChromeHearts authored and Orbax Authors committed Jan 24, 2025
1 parent a812ba2 commit ab97739
Showing 1 changed file with 119 additions and 12 deletions.
131 changes: 119 additions & 12 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
_SHARDING_SUFFIX_RE = r'/\d+(\.\d+)*$' # /0, /0.0, /1.0.1, etc.
_ZARRAY_SUFFIX_RE = r'/\.zarray$'
_ZARRAY_SUFFIX = '/.zarray'
_ARRAY_EXT_METADATA_FILE = '_array_ext_metadata.json'
_RANDOM_KEY_IMPL = 'random_key_impl'


async def _assert_parameter_files_exist(
Expand Down Expand Up @@ -1022,11 +1024,77 @@ async def _serialize_sharding(
serialized_sharding
)

async def _serialize_array_ext_metadata(
    self, info: types.ParamInfo, metadata: Dict[str, Any]
):
  """Writes extra array metadata as a JSON file alongside the checkpoint.

  The metadata (e.g. random-key impl names keyed by param name) is stored
  in `_ARRAY_EXT_METADATA_FILE` under `info.parent_dir`, using the same
  kvstore (OCDBT or plain file) as the checkpoint itself, so that
  `_deserialize_array_ext_metadata` can recover it on restore.

  Args:
    info: parameter info; `parent_dir`, `is_ocdbt_checkpoint` and
      `ts_context` are used to locate and open the metadata file.
    metadata: a JSON-serializable mapping, written verbatim.

  Raises:
    ValueError: if `info.parent_dir` is None.
  """
  if info.parent_dir is None:
    raise ValueError('parent_dir cannot be None')

  kvstore_tspec = ts_utils.build_kvstore_tspec(
      info.parent_dir.as_posix(),
      name=_ARRAY_EXT_METADATA_FILE,
      use_ocdbt=info.is_ocdbt_checkpoint,
  )
  # TensorStore's 'json' driver serializes the dict for us.
  tspec = {
      'driver': 'json',
      'kvstore': kvstore_tspec,
  }
  logging.vlog(
      1,
      # Fixed: log label previously said '_serialize_array_extra_metadata',
      # which does not match this method's name.
      '_serialize_array_ext_metadata: tspec: %s, metadata: %s',
      tspec,
      metadata,
  )
  t = await ts.open(
      tspec,
      open=True,  # create the file if it does not exist yet
      context=info.ts_context,
  )
  await t.write(metadata)

async def _deserialize_array_ext_metadata(
    self, info: types.ParamInfo
) -> Optional[Dict[str, Any]]:
  """Reads extra array metadata written by `_serialize_array_ext_metadata`.

  Args:
    info: parameter info; `parent_dir`, `is_ocdbt_checkpoint` and
      `ts_context` are used to locate and open the metadata file.

  Returns:
    The metadata dict, or None when no ext-metadata file is present
    (e.g. checkpoints written before this feature existed).

  Raises:
    ValueError: if `info.parent_dir` is None.
  """
  if info.parent_dir is None:
    raise ValueError('parent_dir cannot be None')

  kvstore_tspec = ts_utils.build_kvstore_tspec(
      info.parent_dir.as_posix(),
      name=_ARRAY_EXT_METADATA_FILE,
      use_ocdbt=info.is_ocdbt_checkpoint,
  )
  tspec = {
      'driver': 'json',
      'kvstore': kvstore_tspec,
  }
  try:
    t = await ts.open(
        tspec,
        context=info.ts_context,
    )
    ret = (await t.read()).item()
  except ValueError:
    # TensorStore raises ValueError when the file is absent; a missing
    # ext-metadata file is an expected, non-error condition.
    ret = None

  logging.vlog(
      1,
      # Fixed: log label previously said '_serialize_array_extra_metadata',
      # a copy-paste from the serializer (as was the old docstring).
      '_deserialize_array_ext_metadata: tspec: %s, ret: %s',
      tspec,
      ret,
  )
  return ret

async def _background_serialize(
self,
values: Sequence[replica_slices.ReplicaSlices],
infos: Sequence[types.ParamInfo],
args: Sequence[types.SaveArgs],
array_ext_metadata: Dict[str, Any],
):
"""Runs serialization in a background thread."""
write_coros = []
Expand Down Expand Up @@ -1075,6 +1143,12 @@ async def _background_serialize(
process_index=multihost.process_index(),
)
)

if array_ext_metadata:
write_coros.append(
self._serialize_array_ext_metadata(infos[0], array_ext_metadata)
)

await asyncio.gather(*write_coros)
await sharding_metadata_txn.commit_async()
if ocdbt_transaction is not None:
Expand All @@ -1087,7 +1161,11 @@ async def serialize(
args: Optional[Sequence[types.SaveArgs]] = None,
) -> Sequence[future.Future]:
"""See superclass documentation."""
for v in values:

ext_metadata = {}
arrays = []

for v, info in zip(values, infos):
if (
isinstance(v, jax.Array)
and jax.process_count() > 1
Expand All @@ -1098,26 +1176,42 @@ async def serialize(
' obtained using pmap. Consider using'
' fully_replicated_host_local_array_to_global_array in'
' orbax/checkpoint/utils.py to convert your arrays into'
' serializable objects.'
f' serializable objects. Array.sharding: {v.sharding}'
)
args = args or [types.SaveArgs()] * len(values)
check_input_arguments(values, infos, args)

if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
# a JAX random key
arrays.append(jax.random.key_data(v))
if multihost.is_primary_host(self._primary_host):
ext_metadata[info.name] = {
_RANDOM_KEY_IMPL: str(jax.random.key_impl(v))
}
else:
# regular array
arrays.append(v)

args = args or [types.SaveArgs()] * len(arrays)
check_input_arguments(arrays, infos, args)

assert all([info.enable_pinned_host_transfer for info in infos]) or all(
[not info.enable_pinned_host_transfer for info in infos]
)

# Complete D2H transfer in parallel for each array.
values_on_host = replica_slices.transfer_arrays_to_host(
values,
arrays,
self._replica_id,
self._use_replica_parallel,
enable_pinned_host_transfer=infos[0].enable_pinned_host_transfer,
)

logging.info('extra_metadata: %s', ext_metadata)

return [
future.CommitFutureAwaitingContractedSignals(
self._background_serialize(values_on_host, infos, args),
self._background_serialize(
values_on_host, infos, args, ext_metadata
),
name='array_type_handler',
)
]
Expand Down Expand Up @@ -1221,21 +1315,34 @@ async def deserialize(
strict=arg.strict if hasattr(arg, 'strict') else True,
)
]
ret = await asyncio.gather(*deserialize_ops)

deserialize_ops.append(self._deserialize_array_ext_metadata(infos[0]))
deserialized_results = await asyncio.gather(*deserialize_ops)
ext_metadata = deserialized_results[-1]
ret = deserialized_results[:-1]

if ext_metadata:
for i, (info, v) in enumerate(zip(infos, ret)):
if meta := ext_metadata.get(info.name):
if impl := meta.get(_RANDOM_KEY_IMPL):
ret[i] = jax.random.wrap_key_data(v, impl=impl)
logging.vlog(
1, '%s: deserialized as a randon key: ', info.name, ret[i]
)

if logging.vlog_is_on(1):
for a in ret:
logging.vlog(
1,
'restored jax.Array.shape = %s, jax.array.dtype = %s,'
' jax.array.layout + %s',
a.shape,
a.dtype,
a.layout,
' jax.array.layout = %s',
getattr(a, 'shape', None),
getattr(a, 'dtype', None),
getattr(a, 'layout', None),
)
_print_ts_debug_data(self._metadata_key, infos)

return ret
return ret # pytype: disable=bad-return-type

def memory_size(
self, values: Sequence[jax.Array]
Expand Down

0 comments on commit ab97739

Please sign in to comment.