Skip to content

Commit

Permalink
Restore backup from specific location
Browse files Browse the repository at this point in the history
  • Loading branch information
mdegat01 committed Dec 17, 2024
1 parent 90590ae commit e3cc263
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 46 deletions.
9 changes: 7 additions & 2 deletions supervisor/api/backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _ensure_list(item: Any) -> list:
{
vol.Optional(ATTR_PASSWORD): vol.Maybe(str),
vol.Optional(ATTR_BACKGROUND, default=False): vol.Boolean(),
vol.Optional(ATTR_LOCATION): vol.Maybe(str),
}
)

Expand Down Expand Up @@ -379,8 +380,10 @@ async def backup_partial(self, request: web.Request):
async def restore_full(self, request: web.Request):
"""Full restore of a backup."""
backup = self._extract_slug(request)
self._validate_cloud_backup_location(request, backup.location)
body = await api_validate(SCHEMA_RESTORE_FULL, request)
self._validate_cloud_backup_location(
request, body.get(ATTR_LOCATION, backup.location)
)
background = body.pop(ATTR_BACKGROUND)
restore_task, job_id = await self._background_backup_task(
self.sys_backups.do_restore_full, backup, **body
Expand All @@ -397,8 +400,10 @@ async def restore_full(self, request: web.Request):
async def restore_partial(self, request: web.Request):
"""Partial restore a backup."""
backup = self._extract_slug(request)
self._validate_cloud_backup_location(request, backup.location)
body = await api_validate(SCHEMA_RESTORE_PARTIAL, request)
self._validate_cloud_backup_location(
request, body.get(ATTR_LOCATION, backup.location)
)
background = body.pop(ATTR_BACKGROUND)
restore_task, job_id = await self._background_backup_task(
self.sys_backups.do_restore_partial, backup, **body
Expand Down
87 changes: 49 additions & 38 deletions supervisor/backups/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import asyncio
from base64 import b64decode, b64encode
from collections import defaultdict
from collections.abc import Awaitable
from collections.abc import AsyncGenerator, Awaitable
from contextlib import asynccontextmanager
from copy import deepcopy
from datetime import timedelta
from functools import cached_property
Expand All @@ -12,6 +13,7 @@
import logging
from pathlib import Path
import tarfile
from tarfile import TarFile
from tempfile import TemporaryDirectory
import time
from typing import Any, Self
Expand Down Expand Up @@ -56,6 +58,7 @@
from ..utils import remove_folder
from ..utils.dt import parse_datetime, utcnow
from ..utils.json import json_bytes
from ..utils.sentinel import DEFAULT
from .const import BUF_SIZE, LOCATION_CLOUD_BACKUP, BackupType
from .utils import key_to_iv, password_to_key
from .validate import SCHEMA_BACKUP
Expand Down Expand Up @@ -86,7 +89,6 @@ def __init__(
self._data: dict[str, Any] = data or {ATTR_SLUG: slug}
self._tmp = None
self._outer_secure_tarfile: SecureTarFile | None = None
self._outer_secure_tarfile_tarfile: tarfile.TarFile | None = None
self._key: bytes | None = None
self._aes: Cipher | None = None
self._locations: dict[str | None, Path] = {location: tar_file}
Expand Down Expand Up @@ -375,59 +377,68 @@ def _load_file():

return True

async def __aenter__(self):
"""Async context to open a backup."""
@asynccontextmanager
async def create(self) -> AsyncGenerator[None]:
"""Create new backup file."""
if self.tarfile.is_file():
raise BackupError(
f"Cannot make new backup at {self.tarfile.as_posix()}, file already exists!",
_LOGGER.error,
)

# create a backup
if not self.tarfile.is_file():
self._outer_secure_tarfile = SecureTarFile(
self.tarfile,
"w",
gzip=False,
bufsize=BUF_SIZE,
self._outer_secure_tarfile = SecureTarFile(
self.tarfile,
"w",
gzip=False,
bufsize=BUF_SIZE,
)
try:
with self._outer_secure_tarfile as outer_tarfile:
yield
await self._create_cleanup(outer_tarfile)
finally:
self._outer_secure_tarfile = None

@asynccontextmanager
async def open(self, location: str | None | type[DEFAULT]) -> AsyncGenerator[None]:
"""Open backup for restore."""
if location != DEFAULT and location not in self.all_locations:
raise BackupError(
f"Backup {self.slug} does not exist in location {location}",
_LOGGER.error,
)

backup_tarfile = (
self.tarfile if location == DEFAULT else self.all_locations[location]
)
if not backup_tarfile.is_file():
raise BackupError(
f"Cannot open backup at {backup_tarfile.as_posix()}, file does not exist!",
_LOGGER.error,
)
self._outer_secure_tarfile_tarfile = self._outer_secure_tarfile.__enter__()
return

# extract an existing backup
self._tmp = TemporaryDirectory(dir=str(self.tarfile.parent))
self._tmp = TemporaryDirectory(dir=str(backup_tarfile.parent))

def _extract_backup():
"""Extract a backup."""
with tarfile.open(self.tarfile, "r:") as tar:
with tarfile.open(backup_tarfile, "r:") as tar:
tar.extractall(
path=self._tmp.name,
members=secure_path(tar),
filter="fully_trusted",
)

await self.sys_run_in_executor(_extract_backup)

async def __aexit__(self, exception_type, exception_value, traceback):
"""Async context to close a backup."""
# exists backup or exception on build
try:
await self._aexit(exception_type, exception_value, traceback)
finally:
if self._tmp:
self._tmp.cleanup()
if self._outer_secure_tarfile:
self._outer_secure_tarfile.__exit__(
exception_type, exception_value, traceback
)
self._outer_secure_tarfile = None
self._outer_secure_tarfile_tarfile = None
with self._tmp:
await self.sys_run_in_executor(_extract_backup)
yield

async def _aexit(self, exception_type, exception_value, traceback):
async def _create_cleanup(self, outer_tarfile: TarFile) -> None:
"""Cleanup after backup creation.
This is a separate method to allow it to be called from __aexit__ to ensure
Separate method to be called from create to ensure
that cleanup is always performed, even if an exception is raised.
"""
# If we're not creating a new backup, or if an exception was raised, we're done
if not self._outer_secure_tarfile or exception_type is not None:
return

# validate data
try:
self._data = SCHEMA_BACKUP(self._data)
Expand All @@ -445,7 +456,7 @@ def _add_backup_json():
tar_info = tarfile.TarInfo(name="./backup.json")
tar_info.size = len(raw_bytes)
tar_info.mtime = int(time.time())
self._outer_secure_tarfile_tarfile.addfile(tar_info, fileobj=fileobj)
outer_tarfile.addfile(tar_info, fileobj=fileobj)

try:
await self.sys_run_in_executor(_add_backup_json)
Expand Down
25 changes: 20 additions & 5 deletions supervisor/backups/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ async def _do_backup(
try:
self.sys_core.state = CoreState.FREEZE

async with backup:
async with backup.create():
# HomeAssistant Folder is for v1
if homeassistant:
self._change_stage(BackupJobStage.HOME_ASSISTANT, backup)
Expand Down Expand Up @@ -575,6 +575,7 @@ async def _do_restore(
folder_list: list[str],
homeassistant: bool,
replace: bool,
location: str | None | type[DEFAULT],
) -> bool:
"""Restore from a backup.
Expand All @@ -585,7 +586,7 @@ async def _do_restore(

try:
task_hass: asyncio.Task | None = None
async with backup:
async with backup.open(location):
# Restore docker config
self._change_stage(RestoreJobStage.DOCKER_CONFIG, backup)
backup.restore_dockerconfig(replace)
Expand Down Expand Up @@ -671,7 +672,10 @@ async def _do_restore(
cleanup=False,
)
async def do_restore_full(
self, backup: Backup, password: str | None = None
self,
backup: Backup,
password: str | None = None,
location: str | None | type[DEFAULT] = DEFAULT,
) -> bool:
"""Restore a backup."""
# Add backup ID to job
Expand Down Expand Up @@ -702,7 +706,12 @@ async def do_restore_full(
await self.sys_core.shutdown()

success = await self._do_restore(
backup, backup.addon_list, backup.folders, True, True
backup,
backup.addon_list,
backup.folders,
homeassistant=True,
replace=True,
location=location,
)
finally:
self.sys_core.state = CoreState.RUNNING
Expand Down Expand Up @@ -731,6 +740,7 @@ async def do_restore_partial(
addons: list[str] | None = None,
folders: list[Path] | None = None,
password: str | None = None,
location: str | None | type[DEFAULT] = DEFAULT,
) -> bool:
"""Restore a backup."""
# Add backup ID to job
Expand Down Expand Up @@ -766,7 +776,12 @@ async def do_restore_partial(

try:
success = await self._do_restore(
backup, addon_list, folder_list, homeassistant, False
backup,
addon_list,
folder_list,
homeassistant=homeassistant,
replace=False,
location=location,
)
finally:
self.sys_core.state = CoreState.RUNNING
Expand Down
47 changes: 47 additions & 0 deletions tests/api/test_backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,50 @@ async def test_partial_backup_all_addons(
)
assert resp.status == 200
store_addons.assert_called_once_with([install_addon_ssh])


async def test_restore_backup_from_location(
api_client: TestClient, coresys: CoreSys, tmp_supervisor_data: Path
):
"""Test restoring a backup from a specific location."""
coresys.core.state = CoreState.RUNNING
coresys.hardware.disk.get_disk_free_space = lambda x: 5000

# Make a backup and a file to test with
(test_file := coresys.config.path_share / "test.txt").touch()
resp = await api_client.post(
"/backups/new/partial",
json={
"name": "Test",
"folders": ["share"],
"location": [None, ".cloud_backup"],
},
)
assert resp.status == 200
body = await resp.json()
backup = coresys.backups.get(body["data"]["slug"])
assert set(backup.all_locations) == {None, ".cloud_backup"}

# The use case of this is user might want to pick a particular mount if one is flaky
# To simulate this, remove the file from one location and show one works and the other doesn't
assert backup.location is None
backup.all_locations[None].unlink()
test_file.unlink()

resp = await api_client.post(
f"/backups/{backup.slug}/restore/partial",
json={"location": None, "folders": ["share"]},
)
assert resp.status == 400
body = await resp.json()
assert (
body["message"]
== f"Cannot open backup at {backup.all_locations[None].as_posix()}, file does not exist!"
)

resp = await api_client.post(
f"/backups/{backup.slug}/restore/partial",
json={"location": ".cloud_backup", "folders": ["share"]},
)
assert resp.status == 200
assert test_file.is_file()
2 changes: 1 addition & 1 deletion tests/backups/test_backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def test_new_backup_stays_in_folder(coresys: CoreSys, tmp_path: Path):
backup.new("test", "2023-07-21T21:05:00.000000+00:00", BackupType.FULL)
assert not listdir(tmp_path)

async with backup:
async with backup.create():
assert len(listdir(tmp_path)) == 1
assert backup.tarfile.exists()

Expand Down

0 comments on commit e3cc263

Please sign in to comment.