Skip to content

Commit

Permalink
Attempt to work around lebrice/SimpleParsing#322
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonGross committed Sep 17, 2024
1 parent 229d099 commit 7d19194
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions gbmi/exp_indhead/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_summary_slug(self, config: Config) -> str:
class IndHeadFineTune(ExperimentConfig):
train: Config[IndHead]
finetune: IndHeadOnlyFineTune
base_model_force: Optional[Literal["train", "load"]] = "load"
base_model_force: Literal["train", "load", None] = "load"
version_number: int = 1

def __post_init__(self):
Expand Down Expand Up @@ -196,7 +196,7 @@ def get_ground_truth(
def from_IndHead(
config: Config[IndHead],
config_finetune: IndHead,
base_model_force: Optional[Literal["train", "load"]] = "load",
base_model_force: Literal["train", "load", None] = "load",
) -> IndHeadFineTune:
return IndHeadFineTune(
train=config,
Expand All @@ -208,7 +208,7 @@ def from_IndHead(
def from_IndHeadConfig(
config: Config[IndHead],
config_finetune: Config[IndHead],
base_model_force: Optional[Literal["train", "load"]] = "load",
base_model_force: Literal["train", "load", None] = "load",
) -> Config[IndHeadFineTune]:
return cast(
Config[IndHeadFineTune],
Expand Down Expand Up @@ -329,7 +329,7 @@ def test_dataloader(self):
def make_default_finetune(
config: Config[IndHead],
alpha_mix_uniform: float = 1,
base_model_force: Optional[Literal["train", "load"]] = "load",
base_model_force: Literal["train", "load", None] = "load",
) -> Config[IndHeadFineTune]:
return IndHeadFineTune.from_IndHeadConfig(
config,
Expand All @@ -350,7 +350,7 @@ def make_default_finetune(
def main(
argv: List[str] = sys.argv,
default: Config[IndHeadFineTune] = ABCAB8_1H_FINETUNE,
default_force: Optional[Literal["train", "load"]] = None,
default_force: Literal["train", "load", None] = None,
):
parser = simple_parsing.ArgumentParser(
description="Train a model with configurable attention rate."
Expand Down

0 comments on commit 7d19194

Please sign in to comment.