diff --git a/src/poetry/repositories/http_repository.py b/src/poetry/repositories/http_repository.py index e9d06d20cf1..ae96481a71d 100644 --- a/src/poetry/repositories/http_repository.py +++ b/src/poetry/repositories/http_repository.py @@ -2,14 +2,13 @@ import functools import hashlib -import os -import urllib -import urllib.parse from collections import defaultdict +from contextlib import contextmanager from pathlib import Path from typing import TYPE_CHECKING from typing import Any +from typing import Iterator import requests @@ -75,39 +74,28 @@ def authenticated_url(self) -> str: def _download(self, url: str, dest: Path) -> None: return download_file(url, dest, session=self.session) - def _get_info_from_wheel(self, url: str) -> PackageInfo: - from poetry.inspection.info import PackageInfo - - wheel_name = urllib.parse.urlparse(url).path.rsplit("/")[-1] - - filepath = self._authenticator.get_cached_file_for_url(url) + @contextmanager + def _cached_or_downloaded_file(self, link: Link) -> Iterator[Path]: + filepath = self._authenticator.get_cached_file_for_url(link.url) if filepath: - return PackageInfo.from_wheel(filepath) + yield filepath + else: + self._log(f"Downloading: {link.url}", level="debug") + with temporary_directory() as temp_dir: + filepath = Path(temp_dir) / link.filename + self._download(link.url, filepath) + yield filepath - self._log(f"Downloading wheel: {wheel_name}", level="debug") - filename = os.path.basename(wheel_name) - with temporary_directory() as temp_dir: - filepath = Path(temp_dir) / filename - self._download(url, filepath) + def _get_info_from_wheel(self, url: str) -> PackageInfo: + from poetry.inspection.info import PackageInfo + with self._cached_or_downloaded_file(Link(url)) as filepath: return PackageInfo.from_wheel(filepath) def _get_info_from_sdist(self, url: str) -> PackageInfo: from poetry.inspection.info import PackageInfo - sdist_name = urllib.parse.urlparse(url).path - sdist_name_log = sdist_name.rsplit("/")[-1] - - filepath = self._authenticator.get_cached_file_for_url(url) - if filepath: - return PackageInfo.from_wheel(filepath) - - self._log(f"Downloading sdist: {sdist_name_log}", level="debug") - filename = os.path.basename(sdist_name) - with temporary_directory() as temp_dir: - filepath = Path(temp_dir) / filename - self._download(url, filepath) - + with self._cached_or_downloaded_file(Link(url)) as filepath: return PackageInfo.from_sdist(filepath) def _get_info_from_urls(self, urls: dict[str, list[str]]) -> PackageInfo: @@ -242,12 +230,7 @@ def _links_to_data(self, links: list[Link], data: PackageInfo) -> dict[str, Any] and link.hash_name not in ("sha256", "sha384", "sha512") and hasattr(hashlib, link.hash_name) ): - with temporary_directory() as temp_dir: - filepath = self._authenticator.get_cached_file_for_url(link.url) - if not filepath: - filepath = Path(temp_dir) / link.filename - self._download(link.url, filepath) - + with self._cached_or_downloaded_file(link) as filepath: known_hash = ( getattr(hashlib, link.hash_name)() if link.hash_name else None )