generated from aequitas-aod/template-python-project-poetry
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: factory pattern for creating datasets
- Loading branch information
Liam James
committed
Mar 6, 2024
1 parent
b429d61
commit b574912
Showing
6 changed files
with
54 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from aequitas.datasets.structured_dataset import StructuredDataset | ||
from aequitas.metrics.binary_label_dataset_metric import BinaryLabelDatasetScoresMetric | ||
|
||
|
||
class BinaryLabelDataset(StructuredDataset): | ||
|
||
def __init__(self, favorable_label=1., unfavorable_label=0., **kwargs): | ||
self._favorable_label = float(favorable_label) | ||
self._unfavorable_label = float(unfavorable_label) | ||
|
||
super(BinaryLabelDataset, self).__init__(**kwargs) | ||
|
||
@classmethod | ||
def new_instance(cls): | ||
return cls | ||
|
||
@property | ||
def metrics(self, **kwargs): | ||
return BinaryLabelDatasetScoresMetric(**kwargs) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from aequitas.datasets.dataset_factory import DatasetFactory | ||
from aequitas.datasets.binary_label_dataset import BinaryLabelDataset | ||
|
||
|
||
class BinaryLabelDatasetFactory(DatasetFactory): | ||
def __init__(self): | ||
pass | ||
|
||
def create_dataset(self, **kwargs) -> BinaryLabelDataset: | ||
print("DEBUG:calling BinaryLabelDataset constructor") | ||
ds = BinaryLabelDataset(favorable_label=1., unfavorable_label=0., **kwargs) | ||
return ds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from aequitas.datasets.structured_dataset import StructuredDataset | ||
from aequitas.datasets.concrete_dataset_factories import BinaryLabelDatasetFactory | ||
|
||
|
||
class CreateDataset: | ||
def __init__(self, dataset_type): | ||
self.dataset_type = dataset_type | ||
|
||
def create_dataset(self, **kwargs) -> StructuredDataset: | ||
if self.dataset_type == "binary label": | ||
bldf = BinaryLabelDatasetFactory() | ||
return bldf.create_dataset(**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from abc import ABC, abstractmethod | ||
from aequitas.datasets.structured_dataset import StructuredDataset | ||
|
||
|
||
class DatasetFactory(ABC): | ||
@abstractmethod | ||
def create_dataset(self, **kwargs) -> StructuredDataset: | ||
raise NotImplementedError | ||
|
||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters