Skip to content

Commit

Permalink
Extract dtypes and mappings from merlin.dtype module's init
Browse files Browse the repository at this point in the history
  • Loading branch information
karlhigley committed Nov 22, 2022
1 parent d4c0bbe commit 5ad1de6
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 98 deletions.
9 changes: 7 additions & 2 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pandas as pd
from dask.core import flatten

from merlin import dtype
from merlin.core.dispatch import concat_columns, is_list_dtype, list_val_dtype
from merlin.core.utils import (
ensure_optimize_dataframe_graph,
Expand Down Expand Up @@ -271,11 +272,15 @@ def transform(
columns = list(flatten(wfn.output_columns.names for wfn in nodes))
columns += additional_columns if additional_columns else []

if isinstance(output_dtypes, dict):
for col_name, col_dtype in output_dtypes.items():
output_dtypes[col_name] = dtype.to("numpy", col_dtype)

if isinstance(output_dtypes, dict) and isinstance(ddf._meta, pd.DataFrame):
dtypes = output_dtypes
output_dtypes = type(ddf._meta)({k: [] for k in columns})
for column, dtype in dtypes.items():
output_dtypes[column] = output_dtypes[column].astype(dtype)
for col_name, col_dtype in dtypes.items():
output_dtypes[col_name] = output_dtypes[col_name].astype(col_dtype)

elif not output_dtypes:
# TODO: constructing meta like this loses dtype information on the ddf
Expand Down
102 changes: 19 additions & 83 deletions merlin/dtype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,84 +14,18 @@
# limitations under the License.
#

# flake8: noqa
import dataclasses
import sys
from copy import copy
from dataclasses import dataclass
from enum import Enum
from types import ModuleType
from typing import Any, Dict, Optional, Tuple
from typing import Any, Optional, Tuple

import numpy as np
from merlin.dtype.dtypes import *
from merlin.dtype.dtypes import DType
from merlin.dtype.mappings import _dtype_registry


class ElementType(Enum):
    """Kinds of element a DType can describe.

    Each member's value is the lowercase spelling of its name.
    """

    Bool = "bool"
    Int = "int"
    UInt = "uint"
    Float = "float"
    String = "string"
    DateTime = "datetime"
    Object = "object"


@dataclass
class DType:
    """Description of a Merlin data type: a name, an element kind and optional details."""

    name: str
    # TODO: Rename elemsize to bits or bytes for clarity
    elemtype: ElementType
    elemsize: Optional[int] = None
    signed: Optional[bool] = None
    shape: Optional[Tuple] = None

    @property
    def is_list(self):
        """True when a shape is set and it has more than one dimension."""
        if self.shape is None:
            return False
        return len(self.shape) > 1

    @property
    def is_ragged(self):
        """True for list dtypes whose shape contains a ``None`` dimension."""
        if not self.is_list:
            return False
        return None in self.shape


# Predefined Merlin dtypes, built as DType(name, elemtype, elemsize, signed=...).
# Signed integers and floats pass signed=True; the unsigned, datetime, string,
# boolean and object dtypes leave `signed` at its default of None. The string,
# boolean and object dtypes also leave `elemsize` unset.
int32 = DType("int32", ElementType.Int, 32, signed=True)
int64 = DType("int64", ElementType.Int, 64, signed=True)
uint32 = DType("uint32", ElementType.UInt, 32)
uint64 = DType("uint64", ElementType.UInt, 64)
float32 = DType("float32", ElementType.Float, 32, signed=True)
float64 = DType("float64", ElementType.Float, 64, signed=True)
datetime64us = DType("datetime64[us]", ElementType.DateTime, 64)
datetime64ns = DType("datetime64[ns]", ElementType.DateTime, 64)
string = DType("str", ElementType.String)
boolean = DType("bool", ElementType.Bool)
object_ = DType("object", ElementType.Object)

# Mappings added through register(), kept in the order they were registered.
_mapping_registry = []


# Is there ever a case where we'd want to preempt the built-in mappings?
def register(mapping: Dict[str, DType]):
    """Add ``mapping`` to the end of the module-level list of dtype mappings."""
    _mapping_registry.append(mapping)


# Make these mappings immutable?
# External dtype -> Merlin dtype, for the built-in Python types.
python_dtypes = {int: int64, float: float64, str: string}
register(python_dtypes)

# External dtype -> Merlin dtype, for NumPy scalar types and `np.dtype` instances.
numpy_dtypes = {
    np.int32: int32,
    np.dtype("int32"): int32,
    np.int64: int64,
    np.dtype("int64"): int64,
    np.float32: float32,
    np.dtype("float32"): float32,
    np.float64: float64,
    np.dtype("float64"): float64,
    np.datetime64: datetime64ns,
    np.dtype("datetime64[ns]"): datetime64ns,
    np.dtype("datetime64[us]"): datetime64us,
    # `np.str` was a deprecated alias of the builtin `str` and no longer exists in
    # NumPy >= 1.24; the builtin is the same key and works on every version.
    str: string,
    np.dtype("O"): object_,
}
register(numpy_dtypes)
# Convenience alias: exposes the dtype registry's `register` method at module level,
# which is the `merlin.dtype.register()` named in the lookup error message.
register = _dtype_registry.register


# This class implements the "call" method for the *module*, which
Expand All @@ -103,20 +37,22 @@ class DTypeModule(ModuleType):
# Look up the Merlin DType registered for `value`. A value that is already a
# DType is returned unchanged; a value with no registered mapping raises TypeError.
def __call__(self, value: Any, shape: Optional[Tuple] = None):
if isinstance(value, DType):
return value
# NOTE(review): this page renders a diff without +/- markers. The loop over
# `_mapping_registry` looks like the removed implementation and the loop over
# `_dtype_registry.mappings` like its replacement -- confirm against the repository.
for mapping in _mapping_registry:
try:
if value in mapping:
merlin_type = copy(mapping[value])
if shape is not None:
merlin_type.shape = shape
return merlin_type
# An unhashable `value` cannot be tested for membership; skip to the next mapping.
except TypeError:
pass

# `dataclasses.replace` returns a new DType whose `shape` is the `shape` argument
# (None when the caller did not pass one).
for _, mapping in _dtype_registry.mappings.items():
if value in mapping.to_merlin:
return dataclasses.replace(mapping.to_merlin[value], **{"shape": shape})

# No registered mapping contains `value`.
raise TypeError(
f"Merlin doesn't have a mapping from {value} to a Merlin dtype. "
"If you'd like to provide one, you can use `merlin.dtype.register()`."
)

def to(self, mapping_name: str, dtype: DType):
    """Translate a Merlin ``dtype`` into the external dtype registered under ``mapping_name``."""
    registered = _dtype_registry.mappings[mapping_name]
    # Don't match using the shape
    shapeless = dataclasses.replace(dtype, shape=None)
    # Always translate to the first external dtype in the list
    candidates = registered.from_merlin[shapeless]
    return candidates[0]


# Swap this module's class for DTypeModule so the module object itself gains the
# methods defined there (it becomes callable and exposes `to`).
sys.modules[__name__].__class__ = DTypeModule
58 changes: 58 additions & 0 deletions merlin/dtype/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass, replace
from enum import Enum
from typing import Optional, Tuple

from merlin.dtype.registry import _dtype_registry


class ElementType(Enum):
    """Element kinds that a DType can carry.

    The value of every member is its own name in lowercase.
    """

    Bool = "bool"
    Int = "int"
    UInt = "uint"
    Float = "float"
    String = "string"
    DateTime = "datetime"
    Object = "object"


@dataclass(eq=True, frozen=True)
class DType:
    """Immutable description of a Merlin data type.

    Instances are frozen and compare by value, which makes them hashable and
    therefore usable as dictionary keys in the dtype mappings.
    """

    name: str
    elemtype: ElementType
    elemsize: Optional[int] = None
    signed: Optional[bool] = None
    shape: Optional[Tuple] = None

    # These properties refer to what's in a single row of the DataFrame/DictArray
    @property
    def is_list(self):
        """Whether ``shape`` is set and has more than one dimension."""
        return self.shape is not None and len(self.shape) > 1

    @property
    def is_ragged(self):
        """Whether this is a list dtype with ``None`` in a dimension after the first."""
        return self.is_list and None in self.shape[1:]

    def to(self, mapping_name):
        """Translate this dtype into the external dtype of a registered mapping.

        Parameters
        ----------
        mapping_name : str
            Name the mapping was registered under (e.g. "numpy").

        Returns
        -------
        The first external dtype registered for this dtype in that mapping.

        Raises
        ------
        KeyError
            If no mapping is registered under ``mapping_name``, or the mapping
            has no entry for this dtype.
        """
        mapping = _dtype_registry.mappings[mapping_name]

        # Ignore the shape when matching dtypes. The key has to be derived from
        # `self`; this method has no `dtype` argument to start from.
        shapeless = replace(self, shape=None)

        external = mapping.from_merlin[shapeless]

        # Always translate to the first external dtype in the list. A mapping may
        # also hold a single bare dtype instead of a list; return that as-is.
        if isinstance(external, list):
            return external[0]
        return external
43 changes: 43 additions & 0 deletions merlin/dtype/dtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from merlin.dtype.base import DType, ElementType

# Unsigned integers (`signed` is left at its default)
uint8 = DType(name="uint8", elemtype=ElementType.UInt, elemsize=8)
uint16 = DType(name="uint16", elemtype=ElementType.UInt, elemsize=16)
uint32 = DType(name="uint32", elemtype=ElementType.UInt, elemsize=32)
uint64 = DType(name="uint64", elemtype=ElementType.UInt, elemsize=64)

# Signed integers
int8 = DType(name="int8", elemtype=ElementType.Int, elemsize=8, signed=True)
int16 = DType(name="int16", elemtype=ElementType.Int, elemsize=16, signed=True)
int32 = DType(name="int32", elemtype=ElementType.Int, elemsize=32, signed=True)
int64 = DType(name="int64", elemtype=ElementType.Int, elemsize=64, signed=True)

# Floating point
float16 = DType(name="float16", elemtype=ElementType.Float, elemsize=16, signed=True)
float32 = DType(name="float32", elemtype=ElementType.Float, elemsize=32, signed=True)
float64 = DType(name="float64", elemtype=ElementType.Float, elemsize=64, signed=True)

# Date/time, one dtype per time unit
datetime64us = DType(name="datetime64[us]", elemtype=ElementType.DateTime, elemsize=64)
datetime64ns = DType(name="datetime64[ns]", elemtype=ElementType.DateTime, elemsize=64)

# Miscellaneous (no element size)
string = DType(name="str", elemtype=ElementType.String)
boolean = DType(name="bool", elemtype=ElementType.Bool)
object_ = DType(name="object", elemtype=ElementType.Object)
55 changes: 55 additions & 0 deletions merlin/dtype/mappings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np

from merlin.dtype import dtypes
from merlin.dtype.registry import _dtype_registry

# Merlin dtype -> built-in Python type; each entry is a single bare type, not a list.
python_dtypes = {
    dtypes.boolean: bool,
    dtypes.int64: int,
    dtypes.float64: float,
    dtypes.string: str,
}
_dtype_registry.register(name="python", mapping=python_dtypes)


# Merlin dtype -> list of equivalent NumPy dtypes. The first entry of each list is
# the one used when translating from Merlin to NumPy.
numpy_dtypes = {
    # Unsigned Integer
    dtypes.uint8: [np.uint8, np.dtype("uint8")],
    dtypes.uint16: [np.uint16, np.dtype("uint16")],
    dtypes.uint32: [np.uint32, np.dtype("uint32")],
    dtypes.uint64: [np.uint64, np.dtype("uint64")],
    # Signed integer
    dtypes.int8: [np.int8, np.dtype("int8")],
    dtypes.int16: [np.int16, np.dtype("int16")],
    dtypes.int32: [np.int32, np.dtype("int32")],
    dtypes.int64: [np.int64, np.dtype("int64")],
    # Floating Point
    dtypes.float16: [np.float16, np.dtype("float16")],
    dtypes.float32: [np.float32, np.dtype("float32")],
    dtypes.float64: [np.float64, np.dtype("float64")],
    # Date/Time
    # TODO: Figure out which datetime64 unit is the default
    # The unit-less `np.datetime64` is listed once only: an external dtype can point
    # back to a single Merlin dtype, so a second listing would silently replace
    # the first in the reverse lookup.
    dtypes.datetime64ns: [np.datetime64, np.dtype("datetime64[ns]")],
    dtypes.datetime64us: [np.dtype("datetime64[us]")],
    # Miscellaneous
    # `np.str` and `np.bool` were deprecated aliases of the builtins and do not
    # exist in NumPy 1.24+; the builtins are the same objects the aliases named.
    dtypes.string: [str],
    dtypes.object_: [np.dtype("O")],
    dtypes.boolean: [bool, np.dtype("bool")],
}
_dtype_registry.register("numpy", numpy_dtypes)
45 changes: 45 additions & 0 deletions merlin/dtype/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Union


class DTypeMapping:
    """Two-way lookup table between Merlin dtypes and one framework's dtypes.

    Parameters
    ----------
    mapping : dict
        Maps each Merlin dtype to either a single external dtype or a list of
        equivalent external dtypes.

    Attributes
    ----------
    from_merlin : dict
        Merlin dtype -> list of external dtypes. Always a list, even when the
        input gave a single bare dtype, so ``from_merlin[dtype][0]`` is valid
        for every entry.
    to_merlin : dict
        External dtype -> Merlin dtype. If the same external dtype appears under
        several Merlin dtypes, the last one in ``mapping`` wins.
    """

    def __init__(self, mapping):
        self.from_merlin = {}
        self.to_merlin = {}

        for merlin_dtype, external in mapping.items():
            # Normalise a bare external dtype to a one-element list; copy lists so
            # later changes to the caller's dict do not leak into the mapping.
            external_dtypes = list(external) if isinstance(external, list) else [external]
            self.from_merlin[merlin_dtype] = external_dtypes

            for external_dtype in external_dtypes:
                self.to_merlin[external_dtype] = merlin_dtype


class DTypeMappingRegistry:
    """Named collection of dtype mappings.

    ``mappings`` holds one ``DTypeMapping`` per registered name; iterating the
    registry yields those names.
    """

    def __init__(self):
        self.mappings = {}

    def __iter__(self):
        return iter(self.mappings)

    def register(self, name: str, mapping: Union[Dict, DTypeMapping]):
        """Store ``mapping`` under ``name``, replacing any mapping already stored there.

        A plain dict is wrapped in a ``DTypeMapping`` first.
        """
        wrapped = mapping if isinstance(mapping, DTypeMapping) else DTypeMapping(mapping)
        self.mappings[name] = wrapped


# Module-level registry instance, created empty when this module is first imported.
_dtype_registry = DTypeMappingRegistry()
Loading

0 comments on commit 5ad1de6

Please sign in to comment.