def inspect_csv(tar_path: str, csv_file: str = "export.csv") -> "pandas.DataFrame":
    """
    Read the job table stored as a csv file inside a tar archive without
    extracting the archive to disk.

    Args:
        tar_path (str): Path to the tar archive (uncompressed or compressed).
        csv_file (str): Name of the csv file to look for inside the archive.

    Returns:
        pandas.DataFrame: Job table, indexed by the first column of the csv file.

    Raises:
        FileNotFoundError: If the archive holds no regular file named `csv_file`.
    """
    nested_suffix = f"/{csv_file}"
    # "r:*" lets tarfile detect the compression itself, so gzip archives keep
    # working and uncompressed ones are accepted as well.
    with tarfile.open(tar_path, mode="r:*") as tar:
        # Iterating the TarFile reads member headers lazily, so the search
        # stops at the first match instead of scanning the whole archive.
        for member in tar:
            if not member.isfile():
                continue
            # Match the file at the archive root or inside any directory, but
            # never a different file whose name merely ends with csv_file.
            if member.name != csv_file and not member.name.endswith(nested_suffix):
                continue
            extracted_file = tar.extractfile(member)
            if extracted_file is None:
                continue
            # Hand the binary handle to pandas directly and close it afterwards.
            with extracted_file:
                return pandas.read_csv(
                    extracted_file, index_col=0, encoding="utf-8"
                )
    raise FileNotFoundError(f"File: {csv_file} in {tar_path} was not found.")
def unpack_csv(self, tar_path: str, csv_file: str = "export.csv"):
    """
    Read the job table stored inside a packed project archive.

    The call is forwarded to import_archive.inspect_csv; this method itself
    does not copy any files or import any jobs.

    Args:
        tar_path (str): path to the tar archive to inspect.
        csv_file (str): name of the csv file inside the archive that holds
            the job table, "export.csv" by default.

    Returns:
        pandas.DataFrame: job table
    """
    return import_archive.inspect_csv(tar_path=tar_path, csv_file=csv_file)