Skip to content

Commit

Permalink
[ENH] refactor datatypes mtypes - example fixtures (#458)
Browse files Browse the repository at this point in the history
This PR refactors the data type specifications and converters to
classes.

Related: sktime/sktime#3512, related to
sktime/sktime#2957.

Contains:

* a base class for datatype examples, `BaseExample`, to replace the more
ad-hoc dictionary design
* a complete refactor of the `Table` and `Proba` mtype submodules to
this interface
* a full refactor of the public framework module with `get_example`
logic, in `datatypes`, to allow extensibility with this design

Partial mirror in `skpro` of sktime/sktime#6033
  • Loading branch information
fkiraly authored Sep 8, 2024
1 parent 31da07d commit 1f5e5f2
Show file tree
Hide file tree
Showing 7 changed files with 345 additions and 162 deletions.
4 changes: 2 additions & 2 deletions skpro/datatypes/_base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Base module for datatypes."""

from skpro.datatypes._base._base import BaseConverter, BaseDatatype
from skpro.datatypes._base._base import BaseConverter, BaseDatatype, BaseExample

__all__ = ["BaseConverter", "BaseDatatype"]
__all__ = ["BaseConverter", "BaseDatatype", "BaseExample"]
37 changes: 37 additions & 0 deletions skpro/datatypes/_base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,43 @@ def _get_key(self):
return (mtype_from, mtype_to, scitype)


class BaseExample(BaseObject):
"""Base class for Example fixtures used in tests and get_examples."""

_tags = {
"object_type": "datatype_example",
"scitype": None,
"mtype": None,
"python_version": None,
"python_dependencies": None,
"index": None, # integer index of the example to match with other mtypes
"lossy": False, # whether the example is lossy
}

def __init__(self):
super().__init__()

def _get_key(self):
"""Get unique dictionary key corresponding to self.
Private function, used in collecting a dictionary of examples.
"""
mtype = self.get_class_tag("mtype")
scitype = self.get_class_tag("scitype")
index = self.get_class_tag("index")
return (mtype, scitype, index)

def build(self):
"""Build example.
Returns
-------
obj : any
Example object.
"""
raise NotImplementedError


def _coerce_str_to_cls(cls_or_str):
"""Get class from string.
Expand Down
57 changes: 34 additions & 23 deletions skpro/datatypes/_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
e.g., metadata such as column names are missing
"""

from functools import lru_cache

from skpro.datatypes._registry import mtype_to_scitype

__author__ = ["fkiraly"]
Expand All @@ -21,29 +23,36 @@
"get_examples",
]

from skpro.datatypes._proba import (
example_dict_lossy_Proba,
example_dict_metadata_Proba,
example_dict_Proba,
)
from skpro.datatypes._table import (
example_dict_lossy_Table,
example_dict_metadata_Table,
example_dict_Table,
)

# pool example_dict-s
example_dict = dict()
example_dict.update(example_dict_Proba)
example_dict.update(example_dict_Table)
@lru_cache(maxsize=1)
def generate_example_dicts(soft_deps="present"):
"""Generate example dicts using lookup."""
from skbase.utils.dependencies import _check_estimator_deps

from skpro.datatypes._base import BaseExample
from skpro.utils.retrieval import _all_classes

classes = _all_classes("skpro.datatypes")
classes = [x[1] for x in classes]
classes = [x for x in classes if issubclass(x, BaseExample)]
classes = [x for x in classes if not x.__name__.startswith("Base")]

example_dict_lossy = dict()
example_dict_lossy.update(example_dict_lossy_Proba)
example_dict_lossy.update(example_dict_lossy_Table)
# subset only to data types with soft dependencies present
if soft_deps == "present":
classes = [x for x in classes if _check_estimator_deps(x, severity="none")]

example_dict_metadata = dict()
example_dict_metadata.update(example_dict_metadata_Proba)
example_dict_metadata.update(example_dict_metadata_Table)
example_dict = dict()
example_dict_lossy = dict()
example_dict_metadata = dict()
for cls in classes:
k = cls()
key = k._get_key()
key_meta = (key[1], key[2])
example_dict[key] = k
example_dict_lossy[key] = k.get_class_tags().get("lossy", False)
example_dict_metadata[key_meta] = k.get_class_tags().get("metadata", {})

return example_dict, example_dict_lossy, example_dict_metadata


def get_examples(
Expand Down Expand Up @@ -79,6 +88,8 @@ def get_examples(
if as_scitype is None:
as_scitype = mtype_to_scitype(mtype)

example_dict, example_dict_lossy, example_dict_metadata = generate_example_dicts()

# retrieve all keys that match the query
exkeys = example_dict.keys()
keys = [k for k in exkeys if k[0] == mtype and k[1] == as_scitype]
Expand All @@ -88,14 +99,14 @@ def get_examples(

for k in keys:
if return_lossy:
fixtures[k[2]] = (example_dict.get(k), example_dict_lossy.get(k))
fixtures[k[2]] = (example_dict.get(k).build(), example_dict_lossy.get(k))
elif return_metadata:
fixtures[k[2]] = (
example_dict.get(k),
example_dict.get(k).build(),
example_dict_lossy.get(k),
example_dict_metadata.get((k[1], k[2])),
)
else:
fixtures[k[2]] = example_dict.get(k)
fixtures[k[2]] = example_dict.get(k).build()

return fixtures
10 changes: 0 additions & 10 deletions skpro/datatypes/_proba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,11 @@

from skpro.datatypes._proba._check import check_dict as check_dict_Proba
from skpro.datatypes._proba._convert import convert_dict as convert_dict_Proba
from skpro.datatypes._proba._examples import example_dict as example_dict_Proba
from skpro.datatypes._proba._examples import (
example_dict_lossy as example_dict_lossy_Proba,
)
from skpro.datatypes._proba._examples import (
example_dict_metadata as example_dict_metadata_Proba,
)
from skpro.datatypes._proba._registry import MTYPE_LIST_PROBA, MTYPE_REGISTER_PROBA

__all__ = [
"check_dict_Proba",
"convert_dict_Proba",
"MTYPE_LIST_PROBA",
"MTYPE_REGISTER_PROBA",
"example_dict_Proba",
"example_dict_lossy_Proba",
"example_dict_metadata_Proba",
]
142 changes: 95 additions & 47 deletions skpro/datatypes/_proba/_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,64 +31,112 @@
import numpy as np
import pandas as pd

example_dict = dict()
example_dict_lossy = dict()
example_dict_metadata = dict()
from skpro.datatypes._base import BaseExample

###
# example 0: univariate

pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_q.columns = pd.MultiIndex.from_product([["foo"], [0.2, 0.6]])

# we need to use this due to numerical inaccuracies from the binary based representation
pseudo_0_2 = 2 * np.abs(0.6 - 0.5)
class _ProbaUniv(BaseExample):
_tags = {
"scitype": "Proba",
"index": 0,
"metadata": {
"is_univariate": True,
"is_empty": False,
"has_nans": False,
},
}

example_dict[("pred_quantiles", "Proba", 0)] = pred_q
example_dict_lossy[("pred_quantiles", "Proba", 0)] = False

pred_int = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_int.columns = pd.MultiIndex.from_tuples(
[("foo", 0.6, "lower"), ("foo", pseudo_0_2, "upper")]
)
class _ProbaUnivPredQ(_ProbaUniv):
_tags = {
"mtype": "pred_quantiles",
"python_dependencies": None,
"lossy": False,
}

example_dict[("pred_interval", "Proba", 0)] = pred_int
example_dict_lossy[("pred_interval", "Proba", 0)] = False
def build(self):
pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_q.columns = pd.MultiIndex.from_product([["foo"], [0.2, 0.6]])

return pred_q


class _ProbaUnivPredInt(_ProbaUniv):
_tags = {
"mtype": "pred_interval",
"python_dependencies": None,
"lossy": False,
}

def build(self):
# we need to use this due to numerical inaccuracies
# from the binary based representation
pseudo_0_2 = 2 * np.abs(0.6 - 0.5)

pred_int = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4]})
pred_int.columns = pd.MultiIndex.from_tuples(
[("foo", 0.6, "lower"), ("foo", pseudo_0_2, "upper")]
)

return pred_int

example_dict_metadata[("Proba", 0)] = {
"is_univariate": True,
"is_empty": False,
"has_nans": False,
}

###
# example 1: multi

pred_q = pd.DataFrame({0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]})
pred_q.columns = pd.MultiIndex.from_product([["foo", "bar"], [0.2, 0.6]])

example_dict[("pred_quantiles", "Proba", 1)] = pred_q
example_dict_lossy[("pred_quantiles", "Proba", 1)] = False

pred_int = pd.DataFrame(
{0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
)
pred_int.columns = pd.MultiIndex.from_tuples(
[
("foo", 0.6, "lower"),
("foo", pseudo_0_2, "upper"),
("bar", 0.6, "lower"),
("bar", pseudo_0_2, "upper"),
]
)

example_dict[("pred_interval", "Proba", 1)] = pred_int
example_dict_lossy[("pred_interval", "Proba", 1)] = False


example_dict_metadata[("Proba", 1)] = {
"is_univariate": False,
"is_empty": False,
"has_nans": False,
}

class _ProbaMulti(BaseExample):
_tags = {
"scitype": "Proba",
"index": 1,
"metadata": {
"is_univariate": False,
"is_empty": False,
"has_nans": False,
},
}


class _ProbaMultiPredQ(_ProbaMulti):
_tags = {
"mtype": "pred_quantiles",
"python_dependencies": None,
"lossy": False,
}

def build(self):
pred_q = pd.DataFrame(
{0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
)
pred_q.columns = pd.MultiIndex.from_product([["foo", "bar"], [0.2, 0.6]])

return pred_q


class _ProbaMultiPredInt(_ProbaMulti):
_tags = {
"mtype": "pred_interval",
"python_dependencies": None,
"lossy": False,
}

def build(self):
# we need to use this due to numerical inaccuracies
# from the binary based representation
pseudo_0_2 = 2 * np.abs(0.6 - 0.5)

pred_int = pd.DataFrame(
{0.2: [1, 2, 3], 0.6: [2, 3, 4], 42: [5, 3, -1], 46: [5, 3, -1]}
)
pred_int.columns = pd.MultiIndex.from_tuples(
[
("foo", 0.6, "lower"),
("foo", pseudo_0_2, "upper"),
("bar", 0.6, "lower"),
("bar", pseudo_0_2, "upper"),
]
)

return pred_int
10 changes: 0 additions & 10 deletions skpro/datatypes/_table/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
"""Module exports: Series type checkers, converters and mtype inference."""

from skpro.datatypes._table._convert import convert_dict as convert_dict_Table
from skpro.datatypes._table._examples import example_dict as example_dict_Table
from skpro.datatypes._table._examples import (
example_dict_lossy as example_dict_lossy_Table,
)
from skpro.datatypes._table._examples import (
example_dict_metadata as example_dict_metadata_Table,
)
from skpro.datatypes._table._registry import MTYPE_LIST_TABLE, MTYPE_REGISTER_TABLE

__all__ = [
"convert_dict_Table",
"MTYPE_LIST_TABLE",
"MTYPE_REGISTER_TABLE",
"example_dict_Table",
"example_dict_lossy_Table",
"example_dict_metadata_Table",
]
Loading

0 comments on commit 1f5e5f2

Please sign in to comment.