-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
daaf6cb
commit f956173
Showing
1 changed file
with
131 additions
and
154 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,205 +1,182 @@ | ||
# Built-in imports | ||
from pathlib import Path | ||
from os import PathLike | ||
from threading import Thread | ||
from contextlib import nullcontext | ||
from pathlib import Path | ||
from mimetypes import guess_extension as guess_mimetype_extension | ||
from queue import Queue | ||
from typing import Union, Optional, Literal, Any, Dict | ||
from math import ceil | ||
from urllib.parse import urlparse, unquote | ||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import Union, Literal, Optional, Dict, List, Tuple | ||
|
||
# Third-party imports | ||
try: | ||
from requests import get, head | ||
except (ImportError, ModuleNotFoundError): | ||
pass | ||
|
||
try: | ||
from rich.progress import ( | ||
Progress, | ||
BarColumn, | ||
DownloadColumn, | ||
TransferSpeedColumn, | ||
TimeRemainingColumn, | ||
TextColumn, | ||
TimeElapsedColumn, | ||
) | ||
except (ImportError, ModuleNotFoundError): | ||
pass | ||
from requests import get, head, exceptions as requests_exceptions | ||
from rich.progress import Progress, DownloadColumn, TransferSpeedColumn, TextColumn, TimeRemainingColumn, BarColumn | ||
|
||
# Local imports | ||
from .exceptions import DownloadError | ||
|
||
|
||
class Downloader: | ||
"""A class for downloading direct download URLs. Created to download YouTube videos and audio streams. However, it can be used to download any direct download URL.""" | ||
"""A class for downloading direct download URLs.""" | ||
|
||
def __init__( | ||
self, | ||
max_connections: Union[int, Literal['auto']] = 'auto', | ||
overwrite: bool = False, | ||
overwrite: bool = True, | ||
show_progress_bar: bool = True, | ||
headers: Dict[Any, Any] = None, | ||
timeout: int = 1440, | ||
timeout: Optional[int] = 1440, | ||
) -> None: | ||
""" | ||
Initialize the Downloader class with the required settings for downloading a file. | ||
:param max_connections: The maximum number of connections (threads) to use for downloading the file. | ||
:param overwrite: Overwrite the file if it already exists. Otherwise, a "_1", "_2", etc. suffix will be added to the file name. | ||
:param overwrite: Overwrite the file if it already exists. Otherwise, a "_1", "_2", etc. suffix will be added. | ||
:param show_progress_bar: Show or hide the download progress bar. | ||
:param headers: A dictionary of custom headers to be sent with the request. | ||
:param timeout: The timeout in seconds for the download process. | ||
:param timeout: The timeout in seconds for the download process. Or None for no timeout. | ||
""" | ||
|
||
if isinstance(max_connections, int) and max_connections <= 0: | ||
raise ValueError('The number of threads must be greater than 0.') | ||
|
||
self.max_connections = max_connections | ||
self._overwrite = overwrite | ||
self._show_progress_bar = show_progress_bar | ||
self._headers: Optional[Dict[Any, Any]] = headers | ||
self._timeout = timeout | ||
|
||
self._queue: Queue = Queue() | ||
|
||
def _generate_unique_filename(self, file_path: Union[str, PathLike]) -> Path: | ||
file_path = Path(file_path) | ||
|
||
if self._overwrite: | ||
return file_path | ||
|
||
counter = 1 | ||
unique_filename = file_path | ||
|
||
while unique_filename.exists(): | ||
unique_filename = Path(f'{file_path.parent}/{file_path.stem}_{counter}{file_path.suffix}') | ||
counter += 1 | ||
|
||
return unique_filename | ||
|
||
def _get_filename_from_url(self, headers: Dict[Any, Any]) -> str: | ||
content_disposition = headers.get('Content-Disposition', '') | ||
|
||
if 'filename=' in content_disposition: | ||
filename = content_disposition.split('filename=')[-1].strip().strip('"') | ||
self._max_connections: Union[int, Literal['auto']] = max_connections | ||
self._overwrite: bool = overwrite | ||
self._show_progress_bar: bool = show_progress_bar | ||
self._timeout: Optional[int] = timeout | ||
self._headers: Dict[str, str] = { | ||
'Accept': '*/*', | ||
'Accept-Encoding': 'identity', | ||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36', | ||
} | ||
|
||
def _calculate_connections(self, file_size: int) -> int: | ||
if self._max_connections != 'auto': | ||
return self._max_connections | ||
|
||
if file_size < 1024 * 1024: | ||
return 1 | ||
elif file_size <= 5 * 1024 * 1024: | ||
return 4 | ||
elif file_size <= 50 * 1024 * 1024: | ||
return 8 | ||
elif file_size <= 200 * 1024 * 1024: | ||
return 16 | ||
elif file_size <= 400 * 1024 * 1024: | ||
return 24 | ||
else: | ||
return 32 | ||
|
||
def _get_file_info(self, url: str) -> tuple[int, str, str]: | ||
try: | ||
response = head(url, headers=self._headers, timeout=self._timeout, allow_redirects=True) | ||
|
||
if response.status_code == 405: | ||
response = get(url, headers=self._headers, timeout=self._timeout, stream=True) | ||
|
||
response.raise_for_status() | ||
|
||
content_length = int(response.headers.get('content-length', 0)) | ||
content_type = response.headers.get('content-type', 'application/octet-stream').split(';')[0] | ||
|
||
content_disp = response.headers.get('content-disposition') | ||
|
||
if content_disp and 'filename=' in content_disp: | ||
filename = content_disp.split('filename=')[-1].strip('"\'') | ||
else: | ||
path = unquote(urlparse(url).path) | ||
filename = Path(path).name | ||
|
||
if filename: | ||
return filename | ||
if not filename: | ||
filename = 'downloaded_file' | ||
ext = guess_mimetype_extension(content_type) | ||
|
||
mimetype = headers.get('Content-Type', '') | ||
if ext: | ||
filename += ext | ||
|
||
if mimetype: | ||
extension = guess_mimetype_extension(mimetype.split(';')[0]) | ||
return content_length, content_type, filename | ||
|
||
if extension: | ||
return f'{mimetype.split('/')[0]}{extension}' | ||
except requests_exceptions.RequestException as e: | ||
raise DownloadError(f'An error occurred while getting file info: {str(e)}') from e | ||
|
||
return 'file.unknown' | ||
def _get_chunk_ranges(self, total_size: int) -> List[Tuple[int, int]]: | ||
if total_size == 0: | ||
return [(0, 0)] | ||
|
||
def _calculate_threads(self, file_size: int) -> int: | ||
if isinstance(self.max_connections, int): | ||
return self.max_connections | ||
connections = self._calculate_connections(total_size) | ||
chunk_size = ceil(total_size / connections) | ||
ranges = [] | ||
|
||
if self.max_connections == 'auto': | ||
if file_size <= 5 * 1024 * 1024: # < 5 MB | ||
return 4 | ||
elif file_size <= 50 * 1024 * 1024: # 5-50 MB | ||
return 8 | ||
elif file_size <= 200 * 1024 * 1024: # 50-200 MB | ||
return 16 | ||
elif file_size <= 400 * 1024 * 1024: # 200-400 MB | ||
return 24 | ||
else: | ||
return 32 | ||
for i in range(0, total_size, chunk_size): | ||
end = min(i + chunk_size - 1, total_size - 1) | ||
ranges.append((i, end)) | ||
|
||
def _download_chunk(self, url: str, start: int, end: int, temp_file: Union[str, PathLike]) -> None: | ||
headers = {} if self._headers is None else self._headers | ||
headers['Range'] = f'bytes={start}-{end}' | ||
return ranges | ||
|
||
r = get(url, headers=headers, stream=True, allow_redirects=True) | ||
def _download_chunk(self, url: str, start: int, end: int, progress: Progress, task_id: int) -> bytes: | ||
headers = {**self._headers} | ||
|
||
with Path(temp_file).open('r+b') as f: | ||
f.seek(start) | ||
f.write(r.content) | ||
if end > 0: | ||
headers['Range'] = f'bytes={start}-{end}' | ||
|
||
def _worker(self, progress_task_id: Optional[int] = None, progress: Optional[Progress] = None) -> None: | ||
while not self._queue.empty(): | ||
task = self._queue.get() | ||
try: | ||
response = get(url, headers=headers, timeout=self._timeout) | ||
response.raise_for_status() | ||
chunk = response.content | ||
progress.update(task_id, advance=len(chunk)) | ||
|
||
try: | ||
self._download_chunk(*task) | ||
|
||
if progress and progress_task_id is not None: | ||
progress.update(progress_task_id, advance=task[2] - task[1] + 1) | ||
finally: | ||
self._queue.task_done() | ||
return chunk | ||
except requests_exceptions.RequestException as e: | ||
raise Exception(f'Erro ao baixar chunk: {str(e)}') | ||
|
||
def download(self, url: str, output_file_path: Union[str, PathLike] = Path.cwd()) -> None: | ||
""" | ||
Download a file from a given URL to a specified output path. | ||
Downloads specified file from the given URL. | ||
:param url: The URL of the file to download. | ||
:param output_file_path: The path where the downloaded file will be saved. If it's a directory, the filename will be determined from the URL or server response. If it's a file path, it will save with that name. If overwrite=False and a conflict occurs, a unique name will be generated automatically (e.g., "file_1.ext"). If overwrite=True and a conflict occurs, existing files will be replaced without warning. Defaults to the current working directory. | ||
:param url: The URL to download from. | ||
:param output_file_path: The file path to save the downloaded file to. If it's a directory, the file name will be generated from the server response. Defaults to the current working directory. | ||
:raises DownloadError: If an error occurs while downloading the file. | ||
""" | ||
|
||
output_file_path = Path(output_file_path).resolve() | ||
|
||
r = head(url, headers=self._headers, allow_redirects=True) | ||
|
||
if r.status_code != 200 or 'Content-Length' not in r.headers: | ||
raise ValueError('Could not determine the file size or access the URL.') | ||
|
||
file_size = int(r.headers['Content-Length']) | ||
filename = self._get_filename_from_url(r.headers) | ||
|
||
if output_file_path.is_dir(): | ||
output_file_path = Path(output_file_path, filename) | ||
|
||
output_file_path.parent.mkdir(parents=True, exist_ok=True) | ||
unique_filename = self._generate_unique_filename(output_file_path) | ||
try: | ||
total_size, mime_type, suggested_filename = self._get_file_info(url) | ||
|
||
with open(unique_filename, 'wb') as f: | ||
f.write(b'\0' * file_size) | ||
output_path = Path(output_file_path) | ||
|
||
threads_count = self._calculate_threads(file_size) | ||
chunk_size = file_size // threads_count | ||
if output_path.is_dir(): | ||
output_path = output_path / suggested_filename | ||
|
||
for i in range(threads_count): | ||
start = i * chunk_size | ||
end = file_size - 1 if i == threads_count - 1 else (start + chunk_size - 1) | ||
temp_file = unique_filename | ||
self._queue.put((url, start, end, temp_file)) | ||
if not self._overwrite: | ||
base, ext = output_path.stem, output_path.suffix | ||
counter = 1 | ||
|
||
mimetype = r.headers.get('Content-Type') | ||
while output_path.exists(): | ||
output_path = output_path.parent / f'{base}_{counter}{ext}' | ||
counter += 1 | ||
|
||
context_manager = ( | ||
Progress( | ||
TextColumn(f'Downloading a {mimetype.split("/")[0] if mimetype else "file"} ({mimetype or "unknown"})'), | ||
progress_columns = [ | ||
TextColumn(f'Downloading "{output_path.name}" ({mime_type})'), | ||
BarColumn(), | ||
DownloadColumn(), | ||
TransferSpeedColumn(), | ||
TimeRemainingColumn(), | ||
TimeElapsedColumn(), | ||
) | ||
if self._show_progress_bar | ||
else nullcontext() | ||
) | ||
|
||
with context_manager as progress: | ||
task_id = None | ||
|
||
if self._show_progress_bar: | ||
task_id = progress.add_task('Downloading', total=file_size) | ||
|
||
threads = [] | ||
|
||
for _ in range(threads_count): | ||
thread = Thread( | ||
target=self._worker, | ||
kwargs={'progress_task_id': task_id, 'progress': progress if self._show_progress_bar else None}, | ||
) | ||
thread.start() | ||
threads.append(thread) | ||
|
||
for thread in threads: | ||
thread.join() | ||
] | ||
|
||
with Progress(*progress_columns, disable=not self._show_progress_bar) as progress: | ||
task_id = progress.add_task('download', total=total_size or 100, filename=output_path.name, mime=mime_type) | ||
|
||
if total_size == 0: | ||
chunk = self._download_chunk(url, 0, 0, progress, task_id) | ||
|
||
with open(output_path, 'wb') as f: | ||
f.write(chunk) | ||
else: | ||
chunks = [] | ||
ranges = self._get_chunk_ranges(total_size) | ||
connections = len(ranges) | ||
|
||
with ThreadPoolExecutor(max_workers=connections) as executor: | ||
futures = [ | ||
executor.submit(self._download_chunk, url, start, end, progress, task_id) for start, end in ranges | ||
] | ||
chunks = [f.result() for f in futures] | ||
|
||
with open(output_path, 'wb') as f: | ||
for chunk in chunks: | ||
f.write(chunk) | ||
except Exception as e: | ||
raise DownloadError(f'An error occurred while downloading file: {str(e)}') from e |