Skip to content

Commit

Permalink
Added recursive to_dict support to AttrDict
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Aug 30, 2024
1 parent 8cc2ed2 commit 58c4737
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2 deletions.
17 changes: 15 additions & 2 deletions elasticsearch_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ def _wrap(val: Any, obj_wrapper: Optional[Callable[[Any], Any]] = None) -> Any:
return val


def _recursive_to_dict(value: Any) -> Any:
if hasattr(value, "to_dict"):
return value.to_dict()
elif isinstance(value, dict) or isinstance(value, AttrDict):
return {k: _recursive_to_dict(v) for k, v in value.items()}
elif isinstance(value, list) or isinstance(value, AttrList):
return [recursive_to_dict(elem) for elem in value]
else:
return value


class AttrList(Generic[_ValT]):
def __init__(
self, l: List[_ValT], obj_wrapper: Optional[Callable[[_ValT], Any]] = None
Expand Down Expand Up @@ -228,8 +239,10 @@ def __setattr__(self, name: str, value: _ValT) -> None:
def __iter__(self) -> Iterator[str]:
return iter(self._d_)

def to_dict(self) -> Dict[str, _ValT]:
return self._d_
def to_dict(self, recursive: bool = False) -> Dict[str, _ValT]:
return cast(
Dict[str, _ValT], _recursive_to_dict(self._d_) if recursive else self._d_
)

def keys(self) -> Iterable[str]:
return self._d_.keys()
Expand Down
21 changes: 21 additions & 0 deletions tests/test_integration/_async/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,27 @@ async def test_inner_hits_are_wrapped_in_response(
)


@pytest.mark.asyncio
async def test_inner_hits_are_serialized_to_dict(
async_data_client: AsyncElasticsearch,
) -> None:
s = AsyncSearch(index="git")[0:1].query(
"has_parent", parent_type="repo", inner_hits={}, query=Q("match_all")
)
response = await s.execute()
d = response.to_dict(recursive=True)
assert isinstance(d, dict)
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)

# iterating over the results changes the format of the internal AttrDict
for hit in response:
pass

d = response.to_dict(recursive=True)
assert isinstance(d, dict)
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)


@pytest.mark.asyncio
async def test_scan_respects_doc_types(async_data_client: AsyncElasticsearch) -> None:
repos = [repo async for repo in Repository.search().scan()]
Expand Down
21 changes: 21 additions & 0 deletions tests/test_integration/_sync/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,27 @@ def test_inner_hits_are_wrapped_in_response(
)


@pytest.mark.sync
def test_inner_hits_are_serialized_to_dict(
data_client: Elasticsearch,
) -> None:
s = Search(index="git")[0:1].query(
"has_parent", parent_type="repo", inner_hits={}, query=Q("match_all")
)
response = s.execute()
d = response.to_dict(recursive=True)
assert isinstance(d, dict)
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)

# iterating over the results changes the format of the internal AttrDict
for hit in response:
pass

d = response.to_dict(recursive=True)
assert isinstance(d, dict)
assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict)


@pytest.mark.sync
def test_scan_respects_doc_types(data_client: Elasticsearch) -> None:
repos = [repo for repo in Repository.search().scan()]
Expand Down

0 comments on commit 58c4737

Please sign in to comment.