Skip to content

Commit

Permalink
feat: factory pattern for creating datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
Liam James committed Mar 6, 2024
1 parent b429d61 commit b574912
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 22 deletions.
20 changes: 20 additions & 0 deletions aequitas/datasets/binary_label_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from aequitas.datasets.structured_dataset import StructuredDataset
from aequitas.metrics.binary_label_dataset_metric import BinaryLabelDatasetScoresMetric


class BinaryLabelDataset(StructuredDataset):

def __init__(self, favorable_label=1., unfavorable_label=0., **kwargs):
self._favorable_label = float(favorable_label)
self._unfavorable_label = float(unfavorable_label)

super(BinaryLabelDataset, self).__init__(**kwargs)

@classmethod
def new_instance(cls):
return cls

@property
def metrics(self, **kwargs):
return BinaryLabelDatasetScoresMetric(**kwargs)

12 changes: 12 additions & 0 deletions aequitas/datasets/concrete_dataset_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from aequitas.datasets.dataset_factory import DatasetFactory
from aequitas.datasets.binary_label_dataset import BinaryLabelDataset


class BinaryLabelDatasetFactory(DatasetFactory):
def __init__(self):
pass

def create_dataset(self, **kwargs) -> BinaryLabelDataset:
print("DEBUG:calling BinaryLabelDataset constructor")
ds = BinaryLabelDataset(favorable_label=1., unfavorable_label=0., **kwargs)
return ds
12 changes: 12 additions & 0 deletions aequitas/datasets/create_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from aequitas.datasets.structured_dataset import StructuredDataset
from aequitas.datasets.concrete_dataset_factories import BinaryLabelDatasetFactory


class CreateDataset:
def __init__(self, dataset_type):
self.dataset_type = dataset_type

def create_dataset(self, **kwargs) -> StructuredDataset:
if self.dataset_type == "binary label":
bldf = BinaryLabelDatasetFactory()
return bldf.create_dataset(**kwargs)
10 changes: 10 additions & 0 deletions aequitas/datasets/dataset_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from aequitas.datasets.structured_dataset import StructuredDataset


class DatasetFactory(ABC):
@abstractmethod
def create_dataset(self, **kwargs) -> StructuredDataset:
raise NotImplementedError


14 changes: 0 additions & 14 deletions aequitas/datasets/my_structured_dataset.py

This file was deleted.

8 changes: 0 additions & 8 deletions aequitas/datasets/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,7 @@ def __init__(self,
kwargs["df"] = self._df
super(StructuredDataset, self).__init__(**kwargs)

@property
def strategy(self):
return self.__strategy

@property
@abstractmethod
def metrics(self):
raise NotImplementedError

def __custom_preprocessing(self, df: pd.DataFrame) -> pd.DataFrame:
print("Applying custom preprocessing")
return self.strategy.custom_preprocessing(df)

0 comments on commit b574912

Please sign in to comment.