From 60a0f339fa64538c69dec6cb16af788c45736042 Mon Sep 17 00:00:00 2001
From: Karl Higley
Date: Fri, 6 Jan 2023 12:33:20 -0500
Subject: [PATCH] Add a new `shape` field to `ColumnSchema`

This creates a place to store shape information for all dimensions of the
data across both array/tensor and dataframe formats. In contrast to the
existing "value_count" property (which only records the value counts of the
lists in a list field), this attribute is intended to capture the size of
_all_ dimensions of the data (the batch dimension, the list lengths,
embedding sizes, etc.)
---
 merlin/schema/schema.py | 102 ++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 101 insertions(+), 1 deletion(-)

diff --git a/merlin/schema/schema.py b/merlin/schema/schema.py
index ede772ef8..14d4aca16 100644
--- a/merlin/schema/schema.py
+++ b/merlin/schema/schema.py
@@ -16,7 +16,7 @@
 
 from dataclasses import dataclass, field, replace
 from enum import Enum
-from typing import Dict, List, Optional, Text, Union
+from typing import Dict, List, Optional, Text, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -55,6 +55,96 @@ def is_bounded(self):
         return self.max and self.min
 
 
+@dataclass(frozen=True)
+class Dimension:
+    """
+    The range of potential sizes for a single dimension of a field or column
+
+    `min` is the smallest allowed size; `max` is the largest allowed size,
+    or None when the dimension has no upper bound.
+    """
+
+    min: int = 0
+    max: Optional[int] = None
+
+    def __post_init__(self):
+        """Validate that `min` and `max` describe a consistent size range
+
+        Raises:
+            ValueError: If `min` or `max` is negative, or if `max` < `min`
+        """
+        if self.min < 0:
+            raise ValueError(
+                "The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}"
+            )
+
+        # Compare against None explicitly so that `max=0` is still validated
+        if self.max is not None and self.max < 0:
+            raise ValueError(
+                "The maximum size of a dimension must be non-negative. " f"Provided max: {self.max}"
+            )
+
+        if self.max is not None and self.max < self.min:
+            raise ValueError(
+                "The maximum size of a dimension must be at least as large as the minimum size. "
+                f"Provided min: {self.min} max: {self.max}"
+            )
+
+    @property
+    def is_bounded(self):
+        """True if this dimension has a maximum size"""
+        return self.max is not None
+
+    @property
+    def is_fixed(self):
+        """True if this dimension can only have a single size (`min` == `max`)"""
+        return self.is_bounded and self.min == self.max
+
+    @property
+    def is_variable(self):
+        """True if this dimension can have more than one size"""
+        return not self.is_fixed
+
+
+@dataclass(frozen=True)
+class Shape:
+    """
+    The range of potential sizes for all the dimensions of a field or column
+    """
+
+    dims: Tuple[Dimension, ...]
+
+    @property
+    def min(self) -> Tuple:
+        """The minimum size of each dimension"""
+        return tuple(dim.min for dim in self.dims)
+
+    @property
+    def max(self) -> Tuple:
+        """The maximum size of each dimension (None where unbounded)"""
+        return tuple(dim.max for dim in self.dims)
+
+    @property
+    def fixed(self) -> Tuple:
+        """The size of each fixed dimension (None where the size can vary)"""
+        return tuple(dim.min if dim.is_fixed else None for dim in self.dims)
+
+    @property
+    def is_bounded(self):
+        """True if every dimension has a maximum size"""
+        return all(dim.is_bounded for dim in self.dims)
+
+    @property
+    def is_fixed(self):
+        """True if every dimension has a single fixed size"""
+        return all(dim.is_fixed for dim in self.dims)
+
+    @property
+    def is_variable(self):
+        """True if at least one dimension can vary in size"""
+        return not self.is_fixed
+
+
 @dataclass(frozen=True)
 class ColumnSchema:
     """A schema containing metadata of a dataframe column."""
@@ -65,6 +155,7 @@ class ColumnSchema:
     dtype: Optional[object] = None
     is_list: Optional[bool] = None
     is_ragged: Optional[bool] = None
+    shape: Optional[Tuple[Dimension, ...]] = None
 
     def __post_init__(self):
         """Standardize tags and dtypes on initialization
@@ -126,6 +217,15 @@ def __post_init__(self):
                 "This is a fixed size list and `is_ragged` should be set to False. "
             )
 
+        if self.shape:
+            # Convert raw (min, max) tuples to Dimension objects for convenience.
+            # Build a real tuple: a generator would be exhausted after one pass
+            # and would not match the declared type of the field.
+            new_shape = tuple(
+                Dimension(*dim) if not isinstance(dim, Dimension) else dim for dim in self.shape
+            )
+            object.__setattr__(self, "shape", new_shape)
+
     @property
     def quantity(self):
         """