Skip to content

Commit

Permalink
[EH] Fix checkpointstore typing (#39451)
Browse files Browse the repository at this point in the history
* sync

* aio

* update sync

* edit

* dont break
  • Loading branch information
l0lawrence authored Jan 31, 2025
1 parent 966e630 commit 3391b16
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ class BlobCheckpointStore(CheckpointStore):
:param container_name:
The name of the container for the blobs.
:type container_name: str
:param credential:
:keyword credential:
The credentials with which to authenticate. This is optional if the
account URL already has a SAS token. The value can be an AzureSasCredential, an AzureNamedKeyCredential,
or a TokenCredential. If the URL already has a SAS token, specifying an explicit credential will take priority.
:type credential: ~azure.core.credentials_async.AsyncTokenCredential or
:paramtype credential: ~azure.core.credentials_async.AsyncTokenCredential or
~azure.core.credentials.AzureSasCredential or ~azure.core.credentials.AzureNamedKeyCredential or None
:keyword str api_version:
The Storage API version to use for requests. Default value is '2019-07-07'.
Expand All @@ -50,29 +50,36 @@ def __init__(
container_name: str,
*,
credential: Optional[Union["AsyncTokenCredential", "AzureNamedKeyCredential", "AzureSasCredential"]] = None,
api_version: str = '2019-07-07',
api_version: str = "2019-07-07",
secondary_hostname: Optional[str] = None,
**kwargs: Any
) -> None:
self._container_client = kwargs.pop("container_client", None)
if not self._container_client:
if api_version:
headers = kwargs.get("headers")
if headers:
headers["x-ms-version"] = api_version
else:
kwargs["headers"] = {"x-ms-version": api_version}
headers = kwargs.get("headers")
if headers:
headers["x-ms-version"] = api_version
else:
kwargs["headers"] = {"x-ms-version": api_version}
self._container_client = ContainerClient(
blob_account_url, container_name, credential=credential, **kwargs
blob_account_url,
container_name,
credential=credential,
api_version=api_version,
secondary_hostname=secondary_hostname,
**kwargs
)
self._cached_blob_clients = defaultdict() # type: Dict[str, BlobClient]

@classmethod
def from_connection_string( # pylint:disable=docstring-keyword-should-match-keyword-only
def from_connection_string( # pylint:disable=docstring-keyword-should-match-keyword-only
cls,
conn_str: str,
container_name: str,
*,
credential: Optional[Union["AsyncTokenCredential", "AzureNamedKeyCredential", "AzureSasCredential"]] = None,
api_version: str = "2019-07-07",
secondary_hostname: Optional[str] = None,
**kwargs: Any
) -> "BlobCheckpointStore":
"""Create BlobCheckpointStore from a storage connection string.
Expand All @@ -87,7 +94,7 @@ def from_connection_string( # pylint:disable=docstring-keyword-should-match-keyw
account URL already has a SAS token. The value can be an AzureSasCredential, an AzureNamedKeyCredential,
or a TokenCredential. If the URL already has a SAS token,
specifying an explicit credential will take priority.
:type credential: ~azure.core.credentials_async.AsyncTokenCredential or
:paramtype credential: ~azure.core.credentials_async.AsyncTokenCredential or
~azure.core.credentials.AzureSasCredential or ~azure.core.credentials.AzureNamedKeyCredential or None
:keyword str api_version:
The Storage API version to use for requests. Default value is '2019-07-07'.
Expand All @@ -96,12 +103,20 @@ def from_connection_string( # pylint:disable=docstring-keyword-should-match-keyw
:returns: A blob checkpoint store.
:rtype: ~azure.eventhub.extensions.checkpointstoreblobaio.BlobCheckpointStore
"""
account_url, secondary, credential = parse_connection_str( # type: ignore[assignment]
conn_str, credential, 'blob') # type: ignore[arg-type]
if 'secondary_hostname' not in kwargs:
kwargs['secondary_hostname'] = secondary
account_url, secondary, credential = parse_connection_str( # type: ignore[assignment]
conn_str, credential, "blob" # type: ignore[arg-type]
)
if not secondary_hostname:
secondary_hostname = secondary

return cls(account_url, container_name, credential=credential, **kwargs)
return cls(
account_url,
container_name,
credential=credential,
api_version=api_version,
secondary_hostname=secondary_hostname,
**kwargs
)

async def __aenter__(self) -> "BlobCheckpointStore":
await self._container_client.__aenter__()
Expand All @@ -114,12 +129,10 @@ def _get_blob_client(self, blob_name: str) -> BlobClient:
result = self._cached_blob_clients.get(blob_name)
if not result:
result = self._container_client.get_blob_client(blob_name)
self._cached_blob_clients[blob_name] = result # type: ignore[assignment]
return result # type: ignore[return-value]
self._cached_blob_clients[blob_name] = result # type: ignore[assignment]
return result # type: ignore[return-value]

async def _upload_ownership(
self, ownership: Dict[str, Any], **kwargs: Any
) -> None:
async def _upload_ownership(self, ownership: Dict[str, Any], **kwargs: Any) -> None:
etag = ownership.get("etag")
if etag:
kwargs["if_match"] = etag
Expand All @@ -133,7 +146,7 @@ async def _upload_ownership(
)
blob_name = blob_name.lower()
blob_client = self._get_blob_client(blob_name)
metadata = {'ownerid': ownership['owner_id']}
metadata = {"ownerid": ownership["owner_id"]}
try:
uploaded_blob_properties = await blob_client.set_blob_metadata(metadata, **kwargs)
except ResourceNotFoundError:
Expand All @@ -142,7 +155,7 @@ async def _upload_ownership(
data=UPLOAD_DATA, overwrite=True, metadata=metadata, **kwargs
)
ownership["etag"] = uploaded_blob_properties["etag"]
ownership["last_modified_time"] = uploaded_blob_properties[ # type: ignore[union-attr]
ownership["last_modified_time"] = uploaded_blob_properties[ # type: ignore[union-attr]
"last_modified"
].timestamp()

Expand Down Expand Up @@ -202,9 +215,7 @@ async def list_ownership(
:rtype: iterable[dict[str, any]]
"""
try:
blob_prefix = "{}/{}/{}/ownership/".format(
fully_qualified_namespace, eventhub_name, consumer_group
)
blob_prefix = "{}/{}/{}/ownership/".format(fully_qualified_namespace, eventhub_name, consumer_group)
blobs = self._container_client.list_blobs(
name_starts_with=blob_prefix.lower(), include=["metadata"], **kwargs
)
Expand All @@ -217,9 +228,7 @@ async def list_ownership(
"partition_id": blob.name.split("/")[-1],
"owner_id": blob.metadata["ownerid"],
"etag": blob.etag,
"last_modified_time": blob.last_modified.timestamp()
if blob.last_modified
else None,
"last_modified_time": blob.last_modified.timestamp() if blob.last_modified else None,
}
result.append(ownership)
return result
Expand Down Expand Up @@ -256,12 +265,9 @@ async def claim_ownership(
:rtype: iterable[dict[str,any]]
"""
results = await asyncio.gather(
*[self._claim_one_partition(x, **kwargs) for x in ownership_list],
return_exceptions=True
*[self._claim_one_partition(x, **kwargs) for x in ownership_list], return_exceptions=True
)
return [
ownership for ownership in results if not isinstance(ownership, Exception) # type: ignore[misc]
]
return [ownership for ownership in results if not isinstance(ownership, Exception)] # type: ignore[misc]

async def update_checkpoint(self, checkpoint: Dict[str, Any], **kwargs: Any) -> None:
"""Updates the checkpoint using the given information for the offset, associated partition and
Expand Down Expand Up @@ -301,9 +307,7 @@ async def update_checkpoint(self, checkpoint: Dict[str, Any], **kwargs: Any) ->
await blob_client.set_blob_metadata(metadata, **kwargs)
except ResourceNotFoundError:
logger.info("Upload checkpoint blob %r because it hasn't existed in the container yet.", blob_name)
await blob_client.upload_blob(
data=UPLOAD_DATA, overwrite=True, metadata=metadata
)
await blob_client.upload_blob(data=UPLOAD_DATA, overwrite=True, metadata=metadata)

async def list_checkpoints(
self, fully_qualified_namespace: str, eventhub_name: str, consumer_group: str, **kwargs: Any
Expand All @@ -327,12 +331,8 @@ async def list_checkpoints(
- `offset` (str): The offset of the :class:`EventData<azure.eventhub.EventData>`.
:rtype: iterable[dict[str,any]]
"""
blob_prefix = "{}/{}/{}/checkpoint/".format(
fully_qualified_namespace, eventhub_name, consumer_group
)
blobs = self._container_client.list_blobs(
name_starts_with=blob_prefix.lower(), include=["metadata"], **kwargs
)
blob_prefix = "{}/{}/{}/checkpoint/".format(fully_qualified_namespace, eventhub_name, consumer_group)
blobs = self._container_client.list_blobs(name_starts_with=blob_prefix.lower(), include=["metadata"], **kwargs)
result = []
async for blob in blobs:
metadata = blob.metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,33 +67,41 @@ class BlobCheckpointStore(CheckpointStore):
"""

def __init__(
self,
blob_account_url: str,
container_name: str,
credential: Optional[Union["AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None,
api_version: str = '2019-07-07',
**kwargs: Any
self,
blob_account_url: str,
container_name: str,
credential: Optional[Union["AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None,
*,
api_version: str = "2019-07-07",
secondary_hostname: Optional[str] = None,
**kwargs: Any
) -> None:
self._container_client = kwargs.pop("container_client", None)
if not self._container_client:
api_version = kwargs.pop("api_version", None)
if api_version:
headers = kwargs.get("headers")
if headers:
headers["x-ms-version"] = api_version
else:
kwargs["headers"] = {"x-ms-version": api_version}
headers = kwargs.get("headers")
if headers:
headers["x-ms-version"] = api_version
else:
kwargs["headers"] = {"x-ms-version": api_version}
self._container_client = ContainerClient(
blob_account_url, container_name, credential=credential, **kwargs
blob_account_url,
container_name,
credential=credential,
api_version=api_version,
secondary_hostname=secondary_hostname,
**kwargs
)
self._cached_blob_clients = defaultdict() # type: Dict[str, BlobClient]

@classmethod
def from_connection_string( # pylint:disable=docstring-keyword-should-match-keyword-only
def from_connection_string(
cls,
conn_str: str,
container_name: str,
credential: Optional[Union["AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None,
*,
api_version: str = "2019-07-07",
secondary_hostname: Optional[str] = None,
**kwargs: Any
) -> "BlobCheckpointStore":
"""Create BlobCheckpointStore from a storage connection string.
Expand All @@ -117,12 +125,20 @@ def from_connection_string( # pylint:disable=docstring-keyword-should-match-keyw
:rtype: ~azure.eventhub.extensions.checkpointstoreblob.BlobCheckpointStore
"""

account_url, secondary, credential = parse_connection_str( # type: ignore[assignment]
conn_str, credential, 'blob') # type: ignore[arg-type]
if 'secondary_hostname' not in kwargs:
kwargs['secondary_hostname'] = secondary

return cls(account_url, container_name, credential=credential, **kwargs)
account_url, secondary, credential = parse_connection_str( # type: ignore[assignment]
conn_str, credential, "blob"
) # type: ignore[arg-type]
if not secondary_hostname:
secondary_hostname = secondary

return cls(
account_url,
container_name,
credential=credential,
api_version=api_version,
secondary_hostname=secondary_hostname,
**kwargs
)

def __enter__(self) -> "BlobCheckpointStore":
self._container_client.__enter__()
Expand Down Expand Up @@ -152,7 +168,7 @@ def _upload_ownership(self, ownership, **kwargs):
)
blob_name = blob_name.lower()
blob_client = self._get_blob_client(blob_name)
metadata = {'ownerid': ownership['owner_id']}
metadata = {"ownerid": ownership["owner_id"]}
try:
uploaded_blob_properties = blob_client.set_blob_metadata(metadata, **kwargs)
except ResourceNotFoundError:
Expand All @@ -161,9 +177,7 @@ def _upload_ownership(self, ownership, **kwargs):
data=UPLOAD_DATA, overwrite=True, metadata=metadata, **kwargs
)
ownership["etag"] = uploaded_blob_properties["etag"]
ownership["last_modified_time"] = _to_timestamp(
uploaded_blob_properties["last_modified"]
)
ownership["last_modified_time"] = _to_timestamp(uploaded_blob_properties["last_modified"])

def _claim_one_partition(self, ownership, **kwargs):
updated_ownership = copy.deepcopy(ownership)
Expand Down Expand Up @@ -197,11 +211,7 @@ def _claim_one_partition(self, ownership, **kwargs):
return updated_ownership # Keep the ownership if an unexpected error happens

def list_ownership(
self,
fully_qualified_namespace: str,
eventhub_name: str,
consumer_group: str,
**kwargs: Any
self, fully_qualified_namespace: str, eventhub_name: str, consumer_group: str, **kwargs: Any
) -> Iterable[Dict[str, Any]]:
"""Retrieves a complete ownership list from the storage blob.
Expand All @@ -225,9 +235,7 @@ def list_ownership(
:rtype: iterable[dict[str, any]]
"""
try:
blob_prefix = "{}/{}/{}/ownership/".format(
fully_qualified_namespace, eventhub_name, consumer_group
)
blob_prefix = "{}/{}/{}/ownership/".format(fully_qualified_namespace, eventhub_name, consumer_group)
blobs = self._container_client.list_blobs(
name_starts_with=blob_prefix.lower(), include=["metadata"], **kwargs
)
Expand Down Expand Up @@ -256,11 +264,7 @@ def list_ownership(
)
raise

def claim_ownership(
self,
ownership_list: Iterable[Dict[str, Any]],
**kwargs: Any
) -> Iterable[Dict[str, Any]]:
def claim_ownership(self, ownership_list: Iterable[Dict[str, Any]], **kwargs: Any) -> Iterable[Dict[str, Any]]:
"""Tries to claim ownership for a list of specified partitions.
:param iterable[dict[str, any]] ownership_list: Iterable of dictionaries containing all the ownerships to claim.
Expand Down Expand Up @@ -324,9 +328,7 @@ def update_checkpoint(self, checkpoint: Dict[str, Union[str, int]], **kwargs: An
blob_client.set_blob_metadata(metadata, **kwargs)
except ResourceNotFoundError:
logger.info("Upload checkpoint blob %r because it hasn't existed in the container yet.", blob_name)
blob_client.upload_blob(
data=UPLOAD_DATA, overwrite=True, metadata=metadata, **kwargs
)
blob_client.upload_blob(data=UPLOAD_DATA, overwrite=True, metadata=metadata, **kwargs)

def list_checkpoints(
self, fully_qualified_namespace: str, eventhub_name: str, consumer_group: str, **kwargs: Any
Expand All @@ -350,12 +352,8 @@ def list_checkpoints(
- `offset` (str): The offset of the :class:`EventData<azure.eventhub.EventData>`.
:rtype: iterable[dict[str,any]]
"""
blob_prefix = "{}/{}/{}/checkpoint/".format(
fully_qualified_namespace, eventhub_name, consumer_group
)
blobs = self._container_client.list_blobs(
name_starts_with=blob_prefix.lower(), include=["metadata"], **kwargs
)
blob_prefix = "{}/{}/{}/checkpoint/".format(fully_qualified_namespace, eventhub_name, consumer_group)
blobs = self._container_client.list_blobs(name_starts_with=blob_prefix.lower(), include=["metadata"], **kwargs)
result = []
for b in blobs:
metadata = b.metadata
Expand Down

0 comments on commit 3391b16

Please sign in to comment.