Skip to content

Commit

Permalink
Add a new shape field to ColumnSchema
Browse files Browse the repository at this point in the history
This creates a place to store shape information for all dimensions of the data across both array/tensor and dataframe formats. In contrast to the existing "value_count" property (which only records the value counts of the lists in list field, this attribute is intended to capture the size of _all_ dimensions of the data (the batch dimension, the list lengths, embedding sizes, etc.)
  • Loading branch information
karlhigley committed Jan 6, 2023
1 parent 39e2c9a commit 60a0f33
Showing 1 changed file with 81 additions and 1 deletion.
82 changes: 81 additions & 1 deletion merlin/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from dataclasses import dataclass, field, replace
from enum import Enum
from typing import Dict, List, Optional, Text, Union
from typing import Dict, List, Optional, Text, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -55,6 +55,78 @@ def is_bounded(self):
return self.max and self.min


@dataclass(frozen=True)
class Dimension:
"""
The range of potential sizes for a single dimension of a field or column
"""

min: int = 0
max: Optional[int] = None

def __post_init__(self):
if self.min < 0:
raise ValueError(
"The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}"
)

if self.max and self.max < 0:
raise ValueError(
"The maximum size of a dimension must be at least one. " f"Provided max: {self.max}"
)

if self.max and self.max < self.min:
raise ValueError(
"The maximum size of a dimension must be at least as large as the minimum size. "
f"Provided min: {self.min} max: {self.max}"
)

@property
def is_bounded(self):
return self.max is not None

@property
def is_fixed(self):
return self.is_bounded and self.min == self.max

@property
def is_variable(self):
return not self.is_fixed


@dataclass(frozen=True)
class Shape:
"""
The range of potential sizes for all the dimensions of a field or column
"""

dims: Tuple[Dimension]

@property
def min(self) -> Tuple:
return tuple(dim.min for dim in self.dims)

@property
def max(self) -> Tuple:
return tuple(dim.max for dim in self.dims)

@property
def fixed(self) -> Tuple:
return tuple(dim.min if dim.is_fixed else None for dim in self.dims)

@property
def is_bounded(self):
return all(dim.is_bounded for dim in self.dims)

@property
def is_fixed(self):
return all(dim.is_fixed for dim in self.dims)

@property
def is_variable(self):
return not self.is_fixed


@dataclass(frozen=True)
class ColumnSchema:
"""A schema containing metadata of a dataframe column."""
Expand All @@ -65,6 +137,7 @@ class ColumnSchema:
dtype: Optional[object] = None
is_list: Optional[bool] = None
is_ragged: Optional[bool] = None
shape: Optional[Tuple[Dimension]] = None

def __post_init__(self):
"""Standardize tags and dtypes on initialization
Expand Down Expand Up @@ -126,6 +199,13 @@ def __post_init__(self):
"This is a fixed size list and `is_ragged` should be set to False. "
)

if self.shape:
# Convert raw (min,max) tuples to Dimension objects for convenience
new_shape = (
Dimension(dim) if not isinstance(dim, Dimension) else dim for dim in self.shape
)
object.__setattr__(self, "shape", new_shape)

@property
def quantity(self):
"""
Expand Down

0 comments on commit 60a0f33

Please sign in to comment.