Skip to content

Commit

Permalink
add bulk persistence read code
Browse files Browse the repository at this point in the history
  • Loading branch information
sadikneipp committed Dec 3, 2024
1 parent 2a494e4 commit 35e88f2
Show file tree
Hide file tree
Showing 2 changed files with 336 additions and 255 deletions.
287 changes: 180 additions & 107 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,93 +26,113 @@


def base64_utf8_stringify(bs: bytes) -> str:
    """Converts bytes to a base64-encoded utf-8 string.

    Args:
      bs: The bytes to convert.

    Returns:
      The base64-encoded utf-8 string.
    """
    return base64.b64encode(bs).decode("utf-8")


def string_to_base64(text: str) -> str:
    """Encodes a string to base64 format.

    The text is encoded to utf-8 bytes first, then base64-encoded.

    Args:
      text: The string to encode.

    Returns:
      The base64-encoded string.
    """
    return base64_utf8_stringify(text.encode("utf-8"))


def get_hlo_sharding_string(
    sharding: jax.sharding.Sharding,
    num_dimensions: int,
) -> str:
    """Serializes a sharding to a base64-encoded HLO sharding string.

    Args:
      sharding: The sharding to serialize.
      num_dimensions: The number of dimensions handed to the sharding's HLO
        conversion.

    Returns:
      The serialized HLO sharding proto, base64-encoded as a utf-8 string.
    """
    # pylint:disable=protected-access
    hlo_sharding = sharding._to_xla_hlo_sharding(num_dimensions)  # pytype: disable=attribute-error
    # pylint:enable=protected-access
    return base64_utf8_stringify(hlo_sharding.to_proto().SerializeToString())


def get_shape_string(
    dtype: np.dtype,
    shape: Sequence[int],
) -> str:
    """Serializes an array shape to a base64-encoded string.

    Args:
      dtype: The element dtype of the array.
      shape: The dimensions of the array.

    Returns:
      The serialized shape proto, base64-encoded as a utf-8 string.
    """
    array_shape = xc.Shape.array_shape(
        xc.PrimitiveType(xc.dtype_to_etype(dtype)),
        shape,
    )
    return base64_utf8_stringify(
        array_shape.with_major_to_minor_layout_if_absent().to_serialized_proto()
    )


def get_write_request(
    location_path: str,
    name: str,
    jax_array: jax.Array,
    timeout: datetime.timedelta,
    return_dict: bool = False,
) -> Union[str, dict]:
    """Builds the plugin program which writes `jax_array` to `location_path`.

    Args:
      location_path: The location to write to; embedded base64-encoded.
      name: The name to write the array under; embedded base64-encoded.
      jax_array: The array to write. Its sharding, shape and device assignment
        are embedded in the request.
      timeout: The timeout, serialized as whole seconds plus nanoseconds.
      return_dict: If True, return the request as a dict instead of a JSON
        string (used to assemble bulk requests).

    Returns:
      The request as a JSON string, or as a dict when `return_dict` is True.
    """
    sharding = jax_array.sharding
    assert isinstance(sharding, jax.sharding.Sharding), sharding

    # Split the timeout with integer arithmetic: going through float seconds
    # (total_seconds() * 1e9) can truncate to one nanosecond short.
    total_microseconds = timeout // datetime.timedelta(microseconds=1)
    timeout_seconds, timeout_microseconds = divmod(total_microseconds, 1_000_000)

    request = {
        "persistenceWriteRequest": {
            "b64_location": string_to_base64(location_path),
            "b64_name": string_to_base64(name),
            "b64_hlo_sharding_string": get_hlo_sharding_string(
                sharding, len(jax_array.shape)
            ),
            "shape": jax_array.shape,
            "devices": {
                "device_ids": [
                    # pylint:disable=protected-access
                    device.id
                    for device in sharding._device_assignment
                    # pylint:enable=protected-access
                ],
            },
            "timeout": {
                "seconds": timeout_seconds,
                "nanos": timeout_microseconds * 1000,
            },
        }
    }

    if return_dict:
        return request
    return json.dumps(request)


