fix(datasets): pass self for __wrapped__ calls
Signed-off-by: Deepyaman Datta <deepyaman.datta@utexas.edu>
deepyaman committed Aug 2, 2024
1 parent 42201cf commit 3204803
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/json/json_dataset.py
@@ -171,6 +171,6 @@ def preview(self) -> JSONPreview:
         Returns:
             A string representing the JSON data for previewing.
         """
-        data = self.load.__wrapped__()  # type: ignore[attr-defined]
+        data = self.load.__wrapped__(self)  # type: ignore[attr-defined]
 
         return JSONPreview(json.dumps(data))
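
For context on the fix: the dataset's `load` is wrapped by a decorator that applies `functools.wraps` (the `# type: ignore[attr-defined]` hints the wrapping happens dynamically), and `functools.wraps` exposes the undecorated function as `__wrapped__` on the wrapper. Since `__wrapped__` is a plain function rather than a bound method, `self` is not filled in automatically and must be passed explicitly. A minimal sketch of the mechanics — the decorator and class names here are hypothetical, not kedro's:

```python
import functools


def traced(func):
    """Hypothetical stand-in for the decorator kedro applies to load/save."""

    @functools.wraps(func)  # copies metadata and sets wrapper.__wrapped__ = func
    def wrapper(self, *args, **kwargs):
        print(f"calling {func.__name__}")
        return func(self, *args, **kwargs)

    return wrapper


class ToyDataset:
    @traced
    def load(self):
        return {"key": "value"}


ds = ToyDataset()
ds.load()  # normal bound-method call; goes through the wrapper

# Attribute lookup on a bound method falls through to the underlying function,
# so ds.load.__wrapped__ is the *plain* undecorated function: no self bound.
assert ds.load.__wrapped__(ds) == {"key": "value"}  # the fixed call style
# ds.load.__wrapped__()  # TypeError: load() missing 1 required positional argument: 'self'
```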
8 changes: 4 additions & 4 deletions kedro-datasets/kedro_datasets/spark/spark_hive_dataset.py
@@ -155,7 +155,7 @@ def save(self, data: DataFrame) -> None:
         self._validate_save(data)
         if self._write_mode == "upsert":
             # check if _table_pk is a subset of df columns
-            if not set(self._table_pk) <= set(self.load.__wrapped__().columns):  # type: ignore[attr-defined]
+            if not set(self._table_pk) <= set(self.load.__wrapped__(self).columns):  # type: ignore[attr-defined]
                 raise DatasetError(
                     f"Columns {str(self._table_pk)} selected as primary key(s) not found in "
                     f"table {self._full_table_address}"
@@ -165,13 +165,13 @@ def save(self, data: DataFrame) -> None:
             self._create_hive_table(data=data)
 
     def _upsert_save(self, data: DataFrame) -> None:
-        if not self._exists() or self.load.__wrapped__().rdd.isEmpty():  # type: ignore[attr-defined]
+        if not self._exists() or self.load.__wrapped__(self).rdd.isEmpty():  # type: ignore[attr-defined]
             self._create_hive_table(data=data, mode="overwrite")
         else:
             _tmp_colname = "tmp_colname"
             _tmp_row = "tmp_row"
             _w = Window.partitionBy(*self._table_pk).orderBy(col(_tmp_colname).desc())
-            df_old = self.load.__wrapped__().select("*", lit(1).alias(_tmp_colname))  # type: ignore[attr-defined]
+            df_old = self.load.__wrapped__(self).select("*", lit(1).alias(_tmp_colname))  # type: ignore[attr-defined]
             df_new = data.select("*", lit(2).alias(_tmp_colname))
             df_stacked = df_new.unionByName(df_old).select(
                 "*", row_number().over(_w).alias(_tmp_row)
@@ -188,7 +188,7 @@ def _validate_save(self, data: DataFrame):
         # or if the `write_mode` is set to overwrite
         if (not self._exists()) or self._write_mode == "overwrite":
             return
-        hive_dtypes = set(self.load.__wrapped__().dtypes)  # type: ignore[attr-defined]
+        hive_dtypes = set(self.load.__wrapped__(self).dtypes)  # type: ignore[attr-defined]
         data_dtypes = set(data.dtypes)
         if data_dtypes != hive_dtypes:
             new_cols = data_dtypes - hive_dtypes
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/yaml/yaml_dataset.py
@@ -169,6 +169,6 @@ def preview(self) -> JSONPreview:
         Returns:
             A string representing the YAML data for previewing.
         """
-        data = self.load.__wrapped__()  # type: ignore[attr-defined]
+        data = self.load.__wrapped__(self)  # type: ignore[attr-defined]
 
         return JSONPreview(json.dumps(data))
10 changes: 5 additions & 5 deletions kedro-datasets/tests/api/test_api_dataset.py
@@ -301,7 +301,7 @@ def json_callback(request: requests.Request, context: Any) -> dict:
             status_code=requests.codes.ok,
             json=json_callback,
         )
-        response = api_dataset._save(data)
+        response = api_dataset.save.__wrapped__(api_dataset, data)
         assert isinstance(response, requests.Response)
         assert response.json() == TEST_SAVE_DATA
 
@@ -312,7 +312,7 @@ def json_callback(request: requests.Request, context: Any) -> dict:
                 save_args={"params": TEST_PARAMS, "headers": TEST_HEADERS},
             )
             with pytest.raises(DatasetError, match="Use PUT or POST methods for save"):
-                api_dataset._save(TEST_SAVE_DATA)
+                api_dataset.save.__wrapped__(api_dataset, TEST_SAVE_DATA)
         else:
             with pytest.raises(
                 ValueError,
@@ -343,16 +343,16 @@ def json_callback(request: requests.Request, context: Any) -> dict:
             headers=TEST_HEADERS,
             json=json_callback,
         )
-        response_list = api_dataset._save(TEST_SAVE_DATA)
+        response_list = api_dataset.save.__wrapped__(api_dataset, TEST_SAVE_DATA)
         assert isinstance(response_list, requests.Response)
         # check that the data was sent in the correct format
         assert response_list.json() == TEST_SAVE_DATA
 
-        response_dict = api_dataset._save({"item1": "key1"})
+        response_dict = api_dataset.save.__wrapped__(api_dataset, {"item1": "key1"})
         assert isinstance(response_dict, requests.Response)
         assert response_dict.json() == {"item1": "key1"}
 
-        response_json = api_dataset._save(TEST_SAVE_DATA[0])
+        response_json = api_dataset.save.__wrapped__(api_dataset, TEST_SAVE_DATA[0])
         assert isinstance(response_json, requests.Response)
         assert response_json.json() == TEST_SAVE_DATA[0]
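
The test updates apply the same rule from the calling side: with the private `_save` gone and the public `save` decorated, exercising the underlying implementation directly means going through `save.__wrapped__` and supplying the instance by hand. A self-contained sketch of that pattern — all names are hypothetical, and the wrapper here discards the return value purely to motivate the direct call:

```python
import functools


def dataset_hook(func):
    """Hypothetical save wrapper; here it hides the implementation's result."""

    @functools.wraps(func)
    def wrapper(self, data):
        func(self, data)  # returns None, like a fire-and-forget save

    return wrapper


class FakeAPIDataset:
    @dataset_hook
    def save(self, data):
        return {"status": 200, "echo": data}  # stand-in for a requests.Response


def test_wrapped_save_exposes_the_response():
    ds = FakeAPIDataset()
    assert ds.save({"a": 1}) is None  # public call: wrapper swallows the result
    response = ds.save.__wrapped__(ds, {"a": 1})  # direct call, self passed explicitly
    assert response == {"status": 200, "echo": {"a": 1}}
```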
