From 7d191941292b736db8552d9e7deaf13677d8d352 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 17 Sep 2024 13:44:14 -0700 Subject: [PATCH] Attempt to work around https://github.com/lebrice/SimpleParsing/issues/322 --- gbmi/exp_indhead/finetune.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gbmi/exp_indhead/finetune.py b/gbmi/exp_indhead/finetune.py index 46ea0c9d..b16c1cc1 100644 --- a/gbmi/exp_indhead/finetune.py +++ b/gbmi/exp_indhead/finetune.py @@ -144,7 +144,7 @@ def get_summary_slug(self, config: Config) -> str: class IndHeadFineTune(ExperimentConfig): train: Config[IndHead] finetune: IndHeadOnlyFineTune - base_model_force: Optional[Literal["train", "load"]] = "load" + base_model_force: Literal["train", "load", None] = "load" version_number: int = 1 def __post_init__(self): @@ -196,7 +196,7 @@ def get_ground_truth( def from_IndHead( config: Config[IndHead], config_finetune: IndHead, - base_model_force: Optional[Literal["train", "load"]] = "load", + base_model_force: Literal["train", "load", None] = "load", ) -> IndHeadFineTune: return IndHeadFineTune( train=config, @@ -208,7 +208,7 @@ def from_IndHead( def from_IndHeadConfig( config: Config[IndHead], config_finetune: Config[IndHead], - base_model_force: Optional[Literal["train", "load"]] = "load", + base_model_force: Literal["train", "load", None] = "load", ) -> Config[IndHeadFineTune]: return cast( Config[IndHeadFineTune], @@ -329,7 +329,7 @@ def test_dataloader(self): def make_default_finetune( config: Config[IndHead], alpha_mix_uniform: float = 1, - base_model_force: Optional[Literal["train", "load"]] = "load", + base_model_force: Literal["train", "load", None] = "load", ) -> Config[IndHeadFineTune]: return IndHeadFineTune.from_IndHeadConfig( config, @@ -350,7 +350,7 @@ def make_default_finetune( def main( argv: List[str] = sys.argv, default: Config[IndHeadFineTune] = ABCAB8_1H_FINETUNE, - default_force: Optional[Literal["train", "load"]] = None, + default_force: Literal["train", "load", None] = None, ): parser = simple_parsing.ArgumentParser( description="Train a model with configurable attention rate."