Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix StructuredDataset empty-str file_format in dc attr access #3027

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
9 changes: 9 additions & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,10 +737,19 @@
# return StructuredDataset(uri=uri)
if python_val.dataframe is None:
uri = python_val.uri
file_format = python_val.file_format

Check warning on line 740 in flytekit/types/structured/structured_dataset.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/structured/structured_dataset.py#L740

Added line #L740 was not covered by tests

# Check the user-specified uri
if not uri:
raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
if not ctx.file_access.is_remote(uri):
uri = await ctx.file_access.async_put_raw_data(uri)

# Check the user-specified file_format
# When users specify file_format for a StructuredDataset, the file_format information must be retained.
# For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
sdt.format = file_format

Check warning on line 751 in flytekit/types/structured/structured_dataset.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/structured/structured_dataset.py#L751

Added line #L751 was not covered by tests

sd_model = literals.StructuredDataset(
uri=uri,
metadata=StructuredDatasetMetadata(structured_dataset_type=sdt),
Expand Down
46 changes: 46 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import uuid
import pytest
from unittest import mock
from dataclasses import dataclass

from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase
from flytekit.configuration import Config, ImageConfig, SerializationSettings
Expand All @@ -26,6 +27,7 @@
from flytekit.remote.remote import FlyteRemote
from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset
from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient
from flytekit.configuration import PlatformConfig

Expand Down Expand Up @@ -867,6 +869,50 @@ def test_attr_access_sd():
bucket, key = url.netloc, url.path.lstrip("/")
file_transfer.delete_file(bucket=bucket, key=key)


def test_sd_attr():
    """Test correctness of StructuredDataset attributes.

    This test considers only the following condition:
    1. Check StructuredDataset (wrapped in a dataclass) file_format attribute

    We'll make sure uri aligns with the user-specified one in the future.

    The uploaded remote file is removed in a ``finally`` clause so that a
    failed execution or assertion does not leave the object in the bucket.
    """
    from workflows.basic.sd_attr import wf

    @dataclass
    class DC:
        sd: StructuredDataset

    FILE_FORMAT = "parquet"

    # Upload a file to minio s3 bucket
    file_transfer = SimpleFileTransfer()
    remote_file_path = file_transfer.upload_file(file_type=FILE_FORMAT)

    try:
        # Create a dataclass as the workflow input because `pyflyte run`
        # can't properly handle input arg `dc` as a json str so far
        dc = DC(sd=StructuredDataset(uri=remote_file_path, file_format=FILE_FORMAT))

        remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True)
        wf_exec = remote.execute(
            wf,
            inputs={"dc": dc, "file_format": FILE_FORMAT},
            wait=True,
            version=VERSION,
            image_config=ImageConfig.from_images(IMAGE),
        )
        assert wf_exec.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {wf_exec.closure.phase}"
        assert wf_exec.outputs["o0"].file_format == FILE_FORMAT, (
            f"Workflow output StructuredDataset file_format should align with the user-specified file_format: {FILE_FORMAT}."
        )
    finally:
        # Delete the remote file to free the space, even when the execution
        # or one of the assertions above fails.
        url = urlparse(remote_file_path)
        bucket, key = url.netloc, url.path.lstrip("/")
        file_transfer.delete_file(bucket=bucket, key=key)


def test_signal_approve_reject(register):
from flytekit.models.types import LiteralType, SimpleType
from time import sleep
Expand Down
68 changes: 68 additions & 0 deletions tests/flytekit/integration/remote/workflows/basic/sd_attr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from dataclasses import dataclass

import pandas as pd
from flytekit import task, workflow
from flytekit.types.structured import StructuredDataset


# Dataclass wrapper with a single StructuredDataset field; it is the return
# type of `create_dc` and the `dc` input of `wf` in this module.
@dataclass
class DC:
    # The wrapped dataset; `wf` reads it through attribute access (`dc.sd`).
    sd: StructuredDataset


@task
def create_dc(uri: str, file_format: str) -> DC:
    """Build a ``DC`` wrapping a StructuredDataset for the given file.

    Args:
        uri: File URI.
        file_format: File format, e.g., parquet, csv.

    Returns:
        A ``DC`` whose ``sd`` field is a StructuredDataset constructed with
        the given ``uri`` and ``file_format``.
    """
    dataset = StructuredDataset(uri=uri, file_format=file_format)
    return DC(sd=dataset)


@task
def check_file_format(sd: StructuredDataset, true_file_format: str) -> StructuredDataset:
    """Verify the file_format carried by a StructuredDataset.

    Both the Python-level ``file_format`` attribute and the ``format`` stored
    in the literal's StructuredDatasetType must equal what the user specified.
    After the checks pass, the dataset, its literal, its dataset type and the
    loaded DataFrame are printed for debugging.

    Args:
        sd: Python native StructuredDataset.
        true_file_format: User-specified file_format.

    Returns:
        The same StructuredDataset that was passed in.

    Raises:
        AssertionError: If either format does not match ``true_file_format``.
    """
    assert sd.file_format == true_file_format, (
        f"StructuredDataset file_format should align with the user-specified file_format: {true_file_format}."
    )

    # Resolve the literal-side objects only after the Python-level check, so
    # the order of failures matches the order of the checks.
    literal_sd = sd._literal_sd
    dataset_type = literal_sd.metadata.structured_dataset_type
    assert dataset_type.format == true_file_format, (
        f"StructuredDatasetType format should align with the user-specified file_format: {true_file_format}."
    )

    # Debug output: Python object, literal, literal type, then the data itself.
    print(f">>> SD <<<\n{sd}")
    print(f">>> Literal SD <<<\n{literal_sd}")
    print(f">>> SDT <<<\n{dataset_type}")
    print(f">>> DF <<<\n{sd.open(pd.DataFrame).all()}")

    return sd


@workflow
def wf(dc: DC, file_format: str) -> StructuredDataset:
    """Run ``check_file_format`` on the StructuredDataset held by ``dc``.

    Args:
        dc: Dataclass whose ``sd`` attribute is the dataset to check.
        file_format: Expected file_format of ``dc.sd``.

    Returns:
        The StructuredDataset returned by ``check_file_format``.
    """
    # The expected format is passed in as a separate workflow input because
    # using dc.sd.file_format as the task input fails (per the original note).
    return check_file_format(sd=dc.sd, true_file_format=file_format)


if __name__ == "__main__":
    # Local run: wrap a parquet file in the dataclass, pass it through the
    # workflow, and print the file_format of the returned dataset.
    local_uri = "tests/flytekit/integration/remote/workflows/basic/data/df.parquet"
    expected_format = "parquet"

    wrapped = create_dc(uri=local_uri, file_format=expected_format)
    checked = wf(dc=wrapped, file_format=expected_format)
    print(checked.file_format)
Loading