Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom entity deserialization #16346

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions sdk/tables/azure-data-tables/azure/data/tables/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _from_generated(cls, generated):
write=generated.write,
retention_policy=RetentionPolicy._from_generated( # pylint:disable=protected-access
generated.retention_policy
)
),
)


Expand All @@ -138,8 +138,7 @@ class Metrics(GeneratedMetrics):
"""

def __init__( # pylint:disable=super-init-not-called
self,
**kwargs # type: Any
self, **kwargs # type: Any
):
self.version = kwargs.get("version", u"1.0")
self.enabled = kwargs.get("enabled", False)
Expand All @@ -161,7 +160,7 @@ def _from_generated(cls, generated):
include_apis=generated.include_apis,
retention_policy=RetentionPolicy._from_generated( # pylint: disable=protected-access
generated.retention_policy
)
),
)


Expand Down Expand Up @@ -307,7 +306,8 @@ def _get_next_cb(self, continuation_token, **kwargs):
def _extract_data_cb(self, get_next_return):
self.location_mode, self._response, self._headers = get_next_return
props_list = [
TableItem._from_generated(t, **self._headers) for t in self._response.value # pylint:disable=protected-access
TableItem._from_generated(t, **self._headers) # pylint:disable=protected-access
for t in self._response.value # pylint:disable=protected-access
]
return self._headers[NEXT_TABLE_NAME] or None, props_list

Expand All @@ -323,13 +323,15 @@ class TableEntityPropertiesPaged(PageIterator):
:keyword str continuation_token: An opaque continuation token.
:keyword str location_mode: The location mode being used to list results. The available
options include "primary" and "secondary".
:keyword Callable[Mapping] entity_hook: A custom entity type for deserialization entities returned
from the service
"""

def __init__(self, command, table, **kwargs):
super(TableEntityPropertiesPaged, self).__init__(
self._get_next_cb,
self._extract_data_cb,
continuation_token=kwargs.get("continuation_token") or {},
continuation_token=kwargs.get("continuation_token", {}),
)
self._command = command
self._headers = None
Expand All @@ -339,6 +341,7 @@ def __init__(self, command, table, **kwargs):
self.filter = kwargs.get("filter")
self.select = kwargs.get("select")
self.location_mode = None
self.entity_hook = kwargs.pop("entity_hook", None) or _convert_to_entity

