Skip to content

Commit

Permalink
format experiment server
Browse files Browse the repository at this point in the history
  • Loading branch information
sondreso committed Sep 15, 2024
1 parent a53531d commit 6eb7862
Showing 1 changed file with 57 additions and 24 deletions.
81 changes: 57 additions & 24 deletions src/ert/experiment_server/main.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,81 @@
import asyncio
import multiprocessing as mp
import os
from concurrent.futures import ProcessPoolExecutor
import queue
import uuid
from concurrent.futures import ProcessPoolExecutor
from multiprocessing.queues import Queue
import queue
from typing import Dict, Tuple, Union

from fastapi import BackgroundTasks, FastAPI, HTTPException, WebSocket
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field

from ert.config import ErtConfig, QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.model_factory import create_model
from ert.run_models.base_run_model import BaseRunModel, StatusEvents
from ert.gui.simulation.ensemble_experiment_panel import Arguments as EnsembleExperimentArguments
from ert.gui.simulation.ensemble_smoother_panel import Arguments as EnsembleSmootherArguments
from ert.gui.simulation.evaluate_ensemble_panel import Arguments as EvaluateEnsembleArguments
from ert.gui.simulation.iterated_ensemble_smoother_panel import Arguments as IteratedEnsembleSmootherArguments
from ert.ensemble_evaluator.event import EndEvent, _UpdateEvent
from ert.gui.simulation.ensemble_experiment_panel import (
Arguments as EnsembleExperimentArguments,
)
from ert.gui.simulation.ensemble_smoother_panel import (
Arguments as EnsembleSmootherArguments,
)
from ert.gui.simulation.evaluate_ensemble_panel import (
Arguments as EvaluateEnsembleArguments,
)
from ert.gui.simulation.iterated_ensemble_smoother_panel import (
Arguments as IteratedEnsembleSmootherArguments,
)
from ert.gui.simulation.manual_update_panel import Arguments as ManualUpdateArguments
from ert.gui.simulation.multiple_data_assimilation_panel import Arguments as MultipleDataAssimilationArguments
from ert.gui.simulation.multiple_data_assimilation_panel import (
Arguments as MultipleDataAssimilationArguments,
)
from ert.gui.simulation.single_test_run_panel import Arguments as SingleTestRunArguments
from ert.run_models.base_run_model import BaseRunModel, StatusEvents
from ert.run_models.model_factory import create_model
from ert.storage import open_storage
from ert.ensemble_evaluator.event import _UpdateEvent, EndEvent

from typing import Dict, Union, Tuple

from ert.config import ErtConfig, QueueSystem
from fastapi.encoders import jsonable_encoder

class Experiment(BaseModel):
args: Union[EnsembleExperimentArguments, EnsembleSmootherArguments, EvaluateEnsembleArguments, IteratedEnsembleSmootherArguments, ManualUpdateArguments, MultipleDataAssimilationArguments, SingleTestRunArguments] = Field(..., discriminator='mode')
args: Union[
EnsembleExperimentArguments,
EnsembleSmootherArguments,
EvaluateEnsembleArguments,
IteratedEnsembleSmootherArguments,
ManualUpdateArguments,
MultipleDataAssimilationArguments,
SingleTestRunArguments,
] = Field(..., discriminator="mode")
ert_config: ErtConfig

mp_ctx = mp.get_context('fork')
process_pool = ProcessPoolExecutor(max_workers=max((os.cpu_count() or 1) - 2, 1), mp_context=mp_ctx)

mp_ctx = mp.get_context("fork")
process_pool = ProcessPoolExecutor(
max_workers=max((os.cpu_count() or 1) - 2, 1), mp_context=mp_ctx
)
app = FastAPI()
experiments: Dict[str, Tuple[BaseRunModel, "Queue[StatusEvents]"]] = {}


@app.get("/")
async def root():
return {"message": "ping"}

experiments : Dict[str, Tuple[BaseRunModel, "Queue[StatusEvents]"]]= {}

async def run_experiment(experiment_id:str, evaluator_server_config: EvaluatorServerConfig):
async def run_experiment(
experiment_id: str, evaluator_server_config: EvaluatorServerConfig
):
loop = asyncio.get_running_loop()
print(f"Starting experiment {experiment_id}")
await loop.run_in_executor(None, lambda: experiments[experiment_id][0].start_simulations_thread(evaluator_server_config))
await loop.run_in_executor(
None,
lambda: experiments[experiment_id][0].start_simulations_thread(
evaluator_server_config
),
)
print(f"Experiment {experiment_id} done")


@app.post("/experiments/")
async def submit_experiment(experiment: Experiment, background_tasks: BackgroundTasks):
storage = open_storage(experiment.ert_config.ens_path, "w")
Expand All @@ -60,7 +88,10 @@ async def submit_experiment(experiment: Experiment, background_tasks: Background
status_queue,
)
except ValueError as e:
return HTTPException(status_code=420, detail=f"{experiment.args.mode} was not valid, failed with: {e}")
return HTTPException(
status_code=420,
detail=f"{experiment.args.mode} was not valid, failed with: {e}",
)

port_range = None
if model.queue_system == QueueSystem.LOCAL:
Expand All @@ -70,9 +101,12 @@ async def submit_experiment(experiment: Experiment, background_tasks: Background
experiment_id = str(uuid.uuid4())
experiments[experiment_id] = (model, status_queue)

background_tasks.add_task(run_experiment, experiment_id, evaluator_server_config=evaluator_server_config)
background_tasks.add_task(
run_experiment, experiment_id, evaluator_server_config=evaluator_server_config
)
return {"message": "Experiment Started", "experiment_id": experiment_id}


@app.put("/experiments/{experiment_id}/cancel")
async def cancel_experiment(experiment_id: str):
if experiment_id in experiments:
Expand All @@ -92,7 +126,7 @@ async def websocket_endpoint(websocket: WebSocket, experiment_id: str):
try:
item: StatusEvents = q.get(block=False)
except queue.Empty:
asyncio.sleep(0.01)
await asyncio.sleep(0.01)
continue

if isinstance(item, _UpdateEvent):
Expand All @@ -104,4 +138,3 @@ async def websocket_endpoint(websocket: WebSocket, experiment_id: str):
await asyncio.sleep(0.1)
if isinstance(item, EndEvent):
break

0 comments on commit 6eb7862

Please sign in to comment.