def get_bulk_write_request(
    location_path: str,
    names: Sequence[str],
    jax_arrays: Sequence[jax.Array],
    timeout: datetime.timedelta,
) -> str:
    """Builds the plugin program which writes several arrays to one location.

    One write request is built per (name, array) pair, and all of them are
    wrapped in a single bulk request.

    Args:
      location_path: The location every array is written to.
      names: The names to write the arrays under, one per array.
      jax_arrays: The arrays to write, in the same order as `names`.
      timeout: The timeout embedded in every individual write request.

    Returns:
      The bulk write request as a JSON string.

    Raises:
      ValueError: If `names` and `jax_arrays` differ in length.
    """
    # zip() would silently drop the unmatched tail, so reject mismatches early.
    if len(names) != len(jax_arrays):
        raise ValueError(
            f"Got {len(names)} names for {len(jax_arrays)} arrays; "
            "they must have the same length."
        )

    write_requests = [
        get_write_request(location_path, name, jax_array, timeout, True)[
            "persistenceWriteRequest"
        ]
        for name, jax_array in zip(names, jax_arrays)
    ]
    return json.dumps(
        {"bulk_persistence_write_request": {"write_requests": write_requests}}
    )


def get_read_request(
    location_path: str,
    name: str,
    dtype: np.dtype,
    shape: Sequence[int],
    sharding: jax.sharding.Sharding,
    devices: Union[Sequence[jax.Device], np.ndarray],
    timeout: datetime.timedelta,
    return_dict: bool = False,
) -> Union[str, dict]:
    """Builds the plugin program which reads an array from `location_path`.

    Args:
      location_path: The location to read from; embedded base64-encoded.
      name: The name of the array to read; embedded base64-encoded.
      dtype: The dtype of the array to read.
      shape: The shape of the array to read.
      sharding: The sharding to read the array into.
      devices: The devices to read onto; flattened to a list of device ids.
      timeout: The timeout, serialized as whole seconds plus nanoseconds.
      return_dict: If True, return the request as a dict instead of a JSON
        string (used to assemble bulk requests).

    Returns:
      The request as a JSON string, or as a dict when `return_dict` is True.
    """
    # NOTE(review): the annotations of the first four parameters are inferred
    # from how the body and the callers use them -- confirm.
    if not isinstance(devices, np.ndarray):
        devices = np.array(devices)

    # Split the timeout with integer arithmetic: going through float seconds
    # (total_seconds() * 1e9) can truncate to one nanosecond short.
    total_microseconds = timeout // datetime.timedelta(microseconds=1)
    timeout_seconds, timeout_microseconds = divmod(total_microseconds, 1_000_000)

    request = {
        "persistenceReadRequest": {
            "b64_location": string_to_base64(location_path),
            "b64_shape_proto_string": get_shape_string(dtype, shape),
            "b64_name": string_to_base64(name),
            "b64_hlo_sharding_string": get_hlo_sharding_string(sharding, len(shape)),
            "devices": {"device_ids": [device.id for device in devices.flatten()]},
            "timeout": {
                "seconds": timeout_seconds,
                "nanos": timeout_microseconds * 1000,
            },
        }
    }

    if return_dict:
        return request
    return json.dumps(request)


def get_bulk_read_request(
    location_path: str,
    names: Sequence[str],
    dtypes: Sequence[np.dtype],
    shapes: Sequence[Sequence[int]],
    shardings: Sequence[jax.sharding.Sharding],
    devices: Sequence[jax.Device],
    timeout: datetime.timedelta,
) -> str:
    """Builds the plugin program which reads several arrays from one location.

    One read request is built per (name, dtype, shape, sharding) tuple, and all
    of them are wrapped in a single bulk request. The same `devices` and
    `timeout` are used for every individual request.

    Args:
      location_path: The location every array is read from.
      names: The names of the arrays to read.
      dtypes: The dtype of each array, in the same order as `names`.
      shapes: The shape of each array, in the same order as `names`.
      shardings: The sharding of each array, in the same order as `names`.
      devices: The devices to read onto.
      timeout: The timeout embedded in every individual read request.

    Returns:
      The bulk read request as a JSON string.

    Raises:
      ValueError: If `names`, `dtypes`, `shapes` and `shardings` do not all
        have the same length.
    """
    # zip() would silently drop the unmatched tail, so reject mismatches early.
    if not len(names) == len(dtypes) == len(shapes) == len(shardings):
        raise ValueError(
            "names, dtypes, shapes and shardings must have the same length; got "
            f"{len(names)}, {len(dtypes)}, {len(shapes)} and {len(shardings)}."
        )

    read_requests = [
        get_read_request(
            location_path, name, dtype, shape, sharding, devices, timeout, True
        )["persistenceReadRequest"]
        for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings)
    ]
    return json.dumps(
        {"bulk_persistence_read_request": {"read_requests": read_requests}}
    )


