Skip to content

Commit

Permalink
fix: ensure that seeds are instantiated in the same way
Browse files Browse the repository at this point in the history
  • Loading branch information
amitkparekh committed Oct 9, 2024
1 parent 52eaf6f commit d92b3f5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/cogelot/entrypoints/interactive_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import typer
from loguru import logger
from pytorch_lightning import seed_everything
from rich.console import Console
from rich.table import Table

Expand All @@ -29,6 +30,8 @@ def create_evaluation_module(config_path: Path) -> EvaluationLightningModule:
config_path.name,
overrides=["model.model.wandb_run_id=8lkml12g", "environment@model.environment=display"],
)
seed = config.get("seed")
seed_everything(seed, workers=True)
evaluation = hydra.utils.instantiate(config["model"])
assert isinstance(evaluation, EvaluationLightningModule)
return evaluation
Expand Down
7 changes: 7 additions & 0 deletions src/vima_bench/env/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import tempfile
import time
from contextlib import suppress
from typing import Literal

import gym
Expand Down Expand Up @@ -159,6 +160,12 @@ def set_task(self, task: BaseTask | str | None = None, task_kwargs: dict | None
# setup task
ALL_TASKS = _ALL_TASKS.copy()
ALL_TASKS.update({k.split("/")[1]: v for k, v in ALL_TASKS.items()})

with suppress(AttributeError):
updated_task_kwargs = {"seed": self.global_seed[0]}
if task_kwargs:
task_kwargs = {**task_kwargs, **updated_task_kwargs}

if isinstance(task, str):
assert task in ALL_TASKS, f"Invalid task name provided {task}"
task = ALL_TASKS[task](debug=self._debug, **(task_kwargs or {}))
Expand Down

0 comments on commit d92b3f5

Please sign in to comment.