Skip to content

Commit

Permalink
Add: Add async GitHub API for downloading release zip and tarball files
Browse files Browse the repository at this point in the history
  • Loading branch information
bjoernricks committed Oct 25, 2022
1 parent 11d7c13 commit e58f02d
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 2 deletions.
51 changes: 49 additions & 2 deletions pontos/github/api/release.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,26 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from pathlib import Path
from typing import ContextManager, Iterable, Iterator, Optional, Tuple, Union
from typing import (
AsyncContextManager,
ContextManager,
Iterable,
Iterator,
Optional,
Tuple,
Union,
)

import httpx

from pontos.github.api.client import GitHubAsyncREST
from pontos.github.api.helper import JSON_OBJECT
from pontos.helper import DownloadProgressIterable, download
from pontos.helper import (
AsyncDownloadProgressIterable,
DownloadProgressIterable,
download,
download_async,
)


class GitHubAsyncRESTReleases(GitHubAsyncREST):
Expand Down Expand Up @@ -99,6 +112,40 @@ async def get(self, repo: str, tag: str) -> JSON_OBJECT:
response.raise_for_status()
return response.json()

def download_release_tarball(
self, repo: str, tag: str
) -> AsyncContextManager[AsyncDownloadProgressIterable]:
"""
Download a release tarball (tar.gz) file
Args:
repo: GitHub repository (owner/name) to use
tag: The git tag for the release
Raises:
HTTPStatusError if the request was invalid
"""
api = f"https://github.com/{repo}/archive/refs/tags/{tag}.tar.gz"
return download_async(self._client.stream(api))

def download_release_zip(
self,
repo: str,
tag: str,
) -> AsyncContextManager[AsyncDownloadProgressIterable]:
"""
Download a release zip file
Args:
repo: GitHub repository (owner/name) to use
tag: The git tag for the release
Raises:
HTTPStatusError if the request was invalid
"""
api = f"https://github.com/{repo}/archive/refs/tags/{tag}.zip"
return download_async(self._client.stream(api))


class GitHubRESTReleaseMixin:
def create_tag(
Expand Down
53 changes: 53 additions & 0 deletions tests/github/api/test_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

# pylint: disable=redefined-builtin

import json
import unittest
from pathlib import Path
Expand All @@ -25,6 +27,7 @@
from pontos.github.api import GitHubRESTApi
from pontos.github.api.release import GitHubAsyncRESTReleases
from pontos.helper import DEFAULT_TIMEOUT
from tests import AsyncIteratorMock, AsyncMock, aiter, anext
from tests.github.api import (
GitHubAsyncRESTTestCase,
create_response,
Expand Down Expand Up @@ -157,6 +160,56 @@ async def test_get_failure(self):
"/repos/foo/bar/releases/tags/v1.2.3",
)

async def test_download_release_tarball(self):
response = create_response(headers=MagicMock())
response.headers.get.return_value = 2
response.aiter_bytes.return_value = AsyncIteratorMock(["1", "2"])
stream_context = AsyncMock()
stream_context.__aenter__.return_value = response
self.client.stream.return_value = stream_context

async with self.api.download_release_tarball(
"foo/bar", "v1.2.3"
) as download_iterable:
it = aiter(download_iterable)
content, progress = await anext(it)

self.assertEqual(content, "1")
self.assertEqual(progress, 50)

content, progress = await anext(it)
self.assertEqual(content, "2")
self.assertEqual(progress, 100)

self.client.stream.assert_called_once_with(
"https://github.com/foo/bar/archive/refs/tags/v1.2.3.tar.gz"
)

async def test_download_release_zip(self):
response = create_response(headers=MagicMock())
response.headers.get.return_value = 2
response.aiter_bytes.return_value = AsyncIteratorMock(["1", "2"])
stream_context = AsyncMock()
stream_context.__aenter__.return_value = response
self.client.stream.return_value = stream_context

async with self.api.download_release_zip(
"foo/bar", "v1.2.3"
) as download_iterable:
it = aiter(download_iterable)
content, progress = await anext(it)

self.assertEqual(content, "1")
self.assertEqual(progress, 50)

content, progress = await anext(it)
self.assertEqual(content, "2")
self.assertEqual(progress, 100)

self.client.stream.assert_called_once_with(
"https://github.com/foo/bar/archive/refs/tags/v1.2.3.zip"
)


class GitHubReleaseTestCase(unittest.TestCase):
@patch("pontos.github.api.api.httpx.post")
Expand Down

0 comments on commit e58f02d

Please sign in to comment.