Skip to content

Commit

Permalink
Merge pull request #1381 from pyiron/script_remove_hdf
Browse files Browse the repository at this point in the history
ScriptJob: from to_hdf() and from_hdf()
  • Loading branch information
jan-janssen authored Mar 26, 2024
2 parents 0c6686f + a688821 commit 7303661
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 50 deletions.
64 changes: 16 additions & 48 deletions pyiron_base/jobs/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,54 +274,22 @@ def set_input_to_read_only(self):
super().set_input_to_read_only()
self.input.read_only = True

def to_hdf(self, hdf=None, group_name=None):
"""
Store the ScriptJob in an HDF5 file
Args:
hdf (ProjectHDFio): HDF5 group object - optional
group_name (str): HDF5 subgroup name - optional
"""
super(ScriptJob, self).to_hdf(hdf=hdf, group_name=group_name)
with self.project_hdf5.open("input") as hdf5_input:
hdf5_input["path"] = self._script_path
hdf5_input["parallel"] = self._enable_mpi4py
self.input.to_hdf(hdf5_input)

def from_hdf(self, hdf=None, group_name=None):
"""
Restore the ScriptJob from an HDF5 file
Args:
hdf (ProjectHDFio): HDF5 group object - optional
group_name (str): HDF5 subgroup name - optional
"""
super(ScriptJob, self).from_hdf(hdf=hdf, group_name=group_name)
if "HDF_VERSION" in self.project_hdf5.list_nodes():
version = self.project_hdf5["HDF_VERSION"]
else:
version = "0.1.0"
if version == "0.1.0":
with self.project_hdf5.open("input") as hdf5_input:
try:
self.script_path = hdf5_input["path"]
gp = GenericParameters(table_name="custom_dict")
gp.from_hdf(hdf5_input)
for k in gp.keys():
self.input[k] = gp[k]
except TypeError:
pass
elif version == "0.2.0":
with self.project_hdf5.open("input") as hdf5_input:
if "parallel" in hdf5_input.list_nodes():
self._enable_mpi4py = hdf5_input["parallel"]
try:
self.script_path = hdf5_input["path"]
except TypeError:
pass
self.input.from_hdf(hdf5_input)
else:
raise ValueError("Cannot handle hdf version: {}".format(version))
def to_dict(self):
job_dict = super().to_dict()
job_dict["input/path"] = self._script_path
job_dict["input/parallel"] = self._enable_mpi4py
job_dict["input/custom_dict"] = self.input.to_builtin()
return job_dict

def from_dict(self, job_dict):
super().from_dict(job_dict=job_dict)
if "parallel" in job_dict["input"].keys():
self._enable_mpi4py = job_dict["input"]["parallel"]
path = job_dict["input"]["path"]
if path is not None:
self.script_path = path
if "custom_dict" in job_dict["input"].keys():
self.input.update(job_dict["input"]["custom_dict"])

def write_input(self):
"""
Expand Down
7 changes: 6 additions & 1 deletion pyiron_base/project/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ def load():
hdf_file = str(hdf_file).replace("\\", "/") + ".h5"
if Path(hdf_file).exists():
obj = DataContainer()
obj.from_hdf(hdf=FileHDFio(hdf_file), group_name=folder + "/input/custom_dict")
hdf_file_obj = FileHDFio(hdf_file)
hdf_input = hdf_file_obj[folder + "/input"]
if "custom_dict" in hdf_input.list_nodes():
obj.update(hdf_file_obj[folder + "/input/custom_dict"])
else: # Backwards compatibility
obj.from_hdf(hdf=hdf_file_obj, group_name=folder + "/input/custom_dict")
obj["project_dir"] = str(project_folder)
return obj
elif Path("input.json").exists():
Expand Down
2 changes: 1 addition & 1 deletion tests/job/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_notebook_input(self):
self.job.input['value'] = 300
self.job.save()
self.assertTrue(
"custom_dict" in self.job["input"].list_groups(),
"custom_dict" in self.job["input"].list_nodes(),
msg="Input not saved in the 'custom_dict' group in HDF"
)

Expand Down

0 comments on commit 7303661

Please sign in to comment.