Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary wrapper forward_model_data_to_json #9551

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,33 +221,6 @@ def handle_default(fm_step: ForwardModelStep, arg: str) -> str:
}


def forward_model_data_to_json(
substitutions: Substitutions,
forward_model_steps: list[ForwardModelStep],
env_vars: dict[str, str],
env_pr_fm_step: dict[str, dict[str, Any]] | None = None,
user_config_file: str | None = "",
run_id: str | None = None,
iens: int = 0,
itr: int = 0,
context_env: dict[str, str] | None = None,
):
if context_env is None:
context_env = {}
if env_pr_fm_step is None:
env_pr_fm_step = {}
return create_forward_model_json(
context=substitutions,
forward_model_steps=forward_model_steps,
user_config_file=user_config_file,
env_vars={**env_vars, **context_env},
env_pr_fm_step=env_pr_fm_step,
run_id=run_id,
iens=iens,
itr=itr,
)


@dataclass
class ErtConfig:
DEFAULT_ENSPATH: ClassVar[str] = "storage"
Expand Down
9 changes: 4 additions & 5 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import xarray as xr
from numpy.random import SeedSequence

from ert.config.ert_config import forward_model_data_to_json
from ert.config.ert_config import create_forward_model_json
from ert.config.forward_model_step import ForwardModelStep
from ert.config.model_config import ModelConfig
from ert.substitutions import Substitutions, substitute_runpath_name
Expand Down Expand Up @@ -272,16 +272,15 @@ def create_run_path(
path = run_path / "jobs.json"
_backup_if_existing(path)

forward_model_output = forward_model_data_to_json(
substitutions=substitutions,
forward_model_output: dict[str, Any] = create_forward_model_json(
context=substitutions,
forward_model_steps=forward_model_steps,
user_config_file=user_config_file,
env_vars=env_vars,
env_vars={**env_vars, **context_env},
env_pr_fm_step=env_pr_fm_step,
run_id=run_arg.run_id,
iens=run_arg.iens,
itr=ensemble.iteration,
context_env=context_env,
)
with open(run_path / "jobs.json", mode="wb") as fptr:
fptr.write(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ert.config import ErtConfig, ForwardModelStep
from ert.config.ert_config import (
_forward_model_step_from_config_file,
forward_model_data_to_json,
create_forward_model_json,
)
from ert.substitutions import Substitutions

Expand Down Expand Up @@ -295,8 +295,8 @@ def test_config_path_and_file(context):
substitutions=context,
user_config_file="path_to_config_file/config.ert",
)
steps_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
steps_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand All @@ -318,8 +318,8 @@ def test_no_steps(context):
user_config_file="path_to_config_file/config.ert",
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand All @@ -339,8 +339,8 @@ def test_one_step(fm_step_list, context):
substitutions=context,
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand All @@ -357,8 +357,8 @@ def run_all(fm_steplist, context):
substitutions=context,
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -400,8 +400,8 @@ def test_that_values_with_brackets_are_ommitted(caplog, fm_step_list, context):
forward_model_steps=forward_model_list, substitutions=context
)

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -558,8 +558,8 @@ def test_forward_model_job(job, forward_model, expected_args):

forward_model = ert_config.forward_model_steps

data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -589,8 +589,8 @@ def test_that_config_path_is_the_directory_of_the_main_ert_config():
fout.write("FORWARD_MODEL job_name")

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -661,8 +661,8 @@ def test_simulation_job(job, forward_model, expected_args):
fout.write(forward_model)

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -696,8 +696,8 @@ def test_that_private_over_global_args_gives_logging_message(caplog):
fout.write("FORWARD_MODEL job_name(<ARG>=B)")

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -735,8 +735,8 @@ def test_that_private_over_global_args_does_not_give_logging_message_for_argpass
fout.write("FORWARD_MODEL job_name(<ARG>=<ARG>)")

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -786,8 +786,8 @@ def test_that_environment_variables_are_set_in_forward_model(
fout.write(forward_model)

ert_config = ErtConfig.from_file("config_file.ert")
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down Expand Up @@ -817,8 +817,8 @@ def test_that_executables_in_path_are_not_made_realpath(tmp_path):
)

ert_config = ErtConfig.from_file(str(config_file))
data = forward_model_data_to_json(
substitutions=ert_config.substitutions,
data = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
Expand Down
50 changes: 30 additions & 20 deletions tests/ert/unit_tests/config/test_ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from ert.config import AnalysisConfig, ConfigValidationError, ErtConfig, HookRuntime
from ert.config.ert_config import (
forward_model_data_to_json,
create_forward_model_json,
site_config_location,
)
from ert.config.parsing import ConfigKeys, ConfigWarning
Expand Down Expand Up @@ -857,11 +857,12 @@ def test_fm_step_config_via_plugin_ends_up_json_data(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["jobList"][0]["environment"]["FOO"] == "bar"

Expand All @@ -885,12 +886,14 @@ def test_fm_step_config_via_plugin_does_not_leak_to_other_step(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)

assert "FOO" not in step_json["jobList"][0]["environment"]


Expand All @@ -913,12 +916,14 @@ def test_fm_step_config_via_plugin_has_key_names_uppercased(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)

assert step_json["jobList"][0]["environment"]["FOO"] == "bar"


Expand All @@ -941,11 +946,12 @@ def test_fm_step_config_via_plugin_stringifies_python_objects(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["jobList"][0]["environment"]["FOO"] == "{'a_dict_as_value': 1}"

Expand All @@ -972,11 +978,12 @@ def test_fm_step_config_via_plugin_ignores_conflict_with_setenv(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["global_environment"]["FOO"] == "bar_from_setenv"
assert step_json["jobList"][0]["environment"]["FOO"] == "bar_from_plugin"
Expand All @@ -1002,11 +1009,12 @@ def test_fm_step_config_via_plugin_does_not_override_default_env(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert (
step_json["jobList"][0]["environment"]["_ERT_RUNPATH"]
Expand Down Expand Up @@ -1034,11 +1042,12 @@ def test_fm_step_config_via_plugin_is_substituted_for_defines(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert step_json["jobList"][0]["environment"]["FOO"] == "define_works"

Expand All @@ -1062,11 +1071,12 @@ def test_fm_step_config_via_plugin_is_dropped_if_not_define_exists(monkeypatch):
encoding="utf-8",
)
ert_config = ErtConfig.with_plugins().from_file("config.ert")
step_json = forward_model_data_to_json(
substitutions=ert_config.substitutions,
step_json = create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
env_pr_fm_step=ert_config.env_pr_fm_step,
run_id=None,
)
assert "FOO" not in step_json["jobList"][0]["environment"]

Expand Down Expand Up @@ -1533,13 +1543,13 @@ def test_validate_no_logs_when_overwriting_with_same_value(caplog):

with caplog.at_level(logging.INFO):
ert_config = ErtConfig.from_file("config_file.ert")
forward_model_data_to_json(
substitutions=ert_config.substitutions,
create_forward_model_json(
context=ert_config.substitutions,
forward_model_steps=ert_config.forward_model_steps,
env_vars=ert_config.env_vars,
user_config_file=ert_config.user_config_file,
run_id="0",
iens="0",
iens=0,
itr=0,
)

Expand Down
Loading
Loading