Skip to content

Commit

Permalink
Move Shape to merlin.dtypes and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
karlhigley committed Jan 18, 2023
1 parent 7d03b5f commit ee8c804
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 76 deletions.
36 changes: 34 additions & 2 deletions merlin/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
# limitations under the License.
#

from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum
from typing import Optional
from typing import Optional, Tuple, Union

from merlin.dtypes.registry import _dtype_registry
from merlin.dtypes.shape import Shape


class ElementType(Enum):
Expand Down Expand Up @@ -69,6 +70,7 @@ class DType:
element_size: Optional[int] = None
element_unit: Optional[ElementUnit] = None
signed: Optional[bool] = None
shape: Optional[Shape] = None

def to(self, mapping_name: str):
"""
Expand Down Expand Up @@ -125,3 +127,33 @@ def is_integer(self):
@property
def is_float(self):
return self.element_type.value == "float"

def with_shape(self, shape: Union[Tuple, Shape]):
"""
Create a copy of this dtype with a new shape
Parameters
----------
shape : Union[Tuple, Shape]
Object to set as shape of dtype, must be either a tuple or Shape.
Returns
-------
DType
A copy of this dtype containing the provided shape value
Raises
------
TypeError
If value is not either a tuple or a Shape
"""
if isinstance(shape, tuple):
shape = Shape(shape)

if not isinstance(shape, Shape):
raise TypeError(
f"Provided value {shape} (of type {type(shape)}) for DType.shape property "
"is not of type Shape."
)

return replace(self, shape=shape)
119 changes: 119 additions & 0 deletions merlin/dtypes/shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Dimension:
"""
The range of potential sizes for a single dimension of a field or column
"""

min: int = 0
max: Optional[int] = None

def __post_init__(self):
if self.min is None:
raise ValueError("The mwhack-a-moleinimum size of a dimension cannot be None. ")

if self.min < 0:
raise ValueError(
"The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}"
)

if self.max and self.max < 0:
raise ValueError(
"The maximum size of a dimension must be at least one. " f"Provided max: {self.max}"
)

if self.max and self.max < self.min:
raise ValueError(
"The maximum size of a dimension must be at least as large as the minimum size. "
f"Provided min: {self.min} max: {self.max}"
)

@property
def is_bounded(self):
return self.max is not None

@property
def is_fixed(self):
return self.is_bounded and self.min == self.max

@property
def is_variable(self):
return not self.is_fixed


@dataclass(frozen=True)
class Shape:
"""
The range of potential sizes for all the dimensions of a field or column
"""

dims: Tuple

def __post_init__(self):
new_dims = []
for i, dim in enumerate(self.dims):
if isinstance(dim, Dimension):
new_dim = dim
elif isinstance(dim, tuple) and len(dim) == 2:
new_dim = Dimension(dim[0], dim[1])
elif isinstance(dim, int):
new_dim = Dimension(dim, dim)
else:
raise ValueError(
"Invalid shape tuple format: {self.dims}. Each dimension is expected to be "
"either a single integer or a length 2 tuple."
)
new_dims.append(new_dim)

object.__setattr__(self, "dims", tuple(new_dims))

@property
def min(self) -> Tuple:
return tuple(dim.min for dim in self.dims)

@property
def max(self) -> Tuple:
return tuple(dim.max for dim in self.dims)

@property
def fixed(self) -> Tuple:
return tuple(dim.min if dim.is_fixed else None for dim in self.dims)

@property
def is_bounded(self):
return all(dim.is_bounded for dim in self.dims)

@property
def is_fixed(self):
return all(dim.is_fixed for dim in self.dims)

@property
def is_variable(self):
return not self.is_fixed

@property
def is_list(self):
return len(self.dims) > 1

@property
def is_ragged(self):
return self.is_list and any(dim.min != dim.max for dim in self.dims[1:])
80 changes: 6 additions & 74 deletions merlin/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

from dataclasses import dataclass, field, replace
from enum import Enum
from typing import Dict, List, Optional, Text, Tuple, Union
from typing import Dict, List, Optional, Text, Union

import pandas as pd

import merlin.dtypes as md
from merlin.dtypes import DType
from merlin.dtypes.shape import Dimension
from merlin.schema.tags import Tags, TagSet


Expand Down Expand Up @@ -56,78 +57,6 @@ def is_bounded(self):
return self.max and self.min


@dataclass(frozen=True)
class Dimension:
"""
The range of potential sizes for a single dimension of a field or column
"""

min: int = 0
max: Optional[int] = None

def __post_init__(self):
if self.min < 0:
raise ValueError(
"The minimum size of a dimension must be non-negative. " f"Provided min: {self.min}"
)

if self.max and self.max < 0:
raise ValueError(
"The maximum size of a dimension must be at least one. " f"Provided max: {self.max}"
)

if self.max and self.max < self.min:
raise ValueError(
"The maximum size of a dimension must be at least as large as the minimum size. "
f"Provided min: {self.min} max: {self.max}"
)

@property
def is_bounded(self):
return self.max is not None

@property
def is_fixed(self):
return self.is_bounded and self.min == self.max

@property
def is_variable(self):
return not self.is_fixed


@dataclass(frozen=True)
class Shape:
"""
The range of potential sizes for all the dimensions of a field or column
"""

dims: Tuple[Dimension]

@property
def min(self) -> Tuple:
return tuple(dim.min for dim in self.dims)

@property
def max(self) -> Tuple:
return tuple(dim.max for dim in self.dims)

@property
def fixed(self) -> Tuple:
return tuple(dim.min if dim.is_fixed else None for dim in self.dims)

@property
def is_bounded(self):
return all(dim.is_bounded for dim in self.dims)

@property
def is_fixed(self):
return all(dim.is_fixed for dim in self.dims)

@property
def is_variable(self):
return not self.is_fixed


@dataclass(frozen=True)
class ColumnSchema:
"""A schema containing metadata of a dataframe column."""
Expand All @@ -138,7 +67,6 @@ class ColumnSchema:
dtype: Optional[DType] = None
is_list: Optional[bool] = None
is_ragged: Optional[bool] = None
shape: Optional[Tuple[Dimension]] = None

def __post_init__(self):
"""Standardize tags and dtypes on initialization
Expand Down Expand Up @@ -194,6 +122,10 @@ def __post_init__(self):
)
object.__setattr__(self, "shape", new_shape)

@property
def shape(self):
return self.dtype.shape

@property
def quantity(self):
"""
Expand Down
Loading

0 comments on commit ee8c804

Please sign in to comment.