Skip to content

Commit

Permalink
♻️ Style Transfer: rename main methods
Browse files Browse the repository at this point in the history
  • Loading branch information
simonmeoni committed Feb 28, 2024
1 parent 56c1bf1 commit 2b1f1da
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions lib/style-transfer/style_transfer/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@hydra.main(version_base="1.3", config_path="../configs", config_name="gen.yaml")
def gen(cfg):
def main(cfg):
api = wandb.Api()
model_artifact = api.artifact(cfg.checkpoint)
model_dir = model_artifact.download()
Expand Down Expand Up @@ -130,4 +130,4 @@ def add_prompt(data_point):


if __name__ == "__main__":
gen()
main()
4 changes: 2 additions & 2 deletions lib/style-transfer/style_transfer/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@hydra.main(version_base="1.3", config_path="../configs", config_name="score.yaml")
def score(cfg):
def main(cfg):
wandb.config = omegaconf.OmegaConf.to_container(
cfg,
)
Expand Down Expand Up @@ -166,4 +166,4 @@ def add_prompt(data_point):


if __name__ == "__main__":
score()
main()
5 changes: 3 additions & 2 deletions lib/style-transfer/style_transfer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

os.environ["WANDB_PROJECT"] = "sft-style-transfer"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_START_METHOD"] = "thread"


@hydra.main(version_base="1.3", config_path="../configs", config_name="sft.yaml")
def sft(cfg):
def main(cfg):
set_seed(cfg.seed)
dataset = build_dataset(
dataset_name=cfg.dataset,
Expand Down Expand Up @@ -83,4 +84,4 @@ def setup(self, args, state, model, **kwargs):


if __name__ == "__main__":
sft()
main()

0 comments on commit 2b1f1da

Please sign in to comment.