From d92b3f5ebfe45ee637abc9f509138b9626770c69 Mon Sep 17 00:00:00 2001 From: Amit Parekh <7276308+amitkparekh@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:02:32 +0100 Subject: [PATCH] fix: ensure that seeds are instantiated in the same way --- src/cogelot/entrypoints/interactive_evaluate.py | 3 +++ src/vima_bench/env/base.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/src/cogelot/entrypoints/interactive_evaluate.py b/src/cogelot/entrypoints/interactive_evaluate.py index e18ff1b..07da436 100644 --- a/src/cogelot/entrypoints/interactive_evaluate.py +++ b/src/cogelot/entrypoints/interactive_evaluate.py @@ -6,6 +6,7 @@ import torch import typer from loguru import logger +from pytorch_lightning import seed_everything from rich.console import Console from rich.table import Table @@ -29,6 +30,8 @@ def create_evaluation_module(config_path: Path) -> EvaluationLightningModule: config_path.name, overrides=["model.model.wandb_run_id=8lkml12g", "environment@model.environment=display"], ) + seed = config.get("seed") + seed_everything(seed, workers=True) evaluation = hydra.utils.instantiate(config["model"]) assert isinstance(evaluation, EvaluationLightningModule) return evaluation diff --git a/src/vima_bench/env/base.py b/src/vima_bench/env/base.py index 2a6d6b9..999672b 100644 --- a/src/vima_bench/env/base.py +++ b/src/vima_bench/env/base.py @@ -3,6 +3,7 @@ import os import tempfile import time +from contextlib import suppress from typing import Literal import gym @@ -159,6 +160,12 @@ def set_task(self, task: BaseTask | str | None = None, task_kwargs: dict | None # setup task ALL_TASKS = _ALL_TASKS.copy() ALL_TASKS.update({k.split("/")[1]: v for k, v in ALL_TASKS.items()}) + + with suppress(AttributeError): + updated_task_kwargs = {"seed": self.global_seed[0]} + if task_kwargs: + task_kwargs = {**task_kwargs, **updated_task_kwargs} + if isinstance(task, str): assert task in ALL_TASKS, f"Invalid task name provided {task}" task = ALL_TASKS[task](debug=self._debug, **(task_kwargs or {}))