diff --git a/.github/workflows/docs-sched-rebuild.yaml b/.github/workflows/docs-sched-rebuild.yaml index accb943f4..a3e323464 100644 --- a/.github/workflows/docs-sched-rebuild.yaml +++ b/.github/workflows/docs-sched-rebuild.yaml @@ -1,9 +1,10 @@ name: docs-sched-rebuild on: - schedule: - # * is a special character in YAML so you have to quote this string - - cron: "0 0 * * *" + push: + branches: [main] + tags: + - v* workflow_dispatch: jobs: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e97668ada..e0ad731ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: hooks: - id: absolufy-imports - repo: https://github.com/timothycrosley/isort - rev: 5.11.2 + rev: 5.12.0 hooks: - id: isort additional_dependencies: [toml] @@ -26,11 +26,11 @@ repos: exclude: ^docs/ # code style - repo: https://github.com/python/black - rev: 22.12.0 + rev: 23.1.0 hooks: - id: black - repo: https://github.com/pycqa/pylint - rev: v2.15.8 + rev: v2.16.0 hooks: - id: pylint - repo: https://github.com/pycqa/flake8 diff --git a/ci/pr.gpu.Jenkinsfile b/ci/pr.gpu.Jenkinsfile index 81aa9b196..2fc78a097 100644 --- a/ci/pr.gpu.Jenkinsfile +++ b/ci/pr.gpu.Jenkinsfile @@ -2,7 +2,7 @@ pipeline { agent { docker { image 'nvcr.io/nvstaging/merlin/merlin-ci-runner-wrapper' - label 'merlin_gpu' + label 'merlin_gpu_gcp || merlin_gpu' registryCredentialsId 'jawe-nvcr-io' registryUrl 'https://nvcr.io' args "--runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=all" diff --git a/merlin/core/compat.py b/merlin/core/compat.py index d372f8a73..cd3113681 100644 --- a/merlin/core/compat.py +++ b/merlin/core/compat.py @@ -13,19 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # -try: - from numba import cuda +import os +try: + from numba import cuda # pylint: disable=unused-import except ImportError: cuda = None -HAS_GPU = False -try: - from dask.distributed.diagnostics import nvml +from dask.distributed.diagnostics import nvml - HAS_GPU = nvml.device_get_count() > 0 -except ImportError: - # We can use `cuda` to set `HAS_GPU` now that we - # know `distributed` is not installed (otherwise - # the `nvml` import would have succeeded) - HAS_GPU = cuda is not None + +def _get_gpu_count(): + """Get Number of GPU devices accounting for CUDA_VISIBLE_DEVICES environment variable""" + # Using the `dask.distributed.diagnostics.nvml.device_get_count` + # helper function from dask to check device counts with NVML + # since this handles some complexity of checking NVML state for us. + + # Note: We can't use `numba.cuda.gpus`, since this has some side effects + # that are incompatible with Dask-CUDA. 
If CUDA runtime functions are + # called before Dask-CUDA can spawn worker processes, + # then Dask-CUDA will not work correctly (it raises an exception) + nvml_device_count = nvml.device_get_count() + if nvml_device_count == 0: + return 0 + try: + cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] + if cuda_visible_devices: + return len(cuda_visible_devices.split(",")) + else: + return 0 + except KeyError: + return nvml_device_count + + +HAS_GPU = _get_gpu_count() > 0 diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index 0118c0f75..771bfa9c6 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -34,7 +34,6 @@ if HAS_GPU: try: import cudf # type: ignore[no-redef] - import cupy as cp # type: ignore[no-redef] import dask_cudf import rmm # type: ignore[no-redef] from cudf.core.column import as_column, build_column @@ -47,10 +46,12 @@ # cudf < 21.08 from cudf.utils.dtypes import is_list_dtype as cudf_is_list_dtype from cudf.utils.dtypes import is_string_dtype as cudf_is_string_dtype - except ImportError: - HAS_GPU = False - + pass + try: + import cupy as cp # type: ignore[no-redef] + except ImportError: + pass try: # Dask >= 2021.5.1 @@ -76,7 +77,7 @@ def inner2(*args, **kwargs): return inner1 -if HAS_GPU: +if HAS_GPU and cudf: DataFrameType = Union[pd.DataFrame, cudf.DataFrame] # type: ignore SeriesType = Union[pd.Series, cudf.Series] # type: ignore else: @@ -415,7 +416,10 @@ def parquet_writer_dispatch(df: DataFrameLike, path=None, **kwargs): elif cudf is not None: _cls = cudf.io.parquet.ParquetWriter else: - ValueError("Unable to load cudf. Please check your environment GPU and cudf available.") + raise ValueError( + "Unable to load cudf. " + "Please check that your environment has GPU(s) and cudf available." + ) if not path: return _cls @@ -489,16 +493,21 @@ def concat(objs, **kwargs): def make_df(_like_df=None, device=None): """Return a DataFrame with the same dtype as `_like_df`""" - if not cudf or isinstance(_like_df, (pd.DataFrame, pd.Series)): - return pd.DataFrame(_like_df) - elif isinstance(_like_df, (cudf.DataFrame, cudf.Series)): + if not cudf or device == "cpu" or isinstance(_like_df, (pd.DataFrame, pd.Series)): + # move to pandas if we need the data on CPU (host memory); + # the input can be a cudf, cupy, or numpy Series + if cudf and isinstance(_like_df, (cudf.DataFrame, cudf.Series)): + # move to cpu + return _like_df.to_pandas() + if cp and isinstance(_like_df, cp.ndarray): + return pd.DataFrame(_like_df.get()) + else: + return pd.DataFrame(_like_df) + else: + if isinstance(_like_df, dict) and len(_like_df) > 0: + if all(isinstance(v, pd.Series) for v in _like_df.values()): + return pd.DataFrame(_like_df) return cudf.DataFrame(_like_df) - elif device is None and isinstance(_like_df, dict) and len(_like_df) > 0: - is_pandas = all(isinstance(v, pd.Series) for v in _like_df.values()) - return pd.DataFrame(_like_df) if is_pandas else cudf.DataFrame(_like_df) - if device == "cpu": - return pd.DataFrame(_like_df) - return cudf.DataFrame(_like_df) def make_series(_like_ser=None, device=None): diff --git a/merlin/dag/__init__.py b/merlin/dag/__init__.py index 1818f8601..eb9b6c96d 100644 --- a/merlin/dag/__init__.py +++ b/merlin/dag/__init__.py @@ -20,3 +20,4 @@ from merlin.dag.graph import Graph from merlin.dag.node import Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes from merlin.dag.selector import ColumnSelector +from merlin.dag.utils import group_values_offsets, ungroup_values_offsets diff --git a/merlin/dag/base_operator.py b/merlin/dag/base_operator.py 
index 85e96117a..0035a004e 100644 --- a/merlin/dag/base_operator.py +++ b/merlin/dag/base_operator.py @@ -146,10 +146,8 @@ def compute_output_schema( output_schema = Schema() for output_col_name, input_col_names in self.column_mapping(col_selector).items(): - col_schema = ColumnSchema(output_col_name) - col_schema = self._compute_dtype(col_schema, input_schema[input_col_names]) - col_schema = self._compute_tags(col_schema, input_schema[input_col_names]) - col_schema = self._compute_properties(col_schema, input_schema[input_col_names]) + input_schema_fragment = input_schema[input_col_names] + col_schema = self.compute_column_schema(output_col_name, input_schema_fragment) output_schema += Schema([col_schema]) if self.dynamic_dtypes and prev_output_schema: @@ -226,7 +224,12 @@ def column_mapping(self, col_selector): return column_mapping def compute_column_schema(self, col_name, input_schema): - methods = [self._compute_dtype, self._compute_tags, self._compute_properties] + methods = [ + self._compute_dtype, + self._compute_tags, + self._compute_properties, + self._compute_shape, + ] return self._compute_column_schema(col_name, input_schema, methods=methods) def _compute_column_schema(self, col_name, input_schema, methods=None): @@ -239,21 +242,24 @@ def _compute_column_schema(self, col_name, input_schema, methods=None): def _compute_dtype(self, col_schema, input_schema): dtype = col_schema.dtype - is_list = col_schema.is_list - is_ragged = col_schema.is_ragged if input_schema.column_schemas: source_col_name = input_schema.column_names[0] dtype = input_schema[source_col_name].dtype - is_list = input_schema[source_col_name].is_list - is_ragged = input_schema[source_col_name].is_ragged if self.output_dtype is not None: dtype = self.output_dtype - is_list = any(cs.is_list for _, cs in input_schema.column_schemas.items()) - is_ragged = any(cs.is_ragged for _, cs in input_schema.column_schemas.items()) - return col_schema.with_dtype(dtype, is_list=is_list, is_ragged=is_ragged) + return col_schema.with_dtype(dtype) + + def _compute_shape(self, col_schema, input_schema): + shape = col_schema.shape + + if input_schema.column_schemas: + source_col_name = input_schema.column_names[0] + shape = input_schema[source_col_name].shape + + return col_schema.with_shape(shape) @property def dynamic_dtypes(self): diff --git a/merlin/dag/executors.py b/merlin/dag/executors.py index 9680a6ed1..12f43736e 100644 --- a/merlin/dag/executors.py +++ b/merlin/dag/executors.py @@ -27,6 +27,7 @@ set_client_deprecated, ) from merlin.dag import ColumnSelector, Graph, Node +from merlin.dtypes.shape import DefaultShapes from merlin.io.worker import clean_worker_cache LOG = logging.getLogger("merlin") @@ -44,6 +45,7 @@ def transform( output_dtypes=None, additional_columns=None, capture_dtypes=False, + validate_dtypes=True, ): """ Transforms a single dataframe (possibly a partition of a Dask Dataframe) @@ -67,11 +69,13 @@ def transform( output_data = None for node in nodes: - input_data = self._build_input_data(node, transformable, capture_dtypes=capture_dtypes) + input_data = self._build_input_data( + node, transformable, capture_dtypes=capture_dtypes, validate_dtypes=validate_dtypes + ) if node.op: transformed_data = self._transform_data( - node, input_data, capture_dtypes=capture_dtypes + node, input_data, capture_dtypes=capture_dtypes, validate_dtypes=validate_dtypes ) else: transformed_data = input_data @@ -85,7 +89,7 @@ def transform( return output_data - def _build_input_data(self, node, transformable, 
capture_dtypes=False): + def _build_input_data(self, node, transformable, capture_dtypes=False, validate_dtypes=True): """ Recurse through the graph executing parent and dependency operators to form the input dataframe for each output node @@ -114,7 +118,12 @@ def _build_input_data(self, node, transformable, capture_dtypes=False): for parent in node.parents_with_dependencies: parent_output_cols = _get_unique(parent.output_schema.column_names) - parent_data = self.transform(transformable, [parent], capture_dtypes=capture_dtypes) + parent_data = self.transform( + transformable, + [parent], + capture_dtypes=capture_dtypes, + validate_dtypes=validate_dtypes, + ) if input_data is None or not len(input_data): input_data = parent_data[parent_output_cols] seen_columns = set(parent_output_cols) @@ -141,7 +150,7 @@ def _build_input_data(self, node, transformable, capture_dtypes=False): return input_data - def _transform_data(self, node, input_data, capture_dtypes=False): + def _transform_data(self, node, input_data, capture_dtypes=False, validate_dtypes=True): """ Run the transform represented by the final node in the graph and check output dtypes against the output schema @@ -153,6 +162,8 @@ def _transform_data(self, node, input_data, capture_dtypes=False): Dataframe to run the graph ending with node on capture_dtypes : bool, optional Overrides the schema dtypes with the actual dtypes when True, by default False + validate_dtypes : bool, optional + Checks the dtype of returned data against the schema, by default True Returns ------- Transformable @@ -161,7 +172,7 @@ def _transform_data(self, node, input_data, capture_dtypes=False): ------ TypeError If the transformed output columns don't have the same dtypes - as the output schema columns + as the output schema columns when validate_dtypes is True RuntimeError If no DataFrame or DictArray is returned from the operator """ @@ -171,40 +182,51 @@ def _transform_data(self, node, input_data, capture_dtypes=False): output_data = node.op.transform(selection, input_data) # Update or validate output_data dtypes - for col_name, output_col_schema in node.output_schema.column_schemas.items(): - col_series = output_data[col_name] - col_dtype = col_series.dtype - is_list = is_list_dtype(col_series) - - if is_list: - col_dtype = list_val_dtype(col_series) - - # TODO: Add a utility that condenses the known methods of fetching dtypes - # from series/arrays into a single function, so that Tensorflow specific - # code doesn't leak into the executors - if not hasattr(col_dtype, "as_numpy_dtype") and hasattr(col_series, "numpy"): - col_dtype = col_series[0].cpu().numpy().dtype - - output_data_schema = output_col_schema.with_dtype(col_dtype, is_list=is_list) - - if capture_dtypes: - node.output_schema.column_schemas[col_name] = output_data_schema - elif len(output_data): - # Validate that the dtypes match but only if they both exist - # (since schemas may not have all dtypes specified, especially - # in the tests) - if ( - output_col_schema.dtype - and output_data_schema.dtype - and output_col_schema.dtype != md.string - and output_col_schema.dtype != output_data_schema.dtype + if capture_dtypes or validate_dtypes: + for col_name, output_col_schema in node.output_schema.column_schemas.items(): + col_series = output_data[col_name] + output_data_dtype = col_series.dtype + col_shape = output_col_schema.shape + is_list = is_list_dtype(col_series) + + if is_list: + output_data_dtype = list_val_dtype(col_series) + + if not col_shape.is_list or col_shape.is_unknown: + 
col_shape = DefaultShapes.LIST + + # TODO: Add a utility that condenses the known methods of fetching dtypes + # from series/arrays into a single function, so that Tensorflow specific + # code doesn't leak into the executors + if not hasattr(output_data_dtype, "as_numpy_dtype") and hasattr( + col_series, "numpy" ): - raise TypeError( - f"Dtype discrepancy detected for column {col_name}: " - f"operator {node.op.label} reported dtype " - f"`{output_col_schema.dtype}` but returned dtype " - f"`{output_data_schema.dtype}`." - ) + output_data_dtype = col_series[0].cpu().numpy().dtype + + output_data_schema = output_col_schema.with_dtype(output_data_dtype).with_shape( + col_shape + ) + + if capture_dtypes: + node.output_schema.column_schemas[col_name] = output_data_schema + elif validate_dtypes and len(output_data): + # Validate that the dtypes match but only if they both exist + # (since schemas may not have all dtypes specified, especially + # in the tests) + output_schema_dtype = output_col_schema.dtype.without_shape + output_data_dtype = md.dtype(output_data_dtype).without_shape + if ( + output_schema_dtype + and output_data_dtype + and output_schema_dtype != md.string + and output_schema_dtype != output_data_dtype + ): + raise TypeError( + f"Dtype discrepancy detected for column {col_name}: " + f"operator {node.op.label} reported dtype " + f"`{output_schema_dtype}` but returned dtype " + f"`{output_data_dtype}`." + ) except Exception: LOG.exception("Failed to transform operator %s", node.op) raise diff --git a/merlin/dag/graph.py b/merlin/dag/graph.py index b30835723..0a7682044 100644 --- a/merlin/dag/graph.py +++ b/merlin/dag/graph.py @@ -31,6 +31,11 @@ class Graph: + """ + Represents an DAG composed of Nodes, each of which contains an operator that + transforms dataframes or dataframe-like data + """ + def __init__(self, output_node: Node, subgraphs: Optional[Dict[str, Node]] = None): self.output_node = output_node self.subgraphs = subgraphs or {} @@ -84,6 +89,21 @@ def column_mapping(self): return column_mapping def construct_schema(self, root_schema: Schema, preserve_dtypes=False) -> "Graph": + """ + Given the schema of a dataset to transform, determine the output schema of the graph + + Parameters + ---------- + root_schema : Schema + The schema of a dataset to be transformed with this DAG + preserve_dtypes : bool, optional + Whether to keep any dtypes that may already be present in the schemas, by default False + + Returns + ------- + Graph + This DAG after the schemas have been filled in + """ nodes = list(postorder_iter_nodes(self.output_node)) self._compute_node_schemas(root_schema, nodes, preserve_dtypes) diff --git a/merlin/dag/utils.py b/merlin/dag/utils.py new file mode 100644 index 000000000..b036687b6 --- /dev/null +++ b/merlin/dag/utils.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +def ungroup_values_offsets(grouped_cols: dict) -> dict: + """ + Flatten columns with values/offsets tuples in a dictionary to separate keys + + Parameters + ---------- + grouped_cols : dict + A dictionary of column arrays including values/offsets tuples + + Returns + ------- + dict + A dictionary of column arrays with separate keys for values and offsets + """ + flat_cols = {} + + for key, value in grouped_cols.items(): + if isinstance(value, tuple): + flat_cols[f"{key}__values"] = value[0] + flat_cols[f"{key}__offsets"] = value[1] + else: + flat_cols[key] = value + + return flat_cols + + +def group_values_offsets(flat_cols: dict) -> dict: + """ + Convert separate values/offsets keys for columns into tuples w/ a single key + + Parameters + ---------- + flat_cols : dict + A dictionary of column arrays with separate keys for values and offsets + + Returns + ------- + dict + A dictionary of column arrays including values/offsets tuples + """ + grouped_cols = {} + + for key, value in flat_cols.items(): + if key.endswith("__values"): + col_name = key.replace("__values", "") + grouped_cols[col_name] = (flat_cols[key], flat_cols[f"{col_name}__offsets"]) + elif key.endswith("__offsets"): + pass + else: + grouped_cols[key] = value + + return grouped_cols diff --git a/merlin/dtypes/__init__.py b/merlin/dtypes/__init__.py index 1b78db48c..fea6038a7 100644 --- a/merlin/dtypes/__init__.py +++ b/merlin/dtypes/__init__.py @@ -49,7 +49,6 @@ def dtype(external_dtype): try: return _dtype_registry.to_merlin(external_dtype) except TypeError as base_exc: - try: return _dtype_registry.to_merlin_via_numpy(external_dtype) except TypeError as exc: diff --git a/merlin/dtypes/base.py b/merlin/dtypes/base.py index 75f5c77a8..d486089bf 100644 --- a/merlin/dtypes/base.py +++ b/merlin/dtypes/base.py @@ -14,11 +14,12 @@ # limitations under the License. # -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum -from typing import Optional +from typing import Optional, Tuple, Union from merlin.dtypes.registry import _dtype_registry +from merlin.dtypes.shape import Shape class ElementType(Enum): @@ -69,6 +70,11 @@ class DType: element_size: Optional[int] = None element_unit: Optional[ElementUnit] = None signed: Optional[bool] = None + shape: Optional[Shape] = None + + def __post_init__(self): + if not self.shape: + object.__setattr__(self, "shape", Shape()) def to(self, mapping_name: str): """ @@ -103,7 +109,7 @@ def to(self, mapping_name: str): ) from exc try: - return mapping.from_merlin(self) + return mapping.from_merlin(self.without_shape) except KeyError as exc: raise ValueError( f"The registered dtype mapping for {mapping_name} doesn't contain type {self.name}." @@ -125,3 +131,48 @@ def is_integer(self): @property def is_float(self): return self.element_type.value == "float" + + def with_shape(self, shape: Union[Tuple, Shape]): + """ + Create a copy of this dtype with a new shape + + Parameters + ---------- + shape : Union[Tuple, Shape] + Object to set as shape of dtype, must be either a tuple or Shape. + + Returns + ------- + DType + A copy of this dtype containing the provided shape value + + Raises + ------ + TypeError + If value is not either a tuple or a Shape + """ + if isinstance(shape, tuple): + shape = Shape(shape) + + if not isinstance(shape, Shape): + raise TypeError( + f"Provided value {shape} (of type {type(shape)}) for DType.shape property " + "is not of type Shape." 
+ ) + + return replace(self, shape=shape) + + @property + def without_shape(self): + """ + Create a copy of this object without the shape + + Returns + ------- + DType + A copy of this object with the shape removed + """ + if self.shape.dims is None: + return self + + return replace(self, shape=Shape()) diff --git a/merlin/dtypes/mapping.py b/merlin/dtypes/mapping.py index 047a38fb8..68e68a6a0 100644 --- a/merlin/dtypes/mapping.py +++ b/merlin/dtypes/mapping.py @@ -19,6 +19,10 @@ class NumpyPreprocessor: + """ + Allows converting framework dtypes to numpy dtypes before mapping to Merlin types + """ + def __init__( self, framework, @@ -32,6 +36,19 @@ def __init__( self.classes = classes or [] def matches(self, raw_dtype) -> bool: + """ + Check if this preprocessor has a translation available for a dtype + + Parameters + ---------- + raw_dtype : external framework dtype + The dtype to be translated to numpy + + Returns + ------- + bool + True if this preprocessor can convert the external dtype to Numpy + """ for attr in self.attrs: if hasattr(raw_dtype, attr): return True @@ -41,6 +58,19 @@ def matches(self, raw_dtype) -> bool: return False def to_numpy(self, raw_dtype) -> np.dtype: + """ + Translate an external framework dtype to Numpy + + Parameters + ---------- + raw_dtype : external framework dtype + The dtype to be translated to numpy + + Returns + ------- + np.dtype + The result of translating raw_dtype to Numpy + """ return self.translation_fn(raw_dtype) @@ -137,10 +167,7 @@ def _matches(self, dtype, mapping, base_class=None): # Some external dtype objects are not hashable, so they # can't be used as dictionary keys. In that case, match # against the dtype class instead. - hashable_dtype = dtype try: - hash(dtype) + return dtype in mapping.keys() except TypeError: - hashable_dtype = type(dtype) - - return hashable_dtype in mapping.keys() + return type(dtype) in mapping.keys() diff --git a/merlin/dtypes/mappings/cudf.py b/merlin/dtypes/mappings/cudf.py index 3f49405a3..4452727cd 100644 --- a/merlin/dtypes/mappings/cudf.py +++ b/merlin/dtypes/mappings/cudf.py @@ -21,6 +21,19 @@ def cudf_translator(raw_dtype) -> np.dtype: + """ + Translate cudf dtypes to Numpy dtypes + + Parameters + ---------- + raw_dtype : cudf dtype + The dtype to be translated + + Returns + ------- + np.dtype + The result of translating raw_dtype to Numpy + """ category_type = raw_dtype._categories.dtype if is_string_dtype(category_type): return np.dtype("str") diff --git a/merlin/dtypes/shape.py b/merlin/dtypes/shape.py new file mode 100644 index 000000000..1222aef7d --- /dev/null +++ b/merlin/dtypes/shape.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from dataclasses import dataclass, replace +from enum import Enum +from typing import Optional, Tuple, Union + + +class DefaultShapes(Enum): + LIST = (None, None) + SCALAR = (None,) + + +@dataclass(frozen=True) +class Dimension: + """ + The range of potential sizes for a single dimension of a field or column + """ + + min: int = 0 + max: Optional[int] = None + + def __post_init__(self): + if self.min is None: + raise ValueError("The minimum size of a dimension cannot be None. ") + + if self.min < 0: + raise ValueError( + "The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}" + ) + + if self.max and self.max < 0: + raise ValueError( + "The maximum size of a dimension must be at least one. " f"Provided max: {self.max}" + ) + + if self.max and self.max < self.min: + raise ValueError( + "The maximum size of a dimension must be at least as large as the minimum size. " + f"Provided min: {self.min} max: {self.max}" + ) + + @property + def is_bounded(self): + return self.max is not None + + @property + def is_fixed(self): + return self.is_bounded and self.min == self.max + + @property + def is_variable(self): + return not self.is_fixed + + @property + def is_unknown(self): + return self.min == 0 and self.max is None + + def with_min(self, value): + return replace(self, min=value) + + def with_max(self, value): + return replace(self, max=value) + + +@dataclass(frozen=True) +class Shape: + """ + The range of potential sizes for all the dimensions of a field or column + """ + + dims: Optional[Union[Tuple, "Shape"]] = None + + def __post_init__(self): + if isinstance(self.dims, DefaultShapes): + object.__setattr__(self, "dims", self.dims.value) + + if isinstance(self.dims, Shape): + object.__setattr__(self, "dims", self.dims.dims) + + if self.dims is not None: + new_dims = [] + for i, dim in enumerate(self.dims): + if isinstance(dim, Dimension): + new_dim = dim + elif isinstance(dim, tuple) and len(dim) == 2: + new_dim = Dimension(dim[0], dim[1]) + elif isinstance(dim, int): + new_dim = Dimension(dim, dim) + elif dim is None: + new_dim = Dimension() + else: + raise ValueError( + f"Invalid shape tuple format: {self.dims}. Each dimension is expected " + " to be None, a single integer, or a tuple with length 2." + ) + new_dims.append(new_dim) + + object.__setattr__(self, "dims", tuple(new_dims)) + + def __eq__(self, other): + """ + Make `dims is None` a wildcard when determining equality + + This definition of equality allows an unknown shape with `dims is None` to be + considered equal or compatible with a known shape with `dims is not None`. 
+ """ + if not isinstance(other, Shape): + return False + + if self.dims is None or other.dims is None: + return True + + return self.dims == other.dims + + def __iter__(self): + return self.dims + + def with_dim(self, index, value): + new_dims = list(self.dims) + new_dims[index] = value + return replace(self, dims=tuple(new_dims)) + + def with_dim_min(self, index, value): + return self.with_dim(index, self.dims[index].with_min(value)) + + def with_dim_max(self, index, value): + return self.with_dim(index, self.dims[index].with_max(value)) + + @property + def min(self) -> Tuple: + return tuple(dim.min for dim in self.dims) + + @property + def max(self) -> Tuple: + return tuple(dim.max for dim in self.dims) + + @property + def is_bounded(self): + return all(dim.is_bounded for dim in self.dims) + + @property + def is_fixed(self): + return all(dim.is_fixed for dim in self.dims) + + @property + def is_variable(self): + return not self.is_fixed + + @property + def is_list(self): + return self.dims is not None and len(self.dims) > 1 + + @property + def is_ragged(self): + return self.is_list and any(dim.min != dim.max for dim in self.dims[1:]) + + @property + def as_tuple(self): + return tuple(((dim.min, dim.max) for dim in self.dims)) if self.dims else None + + @property + def is_unknown(self): + return self.dims is None diff --git a/merlin/io/avro.py b/merlin/io/avro.py index ce4bcefca..1e2cfde11 100644 --- a/merlin/io/avro.py +++ b/merlin/io/avro.py @@ -46,7 +46,6 @@ def __init__(self, paths, part_size, storage_options=None, cpu=False, **kwargs): raise ValueError("cpu=True not supported for AvroDatasetEngine.") def to_ddf(self, columns=None, cpu=None): - # Check if we are using cpu cpu = self.cpu if cpu is None else cpu if cpu: @@ -80,7 +79,6 @@ def to_gpu(self): self.cpu = False def process_metadata(self, columns=None): - with open(self.paths[0], "rb") as fo: header = ua.core.read_header(fo) @@ -149,10 +147,8 @@ def process_metadata(self, columns=None): @classmethod def read_partition(cls, fs, piece, columns): - path = piece["path"] if "rows" in piece: - # See: (https://github.com/rapidsai/cudf/issues/6529) # Using `uavro` library for now. 
This means we must convert # data to pandas, and then to cudf (which is much slower diff --git a/merlin/io/csv.py b/merlin/io/csv.py index 73b455c2d..8958930ed 100644 --- a/merlin/io/csv.py +++ b/merlin/io/csv.py @@ -49,7 +49,6 @@ def __init__(self, paths, part_size, storage_options=None, cpu=False, **kwargs): self.paths = self.fs.glob(self.fs.sep.join([self.paths[0], "*"])) def to_ddf(self, columns=None, cpu=None): - # Check if we are using cpu cpu = self.cpu if cpu is None else cpu if cpu: diff --git a/merlin/io/dask.py b/merlin/io/dask.py index 9d1ece6b6..673db778a 100644 --- a/merlin/io/dask.py +++ b/merlin/io/dask.py @@ -177,7 +177,6 @@ def _write_subgraph( cpu, suffix, ): - fns = fns if isinstance(fns, (tuple, list)) else (fns,) writer = writer_factory( output_format, @@ -204,7 +203,6 @@ def _write_subgraph( def _write_metadata_files(md_list, output_path, output_format, cpu, schema): - # Separate and merge metadata general_md = [] special_md = [] @@ -223,7 +221,6 @@ def _write_metadata_files(md_list, output_path, output_format, cpu, schema): def _simple_shuffle(ddf, plan): - # Construct graph for a simple shuffle token = tokenize(ddf, plan) name = "shuffled-" + token @@ -256,7 +253,6 @@ def _ddf_to_dataset( partition_on=None, schema=None, ): - # Construct graph for Dask-based dataset write token = tokenize( ddf, shuffle, out_files_per_proc, cat_names, cont_names, label_names, suffix, partition_on @@ -377,7 +373,7 @@ def _finish_dataset(client, ddf, output_path, fs, output_format, cpu, schema): general_md = [] special_md = [] - for (gen, spec) in out.values(): + for gen, spec in out.values(): general_md.append(gen) if spec: special_md.append(spec) diff --git a/merlin/io/dataset.py b/merlin/io/dataset.py index fc1f49d1b..e3f0c98f1 100644 --- a/merlin/io/dataset.py +++ b/merlin/io/dataset.py @@ -38,6 +38,7 @@ list_val_dtype, ) from merlin.core.utils import device_mem_size, global_dask_client, set_client_deprecated +from merlin.dtypes.shape import DefaultShapes from merlin.io.csv import CSVDatasetEngine from merlin.io.dask import _ddf_to_dataset, _simple_shuffle from merlin.io.dataframe_engine import DataFrameDatasetEngine @@ -473,7 +474,6 @@ def shuffle_by_keys(self, keys, hive_data=None, npartitions=None): hive_mapping[_key].append(_val) if set(hive_mapping.keys()) == set(keys): - # Generate hive-mapping DataFrame summary hive_mapping = type(ddf._meta)(hive_mapping) cols = list(hive_mapping.columns) @@ -752,7 +752,6 @@ def to_parquet( """ if partition_on: - # Check that the user is not expecting a specific output-file # count/structure that is not supported if output_files: @@ -763,7 +762,6 @@ def to_parquet( raise ValueError("`preserve_files` not supported when `partition_on` is used.") else: - # Check that method (algorithm) is valid if method not in ("subgraph", "worker"): raise ValueError(f"{method} not a recognized method for `Dataset.to_parquet`") @@ -800,7 +798,6 @@ def to_parquet( # Deal with `method=="subgraph"`. 
# Convert `output_files` argument to a dict mapping if output_files: - # NOTES on `output_files`: # # - If a list of file names is specified, a contiguous range of @@ -896,14 +893,18 @@ def to_parquet( output_files[fn + suffix] = rgs suffix = "" # Don't add a suffix later - Names already include it + schema = Schema({**self.schema.column_schemas}) + if dtypes: _meta = _set_dtypes(ddf._meta, dtypes) ddf = ddf.map_partitions(_set_dtypes, dtypes, meta=_meta) + for col_name, col_dtype in dtypes.items(): + schema[col_name] = schema[col_name].with_dtype(col_dtype) fs = get_fs_token_paths(output_path)[0] fs.mkdirs(output_path, exist_ok=True) - tf_metadata = TensorflowMetadata.from_merlin_schema(self.schema) + tf_metadata = TensorflowMetadata.from_merlin_schema(schema) tf_metadata.to_proto_text_file(output_path) # Output dask_cudf DataFrame to dataset @@ -922,7 +923,7 @@ def to_parquet( self.cpu, suffix=suffix, partition_on=partition_on, - schema=self.schema if write_hugectr_keyset else None, + schema=schema if write_hugectr_keyset else None, ) def to_hugectr( @@ -1134,8 +1135,10 @@ def infer_schema(self, n=1): column_schemas = [] for column, dtype_info in dtypes.items(): dtype_val = dtype_info["dtype"] - is_list = dtype_info["is_list"] - col_schema = ColumnSchema(column, dtype=dtype_val, is_list=is_list, is_ragged=is_list) + + dims = DefaultShapes.LIST if dtype_info["is_list"] else DefaultShapes.SCALAR + col_schema = ColumnSchema(column, dtype=dtype_val, dims=dims) + column_schemas.append(col_schema) self.schema = Schema(column_schemas) diff --git a/merlin/io/fsspec_utils.py b/merlin/io/fsspec_utils.py index 7b1127af2..490d80ef9 100644 --- a/merlin/io/fsspec_utils.py +++ b/merlin/io/fsspec_utils.py @@ -113,7 +113,6 @@ def _optimized_read_partition_remote( def _optimized_read_remote(path, row_groups, columns, fs, **kwargs): - if row_groups is not None and not isinstance(row_groups, list): row_groups = [row_groups] @@ -254,7 +253,6 @@ def _fsspec_data_transfer( mode="rb", **kwargs, ): - # Calculate total file size file_size = file_size or fs.size(path_or_fob) @@ -265,7 +263,6 @@ def _fsspec_data_transfer( # Threaded read into "dummy" buffer buf = np.zeros(file_size, dtype="b") if byte_ranges: - # Optimize/merge the ranges byte_ranges = _merge_ranges( byte_ranges, @@ -320,7 +317,7 @@ def _merge_ranges(byte_ranges, max_block=256_000_000, max_gap=64_000): return new_ranges offset, size = byte_ranges[0] - for (new_offset, new_size) in byte_ranges[1:]: + for new_offset, new_size in byte_ranges[1:]: gap = new_offset - (offset + size) if gap > max_gap or (size + new_size + gap) > max_block: # Gap is too large or total read is too large @@ -349,9 +346,8 @@ def _read_byte_ranges( fs, **kwargs, ): - workers = [] - for (offset, nbytes) in ranges: + for offset, nbytes in ranges: if len(ranges) > 1: workers.append( Thread(target=_assign_block, args=(fs, path_or_fob, local_buffer, offset, nbytes)) diff --git a/merlin/io/parquet.py b/merlin/io/parquet.py index a46182caa..cfc3ed303 100644 --- a/merlin/io/parquet.py +++ b/merlin/io/parquet.py @@ -104,7 +104,6 @@ def read_metadata(*args, **kwargs): if (cudf_version.major == 21 and cudf_version.minor == 10) or ( cudf_version.major == 0 and cudf_version.minor == 0 ): - # We only need this work-around for cudf-21.10 return _override_read_metadata(_cudf_read_metadata, *args, **kwargs) return _override_read_metadata(CudfEngine.read_metadata, *args, **kwargs) @@ -287,7 +286,6 @@ def _override_read_metadata( # Apply file aggregation if aggregate_row_groups is not None: - # 
Convert `aggregate_files` to an integer `aggregation_depth` aggregation_depth = False if len(parts) and aggregate_files: @@ -448,7 +446,6 @@ def _process_parquet_metadata(self): self._pp_map = _pp_map def to_ddf(self, columns=None, cpu=None): - # Check if we are using cpu or gpu backend cpu = self.cpu if cpu is None else cpu backend_engine = CPUParquetEngine if cpu else GPUParquetEngine @@ -825,7 +822,6 @@ def regenerate_dataset( out_parts = 0 remaining_out_part_rows = rows_per_part for i, in_part_size in enumerate(size_list): - # The `split` dictionary will be passed to this input # partition to dictate how that partition will be split # into different output partitions/files. The "key" of @@ -834,7 +830,6 @@ def regenerate_dataset( split = {} last = 0 while in_part_size >= remaining_out_part_rows: - gets[out_parts].append(i) split[out_parts] = (last, last + remaining_out_part_rows) last += remaining_out_part_rows @@ -911,7 +906,6 @@ def regenerate_dataset( def _write_metadata_file(md_list, fs, output_path, gmd_base): - # Prepare both "general" and parquet metadata gmd = gmd_base.copy() pmd = {} @@ -939,7 +933,6 @@ def _write_metadata_file(md_list, fs, output_path, gmd_base): def _write_data(data_list, output_path, fs, fn): - # Initialize chunked writer path = fs.sep.join([output_path, fn]) writer = pwriter_cudf(path, compression=None) @@ -1125,7 +1118,6 @@ def _to_parquet(self, df, sink): return md def _append_writer(self, path, schema=None): - # Define "metadata collector" for pyarrow _md_collector = [] _args = [schema] diff --git a/merlin/io/writer.py b/merlin/io/writer.py index baeeb749d..c99154820 100644 --- a/merlin/io/writer.py +++ b/merlin/io/writer.py @@ -130,7 +130,6 @@ def _write_thread(self): @annotate("add_data", color="orange", domain="merlin_python") def add_data(self, df): - # Early return if isinstance(df, list) and not df: return diff --git a/merlin/schema/io/tensorflow_metadata.py b/merlin/schema/io/tensorflow_metadata.py index b126ff0d6..012615c8b 100644 --- a/merlin/schema/io/tensorflow_metadata.py +++ b/merlin/schema/io/tensorflow_metadata.py @@ -252,6 +252,7 @@ def _pb_extra_metadata(column_schema): properties = { k: v for k, v in column_schema.properties.items() if k not in ("domain", "value_count") } + properties["_dims"] = list(list(dim) for dim in column_schema.shape.as_tuple or []) properties["is_list"] = column_schema.is_list properties["is_ragged"] = column_schema.is_ragged if column_schema.dtype.element_size: @@ -270,8 +271,8 @@ def _pb_feature(column_schema): value_count = column_schema.properties.get("value_count", {}) if value_count: - min_length = value_count.get("min", 0) - max_length = value_count.get("max", 0) + min_length = value_count.get("min", 0) or 0 + max_length = value_count.get("max", 0) or 0 feature.value_count = ValueCount(min=min_length, max=max_length) feature.annotation.tag = _pb_tag(column_schema) @@ -323,9 +324,9 @@ def _merlin_value_count(feature): if proto_utils.has_field(feature, "value_count"): value_count = feature.value_count value_count_dict = {} - if value_count.min > 0: + if value_count.min and value_count.min > 0: value_count_dict["min"] = value_count.min - if value_count.max > 0: + if value_count.max and value_count.max > 0: value_count_dict["max"] = value_count.max return value_count_dict @@ -377,7 +378,7 @@ def _merlin_properties(feature): def _merlin_dtype(feature, properties): - dtype = None + dtype = md.unknown item_size = int(properties.get("dtype_item_size", 0)) or None if feature.type == FeatureType.INT: if item_size 
and item_size in int_dtypes_map: @@ -391,6 +392,18 @@ def _merlin_dtype(feature, properties): dtype = md.float64 elif feature.type == FeatureType.BYTES: dtype = md.string + + dims_list = properties.pop("_dims", None) + + if dims_list: + dims_tuple = tuple(tuple(dim) for dim in dims_list) + dtype = dtype.with_shape(dims_tuple) + + # If we found dims, avoid overwriting that shape with one inferred from counts or flags + properties.pop("value_count", None) + properties.pop("is_list", None) + properties.pop("is_ragged", None) + return dtype @@ -409,7 +422,12 @@ def _merlin_column(feature): if Tags.CATEGORICAL not in tags: tags.append(Tags.CATEGORICAL) - return ColumnSchema(name, tags, properties, dtype, is_list, is_ragged=is_ragged) + dims = dtype.shape.as_tuple + + if dims: + return ColumnSchema(name, tags, properties, dtype, dims=dims) + else: + return ColumnSchema(name, tags, properties, dtype, is_list=is_list, is_ragged=is_ragged) def _read_file(path: os.PathLike): diff --git a/merlin/schema/schema.py b/merlin/schema/schema.py index 41e4311bb..413f20111 100644 --- a/merlin/schema/schema.py +++ b/merlin/schema/schema.py @@ -14,25 +14,17 @@ # limitations under the License. # -from dataclasses import dataclass, field, replace -from enum import Enum -from typing import Dict, List, Optional, Text, Union +from dataclasses import InitVar, dataclass, field, replace +from typing import Dict, List, Optional, Text, Tuple, Union import pandas as pd import merlin.dtypes as md from merlin.dtypes import DType +from merlin.dtypes.shape import Shape from merlin.schema.tags import Tags, TagSet -class ColumnQuantity(Enum): - """Describes the number of elements in each row of a column""" - - SCALAR = "scalar" - FIXED_LIST = "fixed_list" - RAGGED_LIST = "ragged_list" - - @dataclass(frozen=True) class Domain: """Describes an integer or float domain. @@ -66,8 +58,9 @@ class ColumnSchema: dtype: Optional[DType] = None is_list: Optional[bool] = None is_ragged: Optional[bool] = None + dims: InitVar[Union[Tuple, Shape]] = None - def __post_init__(self): + def __post_init__(self, dims): """Standardize tags and dtypes on initialization This method works around the inability to set attributes on frozen dataclass @@ -78,60 +71,56 @@ def __post_init__(self): Raises: TypeError: If the provided dtype cannot be cast to a numpy dtype + ValueError: If the provided shape, value counts, and/or flags are inconsistent """ + # Provide defaults and minor conversions for convenience object.__setattr__(self, "tags", TagSet(self.tags)) - object.__setattr__(self, "dtype", md.dtype(self.dtype or md.unknown)) - # Validate the allowed range of value count - value_count = Domain(**self.properties.get("value_count", {})) - if value_count.min == 0 or value_count.max == 0: - raise ValueError( - "`value_count` min and max must be greater than zero. 
" - f"Provided min: {value_count.min} max: {value_count.max}" - ) + dtype = md.dtype(self.dtype or md.unknown).without_shape + object.__setattr__(self, "dtype", dtype) + + # Validate that everything provided is consistent + value_counts = self.properties.get("value_count", {}) + if self.is_list and self.is_ragged is False: + if "max" in value_counts and "min" not in value_counts: + value_counts["min"] = value_counts["max"] + if "max" not in value_counts and "min" in value_counts: + value_counts["max"] = value_counts["min"] + + self._validate_shape_info(self.shape, value_counts, self.is_list, self.is_ragged) + + # Pick which source to pull shape info from + if dims: + new_shape = Shape(dims) + elif dtype.shape.dims: + new_shape = dtype.shape + elif value_counts: + new_shape = self._shape_from_counts(Domain(**value_counts)) + elif self.is_list: + new_shape = self._shape_from_flags(self.is_list) + else: + new_shape = Shape() - if self.is_list is None: - object.__setattr__(self, "is_list", bool(value_count.max and value_count.max > 0)) + # Update the shape and propagate out to flags and value counts + dtype = dtype.with_shape(new_shape) + object.__setattr__(self, "dtype", dtype) + object.__setattr__(self, "is_list", dtype.shape.is_list) + object.__setattr__(self, "is_ragged", dtype.shape.is_ragged) - if self.is_ragged is None: - if value_count.is_bounded and value_count.max > value_count.min: - object.__setattr__(self, "is_ragged", True) - elif value_count.is_bounded and value_count.max == value_count.min: - object.__setattr__(self, "is_ragged", False) - else: - object.__setattr__(self, "is_ragged", bool(self.is_list)) + if new_shape.dims is not None and len(new_shape.dims) > 1: + value_counts = {"min": new_shape.dims[1].min, "max": new_shape.dims[1].max} + properties = {**self.properties, **{"value_count": value_counts}} + object.__setattr__(self, "properties", properties) - if self.is_ragged and not self.is_list: - raise ValueError( - "`is_ragged` is set to `True` but `is_list` is not. " - "Only list columns can set the `is_ragged` flag." - ) + def _shape_from_flags(self, is_list): + return Shape(((0, None), (0, None))) if is_list else None - if self.is_ragged and value_count.is_bounded and value_count.min == value_count.max: - raise ValueError( - "`is_ragged` is set to `True` but `value_count.min` == `value_count.max`. " - "If value_count min/max are equal. " - "This is a fixed size list and `is_ragged` should be set to False. 
" - ) + def _shape_from_counts(self, value_count): + return Shape(((0, None), (value_count.min or 0, value_count.max))) @property - def quantity(self): - """ - Describes the number of elements in each row of this column - - Returns - ------- - ColumnQuantity - SCALAR when one element per row - FIXED_LIST when the same number of elements per row - RAGGED_LIST when different numbers of elements per row - """ - if self.is_list and self.is_ragged: - return ColumnQuantity.RAGGED_LIST - elif self.is_list: - return ColumnQuantity.FIXED_LIST - else: - return ColumnQuantity.SCALAR + def shape(self): + return self.dtype.shape def with_name(self, name: str) -> "ColumnSchema": """Create a copy of this ColumnSchema object with a different column name @@ -185,17 +174,26 @@ def with_properties(self, properties: dict) -> "ColumnSchema": """ if not isinstance(properties, dict): - raise TypeError("properties must be in dict format, key: value") + raise TypeError("ColumnSchema properties must be a dictionary") # Using new dictionary to avoid passing old ref to new schema new_properties = {**self.properties, **properties} - is_ragged = self.is_ragged - value_count = Domain(**new_properties.get("value_count", {})) - if value_count.is_bounded and value_count.max == value_count.min: - is_ragged = False + value_counts = properties.get("value_count", {}) - return replace(self, properties=new_properties, is_ragged=is_ragged) + if value_counts: + return replace( + self, + properties=new_properties, + dtype=self.dtype.without_shape, + is_list=None, + is_ragged=None, + ) + else: + return replace( + self, + properties=new_properties, + ) def with_dtype(self, dtype, is_list: bool = None, is_ragged: bool = None) -> "ColumnSchema": """Create a copy of this ColumnSchema object with different column dtype @@ -217,14 +215,46 @@ def with_dtype(self, dtype, is_list: bool = None, is_ragged: bool = None) -> "Co Copied object with new column dtype """ - is_list = is_list if is_list is not None else self.is_list + new_dtype = md.dtype(dtype).with_shape(self.shape) - if is_list: - is_ragged = is_ragged if is_ragged is not None else self.is_ragged - else: - is_ragged = False + properties = self.properties.copy() + if is_list is not None or is_ragged is not None: + properties.pop("value_count", None) + new_dtype = new_dtype.without_shape - return replace(self, dtype=dtype, is_list=is_list, is_ragged=is_ragged) + return replace( + self, dtype=new_dtype, properties=properties, is_list=is_list, is_ragged=is_ragged + ) + + def with_shape(self, shape: Union[Tuple, Shape]) -> "ColumnSchema": + """ + Create a copy of this object with a new shape + + Parameters + ---------- + shape : Union[Tuple, Shape] + Object to set as shape, must be either a tuple or Shape. 
+ + Returns + ------- + ColumnSchema + A copy of this object containing the provided shape value + + Raises + ------ + TypeError + If value is not either a tuple or a Shape + """ + dims = Shape(shape).as_tuple + properties = self.properties.copy() + properties.pop("value_count", None) + return replace( + self, + dims=dims, + properties=properties, + is_list=None, + is_ragged=None, + ) @property def int_domain(self) -> Optional[Domain]: @@ -240,12 +270,13 @@ def value_count(self) -> Optional[Domain]: return Domain(**value_count) if value_count else None def __merge__(self, other): - col_schema = self.with_tags(other.tags) - col_schema = col_schema.with_properties(other.properties) - col_schema = col_schema.with_dtype( - other.dtype, is_list=other.is_list, is_ragged=other.is_ragged + col_schema = ( + self.with_name(other.name) + .with_dtype(other.dtype) + .with_tags(other.tags) + .with_properties(other.properties) + .with_shape(other.shape) ) - col_schema = col_schema.with_name(other.name) return col_schema def __str__(self) -> str: @@ -256,6 +287,59 @@ def _domain(self) -> Optional[Domain]: domain = self.properties.get("domain") return Domain(**domain) if domain else None + def _validate_shape_info(self, shape, value_counts, is_list, is_ragged): + value_counts = value_counts or {} + + min_count = value_counts.get("min", None) + max_count = value_counts.get("max", None) + ragged_counts = min_count != max_count + + if shape and shape.dims is not None: + if is_ragged is not None and shape.is_ragged != is_ragged: + raise ValueError( + f"Provided value of `is_ragged={is_ragged}` " + f"is inconsistent with shape `{shape}`." + ) + elif is_list is not None and shape.is_list != is_list: + raise ValueError( + f"Provided value of `is_list={is_list}` " + f"is inconsistent with shape `{shape}`." + ) + + if value_counts and shape and shape.dims is not None: + if (min_count and min_count != shape.dims[1].min) or ( + max_count and max_count != shape.dims[1].max + ): + raise ValueError( + f"Provided value counts `{value_counts}` " + f"are inconsistent with shape `{shape}`." + ) + + if is_list is False and is_ragged is True: + raise ValueError( + "Columns with `is_list=False` can't set `is_ragged=True`, " + "since non-list columns can't be ragged." + ) + + if value_counts and is_ragged is not None and is_ragged != ragged_counts: + raise ValueError( + f"Provided value of `is_ragged={is_ragged}` " + f"is inconsistent with value counts `{value_counts}`." + ) + + # TODO: Enable this validation once we've removed these cases + # from downstream Merlin libraries + # if ( + # not value_counts + # and not (shape and shape.dims) + # and is_list is True + # and is_ragged is False + # ): + # raise ValueError( + # "Can't determine a shape for this column from " + # "`is_list=True` and `is_ragged=False` without value counts. 
" + # ) + class Schema: """A collection of column schemas for a dataset.""" diff --git a/requirements.txt b/requirements.txt index 2dc58350c..5c4d769fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ dask==2022.7.1 distributed==2022.7.1 -pandas>=1.2.0,<1.4.0dev0 +fsspec==2022.7.1 +pandas>=1.2.0,<1.6.0dev0 numba>=0.54 pyarrow==8.0.0 protobuf>=3.0.0 @@ -8,4 +9,3 @@ tqdm>=4.0 tensorflow-metadata>=1.2.0 betterproto<2.0.0 packaging -fsspec==2022.7.1 diff --git a/tests/conftest.py b/tests/conftest.py index 565f8c626..2af0d595a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -152,8 +152,20 @@ def datasets(tmpdir_factory): half = int(len(df) // 2) # Write Parquet Dataset - df.iloc[:half].to_parquet(str(datadir["parquet"].join("dataset-0.parquet")), chunk_size=1000) - df.iloc[half:].to_parquet(str(datadir["parquet"].join("dataset-1.parquet")), chunk_size=1000) + if cudf: + df.iloc[:half].to_parquet( + str(datadir["parquet"].join("dataset-0.parquet")), row_group_size_rows=5000 + ) + df.iloc[half:].to_parquet( + str(datadir["parquet"].join("dataset-1.parquet")), row_group_size_rows=5000 + ) + else: + df.iloc[:half].to_parquet( + str(datadir["parquet"].join("dataset-0.parquet")), chunk_size=1000 + ) + df.iloc[half:].to_parquet( + str(datadir["parquet"].join("dataset-1.parquet")), chunk_size=1000 + ) # Write CSV Dataset (Leave out categorical column) df.iloc[:half].drop(columns=["name-cat"]).to_csv( diff --git a/tests/unit/core/test_dispatch.py b/tests/unit/core/test_dispatch.py index 7d55f27b3..62893af1a 100644 --- a/tests/unit/core/test_dispatch.py +++ b/tests/unit/core/test_dispatch.py @@ -14,10 +14,17 @@ # limitations under the License. # import numpy as np +import pandas as pd import pytest from merlin.core.dispatch import HAS_GPU, concat_columns, is_list_dtype, list_val_dtype, make_df +try: + import cupy as cp +except ImportError: + cp = None + + if HAS_GPU: _DEVICES = ["cpu", "gpu"] else: @@ -44,3 +51,15 @@ def test_concat_columns(device): data_frames = [df1, df2] res = concat_columns(data_frames) assert res.columns.to_list() == ["a", "b", "c"] + + +@pytest.mark.skipif(not cp, reason="Cupy not available") +def test_pandas_cupy_combo(): + rand_cp_nd_arr = cp.random.uniform(0.0, 1.0, size=100) + with pytest.raises(TypeError) as exc_info: + pd.DataFrame(rand_cp_nd_arr) + + assert "Implicit conversion to a NumPy array is not allowed" in str(exc_info) + pd_df = pd.DataFrame(rand_cp_nd_arr.get())[0] + mk_df = make_df(rand_cp_nd_arr)[0] + assert all(pd_df.to_numpy() == mk_df.to_numpy()) diff --git a/tests/unit/dag/test_dag_utils.py b/tests/unit/dag/test_dag_utils.py new file mode 100644 index 000000000..14d2662bf --- /dev/null +++ b/tests/unit/dag/test_dag_utils.py @@ -0,0 +1,31 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np + +from merlin.dag import group_values_offsets, ungroup_values_offsets + + +def test_flat_dict_to_tuple_dict(): + col1 = np.array([1, 2, 3, 4, 5]) + col2_values = np.array([6, 7, 8, 9, 10]) + col2_offsets = np.array([0, 2, 5]) + + flat_dict = {"col1": col1, "col2__values": col2_values, "col2__offsets": col2_offsets} + + tuple_dict = {"col1": col1, "col2": (col2_values, col2_offsets)} + + assert ungroup_values_offsets(tuple_dict) == flat_dict + assert group_values_offsets(flat_dict) == tuple_dict diff --git a/tests/unit/dtypes/test_shape.py b/tests/unit/dtypes/test_shape.py new file mode 100644 index 000000000..5dd9a164a --- /dev/null +++ b/tests/unit/dtypes/test_shape.py @@ -0,0 +1,222 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +import merlin.dtypes as md +from merlin.dtypes.shape import Dimension, Shape + +# Dimension + + +def test_empty_dimension(): + dim = Dimension() + assert dim.min == 0 + assert dim.max is None + + +def test_min_max_val_dimension(): + dim = Dimension(2, 3) + assert dim.min == 2 + assert dim.max == 3 + + +def test_fixed_min_with_unbounded_max(): + dim = Dimension(2) + assert dim.min == 2 + assert dim.max is None + + dim = Dimension(2, None) + assert dim.min == 2 + assert dim.max is None + + +def test_min_is_none_raises_error(): + with pytest.raises(ValueError): + Dimension(None) + + with pytest.raises(ValueError): + Dimension(None, 1) + + +def test_bounds_must_be_non_negative(): + with pytest.raises(ValueError): + Dimension(-1, 2) + + with pytest.raises(ValueError): + Dimension(2, -1) + + +def test_max_less_than_min(): + with pytest.raises(ValueError): + Dimension(2, 1) + + +def test_is_bounded(): + dim = Dimension() + assert dim.is_bounded is False + + dim = Dimension(2) + assert dim.is_bounded is False + + dim = Dimension(2, 2) + assert dim.is_bounded is True + + dim = Dimension(2, 4) + assert dim.is_bounded is True + + dim = Dimension(2, None) + assert dim.is_bounded is False + + +def test_is_fixed(): + dim = Dimension() + assert dim.is_fixed is False + + dim = Dimension(2) + assert dim.is_fixed is False + + dim = Dimension(2, 2) + assert dim.is_fixed is True + + dim = Dimension(2, 4) + assert dim.is_fixed is False + + dim = Dimension(2, None) + assert dim.is_fixed is False + + +def test_is_variable(): + dim = Dimension() + assert dim.is_variable is True + + dim = Dimension(2) + assert dim.is_variable is True + + dim = Dimension(2, 2) + assert dim.is_variable is False + + dim = Dimension(2, 4) + assert dim.is_variable is True + + dim = Dimension(2, None) + assert dim.is_variable is True + + +# Shape + + +def test_shape_without_args_represents_unknown(): + shape = Shape() + assert shape.dims is None + assert shape.is_list is False + assert shape.is_ragged is False + + +def test_shape_with_empty_tuple_represents_scalar(): + shape = Shape(()) + assert shape.dims == () + assert shape.is_list is False + assert shape.is_ragged is False + + +def 
test_flat_tuple_creates_fixed_shape(): + shape = Shape((1, 2, 3)) + assert shape.is_fixed is True + + +def test_none_is_shorthand_for_unknown_unbounded(): + shape = Shape((None, (4, 16))) + assert shape == Shape(((0, None), (4, 16))) + + +def test_nested_tuple_creates_variable_shape(): + shape = Shape(((5, 5), (2, 2), (3, 3))) + assert shape.is_variable is False + + shape = Shape(((1, 3), (2, 2), (3, 4))) + assert shape.is_variable is True + + shape = Shape(((1, 3), (2, 4), (4, 4))) + assert shape.is_variable is True + + shape = Shape(((1, 3), (2, 4), (4, 7))) + assert shape.is_variable is True + + +def test_mixed_tuple_creates_variable_shape(): + shape = Shape((5, (2, 3), 4)) + assert shape.is_variable is True + + +def test_nested_tuple_error(): + with pytest.raises(ValueError): + Shape((5, (2, None), (4, 5, 6))) + + with pytest.raises(ValueError): + Shape((5.3, (2, None), (4, 6))) + + with pytest.raises(ValueError): + Shape(("asdf", (2, None), (4, 6))) + + +def test_shape_properties(): + shape = Shape((5,)) + assert shape.is_fixed is True + assert shape.is_variable is False + assert shape.is_bounded is True + assert shape.is_ragged is False + assert shape.is_list is False + + shape = Shape((5, 1)) + assert shape.is_fixed is True + assert shape.is_variable is False + assert shape.is_bounded is True + assert shape.is_ragged is False + assert shape.is_list is True + + shape = Shape(((5, None), 2, 4)) + assert shape.is_fixed is False + assert shape.is_variable is True + assert shape.is_bounded is False + assert shape.is_ragged is False + assert shape.is_list is True + + shape = Shape((5, (2, None), 4)) + assert shape.is_fixed is False + assert shape.is_variable is True + assert shape.is_bounded is False + assert shape.is_ragged is True + assert shape.is_list is True + + shape = Shape((5, 2, (4, None))) + assert shape.is_fixed is False + assert shape.is_variable is True + assert shape.is_bounded is False + assert shape.is_ragged is True + assert shape.is_list is True + + +# DType + + +def test_dtype_has_a_shape(): + assert md.int32.shape == Shape() + + +def test_dtype_with_shape(): + dtype = md.int32.with_shape((3, 4, 5)) + assert dtype.shape != (3, 4, 5) + assert dtype.shape == Shape((3, 4, 5)) diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index cc0e6ef66..fabf8932d 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -27,6 +27,7 @@ from dask.dataframe import assert_eq from packaging.version import Version +import merlin.dtypes as md import merlin.io from merlin.core import dispatch from merlin.io.parquet import GPUParquetWriter @@ -51,7 +52,7 @@ def test_validate_dataset_bad_schema(tmpdir): pytest.skip("Test requires newer version of Dask.") path = str(tmpdir) - for (fn, df) in [ + for fn, df in [ ("part.0.parquet", pd.DataFrame({"a": range(10), "b": range(10)})), ("part.1.parquet", pd.DataFrame({"a": [None] * 10, "b": range(10)})), ]: @@ -250,7 +251,6 @@ def test_dask_dataset(datasets, engine, num_files, cpu): @pytest.mark.parametrize("origin", ["cudf", "dask_cudf", "pd", "dd"]) @pytest.mark.parametrize("cpu", [None, True]) def test_dask_dataset_from_dataframe(tmpdir, origin, cpu): - # Generate a DataFrame-based input if origin in ("pd", "dd"): df = pd.DataFrame({"a": range(100)}) @@ -456,7 +456,6 @@ def test_validate_dataset(datasets, engine): def test_validate_and_regenerate_dataset(tmpdir): - # Initial timeseries dataset (in cpu memory) ddf = dask.datasets.timeseries( start="2000-01-01", @@ -502,7 +501,6 @@ def 
test_validate_and_regenerate_dataset(tmpdir): @pytest.mark.parametrize("preserve_files", [True, False]) @pytest.mark.parametrize("cpu", [True, False]) def test_dataset_conversion(tmpdir, cpu, preserve_files): - # Generate toy dataset. # Include "hex" strings to mimic Criteo. size = 100 @@ -561,7 +559,6 @@ def test_dataset_conversion(tmpdir, cpu, preserve_files): @pytest.mark.parametrize("use_file_metadata", [True, None]) @pytest.mark.parametrize("shuffle", [True, False]) def test_parquet_iterator_len(tmpdir, shuffle, use_file_metadata): - ddf1 = dask.datasets.timeseries( start="2000-01-01", end="2000-01-6", @@ -594,7 +591,6 @@ def test_parquet_iterator_len(tmpdir, shuffle, use_file_metadata): @pytest.mark.parametrize("cpu", [True, False]) def test_hive_partitioned_data(tmpdir, cpu): - # Initial timeseries dataset (in cpu memory). # Round the full "timestamp" to the hour for partitioning. ddf = dask.datasets.timeseries( @@ -658,7 +654,6 @@ def test_hive_partitioned_data(tmpdir, cpu): @pytest.mark.parametrize("keys", [["name"], ["id"], ["name", "id"]]) @pytest.mark.parametrize("npartitions", [None, 2]) def test_dataset_shuffle_on_keys(tmpdir, cpu, partition_on, keys, npartitions): - # Initial timeseries dataset size = 60 df1 = pd.DataFrame( @@ -706,7 +701,6 @@ def test_dataset_shuffle_on_keys(tmpdir, cpu, partition_on, keys, npartitions): @pytest.mark.parametrize("cpu", [True, False]) def test_parquet_filtered_flat(tmpdir, cpu): - # Initial timeseries dataset (in cpu memory). # Round the full "timestamp" to the hour for partitioning. path = str(tmpdir) @@ -726,7 +720,6 @@ def test_parquet_filtered_flat(tmpdir, cpu): @pytest.mark.parametrize("cpu", [True, False]) def test_parquet_filtered_hive(tmpdir, cpu): - # Initial timeseries dataset (in cpu memory). # Round the full "timestamp" to the hour for partitioning. path = str(tmpdir) @@ -753,7 +746,6 @@ def test_parquet_filtered_hive(tmpdir, cpu): ) @pytest.mark.parametrize("cpu", [True, False]) def test_parquet_aggregate_files(tmpdir, cpu): - # Initial timeseries dataset (in cpu memory). # Round the full "timestamp" to the hour for partitioning. 
path = str(tmpdir) @@ -792,3 +784,18 @@ def test_parquet_aggregate_files(tmpdir, cpu): assert ds.to_ddf().npartitions == 1 assert len(ds.to_ddf().timestamp.unique()) == 1 _check_partition_lens(ds) + + +def test_to_parquet_dtypes_schema(tmpdir): + df = dispatch.make_df({"a": np.array([1, 2, 3], dtype=np.int32)}) + dataset = merlin.io.Dataset(df) + + # save to parquet with different dtypes and reload + dataset.to_parquet(output_path=str(tmpdir), dtypes={"a": np.float32}) + + # check that dtypes are unchanged + assert dataset.schema["a"].dtype == md.dtype("int32") + + reloaded_dataset = merlin.io.Dataset(str(tmpdir), engine="parquet") + # check that data was saved with the requested dtype + assert reloaded_dataset.schema["a"].dtype == md.dtype("float32") diff --git a/tests/unit/schema/test_column_schemas.py b/tests/unit/schema/test_column_schemas.py index 4dececf0f..daac73fdd 100644 --- a/tests/unit/schema/test_column_schemas.py +++ b/tests/unit/schema/test_column_schemas.py @@ -17,8 +17,8 @@ import pytest import merlin.dtypes as md +from merlin.dtypes.shape import Shape from merlin.schema import ColumnSchema -from merlin.schema.schema import ColumnQuantity from merlin.schema.tags import Tags, TagSet @@ -176,58 +176,32 @@ def test_list_column_attributes(): assert not col0_schema.is_list assert not col0_schema.is_ragged - assert col0_schema.quantity == ColumnQuantity.SCALAR col1_schema = ColumnSchema("col1", is_list=False, is_ragged=False) assert not col1_schema.is_list assert not col1_schema.is_ragged - assert col1_schema.quantity == ColumnQuantity.SCALAR col2_schema = ColumnSchema("col2", is_list=True) assert col2_schema.is_list assert col2_schema.is_ragged - assert col2_schema.quantity == ColumnQuantity.RAGGED_LIST col3_schema = ColumnSchema("col3", is_list=True, is_ragged=True) assert col3_schema.is_list assert col3_schema.is_ragged - assert col3_schema.quantity == ColumnQuantity.RAGGED_LIST - col4_schema = ColumnSchema("col4", is_list=True, is_ragged=False) + # TODO: Re-enable this test case once we've addressed cases + # like this in downstream libraries - assert col4_schema.is_list - assert not col4_schema.is_ragged - assert col4_schema.quantity == ColumnQuantity.FIXED_LIST + # with pytest.raises(ValueError): + # ColumnSchema("col4", is_list=True, is_ragged=False) with pytest.raises(ValueError): ColumnSchema("col5", is_list=False, is_ragged=True) -def test_value_count_invalid_min_max(): - with pytest.raises(ValueError) as exc_info: - ColumnSchema("col", is_ragged=True, properties={"value_count": {"min": 2, "max": 2}}) - assert "`is_ragged` is set to `True` but `value_count.min` == `value_count.max`" in str( - exc_info.value - ) - - -@pytest.mark.parametrize( - "properties", - [ - {"value_count": {"max": 0}}, - {"value_count": {"min": 0}}, - {"value_count": {"min": 0, "max": 2}}, - ], -) -def test_value_count_zero_min_max(properties): - with pytest.raises(ValueError) as exc_info: - ColumnSchema("col", is_ragged=True, properties=properties) - assert "`value_count` min and max must be greater than zero. 
" in str(exc_info.value) - - @pytest.mark.parametrize( ["value_count_min", "value_count_max"], [ @@ -246,11 +220,74 @@ def test_value_count(value_count_min, value_count_max): col_schema = ColumnSchema("col", properties={"value_count": value_count}) assert col_schema.value_count.max == value_count_max - assert col_schema.value_count.min == value_count_min + assert col_schema.value_count.min == (value_count_min or 0) + + +def test_value_count_inconsistency_with_flags(): + with pytest.raises(ValueError) as exc_info: + ColumnSchema( + "col", properties={"value_count": {"min": 5, "max": 5}}, is_list=True, is_ragged=True + ) + assert "Provided value of `is_ragged=True` is inconsistent with value counts" in str( + exc_info.value + ) + + +def test_column_schema_with_shape(): + col_schema = ColumnSchema("col") + assert col_schema.shape == Shape() + + col_schema = ColumnSchema("col", dtype=md.int32.with_shape((3, 4, 5))) + assert col_schema.shape != (3, 4, 5) + assert col_schema.shape == Shape((3, 4, 5)) + + col_schema = ColumnSchema("col", dims=(3, 4, 5)) + assert col_schema.shape != (3, 4, 5) + assert col_schema.shape == Shape((3, 4, 5)) + + col_schema = ColumnSchema("col").with_shape((3, 4, 5)) + assert col_schema.shape != (3, 4, 5) + assert col_schema.shape == Shape((3, 4, 5)) + + +@pytest.mark.parametrize("value_count", [{"max": 10}, {"min": 10}]) +def test_setting_partial_value_count(value_count): + col_schema = ColumnSchema( + "col", is_list=True, is_ragged=False, properties={"value_count": value_count} + ) + assert col_schema.is_list + assert not col_schema.is_ragged + assert col_schema.shape == Shape((None, 10)) + assert col_schema.properties["value_count"] == {"min": 10, "max": 10} + + +def test_setting_value_counts_updates_shape_and_flags(): + col_schema = ColumnSchema("col", dims=(None,)) + + counts = {"min": 4, "max": 5} + updated_schema = col_schema.with_properties({"value_count": counts}) + + assert updated_schema.properties["value_count"] == counts + assert updated_schema.shape == Shape((None, (4, 5))) + assert updated_schema.is_list + assert updated_schema.is_ragged + + +def test_setting_shape_updates_value_counts_and_flags(): + col_schema = ColumnSchema("col") + updated_schema = col_schema.with_shape((64, (4, 16))) + + assert updated_schema.shape == Shape((64, (4, 16))) + assert updated_schema.properties["value_count"] == {"min": 4, "max": 16} + assert updated_schema.is_list + assert updated_schema.is_ragged + +def test_setting_flags_updates_shape_and_value_counts(): + col_schema = ColumnSchema("col") + updated_schema = col_schema.with_dtype(md.int64, is_list=True, is_ragged=True) -def test_value_count_assign_properties(): - col_schema = ColumnSchema("col", is_list=True, is_ragged=True) - new_col_schema = col_schema.with_properties({"value_count": {"min": 5, "max": 5}}) - assert new_col_schema.is_ragged is False - assert new_col_schema.value_count.min == new_col_schema.value_count.max == 5 + assert updated_schema.shape == Shape((None, None)) + assert updated_schema.properties["value_count"] == {"min": 0, "max": None} + assert updated_schema.is_list + assert updated_schema.is_ragged diff --git a/tests/unit/schema/test_schema.py b/tests/unit/schema/test_schema.py index 1a648c257..239f89441 100644 --- a/tests/unit/schema/test_schema.py +++ b/tests/unit/schema/test_schema.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import dataclasses + import pytest from merlin.dag import ColumnSelector @@ -157,8 +159,11 @@ def test_schema_to_pandas(): schema_set = Schema(["a", "b", "c"]) df = schema_set.to_pandas() + expected_columns = [field.name for field in dataclasses.fields(ColumnSchema)] + expected_columns.remove("properties") + assert isinstance(df, pd.DataFrame) - assert list(df.columns) == ["name", "tags", "dtype", "is_list", "is_ragged"] + assert list(df.columns) == expected_columns def test_construct_schema_with_column_names(): diff --git a/tests/unit/schema/test_schema_io.py b/tests/unit/schema/test_schema_io.py index 81282469b..ced834175 100644 --- a/tests/unit/schema/test_schema_io.py +++ b/tests/unit/schema/test_schema_io.py @@ -19,6 +19,7 @@ import pytest import merlin.dtypes as md +from merlin.dtypes.shape import Shape from merlin.schema import ColumnSchema, Schema, Tags from merlin.schema.io.tensorflow_metadata import TensorflowMetadata @@ -231,6 +232,22 @@ def test_tensorflow_metadata_from_json(): # make sure the JSON formatted extra_metadata properties are human readable json_schema = json.loads(TensorflowMetadata.from_merlin_schema(schema).to_json()) - assert json_schema["feature"][0]["annotation"]["extraMetadata"] == [ - {"is_list": True, "is_ragged": True, "dtype_item_size": 64.0} - ] + extra_metadata = json_schema["feature"][0]["annotation"]["extraMetadata"][0] + extra_metadata = { + key: value for key, value in extra_metadata.items() if not key.startswith("_") + } + assert extra_metadata == {"is_list": True, "is_ragged": True, "dtype_item_size": 64.0} + + +@pytest.mark.parametrize("dim1", [1, None, (1, 3), (3, 3), (0, None), (4, None)]) +@pytest.mark.parametrize("dim2", [1, None, (1, 3), (3, 3), (0, None), (4, None)]) +@pytest.mark.parametrize("dim3", [1, None, (1, 3), (3, 3), (0, None), (4, None)]) +def test_shapes_survive_round_trip(dim1, dim2, dim3): + dims = (dim1, dim2, dim3) + + col_schema1 = ColumnSchema("col1", dtype=numpy.int, dims=dims) + + schema = Schema([col_schema1]) + loaded_schema = TensorflowMetadata.from_merlin_schema(schema).to_merlin_schema() + + assert loaded_schema["col1"].shape == Shape(dims) diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 08024ca0c..007a0fdce 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -30,7 +30,6 @@ @pytest.mark.parametrize("cpu", _CPU) def test_serial_context(client, cpu): - # Set distributed client set_dask_client(client=client) assert global_dask_client() == client @@ -48,7 +47,6 @@ def test_serial_context(client, cpu): @pytest.mark.parametrize("cpu", [True, False]) @pytest.mark.parametrize("nested_serial", _CPU) def test_nvt_distributed(cpu, nested_serial): - if cpu: distributed = pytest.importorskip("distributed") cluster_type = "cpu" @@ -84,7 +82,6 @@ def test_nvt_distributed(cpu, nested_serial): @pytest.mark.parametrize("cpu", _CPU) def test_nvt_distributed_force(client, cpu): - if cpu: distributed = pytest.importorskip("distributed") cluster_type = "cpu" diff --git a/tox.ini b/tox.ini index a759c16cc..9984b666b 100644 --- a/tox.ini +++ b/tox.ini @@ -84,6 +84,8 @@ commands = ; NOTE!!!! 
We must clean this up afterward with `rm -rf "systems-$GIT_COMMIT"` git clone --depth 1 --branch {posargs:main} https://github.com/NVIDIA-Merlin/systems.git systems-{env:GIT_COMMIT} python -m pip install --upgrade ./systems-{env:GIT_COMMIT}[test-cpu] + python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/NVTabular.git@{posargs:main} + python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/dataloader.git@{posargs:main} python -m pip install . ; this runs the tests then removes the systems repo directory whether the tests work or fail
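
The new test_flat_dict_to_tuple_dict exercises the `group_values_offsets`/`ungroup_values_offsets` helpers now exported from `merlin.dag`. A minimal usage sketch, assuming the behaviour shown in that test; the column names here are illustrative, and the `<name>__values`/`<name>__offsets` key convention comes from the flat dictionary used there:

import numpy as np

from merlin.dag import group_values_offsets, ungroup_values_offsets

# A ragged list column: values [6, 7, 8, 9, 10] split into rows [6, 7] and [8, 9, 10]
values = np.array([6, 7, 8, 9, 10])
offsets = np.array([0, 2, 5])

# Grouped form: each ragged column is a (values, offsets) tuple
grouped = {"scalar": np.array([1, 2, 3]), "ragged": (values, offsets)}

# Flat form: each ragged column becomes "<name>__values" and "<name>__offsets"
flat = ungroup_values_offsets(grouped)
assert set(flat) == {"scalar", "ragged__values", "ragged__offsets"}

# group_values_offsets reverses the transformation
assert set(group_values_offsets(flat)) == {"scalar", "ragged"}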
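
The `Dimension`/`Shape` semantics covered by tests/unit/dtypes/test_shape.py can be summarised in a short sketch; the dims below are taken directly from the assertions in those tests:

from merlin.dtypes.shape import Dimension, Shape

# A dimension is a (min, max) bound; max of None means unbounded
dim = Dimension(2, 4)
assert dim.is_bounded and dim.is_variable and not dim.is_fixed

# Plain ints give a fixed list shape
fixed = Shape((5, 1))
assert fixed.is_fixed and fixed.is_list and not fixed.is_ragged

# A variable inner dimension makes the shape ragged and unbounded
ragged = Shape((5, (2, None), 4))
assert ragged.is_ragged and ragged.is_list and not ragged.is_bounded

# None in a dims tuple is shorthand for an unknown, unbounded dimension
assert Shape((None, (4, 16))) == Shape(((0, None), (4, 16)))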
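
test_to_parquet_dtypes_schema checks that the `dtypes` argument to `Dataset.to_parquet` affects only the files written out, not the in-memory schema. A condensed sketch of that behaviour, with an illustrative output path in place of the test's tmpdir:

import numpy as np

import merlin.dtypes as md
import merlin.io
from merlin.core import dispatch

df = dispatch.make_df({"a": np.array([1, 2, 3], dtype=np.int32)})
dataset = merlin.io.Dataset(df)

# Request float32 for column "a" in the written files only
dataset.to_parquet(output_path="/tmp/dtypes_example", dtypes={"a": np.float32})
assert dataset.schema["a"].dtype == md.dtype("int32")  # in-memory schema unchanged

reloaded = merlin.io.Dataset("/tmp/dtypes_example", engine="parquet")
assert reloaded.schema["a"].dtype == md.dtype("float32")  # files carry the requested dtype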
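
The updated column-schema tests show `ColumnSchema` keeping its shape, `value_count` property, and `is_list`/`is_ragged` flags in sync whichever of the three is set. A sketch using the same values as those tests:

from merlin.dtypes.shape import Shape
from merlin.schema import ColumnSchema

# Setting a shape back-fills the value counts and the list/ragged flags
col = ColumnSchema("col").with_shape((64, (4, 16)))
assert col.properties["value_count"] == {"min": 4, "max": 16}
assert col.is_list and col.is_ragged

# Setting value counts updates the shape in the other direction
col = ColumnSchema("col", dims=(None,)).with_properties({"value_count": {"min": 4, "max": 5}})
assert col.shape == Shape((None, (4, 5)))
assert col.is_list and col.is_ragged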