Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable defining nested data types #193

Merged
merged 9 commits into from
Jun 9, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ consumes:
embeddings:
fields:
data:
type: float32_list
type: array
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
items:
type: float32

produces:
images:
Expand Down
4 changes: 3 additions & 1 deletion components/image_embedding/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ produces:
embeddings:
fields:
data:
type: float32_list
type: array
items:
type: float32

args:
model_id:
Expand Down
4 changes: 3 additions & 1 deletion components/segment_images/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ produces:
segmentations:
fields:
data:
type: binary
NielsRogge marked this conversation as resolved.
Show resolved Hide resolved
type: array
items:
type: binary

args:
model_id:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ consumes:
segmentations:
fields:
data:
type: binary
type: array
items:
type: binary

args:
hf_token:
Expand Down
2 changes: 1 addition & 1 deletion fondant/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __repr__(self) -> str:
def fields(self) -> t.Mapping[str, Field]:
return types.MappingProxyType(
{
name: Field(name=name, type=Type[field["type"]])
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification["fields"].items()
}
)
Expand Down
6 changes: 3 additions & 3 deletions fondant/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def fields(self) -> t.Mapping[str, Field]:
"""The fields of the subset returned as an immutable mapping."""
return types.MappingProxyType(
{
name: Field(name=name, type=Type[field["type"]])
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification["fields"].items()
}
)
Expand All @@ -62,8 +62,8 @@ class Index(Subset):
@property
def fields(self) -> t.Dict[str, Field]:
return {
"id": Field(name="id", type=Type.string),
"source": Field(name="source", type=Type.string),
"id": Field(name="id", type=Type("string")),
"source": Field(name="source", type=Type("string")),
}


Expand Down
157 changes: 115 additions & 42 deletions fondant/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,129 @@
and pipelines.
"""

import enum
import typing as t

import pyarrow as pa

KubeflowCommandArguments = t.List[t.Union[str, t.Dict[str, str]]]


class Type(enum.Enum):
"""Supported types.

Based on:
- https://arrow.apache.org/docs/python/api/datatypes.html#api-types
- https://pola-rs.github.io/polars/py-polars/html/reference/datatypes.html
"""
Types based on:
- https://arrow.apache.org/docs/python/api/datatypes.html#api-types
- https://pola-rs.github.io/polars/py-polars/html/reference/datatypes.html
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now just based on arrow, right?

"""
# Registry mapping every type name accepted by `Type` to its `pyarrow` data type.
# Insertion order is visible to users: the keys are printed in the "Invalid schema" error.
_TYPES: t.Dict[str, pa.DataType] = {
# Null and boolean
"null": pa.null(),
"bool": pa.bool_(),
# Signed integers
"int8": pa.int8(),
"int16": pa.int16(),
"int32": pa.int32(),
"int64": pa.int64(),
# Unsigned integers
"uint8": pa.uint8(),
"uint16": pa.uint16(),
"uint32": pa.uint32(),
"uint64": pa.uint64(),
# Floating point
"float16": pa.float16(),
"float32": pa.float32(),
"float64": pa.float64(),
# Parametrized Arrow types are pinned to a single fixed configuration here
# (decimal precision 38, second / microsecond units).
"decimal128": pa.decimal128(38),
"time32": pa.time32("s"),
"time64": pa.time64("us"),
"timestamp": pa.timestamp("us"),
"date32": pa.date32(),
"date64": pa.date64(),
"duration": pa.duration("us"),
# Text and binary
# NOTE(review): pyarrow documents `utf8` as an alias for `string`, so "string" and "utf8"
# presumably resolve to the same type -- confirm before relying on them being distinct.
"string": pa.string(),
"utf8": pa.utf8(),
"binary": pa.binary(),
"large_binary": pa.large_binary(),
"large_utf8": pa.large_utf8(),
}


class Type:
"""
The `Type` class wraps a `pyarrow` data type and validates it on construction. It supports the
primitive types registered in `_TYPES` as well as (nested) list types built with `Type.list`.
"""

bool = pa.bool_()

int8 = pa.int8()
int16 = pa.int16()
int32 = pa.int32()
int64 = pa.int64()

uint8 = pa.uint8()
uint16 = pa.uint16()
uint32 = pa.uint32()
uint64 = pa.uint64()

float16 = pa.float16()
float32 = pa.float32()
float64 = pa.float64()

decimal = pa.decimal128(38)

time32 = pa.time32("s")
time64 = pa.time64("us")
timestamp = pa.timestamp("us")

date32 = pa.date32()
date64 = pa.date64()
duration = pa.duration("us")

