From e58f02d4703aca9bcb4de314a237aff95ecfff0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Ricks?= Date: Fri, 14 Oct 2022 14:09:47 +0200 Subject: [PATCH] Add: Add async GitHub API for downloading release zip and tarball files --- pontos/github/api/release.py | 51 ++++++++++++++++++++++++++++-- tests/github/api/test_release.py | 53 ++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/pontos/github/api/release.py b/pontos/github/api/release.py index 88aa5bf7..21b27e2b 100644 --- a/pontos/github/api/release.py +++ b/pontos/github/api/release.py @@ -16,13 +16,26 @@ # along with this program. If not, see . from pathlib import Path -from typing import ContextManager, Iterable, Iterator, Optional, Tuple, Union +from typing import ( + AsyncContextManager, + ContextManager, + Iterable, + Iterator, + Optional, + Tuple, + Union, +) import httpx from pontos.github.api.client import GitHubAsyncREST from pontos.github.api.helper import JSON_OBJECT -from pontos.helper import DownloadProgressIterable, download +from pontos.helper import ( + AsyncDownloadProgressIterable, + DownloadProgressIterable, + download, + download_async, +) class GitHubAsyncRESTReleases(GitHubAsyncREST): @@ -99,6 +112,40 @@ async def get(self, repo: str, tag: str) -> JSON_OBJECT: response.raise_for_status() return response.json() + def download_release_tarball( + self, repo: str, tag: str + ) -> AsyncContextManager[AsyncDownloadProgressIterable]: + """ + Download a release tarball (tar.gz) file + + Args: + repo: GitHub repository (owner/name) to use + tag: The git tag for the release + + Raises: + HTTPStatusError if the request was invalid + """ + api = f"https://github.com/{repo}/archive/refs/tags/{tag}.tar.gz" + return download_async(self._client.stream(api)) + + def download_release_zip( + self, + repo: str, + tag: str, + ) -> AsyncContextManager[AsyncDownloadProgressIterable]: + """ + Download a release zip file + + Args: + repo: GitHub repository (owner/name) to use + tag: The git tag for the release + + Raises: + HTTPStatusError if the request was invalid + """ + api = f"https://github.com/{repo}/archive/refs/tags/{tag}.zip" + return download_async(self._client.stream(api)) + class GitHubRESTReleaseMixin: def create_tag( diff --git a/tests/github/api/test_release.py b/tests/github/api/test_release.py index 0b3cabb3..c4bb8cde 100644 --- a/tests/github/api/test_release.py +++ b/tests/github/api/test_release.py @@ -15,6 +15,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +# pylint: disable=redefined-builtin + import json import unittest from pathlib import Path @@ -25,6 +27,7 @@ from pontos.github.api import GitHubRESTApi from pontos.github.api.release import GitHubAsyncRESTReleases from pontos.helper import DEFAULT_TIMEOUT +from tests import AsyncIteratorMock, AsyncMock, aiter, anext from tests.github.api import ( GitHubAsyncRESTTestCase, create_response, @@ -157,6 +160,56 @@ async def test_get_failure(self): "/repos/foo/bar/releases/tags/v1.2.3", ) + async def test_download_release_tarball(self): + response = create_response(headers=MagicMock()) + response.headers.get.return_value = 2 + response.aiter_bytes.return_value = AsyncIteratorMock(["1", "2"]) + stream_context = AsyncMock() + stream_context.__aenter__.return_value = response + self.client.stream.return_value = stream_context + + async with self.api.download_release_tarball( + "foo/bar", "v1.2.3" + ) as download_iterable: + it = aiter(download_iterable) + content, progress = await anext(it) + + self.assertEqual(content, "1") + self.assertEqual(progress, 50) + + content, progress = await anext(it) + self.assertEqual(content, "2") + self.assertEqual(progress, 100) + + self.client.stream.assert_called_once_with( + "https://github.com/foo/bar/archive/refs/tags/v1.2.3.tar.gz" + ) + + async def test_download_release_zip(self): + response = create_response(headers=MagicMock()) + response.headers.get.return_value = 2 + response.aiter_bytes.return_value = AsyncIteratorMock(["1", "2"]) + stream_context = AsyncMock() + stream_context.__aenter__.return_value = response + self.client.stream.return_value = stream_context + + async with self.api.download_release_zip( + "foo/bar", "v1.2.3" + ) as download_iterable: + it = aiter(download_iterable) + content, progress = await anext(it) + + self.assertEqual(content, "1") + self.assertEqual(progress, 50) + + content, progress = await anext(it) + self.assertEqual(content, "2") + self.assertEqual(progress, 100) + + self.client.stream.assert_called_once_with( + "https://github.com/foo/bar/archive/refs/tags/v1.2.3.zip" + ) + class GitHubReleaseTestCase(unittest.TestCase): @patch("pontos.github.api.api.httpx.post")