From c41f23d82faec1b340bc67e96a698f5910da62d4 Mon Sep 17 00:00:00 2001 From: Karl Higley Date: Tue, 7 Feb 2023 08:43:27 -0500 Subject: [PATCH] Add a new `shape` field to `ColumnSchema` (#195) * Improve robustness of test for converting schemas to dataframes This test was brittle to the addition of new fields in the `ColumnSchema` dataclass, but minor rework avoids that issue. * Add a new `shape` field to `ColumnSchema` This creates a place to store shape information for all dimensions of the data across both array/tensor and dataframe formats. In contrast to the existing "value_count" property (which only records the value counts of the lists in a list field), this attribute is intended to capture the size of _all_ dimensions of the data (the batch dimension, the list lengths, embedding sizes, etc.) * Move `Shape` to `merlin.dtypes` and add tests * Compute `is_list` and `is_ragged` from `ColumnSchema.shape` * Remove shape from dtype when translating across frameworks For now, since all existing dtype translations rely on exact matching, we can drop the shape. In the future, when we add translations that need to know whether to use a list dtype or not, we'll have the information available here in the translation code. * Make the default `Shape()` represent unknown shapes * Ignore shapes when validating operator output dtypes * Fall back to the existing shape if there is one * Remove `Shape.fixed` property * Insert missing f-string * Use `DType.without_shape` * Make `None` shorthand for a dimension with unknown or unbounded min/max * Use whatever shape info is provided to fill in the rest This changes the way validation is done so that only the new shape info that's provided gets validated for consistency, and the rest gets inferred and filled in based on what was provided (assuming it's valid.) 
* Remove the value count min/max test This is now handled by the shape validation * Fix stray linter error * Minor test fix * Disable validation that shape info is provided when `is_ragged=False` * Add few convenience methods to `Shape` * Update `ColumnSchema.with_*` methods to clear existing shape info * Drop shapes from dtypes in `ColumnSchema` constructor * Fix `with_dtype` so dtypes don't overwrite the shape --- merlin/core/dispatch.py | 5 +- merlin/dag/executors.py | 5 +- merlin/dtypes/base.py | 54 +++++- merlin/dtypes/shape.py | 143 +++++++++++++++ merlin/schema/io/tensorflow_metadata.py | 8 +- merlin/schema/schema.py | 202 ++++++++++++++++----- tests/unit/dtypes/test_shape.py | 222 +++++++++++++++++++++++ tests/unit/schema/test_column_schemas.py | 95 ++++++---- tests/unit/schema/test_schema.py | 7 +- 9 files changed, 650 insertions(+), 91 deletions(-) create mode 100644 merlin/dtypes/shape.py create mode 100644 tests/unit/dtypes/test_shape.py diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index 0118c0f75..6b9c901e7 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -415,7 +415,10 @@ def parquet_writer_dispatch(df: DataFrameLike, path=None, **kwargs): elif cudf is not None: _cls = cudf.io.parquet.ParquetWriter else: - ValueError("Unable to load cudf. Please check your environment GPU and cudf available.") + raise ValueError( + "Unable to load cudf. " + "Please check that your environment has GPU(s) and cudf available." 
+ ) if not path: return _cls diff --git a/merlin/dag/executors.py b/merlin/dag/executors.py index 9680a6ed1..2ff6cec92 100644 --- a/merlin/dag/executors.py +++ b/merlin/dag/executors.py @@ -196,8 +196,9 @@ def _transform_data(self, node, input_data, capture_dtypes=False): if ( output_col_schema.dtype and output_data_schema.dtype - and output_col_schema.dtype != md.string - and output_col_schema.dtype != output_data_schema.dtype + and output_col_schema.dtype.without_shape != md.string + and output_col_schema.dtype.without_shape + != output_data_schema.dtype.without_shape ): raise TypeError( f"Dtype discrepancy detected for column {col_name}: " diff --git a/merlin/dtypes/base.py b/merlin/dtypes/base.py index 75f5c77a8..05ea5ed26 100644 --- a/merlin/dtypes/base.py +++ b/merlin/dtypes/base.py @@ -14,11 +14,12 @@ # limitations under the License. # -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum -from typing import Optional +from typing import Optional, Tuple, Union from merlin.dtypes.registry import _dtype_registry +from merlin.dtypes.shape import Shape class ElementType(Enum): @@ -69,6 +70,11 @@ class DType: element_size: Optional[int] = None element_unit: Optional[ElementUnit] = None signed: Optional[bool] = None + shape: Optional[Shape] = None + + def __post_init__(self): + if not self.shape: + object.__setattr__(self, "shape", Shape()) def to(self, mapping_name: str): """ @@ -103,7 +109,7 @@ def to(self, mapping_name: str): ) from exc try: - return mapping.from_merlin(self) + return mapping.from_merlin(self.without_shape) except KeyError as exc: raise ValueError( f"The registered dtype mapping for {mapping_name} doesn't contain type {self.name}." 
@@ -125,3 +131,45 @@ def is_integer(self): @property def is_float(self): return self.element_type.value == "float" + + def with_shape(self, shape: Union[Tuple, Shape]): + """ + Create a copy of this dtype with a new shape + + Parameters + ---------- + shape : Union[Tuple, Shape] + Object to set as shape of dtype, must be either a tuple or Shape. + + Returns + ------- + DType + A copy of this dtype containing the provided shape value + + Raises + ------ + TypeError + If value is not either a tuple or a Shape + """ + if isinstance(shape, tuple): + shape = Shape(shape) + + if not isinstance(shape, Shape): + raise TypeError( + f"Provided value {shape} (of type {type(shape)}) for DType.shape property " + "is not of type Shape." + ) + + return replace(self, shape=shape) + + @property + def without_shape(self): + """ + Create a copy of this object without the shape + + Returns + ------- + DType + A copy of this object with the shape removed + """ + return self.with_shape(Shape()) diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py new file mode 100644 index 000000000..5f2666d95 --- /dev/null +++ b/merlin/dtypes/shape.py @@ -0,0 +1,143 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + + +@dataclass(frozen=True) +class Dimension: + """ + The range of potential sizes for a single dimension of a field or column + """ + + min: int = 0 + max: Optional[int] = None + + def __post_init__(self): + if self.min is None: + raise ValueError("The minimum size of a dimension cannot be None. ") + + if self.min < 0: + raise ValueError( + "The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}" + ) + + if self.max and self.max < 0: + raise ValueError( + "The maximum size of a dimension must be at least one. " f"Provided max: {self.max}" + ) + + if self.max and self.max < self.min: + raise ValueError( + "The maximum size of a dimension must be at least as large as the minimum size. " + f"Provided min: {self.min} max: {self.max}" + ) + + @property + def is_bounded(self): + return self.max is not None + + @property + def is_fixed(self): + return self.is_bounded and self.min == self.max + + @property + def is_variable(self): + return not self.is_fixed + + +@dataclass(frozen=True) +class Shape: + """ + The range of potential sizes for all the dimensions of a field or column + """ + + dims: Optional[Union[Tuple, "Shape"]] = None + + def __post_init__(self): + if isinstance(self.dims, Shape): + object.__setattr__(self, "dims", self.dims.dims) + + if self.dims is not None: + new_dims = [] + for i, dim in enumerate(self.dims): + if isinstance(dim, Dimension): + new_dim = dim + elif isinstance(dim, tuple) and len(dim) == 2: + new_dim = Dimension(dim[0], dim[1]) + elif isinstance(dim, int): + new_dim = Dimension(dim, dim) + elif dim is None: + new_dim = Dimension() + else: + raise ValueError( + f"Invalid shape tuple format: {self.dims}. Each dimension is expected " + " to be None, a single integer, or a tuple with length 2." 
+ ) + new_dims.append(new_dim) + + object.__setattr__(self, "dims", tuple(new_dims)) + + def __eq__(self, other): + """ + Make `dims is None` a wildcard when determining equality + + This definition of equality allows an unknown shape with `dims is None` to be + considered equal or compatible with a known shape with `dims is not None`. + """ + if not isinstance(other, Shape): + return False + + if self.dims is None or other.dims is None: + return True + + return self.dims == other.dims + + def __iter__(self): + return self.dims + + @property + def min(self) -> Tuple: + return tuple(dim.min for dim in self.dims) + + @property + def max(self) -> Tuple: + return tuple(dim.max for dim in self.dims) + + @property + def is_bounded(self): + return all(dim.is_bounded for dim in self.dims) + + @property + def is_fixed(self): + return all(dim.is_fixed for dim in self.dims) + + @property + def is_variable(self): + return not self.is_fixed + + @property + def is_list(self): + return self.dims is not None and len(self.dims) > 1 + + @property + def is_ragged(self): + return self.is_list and any(dim.min != dim.max for dim in self.dims[1:]) + + @property + def as_tuple(self): + return ((dim.min, dim.max) for dim in self.dims) if self.dims else None diff --git a/merlin/schema/io/tensorflow_metadata.py b/merlin/schema/io/tensorflow_metadata.py index b126ff0d6..108e645bf 100644 --- a/merlin/schema/io/tensorflow_metadata.py +++ b/merlin/schema/io/tensorflow_metadata.py @@ -270,8 +270,8 @@ def _pb_feature(column_schema): value_count = column_schema.properties.get("value_count", {}) if value_count: - min_length = value_count.get("min", 0) - max_length = value_count.get("max", 0) + min_length = value_count.get("min", 0) or 0 + max_length = value_count.get("max", 0) or 0 feature.value_count = ValueCount(min=min_length, max=max_length) feature.annotation.tag = _pb_tag(column_schema) @@ -323,9 +323,9 @@ def _merlin_value_count(feature): if proto_utils.has_field(feature, "value_count"): 
value_count = feature.value_count value_count_dict = {} - if value_count.min > 0: + if value_count.min and value_count.min > 0: value_count_dict["min"] = value_count.min - if value_count.max > 0: + if value_count.max and value_count.max > 0: value_count_dict["max"] = value_count.max return value_count_dict diff --git a/merlin/schema/schema.py b/merlin/schema/schema.py index 41e4311bb..7e38a460d 100644 --- a/merlin/schema/schema.py +++ b/merlin/schema/schema.py @@ -14,14 +14,15 @@ # limitations under the License. # -from dataclasses import dataclass, field, replace +from dataclasses import InitVar, dataclass, field, replace from enum import Enum -from typing import Dict, List, Optional, Text, Union +from typing import Dict, List, Optional, Text, Tuple, Union import pandas as pd import merlin.dtypes as md from merlin.dtypes import DType +from merlin.dtypes.shape import Shape from merlin.schema.tags import Tags, TagSet @@ -66,8 +67,9 @@ class ColumnSchema: dtype: Optional[DType] = None is_list: Optional[bool] = None is_ragged: Optional[bool] = None + dims: InitVar[Union[Tuple, Shape]] = None - def __post_init__(self): + def __post_init__(self, dims): """Standardize tags and dtypes on initialization This method works around the inability to set attributes on frozen dataclass @@ -78,41 +80,50 @@ def __post_init__(self): Raises: TypeError: If the provided dtype cannot be cast to a numpy dtype + ValueError: If the provided shape, value counts, and/or flags are inconsistent """ + # Provide defaults and minor conversions for convenience object.__setattr__(self, "tags", TagSet(self.tags)) - object.__setattr__(self, "dtype", md.dtype(self.dtype or md.unknown)) - # Validate the allowed range of value count - value_count = Domain(**self.properties.get("value_count", {})) - if value_count.min == 0 or value_count.max == 0: - raise ValueError( - "`value_count` min and max must be greater than zero. 
" - f"Provided min: {value_count.min} max: {value_count.max}" - ) + dtype = md.dtype(self.dtype or md.unknown).without_shape + object.__setattr__(self, "dtype", dtype) - if self.is_list is None: - object.__setattr__(self, "is_list", bool(value_count.max and value_count.max > 0)) + # Validate that everything provided is consistent + value_counts = self.properties.get("value_count", {}) + self._validate_shape_info(self.shape, value_counts, self.is_list, self.is_ragged) - if self.is_ragged is None: - if value_count.is_bounded and value_count.max > value_count.min: - object.__setattr__(self, "is_ragged", True) - elif value_count.is_bounded and value_count.max == value_count.min: - object.__setattr__(self, "is_ragged", False) - else: - object.__setattr__(self, "is_ragged", bool(self.is_list)) + # Pick which source to pull shape info from + if dims: + new_shape = Shape(dims) + elif dtype.shape.dims: + new_shape = dtype.shape + elif value_counts: + new_shape = self._shape_from_counts(Domain(**value_counts)) + elif self.is_list: + new_shape = self._shape_from_flags(self.is_list) + else: + new_shape = Shape() - if self.is_ragged and not self.is_list: - raise ValueError( - "`is_ragged` is set to `True` but `is_list` is not. " - "Only list columns can set the `is_ragged` flag." - ) + # Update the shape and propagate out to flags and value counts + dtype = dtype.with_shape(new_shape) + object.__setattr__(self, "dtype", dtype) + object.__setattr__(self, "is_list", dtype.shape.is_list) + object.__setattr__(self, "is_ragged", dtype.shape.is_ragged) - if self.is_ragged and value_count.is_bounded and value_count.min == value_count.max: - raise ValueError( - "`is_ragged` is set to `True` but `value_count.min` == `value_count.max`. " - "If value_count min/max are equal. " - "This is a fixed size list and `is_ragged` should be set to False. 
" - ) + if new_shape.dims is not None and len(new_shape.dims) > 1: + value_counts = {"min": new_shape.dims[1].min, "max": new_shape.dims[1].max} + properties = {**self.properties, **{"value_count": value_counts}} + object.__setattr__(self, "properties", properties) + + def _shape_from_flags(self, is_list): + return Shape(((0, None), (0, None))) if is_list else None + + def _shape_from_counts(self, value_count): + return Shape(((0, None), (value_count.min or 0, value_count.max))) + + @property + def shape(self): + return self.dtype.shape @property def quantity(self): @@ -185,17 +196,26 @@ def with_properties(self, properties: dict) -> "ColumnSchema": """ if not isinstance(properties, dict): - raise TypeError("properties must be in dict format, key: value") + raise TypeError("ColumnSchema properties must be a dictionary") # Using new dictionary to avoid passing old ref to new schema new_properties = {**self.properties, **properties} - is_ragged = self.is_ragged - value_count = Domain(**new_properties.get("value_count", {})) - if value_count.is_bounded and value_count.max == value_count.min: - is_ragged = False + value_counts = properties.get("value_count", {}) - return replace(self, properties=new_properties, is_ragged=is_ragged) + if value_counts: + return replace( + self, + properties=new_properties, + dtype=self.dtype.without_shape, + is_list=None, + is_ragged=None, + ) + else: + return replace( + self, + properties=new_properties, + ) def with_dtype(self, dtype, is_list: bool = None, is_ragged: bool = None) -> "ColumnSchema": """Create a copy of this ColumnSchema object with different column dtype @@ -217,14 +237,46 @@ def with_dtype(self, dtype, is_list: bool = None, is_ragged: bool = None) -> "Co Copied object with new column dtype """ - is_list = is_list if is_list is not None else self.is_list + new_dtype = md.dtype(dtype).with_shape(self.shape) - if is_list: - is_ragged = is_ragged if is_ragged is not None else self.is_ragged - else: - is_ragged = False + 
properties = self.properties.copy() + if is_list is not None or is_ragged is not None: + properties.pop("value_count", None) + new_dtype = new_dtype.without_shape + + return replace( + self, dtype=new_dtype, properties=properties, is_list=is_list, is_ragged=is_ragged + ) + + def with_shape(self, shape: Union[Tuple, Shape]) -> "ColumnSchema": + """ + Create a copy of this object with a new shape + + Parameters + ---------- + shape : Union[Tuple, Shape] + Object to set as shape, must be either a tuple or Shape. + + Returns + ------- + ColumnSchema + A copy of this object containing the provided shape value - return replace(self, dtype=dtype, is_list=is_list, is_ragged=is_ragged) + Raises + ------ + TypeError + If value is not either a tuple or a Shape + """ + dims = Shape(shape).as_tuple + properties = self.properties.copy() + properties.pop("value_count", None) + return replace( + self, + dims=dims, + properties=properties, + is_list=None, + is_ragged=None, + ) @property def int_domain(self) -> Optional[Domain]: @@ -240,12 +292,13 @@ def value_count(self) -> Optional[Domain]: return Domain(**value_count) if value_count else None def __merge__(self, other): - col_schema = self.with_tags(other.tags) - col_schema = col_schema.with_properties(other.properties) - col_schema = col_schema.with_dtype( - other.dtype, is_list=other.is_list, is_ragged=other.is_ragged + col_schema = ( + self.with_name(other.name) + .with_dtype(other.dtype) + .with_tags(other.tags) + .with_properties(other.properties) + .with_shape(other.shape) ) - col_schema = col_schema.with_name(other.name) return col_schema def __str__(self) -> str: @@ -256,6 +309,59 @@ def _domain(self) -> Optional[Domain]: domain = self.properties.get("domain") return Domain(**domain) if domain else None + def _validate_shape_info(self, shape, value_counts, is_list, is_ragged): + value_counts = value_counts or {} + + min_count = value_counts.get("min", None) + max_count = value_counts.get("max", None) + ragged_counts = 
min_count != max_count + + if shape and shape.dims is not None: + if is_ragged is not None and shape.is_ragged != is_ragged: + raise ValueError( + f"Provided value of `is_ragged={is_ragged}` " + f"is inconsistent with shape `{shape}`." + ) + elif is_list is not None and shape.is_list != is_list: + raise ValueError( + f"Provided value of `is_list={is_list}` " + f"is inconsistent with shape `{shape}`." + ) + + if value_counts and shape and shape.dims is not None: + if (min_count and min_count != shape.dims[1].min) or ( + max_count and max_count != shape.dims[1].max + ): + raise ValueError( + f"Provided value counts `{value_counts}` " + f"are inconsistent with shape `{shape}`." + ) + + if is_list is False and is_ragged is True: + raise ValueError( + "Columns with `is_list=False` can't set `is_ragged=True`, " + "since non-list columns can't be ragged." + ) + + if value_counts and is_ragged is not None and is_ragged != ragged_counts: + raise ValueError( + f"Provided value of `is_ragged={is_ragged}` " + f"is inconsistent with value counts `{value_counts}`." + ) + + # TODO: Enable this validation once we've removed these cases + # from downstream Merlin libraries + # if ( + # not value_counts + # and not (shape and shape.dims) + # and is_list is True + # and is_ragged is False + # ): + # raise ValueError( + # "Can't determine a shape for this column from " + # "`is_list=True` and `is_ragged=False` without value counts. " + # ) + class Schema: """A collection of column schemas for a dataset.""" diff --git a/tests/unit/dtypes/test_shape.py b/tests/unit/dtypes/test_shape.py new file mode 100644 index 000000000..5dd9a164a --- /dev/null +++ b/tests/unit/dtypes/test_shape.py @@ -0,0 +1,222 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +import merlin.dtypes as md +from merlin.dtypes.shape import Dimension, Shape + +# Dimension + + +def test_empty_dimension(): + dim = Dimension() + assert dim.min == 0 + assert dim.max is None + + +def test_min_max_val_dimension(): + dim = Dimension(2, 3) + assert dim.min == 2 + assert dim.max == 3 + + +def test_fixed_min_with_unbounded_max(): + dim = Dimension(2) + assert dim.min == 2 + assert dim.max is None + + dim = Dimension(2, None) + assert dim.min == 2 + assert dim.max is None + + +def test_min_is_none_raises_error(): + with pytest.raises(ValueError): + Dimension(None) + + with pytest.raises(ValueError): + Dimension(None, 1) + + +def test_bounds_must_be_non_negative(): + with pytest.raises(ValueError): + Dimension(-1, 2) + + with pytest.raises(ValueError): + Dimension(2, -1) + + +def test_max_less_than_min(): + with pytest.raises(ValueError): + Dimension(2, 1) + + +def test_is_bounded(): + dim = Dimension() + assert dim.is_bounded is False + + dim = Dimension(2) + assert dim.is_bounded is False + + dim = Dimension(2, 2) + assert dim.is_bounded is True + + dim = Dimension(2, 4) + assert dim.is_bounded is True + + dim = Dimension(2, None) + assert dim.is_bounded is False + + +def test_is_fixed(): + dim = Dimension() + assert dim.is_fixed is False + + dim = Dimension(2) + assert dim.is_fixed is False + + dim = Dimension(2, 2) + assert dim.is_fixed is True + + dim = Dimension(2, 4) + assert dim.is_fixed is False + + dim = Dimension(2, None) + assert dim.is_fixed is False + + +def test_is_variable(): + dim = Dimension() + assert 
dim.is_variable is True + + dim = Dimension(2) + assert dim.is_variable is True + + dim = Dimension(2, 2) + assert dim.is_variable is False + + dim = Dimension(2, 4) + assert dim.is_variable is True + + dim = Dimension(2, None) + assert dim.is_variable is True + + +# Shape + + +def test_shape_without_args_represents_unknown(): + shape = Shape() + assert shape.dims is None + assert shape.is_list is False + assert shape.is_ragged is False + + +def test_shape_with_empty_tuple_represents_scalar(): + shape = Shape(()) + assert shape.dims == () + assert shape.is_list is False + assert shape.is_ragged is False + + +def test_flat_tuple_creates_fixed_shape(): + shape = Shape((1, 2, 3)) + assert shape.is_fixed is True + + +def test_none_is_shorthand_for_unknown_unbounded(): + shape = Shape((None, (4, 16))) + assert shape == Shape(((0, None), (4, 16))) + + +def test_nested_tuple_creates_variable_shape(): + shape = Shape(((5, 5), (2, 2), (3, 3))) + assert shape.is_variable is False + + shape = Shape(((1, 3), (2, 2), (3, 4))) + assert shape.is_variable is True + + shape = Shape(((1, 3), (2, 4), (4, 4))) + assert shape.is_variable is True + + shape = Shape(((1, 3), (2, 4), (4, 7))) + assert shape.is_variable is True + + +def test_mixed_tuple_creates_variable_shape(): + shape = Shape((5, (2, 3), 4)) + assert shape.is_variable is True + + +def test_nested_tuple_error(): + with pytest.raises(ValueError): + Shape((5, (2, None), (4, 5, 6))) + + with pytest.raises(ValueError): + Shape((5.3, (2, None), (4, 6))) + + with pytest.raises(ValueError): + Shape(("asdf", (2, None), (4, 6))) + + +def test_shape_properties(): + shape = Shape((5,)) + assert shape.is_fixed is True + assert shape.is_variable is False + assert shape.is_bounded is True + assert shape.is_ragged is False + assert shape.is_list is False + + shape = Shape((5, 1)) + assert shape.is_fixed is True + assert shape.is_variable is False + assert shape.is_bounded is True + assert shape.is_ragged is False + assert shape.is_list 
is True + + shape = Shape(((5, None), 2, 4)) + assert shape.is_fixed is False + assert shape.is_variable is True + assert shape.is_bounded is False + assert shape.is_ragged is False + assert shape.is_list is True + + shape = Shape((5, (2, None), 4)) + assert shape.is_fixed is False + assert shape.is_variable is True + assert shape.is_bounded is False + assert shape.is_ragged is True + assert shape.is_list is True + + shape = Shape((5, 2, (4, None))) + assert shape.is_fixed is False + assert shape.is_variable is True + assert shape.is_bounded is False + assert shape.is_ragged is True + assert shape.is_list is True + + +# DType + + +def test_dtype_has_a_shape(): + assert md.int32.shape == Shape() + + +def test_dtype_with_shape(): + dtype = md.int32.with_shape((3, 4, 5)) + assert dtype.shape != (3, 4, 5) + assert dtype.shape == Shape((3, 4, 5)) diff --git a/tests/unit/schema/test_column_schemas.py b/tests/unit/schema/test_column_schemas.py index 4dececf0f..dadff7e9d 100644 --- a/tests/unit/schema/test_column_schemas.py +++ b/tests/unit/schema/test_column_schemas.py @@ -17,6 +17,7 @@ import pytest import merlin.dtypes as md +from merlin.dtypes.shape import Shape from merlin.schema import ColumnSchema from merlin.schema.schema import ColumnQuantity from merlin.schema.tags import Tags, TagSet @@ -196,38 +197,16 @@ def test_list_column_attributes(): assert col3_schema.is_ragged assert col3_schema.quantity == ColumnQuantity.RAGGED_LIST - col4_schema = ColumnSchema("col4", is_list=True, is_ragged=False) + # TODO: Re-enable this test case once we've addressed cases + # like this in downstream libraries - assert col4_schema.is_list - assert not col4_schema.is_ragged - assert col4_schema.quantity == ColumnQuantity.FIXED_LIST + # with pytest.raises(ValueError): + # ColumnSchema("col4", is_list=True, is_ragged=False) with pytest.raises(ValueError): ColumnSchema("col5", is_list=False, is_ragged=True) -def test_value_count_invalid_min_max(): - with pytest.raises(ValueError) as 
exc_info: - ColumnSchema("col", is_ragged=True, properties={"value_count": {"min": 2, "max": 2}}) - assert "`is_ragged` is set to `True` but `value_count.min` == `value_count.max`" in str( - exc_info.value - ) - - -@pytest.mark.parametrize( - "properties", - [ - {"value_count": {"max": 0}}, - {"value_count": {"min": 0}}, - {"value_count": {"min": 0, "max": 2}}, - ], -) -def test_value_count_zero_min_max(properties): - with pytest.raises(ValueError) as exc_info: - ColumnSchema("col", is_ragged=True, properties=properties) - assert "`value_count` min and max must be greater than zero. " in str(exc_info.value) - - @pytest.mark.parametrize( ["value_count_min", "value_count_max"], [ @@ -246,11 +225,63 @@ def test_value_count(value_count_min, value_count_max): col_schema = ColumnSchema("col", properties={"value_count": value_count}) assert col_schema.value_count.max == value_count_max - assert col_schema.value_count.min == value_count_min + assert col_schema.value_count.min == (value_count_min or 0) + + +def test_value_count_inconsistency_with_flags(): + with pytest.raises(ValueError) as exc_info: + ColumnSchema( + "col", properties={"value_count": {"min": 5, "max": 5}}, is_list=True, is_ragged=True + ) + assert "Provided value of `is_ragged=True` is inconsistent with value counts" in str( + exc_info.value + ) + + +def test_column_schema_with_shape(): + col_schema = ColumnSchema("col") + assert col_schema.shape == Shape() + + col_schema = ColumnSchema("col", dtype=md.int32.with_shape((3, 4, 5))) + assert col_schema.shape != (3, 4, 5) + assert col_schema.shape == Shape((3, 4, 5)) + + col_schema = ColumnSchema("col", dims=(3, 4, 5)) + assert col_schema.shape != (3, 4, 5) + assert col_schema.shape == Shape((3, 4, 5)) + + col_schema = ColumnSchema("col").with_shape((3, 4, 5)) + assert col_schema.shape != (3, 4, 5) + assert col_schema.shape == Shape((3, 4, 5)) + + +def test_setting_value_counts_updates_shape_and_flags(): + col_schema = ColumnSchema("col", dims=(None,)) + + 
counts = {"min": 4, "max": 5} + updated_schema = col_schema.with_properties({"value_count": counts}) + + assert updated_schema.properties["value_count"] == counts + assert updated_schema.shape == Shape((None, (4, 5))) + assert updated_schema.is_list + assert updated_schema.is_ragged + + +def test_setting_shape_updates_value_counts_and_flags(): + col_schema = ColumnSchema("col") + updated_schema = col_schema.with_shape((64, (4, 16))) + + assert updated_schema.shape == Shape((64, (4, 16))) + assert updated_schema.properties["value_count"] == {"min": 4, "max": 16} + assert updated_schema.is_list + assert updated_schema.is_ragged + +def test_setting_flags_updates_shape_and_value_counts(): + col_schema = ColumnSchema("col") + updated_schema = col_schema.with_dtype(md.int64, is_list=True, is_ragged=True) -def test_value_count_assign_properties(): - col_schema = ColumnSchema("col", is_list=True, is_ragged=True) - new_col_schema = col_schema.with_properties({"value_count": {"min": 5, "max": 5}}) - assert new_col_schema.is_ragged is False - assert new_col_schema.value_count.min == new_col_schema.value_count.max == 5 + assert updated_schema.shape == Shape((None, None)) + assert updated_schema.properties["value_count"] == {"min": 0, "max": None} + assert updated_schema.is_list + assert updated_schema.is_ragged diff --git a/tests/unit/schema/test_schema.py b/tests/unit/schema/test_schema.py index 1a648c257..239f89441 100644 --- a/tests/unit/schema/test_schema.py +++ b/tests/unit/schema/test_schema.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import dataclasses + import pytest from merlin.dag import ColumnSelector @@ -157,8 +159,11 @@ def test_schema_to_pandas(): schema_set = Schema(["a", "b", "c"]) df = schema_set.to_pandas() + expected_columns = [field.name for field in dataclasses.fields(ColumnSchema)] + expected_columns.remove("properties") + assert isinstance(df, pd.DataFrame) - assert list(df.columns) == ["name", "tags", "dtype", "is_list", "is_ragged"] + assert list(df.columns) == expected_columns def test_construct_schema_with_column_names():