def write_one_array(
    location: str,
    name: str,
    value: jax.Array,
    timeout: datetime.timedelta,
):
    """Writes one array through a write plugin program.

    Creates the write request string, compiles it to a plugin executable and
    calls the executable with the array.

    Args:
      location: The location to write to.
      name: The name to write the array under.
      value: The array to write.
      timeout: The timeout embedded in the write request.

    Returns:
      The awaitable future returned by the executable call.
    """
    # NOTE(review): the annotations of `location` and `name` are inferred from
    # how they are passed to get_write_request -- confirm.
    write_request = get_write_request(location, name, value, timeout)
    write_executable = plugin_executable.PluginExecutable(write_request)
    _, write_future = write_executable.call([value])
    return write_future


def write_arrays(
    location: str,
    names: Sequence[str],
    values: Sequence[jax.Array],
    timeout: datetime.timedelta,
):
    """Writes several arrays through a single bulk write plugin program.

    Creates the bulk write request string, compiles it to a plugin executable
    and calls the executable with all the arrays.

    Args:
      location: The location every array is written to.
      names: The names to write the arrays under, one per array.
      values: The arrays to write, in the same order as `names`.
      timeout: The timeout embedded in every individual write request.

    Returns:
      The awaitable future returned by the executable call.
    """
    bulk_write_request = get_bulk_write_request(location, names, values, timeout)
    bulk_write_executable = plugin_executable.PluginExecutable(bulk_write_request)
    _, bulk_write_future = bulk_write_executable.call(values)
    return bulk_write_future


def read_one_array(
    location: str,
    name: str,
    dtype: np.dtype,
    shape: Sequence[int],
    shardings: jax.sharding.Sharding,
    devices: Union[Sequence[jax.Device], np.ndarray],
    timeout: datetime.timedelta,
):
    """Reads one array through a read plugin program and waits for it.

    Creates the read request string, compiles it to a plugin executable, calls
    it and blocks on the returned future before handing back the array.

    Args:
      location: The location to read from.
      name: The name of the array to read.
      dtype: The dtype of the array to read.
      shape: The shape of the array to read.
      shardings: The sharding to read the array into.
      devices: The devices to read onto.
      timeout: The timeout embedded in the read request.

    Returns:
      The array that was read.
    """
    # NOTE(review): the annotations of the first five parameters are inferred
    # from how the body uses them -- confirm.
    read_request = get_read_request(
        location,
        name,
        dtype,
        shape,
        shardings,
        devices,
        timeout,
    )
    read_executable = plugin_executable.PluginExecutable(read_request)
    out_aval = core.ShapedArray(shape, dtype)
    read_array, read_future = read_executable.call(
        out_shardings=[shardings], out_avals=[out_aval]
    )
    # Block until the read has finished before returning the array.
    read_future.result()
    return read_array[0]


def read_arrays(
    location: str,
    names: Sequence[str],
    dtypes: Sequence[np.dtype],
    shapes: Sequence[Sequence[int]],
    shardings: Sequence[jax.sharding.Sharding],
    devices: Union[Sequence[jax.Device], np.ndarray],
    timeout: datetime.timedelta,
):
    """Reads several arrays through a single bulk read plugin program.

    Creates the bulk read request string, compiles it to a plugin executable
    and calls it. This function does not wait on the returned future itself.

    Args:
      location: The location every array is read from.
      names: The names of the arrays to read.
      dtypes: The dtype of each array, in the same order as `names`.
      shapes: The shape of each array, in the same order as `names`.
      shardings: The sharding of each array, in the same order as `names`.
      devices: The devices to read onto.
      timeout: The timeout embedded in every individual read request.

    Returns:
      A tuple `(arrays, future)` as returned by the executable call.
    """
    bulk_read_request = get_bulk_read_request(
        location, names, dtypes, shapes, shardings, devices, timeout
    )
    bulk_read_executable = plugin_executable.PluginExecutable(bulk_read_request)
    out_avals = [
        core.ShapedArray(shape, dtype) for shape, dtype in zip(shapes, dtypes)
    ]
    # Named `arrays` rather than `read_arrays` so the local does not shadow
    # this function's own name.
    arrays, read_future = bulk_read_executable.call(
        out_shardings=shardings, out_avals=out_avals
    )
    return (arrays, read_future)
Loading

0 comments on commit 35e88f2

Please sign in to comment.