-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathraw_exp_monolithic_models.py
32 lines (23 loc) · 1.28 KB
/
raw_exp_monolithic_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import dataclasses
import typing
import mashumaro
import experiments
from . import base, raw_exp_abstract
@dataclasses.dataclass
class ExportConfigExpBaseModelsRaw(
    mashumaro.DataClassDictMixin,
    base.RawToParsed[experiments.ExportConfigBaseModels],
):
    """Raw, dict-loadable export settings for base-model experiments.

    ``parse`` converts this raw form into ``experiments.ExportConfigBaseModels``.
    """

    # Passed straight through to ExportConfigBaseModels; False unless configured.
    exists_ok: bool = False

    def parse(self) -> experiments.ExportConfigBaseModels:
        """Return the parsed export config carrying ``exists_ok``.

        The base directory is deliberately not set here; per the original
        author's note it is added at a later stage.
        """
        allow_existing = self.exists_ok
        return experiments.ExportConfigBaseModels(exists_ok=allow_existing)
@dataclasses.dataclass
class ExperimentMonolithicModelRaw(
    mashumaro.DataClassDictMixin,
    base.RawToParsed[experiments.ExperimentMonolithicModels],
    raw_exp_abstract.AbstractExperiment,
):
    """Raw, dict-loadable description of a monolithic-models experiment.

    ``parse`` builds ``experiments.ExperimentMonolithicModels`` from the dataset
    generator returned by ``self.get_dg()`` and from ``self.repetitions``.
    Neither is declared in this class. Presumably both are inherited from
    ``raw_exp_abstract.AbstractExperiment``; confirm against that class.
    """

    # Maps each model name to its raw function pair; ``parse`` calls
    # ``.parse()`` on every value. The default must be an empty dict, not a
    # list, because ``parse`` iterates ``.items()``.
    monolithic_models: typing.Dict[str, base.FuncPair] = dataclasses.field(default_factory=dict)
    # NOTE(review): ``parse`` below never reads this field. Confirm it is
    # consumed elsewhere, e.g. when the base directory is attached later.
    export_config: ExportConfigExpBaseModelsRaw = dataclasses.field(
        default_factory=ExportConfigExpBaseModelsRaw)

    def parse(self) -> experiments.ExperimentMonolithicModels:
        """Parse each raw model and build the experiment from the dataset generator.

        Returns:
            The experiment object created by
            ``ExperimentMonolithicModels.from_dataset_generator``, with the
            models passed as a list of ``(name, parsed_func_pair)`` tuples in
            dict iteration order.
        """
        dg = self.get_dg()
        parsed_models = [(name, pair.parse()) for name, pair in self.monolithic_models.items()]
        return experiments.ExperimentMonolithicModels.from_dataset_generator(
            dg=dg, monolithic_models=parsed_models, repetitions=self.repetitions)