diff --git a/pathwaysutils/persistence/pathways_orbax_handler.py b/pathwaysutils/persistence/pathways_orbax_handler.py index 6d1babd..98ad276 100644 --- a/pathwaysutils/persistence/pathways_orbax_handler.py +++ b/pathwaysutils/persistence/pathways_orbax_handler.py @@ -144,7 +144,7 @@ async def deserialize( results = [None] * len(infos) logging.warning(f"[ksadi] restore loop will take {len(inputs_by_global_mesh.items())} iterations") - for i, global_mesh, idxs in enumerate(inputs_by_global_mesh.items()): + for i, (global_mesh, idxs) in enumerate(inputs_by_global_mesh.items()): grouped_infos = [infos[idx] for idx in idxs] grouped_global_shapes = [global_shapes[idx] for idx in idxs] grouped_dtypes = [dtypes[idx] for idx in idxs]