From e2b7bbc58555ab686f328beec8ab8af7bd05c82e Mon Sep 17 00:00:00 2001 From: "Bala.FA" Date: Wed, 22 Nov 2023 18:52:22 +0530 Subject: [PATCH] Add typing to helpers.py Signed-off-by: Bala.FA --- minio/helpers.py | 238 +++++++++++++++++++++++++++------------ tests/unit/minio_test.py | 4 +- 2 files changed, 169 insertions(+), 73 deletions(-) diff --git a/minio/helpers.py b/minio/helpers.py index f68bc987a..cde5bb463 100644 --- a/minio/helpers.py +++ b/minio/helpers.py @@ -16,7 +16,7 @@ """Helper functions.""" -from __future__ import absolute_import, division, unicode_literals +from __future__ import absolute_import, annotations, division, unicode_literals import base64 import errno @@ -26,14 +26,17 @@ import platform import re import urllib.parse +from abc import ABCMeta from queue import Queue from threading import BoundedSemaphore, Thread +from typing import BinaryIO + +from urllib3._collections import HTTPHeaderDict from . import __title__, __version__ from .sse import Sse, SseCustomerKey from .time import to_iso8601utc -# Constants _DEFAULT_USER_AGENT = ( f"MinIO ({platform.system()}; {platform.machine()}) " f"{__title__}/{__version__}" @@ -76,7 +79,12 @@ re.IGNORECASE) -def quote(resource, safe='/', encoding=None, errors=None): +def quote( + resource: str, + safe: str = "/", + encoding: str | None = None, + errors: str | None = None, +) -> str: """ Wrapper to urllib.parse.quote() replacing back to '~' for older python versions. @@ -89,17 +97,25 @@ def quote(resource, safe='/', encoding=None, errors=None): ).replace("%7E", "~") -def queryencode(query, safe='', encoding=None, errors=None): +def queryencode( + query: str, + safe: str = "", + encoding: str | None = None, + errors: str | None = None, +) -> str: """Encode query parameter value.""" return quote(query, safe, encoding, errors) -def headers_to_strings(headers, titled_key=False): +def headers_to_strings( + headers: dict[str, str], + titled_key: bool = False, +) -> str: """Convert HTTP headers to multi-line string.""" - def _get_key(key): + def _get_key(key: str) -> str: return key.title() if titled_key else key - def _get_value(value): + def _get_value(value: str) -> str: return re.sub( r"Credential=([^/]+)", "Credential=*REDACTED*", @@ -118,7 +134,7 @@ def _get_value(value): ) -def _validate_sizes(object_size, part_size): +def _validate_sizes(object_size: int, part_size: int): """Validate object and part size.""" if part_size > 0: if part_size < MIN_PART_SIZE: @@ -142,7 +158,7 @@ def _validate_sizes(object_size, part_size): ) -def _get_part_info(object_size, part_size): +def _get_part_info(object_size: int, part_size: int): """Compute part information for object and part size.""" _validate_sizes(object_size, part_size) @@ -159,7 +175,7 @@ def _get_part_info(object_size, part_size): return part_size, math.ceil(object_size / part_size) if part_size else 1 -def get_part_info(object_size, part_size): +def get_part_info(object_size: int, part_size: int) -> tuple[int, int]: """Compute part information for object and part size.""" part_size, part_count = _get_part_info(object_size, part_size) if part_count > MAX_MULTIPART_COUNT: @@ -170,7 +186,23 @@ def get_part_info(object_size, part_size): return part_size, part_count -def read_part_data(stream, size, part_data=b'', progress=None): +class Progress: + """Progress base class for put object API.""" + __metaclass__ = ABCMeta + + def set_meta(self, total_length: int, object_name: str): + """Set object information to progress.""" + + def update(self, size: int): + """Update current progress size.""" + + +def read_part_data( + stream: BinaryIO, + size: int, + part_data: bytes = b"", + progress: Progress | None = None, +) -> bytes: """Read part data of given size from stream.""" size -= len(part_data) while size: @@ -186,7 +218,7 @@ def read_part_data(stream, size, part_data=b'', progress=None): return part_data -def makedirs(path): +def makedirs(path: str): """Wrapper of os.makedirs() ignores errno.EEXIST.""" try: if path: @@ -199,7 +231,11 @@ def makedirs(path): raise ValueError(f"path {path} is not a directory") from exc -def check_bucket_name(bucket_name, strict=False, s3_check=False): +def check_bucket_name( + bucket_name: str, + strict: bool = False, + s3_check: bool = False, +): """Check whether bucket name is valid optional with strict check or not.""" if strict: @@ -229,7 +265,7 @@ def check_bucket_name(bucket_name, strict=False, s3_check=False): "'--ol-s3'") -def check_non_empty_string(string): +def check_non_empty_string(string: str | bytes): """Check whether given string is not empty.""" try: if not string.strip(): @@ -238,7 +274,7 @@ def check_non_empty_string(string): raise TypeError() from exc -def is_valid_policy_type(policy): +def is_valid_policy_type(policy: str | bytes): """ Validate if policy is type str @@ -254,19 +290,19 @@ def is_valid_policy_type(policy): return True -def check_ssec(sse): +def check_ssec(sse: SseCustomerKey): """Check sse is SseCustomerKey type or not.""" if sse and not isinstance(sse, SseCustomerKey): raise ValueError("SseCustomerKey type is required") -def check_sse(sse): +def check_sse(sse: Sse): """Check sse is Sse type or not.""" if sse and not isinstance(sse, Sse): raise ValueError("Sse type is required") -def md5sum_hash(data): +def md5sum_hash(data: str | bytes | None) -> str | None: """Compute MD5 of data and return hash as Base64 encoded value.""" if data is None: return None @@ -279,18 +315,25 @@ def md5sum_hash(data): return md5sum.decode() if isinstance(md5sum, bytes) else md5sum -def sha256_hash(data): +def sha256_hash(data: str | bytes | None) -> str: """Compute SHA-256 of data and return hash as hex encoded value.""" data = data or b"" hasher = hashlib.sha256() hasher.update(data.encode() if isinstance(data, str) else data) sha256sum = hasher.hexdigest() - return sha256sum.decode() if isinstance(sha256sum, bytes) else sha256sum + if isinstance(sha256sum, bytes): + return sha256sum.decode() + return sha256sum def url_replace( - url, scheme=None, netloc=None, path=None, query=None, fragment=None -): + url: urllib.parse.SplitResult, + scheme: str | None = None, + netloc: str | None = None, + path: str | None = None, + query: str | None = None, + fragment: str | None = None, +) -> urllib.parse.SplitResult: """Return new URL with replaced properties in given URL.""" return urllib.parse.SplitResult( scheme if scheme is not None else url.scheme, @@ -301,14 +344,16 @@ def url_replace( ) -def _metadata_to_headers(metadata): +def _metadata_to_headers( + metadata: dict[str, str | list | tuple], +) -> dict[str, list[str]]: """Convert user metadata to headers.""" - def normalize_key(key): + def normalize_key(key: str) -> str: if not key.lower().startswith("x-amz-meta-"): key = "X-Amz-Meta-" + key return key - def to_string(value): + def to_string(value) -> str: value = str(value) try: value.encode("us-ascii") @@ -319,7 +364,7 @@ def to_string(value): ) from exc return value - def normalize_value(values): + def normalize_value(values: str | list | tuple) -> list[str]: if not isinstance(values, (list, tuple)): values = [values] return [to_string(value) for value in values] @@ -330,11 +375,13 @@ def normalize_value(values): } -def normalize_headers(headers): +def normalize_headers( + headers: dict[str, str | list | tuple], +) -> dict[str, str | list | tuple]: """Normalize headers by prefixing 'X-Amz-Meta-' for user metadata.""" headers = {str(key): value for key, value in (headers or {}).items()} - def guess_user_metadata(key): + def guess_user_metadata(key: str) -> bool: key = key.lower() return not ( key.startswith("x-amz-") or @@ -359,7 +406,13 @@ def guess_user_metadata(key): return headers -def genheaders(headers, sse, tags, retention, legal_hold): +def genheaders( + headers: dict[str, str | list | tuple], + sse: Sse | None, + tags: dict[str, str] | None, + retention, + legal_hold: bool, +) -> dict[str, str | list | tuple]: """Generate headers for given parameters.""" headers = normalize_headers(headers) headers.update(sse.headers() if sse else {}) @@ -374,14 +427,18 @@ def genheaders(headers, sse, tags, retention, legal_hold): if retention and retention.mode: headers["x-amz-object-lock-mode"] = retention.mode headers["x-amz-object-lock-retain-until-date"] = ( - to_iso8601utc(retention.retain_until_date) + to_iso8601utc(retention.retain_until_date) or "" ) if legal_hold: headers["x-amz-object-lock-legal-hold"] = "ON" return headers -def _get_aws_info(host, https, region): +def _get_aws_info( + host: str, + https: bool, + region: str | None, +) -> tuple[dict | None, str | None]: """Extract AWS domain information. """ if not _HOSTNAME_REGEX.match(host): @@ -397,7 +454,8 @@ def _get_aws_info(host, https, region): if not _AWS_S3_ENDPOINT_REGEX.match(host): raise ValueError(f"invalid Amazon AWS host {host}") - end = _AWS_S3_PREFIX_REGEX.match(host).end() + matcher = _AWS_S3_PREFIX_REGEX.match(host) + end = matcher.end() if matcher else 0 aws_s3_prefix = host[:end] if "s3-accesspoint" in aws_s3_prefix and not https: @@ -407,7 +465,7 @@ def _get_aws_info(host, https, region): dualstack = tokens[0] == "dualstack" if dualstack: tokens = tokens[1:] - region_in_host = None + region_in_host = "" if tokens[0] not in ["vpce", "amazonaws"]: region_in_host = tokens[0] tokens = tokens[1:] @@ -434,7 +492,7 @@ def _get_aws_info(host, https, region): "dualstack": dualstack}, None) -def _parse_url(endpoint): +def _parse_url(endpoint: str) -> urllib.parse.SplitResult: """Parse url string.""" url = urllib.parse.urlsplit(endpoint) @@ -478,17 +536,23 @@ def _parse_url(endpoint): class BaseURL: """Base URL of S3 endpoint.""" + _aws_info: dict | None + _virtual_style_flag: bool + _url: urllib.parse.SplitResult + _region: str | None + _accelerate_host_flag: bool - def __init__(self, endpoint, region): + def __init__(self, endpoint: str, region: str | None): url = _parse_url(endpoint) if region and not _REGION_REGEX.match(region): raise ValueError(f"invalid region {region}") + hostname = url.hostname or "" self._aws_info, region_in_host = _get_aws_info( - url.hostname, url.scheme == "https", region) + hostname, url.scheme == "https", region) self._virtual_style_flag = ( - self._aws_info or url.hostname.endswith("aliyuncs.com") + self._aws_info is not None or hostname.endswith("aliyuncs.com") ) self._url = url self._region = region or region_in_host @@ -500,32 +564,32 @@ def __init__(self, endpoint, region): ) @property - def region(self): + def region(self) -> str | None: """Get region.""" return self._region @property - def is_https(self): + def is_https(self) -> bool: """Check if scheme is HTTPS.""" return self._url.scheme == "https" @property - def host(self): + def host(self) -> str: """Get hostname.""" return self._url.netloc @property - def is_aws_host(self): + def is_aws_host(self) -> bool: """Check if URL points to AWS host.""" return self._aws_info is not None @property - def aws_s3_prefix(self): + def aws_s3_prefix(self) -> str | None: """Get AWS S3 domain prefix.""" return self._aws_info["s3_prefix"] if self._aws_info else None @aws_s3_prefix.setter - def aws_s3_prefix(self, s3_prefix): + def aws_s3_prefix(self, s3_prefix: str): """Set AWS s3 domain prefix.""" if not _AWS_S3_PREFIX_REGEX.match(s3_prefix): raise ValueError(f"invalid AWS S3 domain prefix {s3_prefix}") @@ -533,40 +597,48 @@ def aws_s3_prefix(self, s3_prefix): self._aws_info["s3_prefix"] = s3_prefix @property - def accelerate_host_flag(self): + def accelerate_host_flag(self) -> bool: """Get AWS accelerate host flag.""" return self._accelerate_host_flag @accelerate_host_flag.setter - def accelerate_host_flag(self, flag): + def accelerate_host_flag(self, flag: bool): """Set AWS accelerate host flag.""" self._accelerate_host_flag = flag @property - def dualstack_host_flag(self): + def dualstack_host_flag(self) -> bool: """Check if URL points to AWS dualstack host.""" return self._aws_info["dualstack"] if self._aws_info else False @dualstack_host_flag.setter - def dualstack_host_flag(self, flag): + def dualstack_host_flag(self, flag: bool): """Set AWS dualstack host.""" if self._aws_info: self._aws_info["dualstack"] = flag @property - def virtual_style_flag(self): + def virtual_style_flag(self) -> bool: """Check to use virtual style or not.""" return self._virtual_style_flag @virtual_style_flag.setter - def virtual_style_flag(self, flag): + def virtual_style_flag(self, flag: bool): """Check to use virtual style or not.""" self._virtual_style_flag = flag - def _build_aws_url(self, url, bucket_name, enforce_path_style, region): + @classmethod + def _build_aws_url( + cls, + aws_info: dict, + url: urllib.parse.SplitResult, + bucket_name: str | None, + enforce_path_style: bool, + region: str, + ) -> urllib.parse.SplitResult: """Build URL for given information.""" - s3_prefix = self._aws_info["s3_prefix"] - domain_suffix = self._aws_info["domain_suffix"] + s3_prefix = aws_info["s3_prefix"] + domain_suffix = aws_info["domain_suffix"] host = f"{s3_prefix}{domain_suffix}" if host in ["s3-external-1.amazonaws.com", @@ -584,7 +656,7 @@ def _build_aws_url(self, url, bucket_name, enforce_path_style, region): if enforce_path_style: netloc = netloc.replace("-accelerate", "", 1) - if self._aws_info["dualstack"]: + if aws_info["dualstack"]: netloc += "dualstack." if "s3-accelerate" not in s3_prefix: netloc += region + "." @@ -592,7 +664,11 @@ def _build_aws_url(self, url, bucket_name, enforce_path_style, region): return url_replace(url, netloc=netloc) - def _build_list_buckets_url(self, url, region): + def _build_list_buckets_url( + self, + url: urllib.parse.SplitResult, + region: str | None, + ) -> urllib.parse.SplitResult: """Build URL for ListBuckets API.""" if not self._aws_info: return url @@ -613,9 +689,13 @@ def _build_list_buckets_url(self, url, region): return url_replace(url, netloc=f"{s3_prefix}{region}.{domain_suffix}") def build( - self, method, region, - bucket_name=None, object_name=None, query_params=None, - ): + self, + method: str, + region: str, + bucket_name: str | None = None, + object_name: str | None = None, + query_params: dict[str, str | list | tuple] | None = None, + ) -> urllib.parse.SplitResult: """Build URL for given information.""" if not bucket_name and object_name: raise ValueError( @@ -649,8 +729,8 @@ def build( ) if self._aws_info: - url = self._build_aws_url( - url, bucket_name, enforce_path_style, region) + url = BaseURL._build_aws_url( + self._aws_info, url, bucket_name, enforce_path_style, region) netloc = url.netloc path = "/" @@ -669,8 +749,14 @@ class ObjectWriteResult: """Result class of any APIs doing object creation.""" def __init__( - self, bucket_name, object_name, version_id, etag, http_headers, - last_modified=None, location=None, + self, + bucket_name: str, + object_name: str, + version_id: str | None, + etag: str | None, + http_headers: HTTPHeaderDict, + last_modified: str | None = None, + location: str | None = None, ): self._bucket_name = bucket_name self._object_name = object_name @@ -681,37 +767,37 @@ def __init__( self._location = location @property - def bucket_name(self): + def bucket_name(self) -> str: """Get bucket name.""" return self._bucket_name @property - def object_name(self): + def object_name(self) -> str: """Get object name.""" return self._object_name @property - def version_id(self): + def version_id(self) -> str | None: """Get version ID.""" return self._version_id @property - def etag(self): + def etag(self) -> str | None: """Get etag.""" return self._etag @property - def http_headers(self): + def http_headers(self) -> HTTPHeaderDict: """Get HTTP headers.""" return self._http_headers @property - def last_modified(self): + def last_modified(self) -> str | None: """Get last-modified time.""" return self._last_modified @property - def location(self): + def location(self) -> str | None: """Get location.""" return self._location @@ -719,7 +805,12 @@ def location(self): class Worker(Thread): """ Thread executing tasks from a given tasks queue """ - def __init__(self, tasks_queue, results_queue, exceptions_queue): + def __init__( + self, + tasks_queue: Queue, + results_queue: Queue, + exceptions_queue: Queue, + ): Thread.__init__(self, daemon=True) self._tasks_queue = tasks_queue self._results_queue = results_queue @@ -751,8 +842,13 @@ def run(self): class ThreadPool: """ Pool of threads consuming tasks from a queue """ + _results_queue: Queue + _exceptions_queue: Queue + _tasks_queue: Queue + _sem: BoundedSemaphore + _num_threads: int - def __init__(self, num_threads): + def __init__(self, num_threads: int): self._results_queue = Queue() self._exceptions_queue = Queue() self._tasks_queue = Queue() @@ -777,7 +873,7 @@ def start_parallel(self): self._tasks_queue, self._results_queue, self._exceptions_queue, ) - def result(self): + def result(self) -> Queue: """ Stop threads and return the result of all called tasks """ # Send None to all threads to cleanly stop them for _ in range(self._num_threads): diff --git a/tests/unit/minio_test.py b/tests/unit/minio_test.py index 38d9312a2..9c7346271 100644 --- a/tests/unit/minio_test.py +++ b/tests/unit/minio_test.py @@ -141,7 +141,7 @@ def test_region_none(self): def test_region_us_west(self): region = BaseURL('https://s3-us-west-1.amazonaws.com', None).region - self.assertIsNone(region) + self.assertEqual(region, "") def test_region_with_dot(self): region = BaseURL('https://s3.us-west-1.amazonaws.com', None).region @@ -155,7 +155,7 @@ def test_region_with_dualstack(self): def test_region_us_east(self): region = BaseURL('http://s3.amazonaws.com', None).region - self.assertIsNone(region) + self.assertEqual(region, "") def test_invalid_value(self): self.assertRaises(ValueError, BaseURL, None, None)