Skip to content

Commit

Permalink
Add type aliases
Browse files Browse the repository at this point in the history
Signed-off-by: Bala.FA <bala@minio.io>
  • Loading branch information
balamurugana committed Dec 12, 2023
1 parent 9855c36 commit 6fa8bab
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 58 deletions.
95 changes: 46 additions & 49 deletions minio/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@
from .error import InvalidResponseError, S3Error, ServerError
from .helpers import (_DEFAULT_USER_AGENT, MAX_MULTIPART_COUNT,
MAX_MULTIPART_OBJECT_SIZE, MAX_PART_SIZE, MIN_PART_SIZE,
BaseURL, ObjectWriteResult, ProgressType, ThreadPool,
check_bucket_name, check_non_empty_string, check_sse,
check_ssec, genheaders, get_part_info,
headers_to_strings, is_valid_policy_type, makedirs,
md5sum_hash, queryencode, read_part_data, sha256_hash)
BaseURL, HeaderType, ObjectWriteResult, ProgressType,
QueryType, ThreadPool, check_bucket_name,
check_non_empty_string, check_sse, check_ssec,
genheaders, get_part_info, headers_to_strings,
is_valid_policy_type, makedirs, md5sum_hash, queryencode,
read_part_data, sha256_hash)
from .legalhold import LegalHold
from .lifecycleconfig import LifecycleConfig
from .notificationconfig import NotificationConfig
Expand Down Expand Up @@ -205,10 +206,10 @@ def _handle_redirect_response(
def _build_headers(
self,
host: str,
headers: dict[str, str | list[str] | tuple[str]] | None,
headers: HeaderType | None,
body: bytes | None,
creds: Credentials | None,
) -> tuple[dict[str, str | list[str] | tuple[str]], datetime]:
) -> tuple[HeaderType, datetime]:
"""Build headers with given parameters."""
headers = headers or {}
md5sum_added = headers.get("Content-MD5")
Expand Down Expand Up @@ -244,9 +245,8 @@ def _url_open(
bucket_name: str | None = None,
object_name: str | None = None,
body: bytes | None = None,
headers: dict[str, str | list[str] | tuple[str]] | None = None,
query_params: dict[str, str | list[str]
| tuple[str]] | None = None,
headers: HeaderType | None = None,
query_params: QueryType | None = None,
preload_content: bool = True,
no_body_trace: bool = False,
) -> BaseHTTPResponse:
Expand Down Expand Up @@ -424,9 +424,9 @@ def _execute(
bucket_name: str | None = None,
object_name: str | None = None,
body: bytes | None = None,
headers: dict[str, str | list[str] | tuple[str]] | None = None,
headers: HeaderType | None = None,
query_params:
dict[str, str | list[str] | tuple[str]] | None = None,
HeaderType | None = None,
preload_content: bool = True,
no_body_trace: bool = False,
) -> BaseHTTPResponse:
Expand Down Expand Up @@ -640,7 +640,7 @@ def make_bucket(
f"but passed {location}"
)
location = self._base_url.region or location or "us-east-1"
headers: dict[str, str | list[str] | tuple[str]] | None = (
headers: HeaderType | None = (
{"x-amz-bucket-object-lock-enabled": "true"}
if object_lock else None
)
Expand Down Expand Up @@ -1051,7 +1051,7 @@ def fput_object(
file_size,
content_type=content_type,
metadata=cast(
dict[str, str | list[str] | tuple[str]] | None,
HeaderType | None,
metadata,
),
sse=sse,
Expand All @@ -1072,7 +1072,7 @@ def fget_object(
ssec: SseCustomerKey | None = None,
version_id: str | None = None,
extra_query_params:
dict[str, str | list[str] | tuple[str]] | None = None,
HeaderType | None = None,
tmp_file_path: str | None = None,
progress: ProgressType | None = None,
):
Expand Down Expand Up @@ -1169,7 +1169,7 @@ def get_object(
ssec: SseCustomerKey | None = None,
version_id: str | None = None,
extra_query_params:
dict[str, str | list[str] | tuple[str]] | None = None,
HeaderType | None = None,
) -> BaseHTTPResponse:
"""
Get data of an object. Returned response should be closed after use to
Expand Down Expand Up @@ -1247,7 +1247,7 @@ def get_object(
"GET",
bucket_name,
object_name,
headers=cast(dict[str, str | list[str] | tuple[str]], headers),
headers=cast(HeaderType, headers),
query_params=extra_query_params,
preload_content=False,
)
Expand All @@ -1258,7 +1258,7 @@ def copy_object(
object_name: str,
source: CopySource,
sse: Sse | None = None,
metadata: dict[str, str | list[str] | tuple[str]] | None = None,
metadata: HeaderType | None = None,
tags: Tags | None = None,
retention: Retention | None = None,
legal_hold: bool = False,
Expand Down Expand Up @@ -1466,7 +1466,7 @@ def _upload_part_copy(
object_name: str,
upload_id: str,
part_number: int,
headers: dict[str, str | list[str] | tuple[str]],
headers: HeaderType,
) -> tuple[str, datetime | None]:
"""Execute UploadPartCopy S3 API."""
response = self._execute(
Expand All @@ -1487,7 +1487,7 @@ def compose_object(
object_name: str,
sources: list[ComposeSource],
sse: Sse | None = None,
metadata: dict[str, str | list[str] | tuple[str]] | None = None,
metadata: HeaderType | None = None,
tags: Tags | None = None,
retention: Retention | None = None,
legal_hold: bool = False,
Expand Down Expand Up @@ -1581,7 +1581,7 @@ def compose_object(
size -= src.offset
offset = src.offset or 0
headers = cast(
dict[str, str | list[str] | tuple[str]], src.headers,
HeaderType, src.headers,
)
headers.update(ssec_headers)
if size <= MAX_PART_SIZE:
Expand Down Expand Up @@ -1686,7 +1686,7 @@ def _create_multipart_upload(
self,
bucket_name: str,
object_name: str,
headers: dict[str, str | list[str] | tuple[str]],
headers: HeaderType,
) -> str:
"""Execute CreateMultipartUpload S3 API."""
if not headers.get("Content-Type"):
Expand All @@ -1706,9 +1706,8 @@ def _put_object(
bucket_name: str,
object_name: str,
data: bytes,
headers: dict[str, str | list[str] | tuple[str]] | None,
query_params: dict[str, str | list[str]
| tuple[str]] | None = None,
headers: HeaderType | None,
query_params: QueryType | None = None,
) -> ObjectWriteResult:
"""Execute PutObject S3 API."""
response = self._execute(
Expand All @@ -1733,7 +1732,7 @@ def _upload_part(
bucket_name: str,
object_name: str,
data: bytes,
headers: dict[str, str | list[str] | tuple[str]] | None,
headers: HeaderType | None,
upload_id: str,
part_number: int,
) -> str:
Expand Down Expand Up @@ -1761,7 +1760,7 @@ def put_object(
data: BinaryIO,
length: int,
content_type: str = "application/octet-stream",
metadata: dict[str, str | list[str] | tuple[str]] | None = None,
metadata: HeaderType | None = None,
sse: Sse | None = None,
progress: ProgressType | None = None,
part_size: int = 0,
Expand Down Expand Up @@ -1890,7 +1889,7 @@ def put_object(
part_data,
(
cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
sse.headers(),
)
if isinstance(sse, SseCustomerKey) else None
Expand Down Expand Up @@ -2009,10 +2008,8 @@ def stat_object(
object_name: str,
ssec: SseCustomerKey | None = None,
version_id: str | None = None,
extra_headers: dict[str, str | list[str]
| tuple[str]] | None = None,
extra_query_params: dict[str, str |
list[str] | tuple[str]] | None = None,
extra_headers: HeaderType | None = None,
extra_query_params: QueryType | None = None,
) -> Object:
"""
Get object information and metadata of an object.
Expand Down Expand Up @@ -2047,7 +2044,7 @@ def stat_object(
check_ssec(ssec)

headers = cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
ssec.headers() if ssec else {},
)
if extra_headers:
Expand Down Expand Up @@ -2130,7 +2127,7 @@ def _delete_objects(
:return: :class:`DeleteResult <DeleteResult>` object.
"""
body = marshal(DeleteRequest(delete_object_list, quiet=quiet))
headers: dict[str, str | list[str] | tuple[str]] = {
headers: HeaderType = {
"Content-MD5": cast(str, md5sum_hash(body)),
}
if bypass_governance_mode:
Expand Down Expand Up @@ -2284,7 +2281,7 @@ def get_presigned_url(
bucket_name=bucket_name,
object_name=object_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]], query_params,
HeaderType, query_params,
),
)

Expand Down Expand Up @@ -2661,7 +2658,7 @@ def delete_object_tags(
bucket_name,
object_name=object_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -2693,7 +2690,7 @@ def get_object_tags(
bucket_name,
object_name=object_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -2739,7 +2736,7 @@ def set_object_tags(
body=body,
headers={"Content-MD5": cast(str, md5sum_hash(body))},
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -2772,7 +2769,7 @@ def enable_object_legal_hold(
body=body,
headers={"Content-MD5": cast(str, md5sum_hash(body))},
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -2805,7 +2802,7 @@ def disable_object_legal_hold(
body=body,
headers={"Content-MD5": cast(str, md5sum_hash(body))},
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -2839,7 +2836,7 @@ def is_object_legal_hold_enabled(
bucket_name,
object_name=object_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -2933,7 +2930,7 @@ def get_object_retention(
bucket_name,
object_name=object_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -2978,7 +2975,7 @@ def set_object_retention(
body=body,
headers={"Content-MD5": cast(str, md5sum_hash(body))},
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
)
Expand Down Expand Up @@ -3081,7 +3078,7 @@ def upload_snowball_objects(
cast(BinaryIO, fileobj),
length,
metadata=cast(
dict[str, str | list[str] | tuple[str]] | None,
HeaderType | None,
metadata,
),
sse=sse,
Expand Down Expand Up @@ -3154,7 +3151,7 @@ def _list_objects(
"GET",
bucket_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query,
),
)
Expand Down Expand Up @@ -3222,11 +3219,11 @@ def _list_multipart_uploads(
"GET",
bucket_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
headers=cast(
dict[str, str | list[str] | tuple[str]] | None,
HeaderType | None,
extra_headers,
),
)
Expand Down Expand Up @@ -3271,11 +3268,11 @@ def _list_parts(
bucket_name,
object_name=object_name,
query_params=cast(
dict[str, str | list[str] | tuple[str]],
HeaderType,
query_params,
),
headers=cast(
dict[str, str | list[str] | tuple[str]] | None,
HeaderType | None,
extra_headers,
),
)
Expand Down
17 changes: 8 additions & 9 deletions minio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from datetime import datetime
from queue import Queue
from threading import BoundedSemaphore, Thread
from typing import BinaryIO, Mapping
from typing import BinaryIO, Mapping, MutableMapping, List, Tuple

from typing_extensions import Protocol
from urllib3._collections import HTTPHeaderDict
Expand Down Expand Up @@ -79,6 +79,9 @@
_REGION_REGEX = re.compile(r'^((?!_)(?!-)[a-z_\d-]{1,63}(?<!-)(?<!_))$',
re.IGNORECASE)

HeaderType = MutableMapping[str, str | List[str] | Tuple[str]]
QueryType = MutableMapping[str, str | List[str] | Tuple[str]]


def quote(
resource: str,
Expand Down Expand Up @@ -335,9 +338,7 @@ def url_replace(
)


def _metadata_to_headers(
metadata: dict[str, str | list[str] | tuple[str]],
) -> dict[str, list[str]]:
def _metadata_to_headers(metadata: HeaderType) -> dict[str, list[str]]:
"""Convert user metadata to headers."""
def normalize_key(key: str) -> str:
if not key.lower().startswith("x-amz-meta-"):
Expand Down Expand Up @@ -366,9 +367,7 @@ def normalize_value(values: str | list[str] | tuple[str]) -> list[str]:
}


def normalize_headers(
headers: dict[str, str | list[str] | tuple[str]] | None,
) -> dict[str, str | list[str] | tuple[str]]:
def normalize_headers(headers: HeaderType | None) -> HeaderType:
"""Normalize headers by prefixing 'X-Amz-Meta-' for user metadata."""
headers = {str(key): value for key, value in (headers or {}).items()}

Expand Down Expand Up @@ -398,12 +397,12 @@ def guess_user_metadata(key: str) -> bool:


def genheaders(
headers: dict[str, str | list[str] | tuple[str]] | None,
headers: HeaderType | None,
sse: Sse | None,
tags: dict[str, str] | None,
retention,
legal_hold: bool,
) -> dict[str, str | list[str] | tuple[str]]:
) -> HeaderType:
"""Generate headers for given parameters."""
headers = normalize_headers(headers)
headers.update(sse.headers() if sse else {})
Expand Down

0 comments on commit 6fa8bab

Please sign in to comment.