string = pa.string()
utf8 = pa.utf8()

binary = pa.binary()

int8_list = pa.list_(pa.int8())

float32_list = pa.list_(pa.float32())
def __init__(self, data_type: t.Union[str, pa.DataType]) -> None:
    """
    Create a `Type` from a data type name or a `pa.DataType` object.

    Args:
        data_type: The data type to wrap; it is passed through `_validate_data_type` first.
    """
    validated = self._validate_data_type(data_type)
    self.value = validated

@staticmethod
def _validate_data_type(
    data_type: t.Union[str, pa.DataType, "Type"]
) -> pa.DataType:
    """
    Validates the provided data type and returns the corresponding data type object.

    Args:
        data_type: The data type to validate: the name of a supported type (a key of
            `_TYPES`), a `pa.DataType` object, or an existing `Type` instance.

    Returns:
        The validated `pa.DataType` object. A `Type` instance is unwrapped to the
        `pa.DataType` it holds.

    Raises:
        ValueError: If `data_type` is an unknown type name, or is neither a string,
            a `pa.DataType` nor a `Type`.
    """
    if isinstance(data_type, Type):
        # Unwrap so that `Type(Type(...))` never stores a `Type` inside `.value`.
        return data_type.value

    if isinstance(data_type, pa.DataType):
        return data_type

    if isinstance(data_type, str):
        try:
            return _TYPES[data_type]
        except KeyError:
            available = ", ".join(_TYPES)
            # The KeyError carries no extra information, so it is not chained.
            raise ValueError(
                f"Invalid schema provided. Current available data types are:"
                f" {available}"
            ) from None

    # Review feedback: anything that is neither a type name nor an Arrow/`Type` object
    # used to be returned unchecked; reject it explicitly instead.
    raise ValueError(
        f"Invalid data type provided: {data_type!r}. Expected a type name,"
        f" a `pa.DataType` or a `Type`."
    )

@classmethod
def list(cls, data_type: t.Union[str, pa.DataType, "Type"]) -> "Type":
    """
    Build a `Type` describing a list whose elements have the given data type.

    Args:
        data_type: The element type: a supported type name, a `pa.DataType` object or an
            existing `Type` instance.

    Returns:
        A new `Type` instance wrapping `pa.list_` of the element type.
    """
    validated = cls._validate_data_type(data_type)
    # A `Type` may come back as-is from validation; `pa.list_` needs the Arrow type inside.
    if isinstance(validated, Type):
        element_type = validated.value
    else:
        element_type = validated
    return cls(pa.list_(element_type))

@classmethod
def from_json(cls, json_schema: dict) -> "Type":
    """
    Creates a new `Type` instance based on a dictionary representation of the json schema
    of a data type (https://swagger.io/docs/specification/data-models/data-types/).

    Args:
        json_schema: The dictionary representation of the data type. An `"array"` type
            must carry an `"items"` dictionary, which may itself describe an array.

    Returns:
        A new `Type` instance representing the specified data type.

    Raises:
        KeyError: If `json_schema` has no `"type"` key.
        ValueError: If the type is unknown, or an `"array"` has no dictionary `"items"`.
    """
    type_name = json_schema["type"]

    if type_name == "array":
        items = json_schema.get("items")
        # Previously a missing `items` raised a bare KeyError and a non-dict `items`
        # (e.g. a list) fell through and returned `None`; both are invalid schemas.
        if not isinstance(items, dict):
            raise ValueError(f"Invalid schema provided: {json_schema}")
        # Recurse through `cls` (review suggestion) so subclasses build their own type.
        return cls.list(cls.from_json(items))

    # The `str` guard keeps an unhashable `type` value from raising TypeError on lookup.
    if isinstance(type_name, str) and type_name in _TYPES:
        return cls(type_name)

    raise ValueError(f"Invalid schema provided: {json_schema}")

@property
def name(self) -> str:
    """Return the string form of the wrapped data type."""
    type_name = str(self.value)
    return type_name

def __repr__(self) -> str:
    """Return `Type(...)` with the `repr` of the wrapped data type inside."""
    return f"Type({self.value!r})"

def __eq__(self, other):
    """
    Compare two `Type` instances by the data type they wrap.

    Returns:
        True if `other` is a `Type` wrapping an equal data type, False otherwise
        (including when `other` is not a `Type`).
    """
    if isinstance(other, Type):
        return self.value == other.value

    return False

def __hash__(self):
    """
    Hash on the wrapped data type so that equal `Type` instances hash equally.

    Defining `__eq__` alone sets `__hash__` to None, which would make `Type` (and any
    tuple containing one) unusable as a dict key or set member.
    """
    return hash(self.value)


