diff --git a/bw2io/backup.py b/bw2io/backup.py
index 78927b4..eecb980 100644
--- a/bw2io/backup.py
+++ b/bw2io/backup.py
@@ -11,6 +11,82 @@
 from bw2data import projects
 from bw_processing import safe_filename
 
+_METADATA_FIELDS = {"is_sourced", "revision", "data", "full_hash"}
+
+
+def _add_project_metadata() -> None:
+    """Write the truthy ``_METADATA_FIELDS`` of ``projects.dataset`` to a JSON file."""
+    fp = projects.dir / "project-metadata.json"
+    data = {
+        field: getattr(projects.dataset, field)
+        for field in _METADATA_FIELDS
+        if getattr(projects.dataset, field)
+    }
+    with open(fp, "w", encoding="utf-8") as f:
+        json.dump(data, f, indent=2, ensure_ascii=False)
+
+
+def _remove_project_metadata() -> None:
+    """Delete ``project-metadata.json`` from ``projects.dir`` if it exists."""
+    fp = projects.dir / "project-metadata.json"
+    # ``is_file`` must be called; the bare bound method is always truthy
+    if fp.is_file():
+        fp.unlink()
+
+
+def _restore_project_metadata() -> None:
+    """Apply the fields saved in ``project-metadata.json`` to ``projects.dataset``."""
+    fp = projects.dir / "project-metadata.json"
+    if not fp.is_file():
+        # Archives written before this metadata file was introduced lack it
+        return
+    with open(fp, encoding="utf-8") as f:
+        metadata = json.load(f)
+    for field in _METADATA_FIELDS:
+        if field in metadata:
+            setattr(projects.dataset, field, metadata[field])
+    projects.dataset.save()
+
+
+def _extract_single_directory_tarball(filepath: Path, output_dir: Path) -> Path:
+    """
+    Extract a ``.tar.gz`` archive into ``output_dir`` and return the single
+    directory found there afterwards.
+
+    Raises ``Exception`` if a member would be extracted outside ``output_dir``
+    and ``ValueError`` unless ``output_dir`` then holds exactly one directory.
+    """
+
+    def is_within_directory(directory, target):
+        abs_directory = os.path.abspath(directory)
+        abs_target = os.path.abspath(target)
+        try:
+            # ``commonpath`` compares whole components; ``commonprefix`` is
+            # character-based and would treat ``/tmp/ab`` as inside ``/tmp/a``
+            common = os.path.commonpath([abs_directory, abs_target])
+        except ValueError:
+            # No common path at all, e.g. different drives on Windows
+            return False
+        return common == abs_directory
+
+    def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+        for member in tar.getmembers():
+            member_path = os.path.join(path, member.name)
+            if not is_within_directory(path, member_path):
+                raise Exception("Attempted path traversal in tar file")
+
+        tar.extractall(path, members, numeric_owner=numeric_owner)
+
+    with tarfile.open(filepath, "r:gz") as tar:
+        safe_extract(tar, output_dir)
+
+    # Find single extracted directory; don't know it ahead of time.
+    # ``iterdir`` already yields paths that include ``output_dir``.
+    extracted_dirs = [path for path in Path(output_dir).iterdir() if path.is_dir()]
+    if len(extracted_dirs) != 1:
+        raise ValueError("Can't find single directory extracted from project archive")
+    return extracted_dirs[0]
+
 
 def backup_data_directory(
     timestamp: Optional[bool] = True, dir_backup: Optional[Union[str, Path]] = None
@@ -67,7 +143,6 @@ def backup_data_directory(
         tar.add(data_directory, arcname=data_directory.name)
 
     print(f"Saved to: {fp}")
-
     return fp
 
 
@@ -75,7 +150,7 @@ def backup_project_directory(
     project: str,
     timestamp: Optional[bool] = True,
     dir_backup: Optional[Union[str, Path]] = None,
-):
+) -> Path:
     """
     Backup project data directory to a ``.tar.gz`` (compressed tar archive) in the user's home directory, or a directory specified by ``dir_backup``.
 
@@ -94,8 +169,8 @@ def backup_project_directory(
 
     Returns
     -------
-    project_name : str
-        Name of the project that was backed up.
+    filepath : Path
+        pathlib.Path of archive file
 
     Raises
     ------
@@ -122,6 +197,7 @@ def backup_project_directory(
     if not os.access(dir_backup, os.W_OK):
         raise PermissionError(f"The directory {dir_backup} is not writable.")
 
+    _add_project_metadata()
     timestamp_str = (
         datetime.datetime.now().strftime("%d-%B-%Y-%I-%M%p") if timestamp else ""
     )
@@ -139,6 +215,7 @@ def backup_project_directory(
 
     print(f"Saved to: {fp}")
 
+    _remove_project_metadata()
     return fp
 
 
@@ -146,6 +223,7 @@ def restore_project_directory(
     fp: Union[str, Path],
     project_name: Optional[str] = None,
     overwrite_existing: Optional[bool] = False,
+    switch: bool = False,
 ):
     """
     Restore a backed up project data directory from a ``.tar.gz`` (compressed tar archive) specified by ``fp``. Choose a custom name, or use the name of the project in the archive. If the project already exists, you must set ``overwrite_existing`` to True.
@@ -155,8 +233,11 @@ def restore_project_directory(
     fp : str, Path
         File path of the project to restore.
     project_name : str, optional
-        Name of new project to create
+        Name of new project to create.
     overwrite_existing : bool, optional
+        Overwrite the project if it already exists.
+    switch : bool, optional
+        Switch to new project after restoring it.
 
     Returns
     -------
@@ -197,42 +278,17 @@ def get_project_name(fp):
         )
 
     with tempfile.TemporaryDirectory() as td:
-        with tarfile.open(fp, "r:gz") as tar:
-
-            def is_within_directory(directory, target):
-                abs_directory = os.path.abspath(directory)
-                abs_target = os.path.abspath(target)
-
-                prefix = os.path.commonprefix([abs_directory, abs_target])
-
-                return prefix == abs_directory
+        extracted_path = _extract_single_directory_tarball(filepath=fp, output_dir=td)
 
-            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
-                for member in tar.getmembers():
-                    member_path = os.path.join(path, member.name)
-                    if not is_within_directory(path, member_path):
-                        raise Exception("Attempted Path Traversal in Tar File")
-
-                tar.extractall(path, members, numeric_owner=numeric_owner)
-
-            safe_extract(tar, td)
-
-        # Find single extracted directory; don't know it ahead of time
-        extracted_dir = [
-            (Path(td) / dirname)
-            for dirname in Path(td).iterdir()
-            if (Path(td) / dirname).is_dir()
-        ]
-        if not len(extracted_dir) == 1:
-            raise ValueError(
-                "Can't find single directory extracted from project archive"
-            )
-        extracted_path = extracted_dir[0]
-
-        _current = projects.current
+        _from_project_name = projects.current
         projects.set_current(project_name, update=False)
         shutil.copytree(extracted_path, projects.dir, dirs_exist_ok=True)
-        projects.set_current(_current)
+
+        _restore_project_metadata()
+        _remove_project_metadata()
+
+        if not switch:
+            projects.set_current(_from_project_name)
 
     print(f"Restored project: {project_name}")
 
diff --git a/tests/test_backup.py b/tests/test_backup.py
new file mode 100644
index 0000000..cb21994
--- /dev/null
+++ b/tests/test_backup.py
@@ -0,0 +1,118 @@
+import json
+import shutil
+
+import pytest
+from bw2data import Database, Method, projects
+from bw2data.tests import bw2test
+
+from bw2io.backup import (
+    _add_project_metadata,
+    _extract_single_directory_tarball,
+    _remove_project_metadata,
+    _restore_project_metadata,
+    backup_project_directory,
+    restore_project_directory,
+)
+
+
+@pytest.fixture
+@bw2test
+def unsourced() -> None:
+    projects.set_current("test-unsourced")
+    projects.dataset.data["arbitrary"] = True
+    Database("foo").write({})
+    Method(("bar",)).register()
+    Method(("bar",)).write([])
+
+
+@pytest.fixture
+@bw2test
+def sourced() -> None:
+    projects.set_current("test-sourced")
+    projects.dataset.set_sourced()
+    projects.dataset.data["arbitrary"] = True
+    Database("foo").write({})
+    Method(("bar",)).register()
+    Method(("bar",)).write([])
+
+
+def test_add_project_metadata_sourced(sourced):
+    assert projects.current == "test-sourced"
+    _add_project_metadata()
+    fp = projects.dir / "project-metadata.json"
+    assert fp.is_file()
+    metadata = json.loads(fp.read_text(encoding="utf-8"))
+    assert "name" not in metadata
+    assert metadata["revision"]
+    assert metadata["is_sourced"]
+    assert "full_hash" not in metadata
+    assert metadata["data"]["arbitrary"]
+    assert metadata["data"]["25"]
+
+
+def test_add_project_metadata_unsourced(unsourced):
+    assert projects.current == "test-unsourced"
+    _add_project_metadata()
+    fp = projects.dir / "project-metadata.json"
+    assert fp.is_file()
+    metadata = json.loads(fp.read_text(encoding="utf-8"))
+    assert "name" not in metadata
+    assert "revision" not in metadata
+    assert "is_sourced" not in metadata
+    assert "full_hash" not in metadata
+    assert metadata["data"]["arbitrary"]
+    assert metadata["data"]["25"]
+
+
+def test_remove_project_metadata(unsourced):
+    assert projects.current == "test-unsourced"
+    _add_project_metadata()
+    assert (projects.dir / "project-metadata.json").is_file()
+    _remove_project_metadata()
+    assert not (projects.dir / "project-metadata.json").is_file()
+
+
+def test_restore_project_metadata(sourced):
+    assert projects.current == "test-sourced"
+    _add_project_metadata()
+
+    fp = projects.dir / "project-metadata.json"
+    assert fp.is_file()
+
+    projects.set_current("other")
+    assert not projects.dataset.is_sourced
+    assert not projects.dataset.revision
+    assert not projects.dataset.data.get("arbitrary")
+
+    shutil.copy(fp, projects.dir / "project-metadata.json")
+    _restore_project_metadata()
+
+    assert projects.dataset.is_sourced
+    assert projects.dataset.revision
+    assert projects.dataset.data["arbitrary"]
+
+
+def test_backup_project(sourced, tmp_path):
+    filepath = backup_project_directory(
+        "test-sourced", timestamp=False, dir_backup=tmp_path
+    )
+    dirpath = _extract_single_directory_tarball(filepath, tmp_path)
+    assert (dirpath / "project-metadata.json").is_file()
+    assert (dirpath / "lci").is_dir()
+    assert (dirpath / "revisions").is_dir()
+
+
+def test_restore_project(sourced, tmp_path):
+    filepath = backup_project_directory(
+        "test-sourced", timestamp=False, dir_backup=tmp_path
+    )
+    revision = projects.dataset.revision
+
+    projects.set_current("default")
+    projects.delete_project(name="test-sourced", delete_dir=True)
+    assert "test-sourced" not in projects
+
+    restore_project_directory(filepath, project_name="something-else", switch=True)
+    assert projects.current == "something-else"
+    assert projects.dataset.is_sourced
+    assert projects.dataset.revision == revision