From 3fb0f383c6e9f304897d6738f294553980c074b9 Mon Sep 17 00:00:00 2001
From: Matthew Broadway
Date: Thu, 11 Apr 2024 11:53:18 +0100
Subject: [PATCH] add prototype for safer checkpoint format

---
 composer/utils/safe_checkpoint.py   | 211 ++++++++++++++++++++++++++++
 tests/utils/test_safe_checkpoint.py |  80 +++++++++++
 2 files changed, 291 insertions(+)
 create mode 100644 composer/utils/safe_checkpoint.py
 create mode 100644 tests/utils/test_safe_checkpoint.py

diff --git a/composer/utils/safe_checkpoint.py b/composer/utils/safe_checkpoint.py
new file mode 100644
index 0000000000..3f8dea19bb
--- /dev/null
+++ b/composer/utils/safe_checkpoint.py
@@ -0,0 +1,211 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""A custom safer checkpoint format that does not allow for arbitrary code execution (unlike pickle)."""
+
+import json
+import shutil
+import tempfile
+import zipfile
+from datetime import timedelta
+from pathlib import Path
+from types import NoneType
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+import safetensors.torch
+import torch
+
+import composer.utils.checkpoint
+from composer.loggers import LoggerDestination
+from composer.utils.object_store import ObjectStore
+
+_original_write_checkpoint_file = composer.utils.checkpoint._write_checkpoint_file
+_original_download_checkpoint = composer.utils.checkpoint.download_checkpoint
+_original_safe_torch_load = composer.utils.checkpoint.safe_torch_load
+
+
+def _patched_write_checkpoint_file(state_dict: dict[str, Any], filename: str) -> None:
+    if Path(filename).suffix == '.safe':
+        SafeState.from_data(state_dict).save(Path(filename))
+    else:
+        _original_write_checkpoint_file(state_dict, filename)
+
+
+def _patched_download_checkpoint(
+    path: str,
+    node_checkpoint_folder: str,
+    object_store: ObjectStore | LoggerDestination | None,
+    progress_bar: bool,
+    fsdp_sharded_state_dict_enabled: bool = False,
+    deepspeed_sharded_checkpoint: bool = False,
+) -> tuple[str, str | None, bool]:
+    """A monkeypatched version of `download_checkpoint` to enable the safe checkpoint format.
+
+    For a local checkpoint file that is not a .tar archive, `download_checkpoint` writes a single symlink
+    into a temporary directory, and `safe_torch_load` is later called on that file. This patch sets the file
+    extension of the symlink so that the patched `safe_torch_load` can identify it as '.safe' rather than '.pt'.
+    """
+    composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = _original_download_checkpoint(
+        path,
+        node_checkpoint_folder,
+        object_store,
+        progress_bar,
+        fsdp_sharded_state_dict_enabled,
+        deepspeed_sharded_checkpoint,
+    )
+    if Path(path).suffix == '.safe':
+        symlink_path = Path(composer_states_filepath)
+        composer_states_filepath = str(symlink_path.rename(symlink_path.with_suffix('.safe')))
+    return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n
+
+
+def _patched_safe_torch_load(
+    composer_states_filepath: Path | str,
+    map_location: str = 'cpu',
+    load_fsdp_monolith_rank0_only: bool = False,
+) -> dict[str, Any]:
+    if Path(composer_states_filepath).suffix == '.safe':
+        return SafeState.load(Path(composer_states_filepath)).to_data()  # type: ignore[no-any-return]
+    else:
+        return _original_safe_torch_load(composer_states_filepath, map_location, load_fsdp_monolith_rank0_only)
+
+
+def install_safe_checkpoint() -> None:
+    """Patch composer to enable the safer checkpoint format when saving with `.safe`."""
+    composer.utils.checkpoint._write_checkpoint_file = _patched_write_checkpoint_file
+    composer.utils.checkpoint.download_checkpoint = _patched_download_checkpoint
+    composer.utils.checkpoint.safe_torch_load = _patched_safe_torch_load
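+
+
+# A minimal usage sketch (illustrative only, not part of the module's API): the
+# patch must be installed before saving or loading, and the '.safe' filename
+# suffix is what routes a checkpoint through this format. The specific `Trainer`
+# arguments shown are assumptions for the sake of the example.
+#
+#   from composer import Trainer
+#   from composer.utils.safe_checkpoint import install_safe_checkpoint
+#
+#   install_safe_checkpoint()
+#   trainer = Trainer(
+#       model=model,
+#       save_folder='checkpoints',
+#       save_filename='ep{epoch}.safe',  # the '.safe' suffix selects this format
+#   )
+#   trainer.fit()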
+ """ + composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = _original_download_checkpoint( + path, + node_checkpoint_folder, + object_store, + progress_bar, + fsdp_sharded_state_dict_enabled, + deepspeed_sharded_checkpoint, + ) + if Path(path).suffix == '.safe': + symlink_path = Path(composer_states_filepath) + composer_states_filepath = str(symlink_path.rename(symlink_path.with_suffix('.safe'))) + return composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n + + +def _patched_safe_torch_load( + composer_states_filepath: Path | str, + map_location: str = 'cpu', + load_fsdp_monolith_rank0_only: bool = False, +) -> dict[str, Any]: + if Path(composer_states_filepath).suffix == '.safe': + return SafeState.load(Path(composer_states_filepath)).to_data() # type: ignore[no-any-return] + else: + return _original_safe_torch_load(composer_states_filepath, map_location, load_fsdp_monolith_rank0_only) + + +def install_safe_checkpoint() -> None: + """Patch composer to enable the safer checkpoint format when saving with `.safe`.""" + composer.utils.checkpoint._write_checkpoint_file = _patched_write_checkpoint_file + composer.utils.checkpoint.download_checkpoint = _patched_download_checkpoint + composer.utils.checkpoint.safe_torch_load = _patched_safe_torch_load + + +class SafeState: + """A data structure that can be serialized and deserialized without executing arbitrary code (unlike Pickle).""" + + def __init__( + self, + tensors: dict[str, torch.Tensor], + ndarrays: dict[str, npt.NDArray[Any]], + other_data: Any, + ) -> None: + self._tensors = tensors + self._ndarrays = ndarrays + self._other_data = other_data + + @staticmethod + def load(path: Path) -> 'SafeState': + if not path.exists(): + raise FileNotFoundError(path) + with tempfile.TemporaryDirectory(prefix='SafeState') as tmpdir: + output_dir = Path(tmpdir) + shutil.unpack_archive(path, extract_dir=output_dir, format='gztar') + tensors = safetensors.torch.load_file(output_dir / 'tensors.safetensors') + ndarrays = np.load(output_dir / 'ndarrays.npz', allow_pickle=False) + with (output_dir / 'other.json').open() as f: + other_data = json.load(f) + return SafeState(tensors, ndarrays, other_data) + + def save(self, path: Path) -> None: + if path.exists(): + raise FileExistsError(path) + with tempfile.TemporaryDirectory(prefix='SafeState') as tmpdir: + tmp_path = Path(tmpdir) + safetensors.torch.save_file(self._tensors, tmp_path / 'tensors.safetensors') + _savez(tmp_path / 'ndarrays.npz', self._ndarrays) + with (tmp_path / 'other.json').open('w') as f: + json.dump(self._other_data, f) + filename = shutil.make_archive(str(path), format='gztar', root_dir=tmp_path) + Path(filename).rename(path) + + @staticmethod + def from_data(data: Any) -> 'SafeState': + objects = _NonJsonObjects({}, {}) + extracted_data = _extract_objects(data, objects) + return SafeState(objects.tensors, objects.ndarrays, extracted_data) + + def to_data(self) -> Any: + objects = _NonJsonObjects(self._tensors, self._ndarrays) + return _insert_objects(self._other_data, objects) + + +def _savez(output_path: Path, ndarrays: dict[str, npt.NDArray[Any]]) -> None: + """Save numpy arrays to a `.npz` file. + + Based on np.lib.npyio._savez() which is used by np.savez() but with allow_pickle=False. 
+ """ + zipf = zipfile.ZipFile(output_path, mode='w', compression=zipfile.ZIP_STORED, allowZip64=True) + for key, val in ndarrays.items(): + with zipf.open(key + '.npy', 'w', force_zip64=True) as fid: + np.lib.format.write_array(fid, np.asanyarray(val), allow_pickle=False, pickle_kwargs={}) + zipf.close() + + +_PRIMITIVE_TYPES = (float, int, str, bool, NoneType) +_KEY = '__safe_storage_obj' + + +class _NonJsonObjects: + + def __init__(self, tensors: dict[str, torch.Tensor], ndarrays: dict[str, npt.NDArray[Any]]) -> None: + self.tensors = tensors + self.ndarrays = ndarrays + self._id_counter = 0 + + @property + def next_id(self) -> str: + next_id = str(self._id_counter) + self._id_counter += 1 + return next_id + + def add_tensor(self, item: torch.Tensor) -> str: + item_id = self.next_id + self.tensors[item_id] = item + return item_id + + def add_ndarray(self, item: npt.NDArray[Any]) -> str: + item_id = self.next_id + self.ndarrays[item_id] = item + return item_id + + +def _extract_objects(data: Any, objects: _NonJsonObjects) -> Any: + if isinstance(data, _PRIMITIVE_TYPES): + return data + elif isinstance(data, list): + return [_extract_objects(item, objects) for item in data] + elif isinstance(data, dict): + return {k: _extract_objects(v, objects) for k, v in data.items()} + elif isinstance(data, tuple): + return {_KEY: 'tuple', 'items': [_extract_objects(item, objects) for item in data]} + elif isinstance(data, timedelta): + return {_KEY: 'timedelta', 'seconds': data.total_seconds()} + elif isinstance(data, torch.Tensor): + return {_KEY: 'tensor', 'id': objects.add_tensor(data)} + elif isinstance(data, np.ndarray): + return {_KEY: 'ndarray', 'id': objects.add_ndarray(data)} + else: + raise TypeError(f'unsupported type: {type(data).__name__}') + + +def _insert_objects(data: Any, objects: _NonJsonObjects) -> Any: + if isinstance(data, _PRIMITIVE_TYPES): + return data + elif isinstance(data, list): + return [_insert_objects(item, objects) for item in data] + elif isinstance(data, dict): + if _KEY in data: + key = data[_KEY] + if key == 'tensor': + return objects.tensors[data['id']] + elif key == 'ndarray': + return objects.ndarrays[data['id']] + elif key == 'tuple': + return tuple(_insert_objects(v, objects) for v in data['items']) + elif key == 'timedelta': + return timedelta(seconds=data['seconds']) + else: + raise ValueError(f'unknown key: "{key}"') + else: + return {k: _insert_objects(v, objects) for k, v in data.items()} + else: + raise TypeError(f'unsupported type: {type(data).__name__}') diff --git a/tests/utils/test_safe_checkpoint.py b/tests/utils/test_safe_checkpoint.py new file mode 100644 index 0000000000..7338382958 --- /dev/null +++ b/tests/utils/test_safe_checkpoint.py @@ -0,0 +1,80 @@ +# Copyright 2024 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the safer checkpoint format.""" + +from pathlib import Path + +import numpy as np +import torch + +from composer.utils.safe_checkpoint import SafeState + + +class TestSafeState: + """Tests for the safer checkpoint format.""" + + def test_convert_between_basic_data(self, tmp_path: Path) -> None: + data = {'foo': 123, 'bar': {'a': [True, False, None], 'b': 1.23}} + + safe_state = SafeState.from_data(data) + + loaded_data = safe_state.to_data() + assert data == loaded_data + + safe_state.save(tmp_path / 'state.safe') + loaded_state = safe_state.load(tmp_path / 'state.safe') + + loaded_data = loaded_state.to_data() + assert data == loaded_data + + def test_convert_with_tensor(self, tmp_path: Path) -> 
None: + data = {'foo': 123, 'bar': torch.tensor([1, 2, 3]), 'baz': np.array([1, 2], dtype=np.uint8)} + + safe_state = SafeState.from_data(data) + + loaded_data = safe_state.to_data() + assert set(loaded_data.keys()) == {'foo', 'bar', 'baz'} + assert loaded_data['foo'] == 123 + assert torch.all(torch.tensor([1, 2, 3]).eq(loaded_data['bar'])) + assert np.array_equal(loaded_data['baz'], [1, 2]) + assert loaded_data['baz'].dtype == np.uint8 + + safe_state.save(tmp_path / 'state.safe') + loaded_state = safe_state.load(tmp_path / 'state.safe') + + loaded_data = loaded_state.to_data() + assert set(loaded_data.keys()) == {'foo', 'bar', 'baz'} + assert loaded_data['foo'] == 123 + assert torch.all(torch.tensor([1, 2, 3]).eq(loaded_data['bar'])) + assert np.array_equal(loaded_data['baz'], [1, 2]) + assert loaded_data['baz'].dtype == np.uint8 + + def test_convert_with_multiple_tensors(self, tmp_path: Path) -> None: + data = { + 'foo': 123, + 'bar': torch.tensor([1, 2, 3]), + 'baz': [torch.tensor([1]), torch.tensor([1.1]), np.array([1.2])], + } + + safe_state = SafeState.from_data(data) + + loaded_data = safe_state.to_data() + assert set(loaded_data.keys()) == {'foo', 'bar', 'baz'} + assert loaded_data['foo'] == 123 + assert torch.all(torch.tensor([1, 2, 3]).eq(loaded_data['bar'])) + assert len(loaded_data['baz']) == 3 + assert torch.all(torch.tensor([1]).eq(loaded_data['baz'][0])) + assert torch.all(torch.tensor([1.1]).eq(loaded_data['baz'][1])) + assert np.array_equal(loaded_data['baz'][2], [1.2]) + + safe_state.save(tmp_path / 'state.safe') + loaded_state = safe_state.load(tmp_path / 'state.safe') + + loaded_data = loaded_state.to_data() + assert set(loaded_data.keys()) == {'foo', 'bar', 'baz'} + assert loaded_data['foo'] == 123 + assert len(loaded_data['baz']) == 3 + assert torch.all(torch.tensor([1]).eq(loaded_data['baz'][0])) + assert torch.all(torch.tensor([1.1]).eq(loaded_data['baz'][1])) + assert np.array_equal(loaded_data['baz'][2], [1.2])
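+
+    def test_convert_with_tuple_and_timedelta(self, tmp_path: Path) -> None:
+        # An illustrative round trip of the remaining encodings handled by
+        # _extract_objects/_insert_objects (tuples and timedeltas); a minimal
+        # sketch rather than exhaustive coverage.
+        from datetime import timedelta
+
+        data = {'foo': (1, 2, 3), 'bar': timedelta(minutes=5)}
+
+        safe_state = SafeState.from_data(data)
+        safe_state.save(tmp_path / 'state.safe')
+        loaded_data = SafeState.load(tmp_path / 'state.safe').to_data()
+
+        # Tuples must come back as tuples (not JSON lists), and timedeltas are
+        # reconstructed from their total seconds.
+        assert loaded_data['foo'] == (1, 2, 3)
+        assert loaded_data['bar'] == timedelta(minutes=5)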