diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py index 021afc99..273282e7 100644 --- a/elasticsearch_dsl/utils.py +++ b/elasticsearch_dsl/utils.py @@ -86,6 +86,17 @@ def _wrap(val: Any, obj_wrapper: Optional[Callable[[Any], Any]] = None) -> Any: return val +def _recursive_to_dict(value: Any) -> Dict[str, Any]: + if hasattr(value, "to_dict"): + return value.to_dict() + elif isinstance(value, dict) or isinstance(value, AttrDict): + return {k: _recursive_to_dict(v) for k, v in value.items()} + elif isinstance(value, list) or isinstance(value, AttrList): + return [recursive_to_dict(elem) for elem in value] + else: + return value + + class AttrList(Generic[_ValT]): def __init__( self, l: List[_ValT], obj_wrapper: Optional[Callable[[_ValT], Any]] = None @@ -228,8 +239,8 @@ def __setattr__(self, name: str, value: _ValT) -> None: def __iter__(self) -> Iterator[str]: return iter(self._d_) - def to_dict(self) -> Dict[str, _ValT]: - return self._d_ + def to_dict(self, recursive: bool = False) -> Dict[str, _ValT]: + return _recursive_to_dict(self._d_) if recursive else self._d_ def keys(self) -> Iterable[str]: return self._d_.keys() @@ -436,9 +447,9 @@ def to_dict(self) -> Dict[str, Any]: else: value = value.to_dict() - # serialize anything with to_dict method - elif hasattr(value, "to_dict"): - value = value.to_dict() + # serialize anything else to dict recursively + else: + value = _recursive_to_dict(value) d[pname] = value return {self.name: d} diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py index 11bc8c72..3dfde51b 100644 --- a/tests/test_integration/_async/test_search.py +++ b/tests/test_integration/_async/test_search.py @@ -112,6 +112,27 @@ async def test_inner_hits_are_wrapped_in_response( ) +@pytest.mark.asyncio +async def test_inner_hits_are_serialized_to_dict( + async_data_client: AsyncElasticsearch, +) -> None: + s = AsyncSearch(index="git")[0:1].query( + "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all") + ) + response = await s.execute() + d = response.to_dict(recursive=True) + assert isinstance(d, dict) + assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict) + + # iterating over the results changes the format of the internal AttrDict + for hit in response: + pass + + d = response.to_dict(recursive=True) + assert isinstance(d, dict) + assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict) + + @pytest.mark.asyncio async def test_scan_respects_doc_types(async_data_client: AsyncElasticsearch) -> None: repos = [repo async for repo in Repository.search().scan()] diff --git a/tests/test_integration/_sync/test_search.py b/tests/test_integration/_sync/test_search.py index 18ed8566..d4e62016 100644 --- a/tests/test_integration/_sync/test_search.py +++ b/tests/test_integration/_sync/test_search.py @@ -104,6 +104,27 @@ def test_inner_hits_are_wrapped_in_response( ) +@pytest.mark.sync +def test_inner_hits_are_serialized_to_dict( + data_client: Elasticsearch, +) -> None: + s = Search(index="git")[0:1].query( + "has_parent", parent_type="repo", inner_hits={}, query=Q("match_all") + ) + response = s.execute() + d = response.to_dict(recursive=True) + assert isinstance(d, dict) + assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict) + + # iterating over the results changes the format of the internal AttrDict + for hit in response: + pass + + d = response.to_dict(recursive=True) + assert isinstance(d, dict) + assert isinstance(d["hits"]["hits"][0]["inner_hits"]["repo"], dict) + + @pytest.mark.sync def test_scan_respects_doc_types(data_client: Elasticsearch) -> None: repos = [repo for repo in Repository.search().scan()]