class Field(t.NamedTuple):
Expand Down
21 changes: 16 additions & 5 deletions fondant/schemas/common.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
"binary",
"list",
"struct",
"int8_list",
"float32_list"
"array"
]
},
"field": {
Expand All @@ -36,11 +35,23 @@
"type": {
"type": "string",
"$ref": "#/definitions/subset_data_type"
},
"items": {
"oneOf": [
{
"$ref": "#/definitions/field"
},
{
"type": "array",
"items": {
"$ref": "#/definitions/field"
}
}
]
}
},
"required": [
"type"
]
"required": ["type"],
"additionalProperties": false
},
"fields": {
"type": "object",
Expand Down
4 changes: 3 additions & 1 deletion tests/example_specs/component_specs/valid_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ consumes:
embeddings:
fields:
data:
type: int8_list
type: array
items:
type: binary

produces:
captions:
Expand Down
10 changes: 3 additions & 7 deletions tests/test_component_specs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Fondant component specs test."""
from pathlib import Path

import pyarrow as pa
import pytest
import yaml

Expand Down Expand Up @@ -53,13 +52,10 @@ def test_attribute_access(valid_fondant_schema):

assert fondant_component.name == "Example component"
assert fondant_component.description == "This is an example component"
assert fondant_component.consumes["images"].fields["data"].type == Type.binary
assert (
fondant_component.consumes["embeddings"].fields["data"].type == Type.int8_list
assert fondant_component.consumes["images"].fields["data"].type == Type("binary")
assert fondant_component.consumes["embeddings"].fields["data"].type == Type.list(
Type("binary")
)
assert fondant_component.consumes["embeddings"].fields[
"data"
].type.value == pa.list_(pa.int8())


def test_kfp_component_creation(valid_fondant_schema, valid_kubeflow_schema):
Expand Down
22 changes: 13 additions & 9 deletions tests/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,16 @@ def test_subset_fields():
subset = Subset(specification=subset_spec, base_path="/tmp")

# add a field
subset.add_field(name="data2", type_=Type.binary)
subset.add_field(name="data2", type_=Type("binary"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit unfortunate that we lose the ability to reference types by attribute (and therefore autocomplete, static analysis, etc.). Unfortunately I don't see a straightforward way to keep them either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah indeed, you would need to know/type all possible combinations which is not very feasible. But it's not user-facing so it should not have that big of an impact.

assert "data2" in subset.fields

# add a duplicate field
with pytest.raises(ValueError):
subset.add_field(name="data2", type_=Type.binary)
subset.add_field(name="data2", type_=Type("binary"))

# add a duplicate field but overwrite
subset.add_field(name="data2", type_=Type.string, overwrite=True)
assert subset.fields["data2"].type == Type.string
subset.add_field(name="data2", type_=Type("string"), overwrite=True)
assert subset.fields["data2"].type.value == Type("string").value

# remove a field
subset.remove_field(name="data2")
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_attribute_access(valid_manifest):
assert manifest.metadata == valid_manifest["metadata"]
assert manifest.index.location == "gs://bucket/index"
assert manifest.subsets["images"].location == "gs://bucket/images"
assert manifest.subsets["images"].fields["data"].type == Type.binary
assert manifest.subsets["images"].fields["data"].type == Type("binary")


def test_manifest_creation():
Expand All @@ -123,8 +123,8 @@ def test_manifest_creation():
manifest = Manifest.create(
base_path=base_path, run_id=run_id, component_id=component_id
)
manifest.add_subset("images", [("width", Type.int32), ("height", Type.int32)])
manifest.subsets["images"].add_field("data", Type.binary)
manifest.add_subset("images", [("width", Type("int32")), ("height", Type("int32"))])
manifest.subsets["images"].add_field("data", Type("binary"))

assert manifest._specification == {
"metadata": {
Expand Down Expand Up @@ -166,12 +166,16 @@ def test_manifest_alteration(valid_manifest):
manifest = Manifest(valid_manifest)

# test adding a subset
manifest.add_subset("images2", [("width", Type.int32), ("height", Type.int32)])
manifest.add_subset(
"images2", [("width", Type("int32")), ("height", Type("int32"))]
)
assert "images2" in manifest.subsets

# test adding a duplicate subset
with pytest.raises(ValueError):
manifest.add_subset("images2", [("width", Type.int32), ("height", Type.int32)])
manifest.add_subset(
"images2", [("width", Type("int32")), ("height", Type("int32"))]
)

# test removing a subset
manifest.remove_subset("images2")
Expand Down
Loading