Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

download_content cleanup #2397

Merged
merged 1 commit into from
Dec 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions custom_components/hacs/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from ..utils.backup import Backup, BackupNetDaemon
from ..utils.decode import decode_content
from ..utils.download import download_content
from ..utils.download import dowload_repository_content, gather_files_to_download
from ..utils.logger import getLogger
from ..utils.path import is_safe
from ..utils.queue_manager import QueueManager
Expand Down Expand Up @@ -547,19 +547,23 @@ async def common_update(self, ignore_issues=False, force=False) -> bool:

return True

async def download_zip_files(self, validate) -> Validate:
async def download_zip_files(self, validate) -> None:
"""Download ZIP archive from repository release."""
download_queue = QueueManager()
try:
contents = False
contents = None
target_ref = self.ref.split("/")[1]

for release in self.releases.objects:
self.logger.info("%s ref: %s --- tag: %s.", self, self.ref, release.tag_name)
if release.tag_name == self.ref.split("/")[1]:
self.logger.debug("%s ref: %s --- tag: %s", self, target_ref, release.tag_name)
if release.tag_name == target_ref:
contents = release.assets
break

if not contents:
return validate
validate.errors.append(f"No assets found for release '{self.ref}'")
return

download_queue = QueueManager()

for content in contents or []:
download_queue.add(self.async_download_zip_file(content, validate))
Expand All @@ -568,9 +572,7 @@ async def download_zip_files(self, validate) -> Validate:
except BaseException: # pylint: disable=broad-except
validate.errors.append("Download was not completed")

return validate

async def async_download_zip_file(self, content, validate) -> Validate:
async def async_download_zip_file(self, content, validate) -> None:
"""Download ZIP archive from repository release."""
try:
filecontent = await self.hacs.async_download_file(content.download_url)
Expand Down Expand Up @@ -601,19 +603,21 @@ def cleanup_temp_dir():
except BaseException: # pylint: disable=broad-except
validate.errors.append("Download was not completed")

return validate

async def download_content(
self,
validate,
_directory_path,
_local_directory,
_ref,
) -> Validate:
async def download_content(self) -> None:
"""Download the content of a directory."""
contents = gather_files_to_download(self)
self.logger.debug(self.data.filename)
if not contents:
raise HacsException("No content to download")

download_queue = QueueManager()

validate = await download_content(self)
return validate
for content in contents:
if self.data.content_in_root and self.data.filename:
if content.name != self.data.filename:
continue
download_queue.add(dowload_repository_content(self, content))
await download_queue.execute()

async def async_get_hacs_json(self, ref: str = None) -> dict[str, Any] | None:
"""Get the content of the hacs.json file."""
Expand Down Expand Up @@ -843,7 +847,7 @@ async def async_install_repository(self) -> None:
if self.data.zip_release and version != self.data.default_branch:
await self.download_zip_files(self.validate)
else:
await download_content(self)
await self.download_content()

if self.validate.errors:
for error in self.validate.errors:
Expand Down
78 changes: 0 additions & 78 deletions custom_components/hacs/utils/download.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
"""Helpers to download repository content."""
from __future__ import annotations

import os
import pathlib
import tempfile
from typing import TYPE_CHECKING
import zipfile

from ..exceptions import HacsException
from ..utils import filters
from ..utils.decorator import concurrent
from ..utils.queue_manager import QueueManager

if TYPE_CHECKING:
from ..base import HacsBase


class FileInformation:
Expand Down Expand Up @@ -97,75 +88,6 @@ def gather_files_to_download(repository):
return files


async def download_zip_files(repository, validate):
"""Download ZIP archive from repository release."""
contents = []
queue = QueueManager()
try:
for release in repository.releases.objects:
repository.logger.info(f"ref: {repository.ref} --- tag: {release.tag_name}")
if release.tag_name == repository.ref.split("/")[1]:
contents = release.assets

if not contents:
return validate

for content in contents or []:
queue.add(async_download_zip_file(repository, content, validate))

await queue.execute()
except BaseException as exception: # pylint: disable=broad-except
validate.errors.append(f"Download was not completed [{exception}]")

return validate


async def async_download_zip_file(repository, content, validate):
"""Download ZIP archive from repository release."""
try:
filecontent = await repository.hacs.async_download_file(content.download_url)

if filecontent is None:
validate.errors.append(f"[{content.name}] was not downloaded.")
return

result = await repository.hacs.async_save_file(
f"{tempfile.gettempdir()}/{repository.data.filename}", filecontent
)
with zipfile.ZipFile(
f"{tempfile.gettempdir()}/{repository.data.filename}", "r"
) as zip_file:
zip_file.extractall(repository.content.path.local)

os.remove(f"{tempfile.gettempdir()}/{repository.data.filename}")

if result:
repository.logger.info(f"Download of {content.name} completed")
return
validate.errors.append(f"[{content.name}] was not downloaded.")
except BaseException as exception: # pylint: disable=broad-except
validate.errors.append(f"Download was not completed [{exception}]")

return validate


async def download_content(repository):
"""Download the content of a directory."""
queue = QueueManager()
contents = gather_files_to_download(repository)
repository.logger.debug(repository.data.filename)
if not contents:
raise HacsException("No content to download")

for content in contents:
if repository.data.content_in_root and repository.data.filename:
if content.name != repository.data.filename:
continue
queue.add(dowload_repository_content(repository, content))
await queue.execute()
return repository.validate


@concurrent(10)
async def dowload_repository_content(repository, content):
"""Download content."""
Expand Down
6 changes: 2 additions & 4 deletions tests/helpers/download/test_download_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from aiogithubapi.objects.repository.content import AIOGitHubAPIRepositoryTreeContent
import pytest

from custom_components.hacs.utils.download import download_content

from tests.sample_data import response_rate_limit_header


Expand All @@ -27,7 +25,7 @@ async def test_download_content(repository, aresponses, tmp_path):
)
]

await download_content(repository)
await repository.download_content()
assert os.path.exists(f"{repository.content.path.local}/test/path/file.file")


Expand Down Expand Up @@ -74,6 +72,6 @@ async def test_download_content_integration(repository_integration, aresponses,
"main",
)
)
await download_content(repository_integration)
await repository_integration.download_content()
for path in repository_integration.tree:
assert os.path.exists(f"{hacs.core.config_path}/{path.full_path}")