From 35e88f2ad4d6f5786ca291e3c6ba997087c5d112 Mon Sep 17 00:00:00 2001 From: Sadi Kneipp Date: Tue, 3 Dec 2024 00:43:26 +0000 Subject: [PATCH] add bulk persistence read code --- pathwaysutils/persistence/helper.py | 287 +++++++++++------ .../persistence/pathways_orbax_handler.py | 304 +++++++++--------- 2 files changed, 336 insertions(+), 255 deletions(-) diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index 794deaf..9153844 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -26,55 +26,55 @@ def base64_utf8_stringify(bs: bytes) -> str: - """Converts bytes to a base64-encoded utf-8 string. + """Converts bytes to a base64-encoded utf-8 string. - Args: - bs: The bytes to convert. + Args: + bs: The bytes to convert. - Returns: - The base64-encoded utf-8 string. - """ - return base64.b64encode(bs).decode("utf-8") + Returns: + The base64-encoded utf-8 string. + """ + return base64.b64encode(bs).decode("utf-8") def string_to_base64(text: str) -> str: - """Encodes a string to base64 format. + """Encodes a string to base64 format. - Args: - text: The string to encode. + Args: + text: The string to encode. - Returns: - The base64-encoded string. - """ - return base64_utf8_stringify(text.encode("utf-8")) + Returns: + The base64-encoded string. + """ + return base64_utf8_stringify(text.encode("utf-8")) def get_hlo_sharding_string( sharding: jax.sharding.Sharding, num_dimensions: int, ) -> str: - """Serializes the sharding to an hlo-sharding, encodes it to base64 and returns the base-64 as an utf-8 string.""" - return base64_utf8_stringify( - # pylint:disable=protected-access - sharding._to_xla_hlo_sharding(num_dimensions) # pytype: disable=attribute-error - # pylint:enable=protected-access - .to_proto().SerializeToString() - ) + """Serializes the sharding to an hlo-sharding, encodes it to base64 and returns the base-64 as an utf-8 string.""" + return base64_utf8_stringify( + # pylint:disable=protected-access + sharding._to_xla_hlo_sharding(num_dimensions) # pytype: disable=attribute-error + # pylint:enable=protected-access + .to_proto().SerializeToString() + ) def get_shape_string( dtype: np.dtype, shape: Sequence[int], ) -> str: - """Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string.""" - return base64_utf8_stringify( - xc.Shape.array_shape( - xc.PrimitiveType(xc.dtype_to_etype(dtype)), - shape, - ) - .with_major_to_minor_layout_if_absent() - .to_serialized_proto() - ) + """Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string.""" + return base64_utf8_stringify( + xc.Shape.array_shape( + xc.PrimitiveType(xc.dtype_to_etype(dtype)), + shape, + ) + .with_major_to_minor_layout_if_absent() + .to_serialized_proto() + ) def get_write_request( @@ -82,37 +82,57 @@ def get_write_request( name: str, jax_array: jax.Array, timeout: datetime.timedelta, + return_dict: bool = False, ) -> str: - """Returns a string representation of the plugin program which writes the given jax_array to the given location.""" - sharding = jax_array.sharding - assert isinstance(sharding, jax.sharding.Sharding), sharding - - timeout_seconds, timeout_fractional_seconds = divmod( - timeout.total_seconds(), 1 - ) - timeout_nanoseconds = timeout_fractional_seconds * 1e9 - return json.dumps({ - "persistenceWriteRequest": { - "b64_location": string_to_base64(location_path), - "b64_name": string_to_base64(name), - "b64_hlo_sharding_string": get_hlo_sharding_string( - jax_array.sharding, 
len(jax_array.shape)
-          ),
-          "shape": jax_array.shape,
-          "devices": {
-              "device_ids": [
-                  # pylint:disable=protected-access
-                  device.id
-                  for device in sharding._device_assignment
-                  # pylint:enable=protected-access
-              ],
-          },
-          "timeout": {
-              "seconds": int(timeout_seconds),
-              "nanos": int(timeout_nanoseconds),
-          },
-      }
-  })
+    """Returns the plugin program which writes the given jax_array to the given location.
+
+    When `return_dict` is True, the request is returned as a dict (so it can be
+    embedded in a bulk request) instead of a JSON string.
+    """
+    sharding = jax_array.sharding
+    assert isinstance(sharding, jax.sharding.Sharding), sharding
+
+    timeout_seconds, timeout_fractional_seconds = divmod(timeout.total_seconds(), 1)
+    timeout_nanoseconds = timeout_fractional_seconds * 1e9
+    request = {
+        "persistenceWriteRequest": {
+            "b64_location": string_to_base64(location_path),
+            "b64_name": string_to_base64(name),
+            "b64_hlo_sharding_string": get_hlo_sharding_string(
+                jax_array.sharding, len(jax_array.shape)
+            ),
+            "shape": jax_array.shape,
+            "devices": {
+                "device_ids": [
+                    # pylint:disable=protected-access
+                    device.id
+                    for device in sharding._device_assignment
+                    # pylint:enable=protected-access
+                ],
+            },
+            "timeout": {
+                "seconds": int(timeout_seconds),
+                "nanos": int(timeout_nanoseconds),
+            },
+        }
+    }
+
+    if return_dict:
+        return request
+    return json.dumps(request)
+
+
+def get_bulk_write_request(
+    location_path: str,
+    names: Sequence[str],
+    jax_arrays: Sequence[jax.Array],
+    timeout: datetime.timedelta,
+) -> str:
+    """Returns the plugin program which writes all of the given jax_arrays in one bulk request."""
+    write_requests = [
+        get_write_request(location_path, name, jax_array, timeout, return_dict=True)[
+            "persistenceWriteRequest"
+        ]
+        for name, jax_array in zip(names, jax_arrays)
+    ]
+    return json.dumps(
+        {"bulk_persistence_write_request": {"write_requests": write_requests}}
+    )
 
 
 def get_read_request(
@@ -123,32 +143,50 @@
     sharding: jax.sharding.Sharding,
     devices: Sequence[jax.Device],
     timeout: datetime.timedelta,
+    return_dict: bool = False,
+) -> str:
+    """Returns the plugin program which reads the given array from the given location into the provided sharding.
+
+    When `return_dict` is True, the request is returned as a dict (so it can be
+    embedded in a bulk request) instead of a JSON string.
+    """
+    if not isinstance(devices, np.ndarray):
+        devices = np.array(devices)
+
+    timeout_seconds, timeout_fractional_seconds = divmod(timeout.total_seconds(), 1)
+    timeout_nanoseconds = timeout_fractional_seconds * 1e9
+    request = {
+        "persistenceReadRequest": {
+            "b64_location": string_to_base64(location_path),
+            "b64_shape_proto_string": get_shape_string(dtype, shape),
+            "b64_name": string_to_base64(name),
+            "b64_hlo_sharding_string": get_hlo_sharding_string(sharding, len(shape)),
+            "devices": {"device_ids": [device.id for device in devices.flatten()]},
+            "timeout": {
+                "seconds": int(timeout_seconds),
+                "nanos": int(timeout_nanoseconds),
+            },
+        }
+    }
+    if return_dict:
+        return request
+    return json.dumps(request)
+
+
+def get_bulk_read_request(
+    location_path: str,
+    names: Sequence[str],
+    dtypes: Sequence[np.dtype],
+    shapes: Sequence[Sequence[int]],
+    shardings: Sequence[jax.sharding.Sharding],
+    devices: Sequence[jax.Device],
+    timeout: datetime.timedelta,
 ) -> str:
-  """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
-  if not isinstance(devices, np.ndarray):
-    devices = np.array(devices)
-
-  timeout_seconds, timeout_fractional_seconds = divmod(
-      timeout.total_seconds(), 1
-  )
-  timeout_nanoseconds = timeout_fractional_seconds * 1e9
-  return json.dumps({
-      "persistenceReadRequest": {
-          "b64_location": string_to_base64(location_path),
-          "b64_shape_proto_string": get_shape_string(dtype, shape),
-          "b64_name": string_to_base64(name),
-          "b64_hlo_sharding_string": get_hlo_sharding_string(
-              sharding, len(shape)
-          ),
-          "devices": {
-              "device_ids": [device.id for device in devices.flatten()]
-          },
-          "timeout": {
-              "seconds": int(timeout_seconds),
-              "nanos": int(timeout_nanoseconds),
-          },
-      }
-  })
+    """Returns the plugin program which reads all of the named arrays in one bulk request."""
+    read_requests = [
+        get_read_request(
+            location_path, name, dtype, shape, sharding, devices, timeout,
+            return_dict=True,
+        )["persistenceReadRequest"]
+        for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings)
+    ]
+    return json.dumps(
+        {"bulk_persistence_read_request": {"read_requests": read_requests}}
+    )
 
 
 def write_one_array(
@@ -157,11 +195,24 @@
     value: jax.Array,
     timeout: datetime.timedelta,
 ):
-  """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future."""
-  write_request = get_write_request(location, name, value, timeout)
-  write_executable = plugin_executable.PluginExecutable(write_request)
-  _, write_future = write_executable.call([value])
-  return write_future
+    """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future."""
+    write_request = get_write_request(location, name, value, timeout)
+    write_executable = plugin_executable.PluginExecutable(write_request)
+    _, write_future = write_executable.call([value])
+    return write_future
+
+
+def write_arrays(
+    location: str,
+    names: Sequence[str],
+    values: Sequence[jax.Array],
+    timeout: datetime.timedelta,
+):
+    """Creates the bulk write plugin program string, compiles it to an executable, calls it and returns an awaitable future."""
+    bulk_write_request = get_bulk_write_request(location, names, values, timeout)
+    bulk_write_executable = plugin_executable.PluginExecutable(bulk_write_request)
+    _, bulk_write_future = bulk_write_executable.call(values)
+    return bulk_write_future
 
 
 def read_one_array(
@@ -173,20 +224,42 @@
     devices: Union[Sequence[jax.Device], np.ndarray],
     timeout: datetime.timedelta,
 ):
-  """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""
-  read_request = get_read_request(
-      location,
-      name,
-      dtype,
-      shape,
-      shardings,
-      devices,
-      timeout,
-  )
-  read_executable = plugin_executable.PluginExecutable(read_request)
-  out_aval = core.ShapedArray(shape, dtype)
-  read_array, read_future = read_executable.call(
-      out_shardings=[shardings], out_avals=[out_aval]
-  )
-  read_future.result()
-  return read_array[0]
+    """Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""
+    read_request = get_read_request(
+        location,
+        name,
+        dtype,
+        shape,
+        shardings,
+        devices,
+        timeout,
+    )
+    read_executable = plugin_executable.PluginExecutable(read_request)
+    out_aval = core.ShapedArray(shape, dtype)
+    read_array, read_future = read_executable.call(
+        out_shardings=[shardings], out_avals=[out_aval]
+    )
+    read_future.result()
+    return read_array[0]
+
+
+def read_arrays(
+    location: str,
+    names: Sequence[str],
+    dtypes: Sequence[np.dtype],
+    shapes: Sequence[Sequence[int]],
+    shardings: Sequence[jax.sharding.Sharding],
+    devices: Union[Sequence[jax.Device], np.ndarray],
+    timeout: datetime.timedelta,
+):
+    """Creates the bulk read plugin program string, compiles it to an executable, calls it and returns the arrays together with a future that resolves once the read completes."""
+    bulk_read_request = get_bulk_read_request(
+        location, names, dtypes, shapes, shardings, devices, timeout
+    )
+    bulk_read_executable = plugin_executable.PluginExecutable(bulk_read_request)
+    out_avals = [core.ShapedArray(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
+    arrays, read_future = bulk_read_executable.call(
+        out_shardings=shardings, out_avals=out_avals
+    )
+    return arrays, read_future
diff --git a/pathwaysutils/persistence/pathways_orbax_handler.py b/pathwaysutils/persistence/pathways_orbax_handler.py
index cb84e48..4148210 100644
--- a/pathwaysutils/persistence/pathways_orbax_handler.py
+++ b/pathwaysutils/persistence/pathways_orbax_handler.py
@@ -37,159 +37,167 @@ def extract_parent_dir_and_name(
     infos: Sequence[ParamInfo],
 ) -> tuple[Sequence[str], Sequence[str]]:
-  """Extracts names and locations from ParamInfos."""
-  parent_dirs = [str(info.parent_dir) for info in infos]
-  names = [str(info.name) for info in infos]
-  return parent_dirs, names
+    """Extracts names and locations from ParamInfos."""
+    parent_dirs = [str(info.parent_dir) for info in infos]
+    names = [str(info.name) for info in infos]
+    return parent_dirs, names
 
 
 class CloudPathwaysArrayHandler(type_handlers.ArrayHandler):
-  """A TypeHandler for array types when using Pathways."""
-
-  def __init__(
-      self,
-      read_timeout: Optional[datetime.timedelta] = None,
-      use_ocdbt: bool = False,
-  ):
-    """Constructor.
-
-    Args:
-      read_timeout: Duration indicating the timeout for reading arrays
-      use_ocdbt: allows using Tensorstore OCDBT driver.
-    """
-    self._read_timeout = read_timeout
-
-    if use_ocdbt:
-      raise ValueError('OCDBT not supported for Pathways.')
-    super().__init__()
-
-  async def serialize(
-      self,
-      values: Sequence[jax.Array],
-      infos: Sequence[ParamInfo],
-      args: Optional[Sequence[SaveArgs]] = None,
-  ) -> Sequence[future.Future]:
-    """Uses Pathways Persistence API to serialize a jax array."""
-    type_handlers.check_input_arguments(values, infos, args)
-
-    if any([arg.dtype is not None for arg in args]):
-      raise ValueError('Casting during save not supported for Pathways.')
-
-    locations, names = extract_parent_dir_and_name(infos)
-    f = functools.partial(
-        helper.write_one_array, timeout=self._read_timeout
-    )
-    return list(map(f, locations, names, values))
-
-  async def deserialize(
-      self,
-      infos: Sequence[ParamInfo],
-      args: Optional[Sequence[RestoreArgs]] = None,
-  ) -> Sequence[jax.Array]:
-    """Uses Pathways Persistence API to deserialize a jax array."""
-    if args is None:
-      raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.')
-    type_handlers.check_input_arguments(infos, args)
-
-    global_meshes = []
-    mesh_axes = []
-    global_shapes = []
-    dtypes = []
-    shardings = []
-
-    should_open_metadata = False
-    for arg in args:
-      if not isinstance(arg, ArrayRestoreArgs):
-        raise ValueError(
-            'To restore jax.Array, provide ArrayRestoreArgs; found'
-            f' {type(arg).__name__}'
-        )
-      arg = typing.cast(ArrayRestoreArgs, arg)
-      if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
-        raise ValueError(
-            'Sharding of jax.Array cannot be None. Provide `mesh`'
-            ' and `mesh_axes` OR `sharding`.'
- ) - if arg.sharding is None: - global_meshes.append(arg.mesh) - mesh_axes.append(arg.mesh_axes) - shardings.append( - jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) - ) - else: - if not isinstance(arg.sharding, jax.sharding.NamedSharding): - raise ValueError('Pathways only supports jax.sharding.NamedSharding.') - sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding) - global_meshes.append(sharding.mesh) - mesh_axes.append(sharding.spec) - shardings.append(sharding) - if arg.global_shape is None or arg.dtype is None: - logger.warning( - 'Shape or dtype not provided for restoration. Provide these' - ' properties for improved performance.' - ) - should_open_metadata = True - global_shapes.append(arg.global_shape) - dtypes.append(arg.dtype) - - if should_open_metadata: - metadatas = await self.metadata(infos) - global_shapes = [ - m.shape if s is None else s for m, s in zip(metadatas, global_shapes) - ] - dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)] - - # Group inputs by global_mesh so that we can perform batched Array - # construction for each global_mesh. - inputs_by_global_mesh = collections.defaultdict(list) - for i, global_mesh in enumerate(global_meshes): - inputs_by_global_mesh[global_mesh].append(i) - - results = [None] * len(infos) - - for global_mesh, idxs in inputs_by_global_mesh.items(): - grouped_infos = [infos[idx] for idx in idxs] - grouped_global_shapes = [global_shapes[idx] for idx in idxs] - grouped_dtypes = [dtypes[idx] for idx in idxs] - grouped_shardings = [shardings[idx] for idx in idxs] - locations, names = extract_parent_dir_and_name(grouped_infos) - f = functools.partial( - helper.read_one_array, - devices=global_mesh.devices, - timeout=self._read_timeout, - ) - grouped_arrays = [ - f( - location=location, - name=name, - dtype=dtype, - shape=shape, - shardings=sharding, - ) - for location, name, dtype, shape, sharding in zip( - locations, - names, - grouped_dtypes, - grouped_global_shapes, - grouped_shardings, - ) - ] - for idx, arr in zip(idxs, grouped_arrays): - results[idx] = arr - return results # pytype: disable=bad-return-type + """A TypeHandler for array types when using Pathways.""" + + def __init__( + self, + read_timeout: Optional[datetime.timedelta] = None, + use_ocdbt: bool = False, + ): + """Constructor. + + Args: + read_timeout: Duration indicating the timeout for reading arrays + use_ocdbt: allows using Tensorstore OCDBT driver. 
+ """ + self._read_timeout = read_timeout + + if use_ocdbt: + raise ValueError("OCDBT not supported for Pathways.") + super().__init__() + + async def serialize( + self, + values: Sequence[jax.Array], + infos: Sequence[ParamInfo], + args: Optional[Sequence[SaveArgs]] = None, + ) -> Sequence[future.Future]: + """Uses Pathways Persistence API to serialize a jax array.""" + type_handlers.check_input_arguments(values, infos, args) + + if any([arg.dtype is not None for arg in args]): + raise ValueError("Casting during save not supported for Pathways.") + + locations, names = extract_parent_dir_and_name(infos) + f = functools.partial(helper.write_one_array, timeout=self._read_timeout) + return list(map(f, locations, names, values)) + + async def deserialize( + self, + infos: Sequence[ParamInfo], + args: Optional[Sequence[RestoreArgs]] = None, + ) -> Sequence[jax.Array]: + """Uses Pathways Persistence API to deserialize a jax array.""" + if args is None: + raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.") + type_handlers.check_input_arguments(infos, args) + + global_meshes = [] + mesh_axes = [] + global_shapes = [] + dtypes = [] + shardings = [] + + should_open_metadata = False + for arg in args: + if not isinstance(arg, ArrayRestoreArgs): + raise ValueError( + "To restore jax.Array, provide ArrayRestoreArgs; found" + f" {type(arg).__name__}" + ) + arg = typing.cast(ArrayRestoreArgs, arg) + if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None): + raise ValueError( + "Sharding of jax.Array cannot be None. Provide `mesh`" + " and `mesh_axes` OR `sharding`." + ) + if arg.sharding is None: + global_meshes.append(arg.mesh) + mesh_axes.append(arg.mesh_axes) + shardings.append( + jax.sharding.NamedSharding(mesh=arg.mesh, spec=arg.mesh_axes) + ) + else: + if not isinstance(arg.sharding, jax.sharding.NamedSharding): + raise ValueError( + "Pathways only supports jax.sharding.NamedSharding." + ) + sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding) + global_meshes.append(sharding.mesh) + mesh_axes.append(sharding.spec) + shardings.append(sharding) + if arg.global_shape is None or arg.dtype is None: + logger.warning( + "Shape or dtype not provided for restoration. Provide these" + " properties for improved performance." + ) + should_open_metadata = True + global_shapes.append(arg.global_shape) + dtypes.append(arg.dtype) + + if should_open_metadata: + metadatas = await self.metadata(infos) + global_shapes = [ + m.shape if s is None else s for m, s in zip(metadatas, global_shapes) + ] + dtypes = [m.dtype if d is None else d for m, d in zip(metadatas, dtypes)] + + # Group inputs by global_mesh so that we can perform batched Array + # construction for each global_mesh. 
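+        # Each group below is issued as a single bulk_persistence_read_request,
+        # so every array that lives on the same mesh is fetched with one plugin
+        # program call instead of one call per parameter.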
+        inputs_by_global_mesh = collections.defaultdict(list)
+        for i, global_mesh in enumerate(global_meshes):
+            inputs_by_global_mesh[global_mesh].append(i)
+
+        results = [None] * len(infos)
+
+        for global_mesh, idxs in inputs_by_global_mesh.items():
+            grouped_infos = [infos[idx] for idx in idxs]
+            grouped_global_shapes = [global_shapes[idx] for idx in idxs]
+            grouped_dtypes = [dtypes[idx] for idx in idxs]
+            grouped_shardings = [shardings[idx] for idx in idxs]
+            locations, names = extract_parent_dir_and_name(grouped_infos)
+            # All params of a checkpoint share the same parent directory, so a
+            # single location is passed for the whole bulk read.
+            grouped_arrays, read_future = helper.read_arrays(
+                locations[0],
+                names,
+                grouped_dtypes,
+                grouped_global_shapes,
+                grouped_shardings,
+                global_mesh.devices,
+                timeout=self._read_timeout,
+            )
+            read_future.result()
+            for idx, arr in zip(idxs, grouped_arrays):
+                results[idx] = arr
+        return results  # pytype: disable=bad-return-type
 
 
 def register_pathways_handlers(
     read_timeout: Optional[datetime.timedelta] = None,
 ):
-  """Function that must be called before saving or restoring with Pathways."""
-  logger.debug(
-      'Registering CloudPathwaysArrayHandler (Pathways Persistence API).'
-  )
-  type_handlers.register_type_handler(
-      jax.Array,
-      CloudPathwaysArrayHandler(
-          read_timeout=read_timeout,
-      ),
-      override=True,
-  )
+    """Function that must be called before saving or restoring with Pathways."""
+    logger.debug("Registering CloudPathwaysArrayHandler (Pathways Persistence API).")
+    type_handlers.register_type_handler(
+        jax.Array,
+        CloudPathwaysArrayHandler(
+            read_timeout=read_timeout,
+        ),
+        override=True,
+    )
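
Usage sketch for the new bulk helpers (illustrative only, not applied by this patch). The bucket path, array names, dtypes, and mesh layout are hypothetical, and a Pathways-backed JAX runtime with pathwaysutils initialized is assumed.

    import datetime

    import jax
    import jax.numpy as jnp
    import numpy as np

    from pathwaysutils.persistence import helper

    # Hypothetical checkpoint location and parameter names.
    location = "gs://my-bucket/checkpoints/step_1000"
    names = ["params.layer0.kernel", "params.layer0.bias"]
    timeout = datetime.timedelta(seconds=60)

    mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(-1), ("data",))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    values = [
        jax.device_put(jnp.ones((8, 128), jnp.float32), sharding),
        jax.device_put(jnp.zeros((128,), jnp.float32), sharding),
    ]

    # One plugin program writes every array in the list.
    write_future = helper.write_arrays(location, names, values, timeout)
    write_future.result()

    # One plugin program reads them all back; the returned future resolves
    # when the bulk read completes.
    arrays, read_future = helper.read_arrays(
        location,
        names,
        [np.dtype(np.float32), np.dtype(np.float32)],
        [(8, 128), (128,)],
        [sharding, sharding],
        mesh.devices,
        timeout,
    )
    read_future.result()

For context, a sketch of how the bulk read path is reached from Orbax once the handler is registered. The checkpoint path and parameter tree are hypothetical; the Orbax entry points shown (PyTreeCheckpointer, ArrayRestoreArgs) reflect the commonly used API at the time of this change and may need adjusting to the Orbax version actually pinned.

    import datetime

    import jax
    import numpy as np
    import orbax.checkpoint as ocp

    from pathwaysutils.persistence import pathways_orbax_handler

    # Route jax.Array save/restore through the Pathways persistence handler.
    pathways_orbax_handler.register_pathways_handlers(
        read_timeout=datetime.timedelta(minutes=10)
    )

    mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(-1), ("data",))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # All leaves share one mesh here, so deserialize groups them into a single
    # bulk_persistence_read_request instead of one read per parameter.
    restore_args = {
        "kernel": ocp.ArrayRestoreArgs(
            sharding=sharding, global_shape=(8, 128), dtype=np.float32
        ),
        "bias": ocp.ArrayRestoreArgs(
            sharding=sharding, global_shape=(128,), dtype=np.float32
        ),
    }
    checkpointer = ocp.PyTreeCheckpointer()
    restored = checkpointer.restore(
        "gs://my-bucket/checkpoints/step_1000",  # hypothetical path
        restore_args=restore_args,
    )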