diff --git a/kedro-datasets/kedro_datasets/api/api_dataset.py b/kedro-datasets/kedro_datasets/api/api_dataset.py index fe061e0af..5e5cdfade 100644 --- a/kedro-datasets/kedro_datasets/api/api_dataset.py +++ b/kedro-datasets/kedro_datasets/api/api_dataset.py @@ -3,15 +3,20 @@ """ from __future__ import annotations -import json as json_ # make pylint happy from copy import deepcopy -from typing import Any +from typing import Any, Type import requests from kedro.io.core import AbstractDataset, DatasetError +from kedro.io.memory_dataset import MemoryDataset from requests import Session, sessions from requests.auth import AuthBase +import json as json_ # make pylint happy +from kedro_datasets.json import JSONDataset +from kedro_datasets.pickle.pickle_dataset import PickleDataset +from kedro_datasets.text import TextDataset + class APIDataset(AbstractDataset[None, requests.Response]): """``APIDataset`` loads/saves data from/to HTTP(S) APIs. @@ -97,6 +102,8 @@ def __init__( # noqa: PLR0913 save_args: dict[str, Any] | None = None, credentials: tuple[str, str] | list[str] | AuthBase | None = None, metadata: dict[str, Any] | None = None, + extension: str | None = None, + wrapped_dataset: dict[str, Any] | None = None, ) -> None: """Creates a new instance of ``APIDataset`` to fetch data from an API endpoint. @@ -155,6 +162,9 @@ def __init__( # noqa: PLR0913 } self.metadata = metadata + self._extension = extension + self._wrapped_dataset_args = wrapped_dataset + self._wrapped_dataset = None @staticmethod def _convert_type(value: Any): @@ -171,6 +181,8 @@ def _describe(self) -> dict[str, Any]: # prevent auth from logging request_args_cp = self._request_args.copy() request_args_cp.pop("auth", None) + if self._extension: + request_args_cp["wrapped_dataset"] = self.wrapped_dataset._describe() return request_args_cp def _execute_request(self, session: Session) -> requests.Response: @@ -184,10 +196,12 @@ def _execute_request(self, session: Session) -> requests.Response: return response - def load(self) -> requests.Response: + def load(self) -> requests.Response | str | Any: if self._request_args["method"] == "GET": with sessions.Session() as session: return self._execute_request(session) + elif self._request_args["method"] in ["PUT", "POST"] and self.wrapped_dataset is not None: + return self.wrapped_dataset.load() raise DatasetError("Only GET method is supported for load") @@ -222,9 +236,19 @@ def _execute_save_request(self, json_data: Any) -> requests.Response: def save(self, data: Any) -> requests.Response: # type: ignore[override] if self._request_args["method"] in ["PUT", "POST"]: if isinstance(data, list): - return self._execute_save_with_chunks(json_data=data) - - return self._execute_save_request(json_data=data) + response: requests.Response = self._execute_save_with_chunks(json_data=data) + else: + response: requests.Response = self._execute_save_request(json_data=data) + + if self._wrapped_dataset is None: + return response + if self._extension == "json": + self.wrapped_dataset.save(response.json()) #TODO(npfp): expose json loads arguments + elif self._extension == "text": + self.wrapped_dataset.save(response.text) + elif self._extension: + self.wrapped_dataset.save(response) + return response raise DatasetError("Use PUT or POST methods for save") @@ -232,3 +256,32 @@ def _exists(self) -> bool: with sessions.Session() as session: response = self._execute_request(session) return response.ok + + @property + def _nested_dataset_type( + self, + ) -> Type[JSONDataset | PickleDataset | MemoryDataset]: + if self._extension == "json": + return JSONDataset + elif self._extension == "text": + return TextDataset + elif self._extension == "pickle": + return PickleDataset + elif self._extension == "memory": + #I'm not sure we need this + return MemoryDataset + else: + raise DatasetError( + f"Unknown extension for WrappedDataset: {self._extension}" + ) + + @property + def wrapped_dataset( + self, + ) -> JSONDataset | PickleDataset | MemoryDataset | None: + """The wrapped dataset where response data is stored.""" + if self._wrapped_dataset is None and self._extension is not None: + self._wrapped_dataset = self._nested_dataset_type( + **self._wrapped_dataset_args + ) + return self._wrapped_dataset