def _get_next_cb(self, continuation_token, **kwargs):
next_partition_key, next_row_key = _extract_continuation_token(
Expand All @@ -361,7 +364,9 @@ def _get_next_cb(self, continuation_token, **kwargs):

def _extract_data_cb(self, get_next_return):
self.location_mode, self._response, self._headers = get_next_return
props_list = [_convert_to_entity(t) for t in self._response.value]
props_list = [
self.entity_hook(t) for t in self._response.value
]
next_entity = {}
if self._headers[NEXT_PARTITION_KEY] or self._headers[NEXT_ROW_KEY]:
next_entity = {
Expand Down Expand Up @@ -460,15 +465,18 @@ def service_stats_deserialize(generated):
def service_properties_deserialize(generated):
"""Deserialize a ServiceProperties objects into a dict."""
return {
"analytics_logging": TableAnalyticsLogging._from_generated(generated.logging), # pylint: disable=protected-access
"analytics_logging": TableAnalyticsLogging._from_generated( # pylint: disable=protected-access
generated.logging
),
"hour_metrics": Metrics._from_generated( # pylint: disable=protected-access
generated.hour_metrics
),
"minute_metrics": Metrics._from_generated( # pylint: disable=protected-access
generated.minute_metrics
),
"cors": [
CorsRule._from_generated(cors) for cors in generated.cors # pylint: disable=protected-access
CorsRule._from_generated(cors) # pylint: disable=protected-access
for cors in generated.cors
],
}

Expand Down
8 changes: 7 additions & 1 deletion sdk/tables/azure-data-tables/azure/data/tables/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime
from math import isnan
from enum import Enum
import six
import sys
import uuid
import isodate
Expand Down Expand Up @@ -112,7 +113,12 @@ def _to_entity_bool(value):


def _to_entity_datetime(value):
return EdmType.DATETIME, _to_utc_datetime(value)
try:
return EdmType.DATETIME, _to_utc_datetime(value)
except AttributeError:
if not isinstance(value, six.string_types):
raise
return EdmType.DATETIME, value


def _to_entity_float(value):
Expand Down
31 changes: 23 additions & 8 deletions sdk/tables/azure-data-tables/azure/data/tables/_table_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def __init__(
super(TableClient, self).__init__(
account_url, table_name, credential=credential, **kwargs
)
kwargs['connection_timeout'] = kwargs.get('connection_timeout') or CONNECTION_TIMEOUT
kwargs["connection_timeout"] = (
kwargs.get("connection_timeout") or CONNECTION_TIMEOUT
)
self._client = AzureTable(
self.url,
policies=kwargs.pop('policies', self._policies),
**kwargs
self.url, policies=kwargs.pop("policies", self._policies), **kwargs
)

@classmethod
Expand Down Expand Up @@ -425,8 +425,9 @@ def list_entities(
# type: (...) -> ItemPaged[TableEntity]
"""Lists entities in a table.

:keyword int results_per_page: Number of entities per page in return ItemPaged
:keyword int results_per_page: Number of entities per page in returned ItemPaged
:keyword select: Specify desired properties of an entity to return certain entities
:keyword Callable[Mapping] entity_hook: for custom deserialization
:paramtype select: str or list[str]
:return: Query of table entities
:rtype: ~azure.core.paging.ItemPaged[~azure.data.tables.TableEntity]
Expand All @@ -441,6 +442,8 @@ def list_entities(
:dedent: 8
:caption: List all entities held within a table
"""
entity_hook = kwargs.pop("entity_hook", None)

user_select = kwargs.pop("select", None)
if user_select and not isinstance(user_select, str):
user_select = ", ".join(user_select)
Expand All @@ -452,7 +455,9 @@ def list_entities(
table=self.table_name,
results_per_page=top,
select=user_select,
page_iterator_class=TableEntityPropertiesPaged,
page_iterator_class=functools.partial(
TableEntityPropertiesPaged, entity_hook=entity_hook
),
)

@distributed_trace
Expand All @@ -467,6 +472,7 @@ def query_entities(
:param str filter: Specify a filter to return certain entities
:keyword int results_per_page: Number of entities per page in return ItemPaged
:keyword select: Specify desired properties of an entity to return certain entities
:keyword Callable[Mapping] entity_hook: for custom deserialization
:paramtype select: str or list[str]
:keyword dict parameters: Dictionary for formatting query with additional, user defined parameters
:return: Query of table entities
Expand All @@ -482,6 +488,8 @@ def query_entities(
:dedent: 8
:caption: Query entities held within a table
"""
entity_hook = kwargs.pop("entity_hook", None)

parameters = kwargs.pop("parameters", None)
filter = self._parameter_filter_substitution(
parameters, filter
Expand All @@ -498,7 +506,9 @@ def query_entities(
results_per_page=top,
filter=filter,
select=user_select,
page_iterator_class=TableEntityPropertiesPaged,
page_iterator_class=functools.partial(
TableEntityPropertiesPaged, entity_hook=entity_hook
),
)

@distributed_trace
Expand Down Expand Up @@ -528,13 +538,16 @@ def get_entity(
:dedent: 8
:caption: Get a single entity from a table
"""
entity_hook = kwargs.pop("entity_hook", None)
try:
entity = self._client.table.query_entity_with_partition_and_row_key(
table=self.table_name,
partition_key=partition_key,
row_key=row_key,
**kwargs
)
if entity_hook:
return entity_hook(entity)

properties = _convert_to_entity(entity)
return properties
Expand Down Expand Up @@ -654,5 +667,7 @@ def send_batch(
:caption: Using batches to send multiple requests at once
"""
return self._batch_send( # pylint:disable=protected-access
batch._entities, *batch._requests, **kwargs # pylint:disable=protected-access
batch._entities, # pylint:disable=protected-access
*batch._requests, # pylint:disable=protected-access
**kwargs
)
8 changes: 6 additions & 2 deletions sdk/tables/azure-data-tables/azure/data/tables/aio/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ async def _get_next_cb(self, continuation_token, **kwargs):
async def _extract_data_cb(self, get_next_return):
self.location_mode, self._response, self._headers = get_next_return
props_list = [
TableItem._from_generated(t, **self._headers) for t in self._response.value # pylint:disable=protected-access
TableItem._from_generated(t, **self._headers) # pylint:disable=protected-access
for t in self._response.value # pylint:disable=protected-access
]
return self._headers[NEXT_TABLE_NAME] or None, props_list

Expand All @@ -76,6 +77,8 @@ class TableEntityPropertiesPaged(AsyncPageIterator):
:keyword str continuation_token: An opaque continuation token.
:keyword str location_mode: The location mode being used to list results. The available
options include "primary" and "secondary".
:keyword callable entity_hook: A custom entity type for deserialization entities returned
from the service
"""

def __init__(self, command, table, **kwargs):
Expand All @@ -92,6 +95,7 @@ def __init__(self, command, table, **kwargs):
self.filter = kwargs.get("filter")
self.select = kwargs.get("select")
self.location_mode = None
self.entity_hook = kwargs.pop("entity_hook", None) or _convert_to_entity

async def _get_next_cb(self, continuation_token, **kwargs):
next_partition_key, next_row_key = _extract_continuation_token(
Expand All @@ -114,7 +118,7 @@ async def _get_next_cb(self, continuation_token, **kwargs):

async def _extract_data_cb(self, get_next_return):
self.location_mode, self._response, self._headers = get_next_return
props_list = [_convert_to_entity(t) for t in self._response.value]
props_list = [self.entity_hook(t) for t in self._response.value]
next_entity = {}
if self._headers[NEXT_PARTITION_KEY] or self._headers[NEXT_ROW_KEY]:
next_entity = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,13 @@ def __init__(
loop=loop,
**kwargs
)
kwargs['connection_timeout'] = kwargs.get('connection_timeout') or CONNECTION_TIMEOUT
kwargs["connection_timeout"] = (
kwargs.get("connection_timeout") or CONNECTION_TIMEOUT
)
self._configure_policies(**kwargs)
self._client = AzureTable(
self.url,
policies=kwargs.pop('policies', self._policies),
policies=kwargs.pop("policies", self._policies),
loop=loop,
**kwargs
)
Expand Down Expand Up @@ -445,6 +447,7 @@ def list_entities(

:keyword int results_per_page: Number of entities per page in return AsyncItemPaged
:keyword select: Specify desired properties of an entity to return certain entities
:keywork Callable[Mapping] entity_hook: Callable for custom deserialization
:paramtype select: str or list[str]
:return: Query of table entities
:rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.data.tables.TableEntity]
Expand All @@ -459,6 +462,7 @@ def list_entities(
:dedent: 8
:caption: Querying entities from a TableClient
"""
entity_hook = kwargs.pop("entity_hook", None)
user_select = kwargs.pop("select", None)
if user_select and not isinstance(user_select, str):
user_select = ", ".join(user_select)
Expand All @@ -470,7 +474,9 @@ def list_entities(
table=self.table_name,
results_per_page=top,
select=user_select,
page_iterator_class=TableEntityPropertiesPaged,
page_iterator_class=functools.partial(
TableEntityPropertiesPaged, entity_hook=entity_hook
),
)

@distributed_trace
Expand All @@ -485,6 +491,7 @@ def query_entities(
:param str filter: Specify a filter to return certain entities
:keyword int results_per_page: Number of entities per page in return AsyncItemPaged
:keyword select: Specify desired properties of an entity to return certain entities
:keywork Callable[Mapping] entity_hook: Callable for custom deserialization
:paramtype select: str or list[str]
:keyword dict parameters: Dictionary for formatting query with additional, user defined parameters
:return: Query of table entities
Expand All @@ -500,6 +507,7 @@ def query_entities(
:dedent: 8
:caption: Querying entities from a TableClient
"""
entity_hook = kwargs.pop("entity_hook", None)
parameters = kwargs.pop("parameters", None)
filter = self._parameter_filter_substitution(
parameters, filter
Expand All @@ -516,7 +524,9 @@ def query_entities(
results_per_page=top,
filter=filter,
select=user_select,
page_iterator_class=TableEntityPropertiesPaged,
page_iterator_class=functools.partial(
TableEntityPropertiesPaged, entity_hook=entity_hook
),
)

@distributed_trace_async
Expand All @@ -534,6 +544,7 @@ async def get_entity(
:param row_key: The row key of the entity.
:type row_key: str
:return: Dictionary mapping operation metadata returned from the service
:keywork Callable[Mapping] entity_hook: Callable for custom deserialization
:rtype: ~azure.data.tables.TableEntity
:raises ~azure.core.exceptions.HttpResponseError:

Expand All @@ -546,13 +557,16 @@ async def get_entity(
:dedent: 8
:caption: Getting an entity from PartitionKey and RowKey
"""
entity_hook = kwargs.pop("entity_hook", None)
try:
entity = await self._client.table.query_entity_with_partition_and_row_key(
table=self.table_name,
partition_key=partition_key,
row_key=row_key,
**kwargs
)
if entity_hook:
return entity_hook(entity)

properties = _convert_to_entity(entity)
return properties
Expand Down Expand Up @@ -668,5 +682,7 @@ async def send_batch(
:caption: Using batches to send multiple requests at once
"""
return await self._batch_send(
batch._entities, *batch._requests, **kwargs # pylint:disable=protected-access
batch._entities, # pylint:disable=protected-access
*batch._requests, # pylint:disable=protected-access
**kwargs
)
Loading