-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathraw_poisoning.py
68 lines (51 loc) · 3.04 KB
/
raw_poisoning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import dataclasses
import typing
import mashumaro
from . import base
import poisoning
@dataclasses.dataclass
class SelectorRaw(mashumaro.DataClassDictMixin, base.RawToParsed[poisoning.AbstractSelector]):
    """Serialisable description of a poisoning selector.

    The fields are plain data (loadable through mashumaro's dict mixin);
    :meth:`parse` turns them into the selector object via ``base.load_func``.
    """

    # Name of the selector to load; resolved by base.load_func.
    name: str
    # Keyword arguments handed to base.load_func together with ``name``.
    init_kwargs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)

    def parse(self) -> poisoning.AbstractSelector:
        """Load the selector described by ``name`` and ``init_kwargs``.

        Returns whatever ``base.load_func(name, init_kwargs)`` produces,
        unchanged (presumably an instance built with those kwargs).
        """
        loaded = base.load_func(self.name, self.init_kwargs)
        return typing.cast(poisoning.AbstractSelector, loaded)
@dataclasses.dataclass
class PerformerRaw(mashumaro.DataClassDictMixin, base.RawToParsed[poisoning.AbstractPerformer]):
    """Serialisable description of a poisoning performer.

    The fields are plain data (loadable through mashumaro's dict mixin);
    :meth:`parse` turns them into the performer object via ``base.load_func``.
    """

    # Name of the performer class to load; resolved by base.load_func.
    name: str
    # Keyword arguments handed to base.load_func together with ``name``.
    init_kwargs: typing.Dict[str, typing.Any] = dataclasses.field(default_factory=dict)
    # NOTE: a `performing_info` field is deliberately not declared here. Its
    # concrete type (e.g. PoisoningInfoMonoDirectional vs
    # PoisoningInfoBiDirectionalMirrored) depends on `name`, so it cannot be
    # typed in advance on this raw config object.

    def parse(self) -> poisoning.AbstractPerformer:
        """Load the performer described by ``name`` and ``init_kwargs``.

        Returns whatever ``base.load_func(name, init_kwargs)`` produces,
        unchanged (presumably an instance built with those kwargs).
        """
        loaded = base.load_func(self.name, self.init_kwargs)
        return typing.cast(poisoning.AbstractPerformer, loaded)
@dataclasses.dataclass
class PoisoningGenerationInfoRaw(mashumaro.DataClassDictMixin, base.RawToParsed[poisoning.PoisoningGenerationInput]):
    """Serialisable description of a poisoning-generation run.

    :meth:`parse` resolves every nested/raw piece and bundles the results
    into a ``poisoning.PoisoningGenerationInput``.
    """

    # Raw selector/performer configs, each parsed by its own ``parse()``.
    selector: SelectorRaw
    performer: PerformerRaw
    # Names of the info classes; loaded with base.load_func(name, None) and
    # forwarded as classes (not instances).
    perform_info_clazz: str
    selection_info_clazz: str
    # Forwarded unchanged. Presumably fractions/percentages of data points
    # and features to poison — TODO confirm against PoisoningGenerationInput.
    perc_data_points: typing.Sequence[float]
    perc_features: typing.Sequence[float] = dataclasses.field(default_factory=list)
    # Optional column names, forwarded unchanged (None when not configured).
    columns: typing.Optional[typing.List[str]] = None
    # Extra kwargs for the info classes, normalised by base.fill_kwargs.
    # NOTE(review): typed Optional, so an explicit null in the config reaches
    # base.fill_kwargs as None — confirm fill_kwargs accepts None.
    perform_info_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = dataclasses.field(default_factory=dict)
    selection_info_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = dataclasses.field(default_factory=dict)

    def parse(self) -> poisoning.PoisoningGenerationInput:
        """Resolve all raw pieces and build the generation input.

        The selector and performer are parsed first. Then each info class is
        loaded by name and its kwargs are passed through ``base.fill_kwargs``.
        The call order matches the original, so a failing lookup raises at
        the same point.
        """
        parsed_selector = self.selector.parse()
        parsed_performer = self.performer.parse()

        # load_func is called with None kwargs here: the result is the class
        # object itself, not an instance of it.
        perform_cls = base.load_func(self.perform_info_clazz, None)
        perform_kw = base.fill_kwargs(self.perform_info_kwargs)
        selection_cls = base.load_func(self.selection_info_clazz, None)
        selection_kw = base.fill_kwargs(self.selection_info_kwargs)

        return poisoning.PoisoningGenerationInput(
            selector=parsed_selector,
            performer=parsed_performer,
            columns=self.columns,
            perc_data_points=self.perc_data_points,
            perc_features=self.perc_features,
            perform_info_clazz=perform_cls,
            perform_info_kwargs=perform_kw,
            selection_info_clazz=selection_cls,
            selection_info_kwargs=selection_kw,
        )