Skip to content

Commit 83789ee

Browse files
committed
Handle tuples, test
1 parent b046a12 commit 83789ee

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

src/ccompass/core.py

+14
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,9 @@ def series_representer(dumper, data):
770770
def float64_representer(dumper, data):
771771
return dumper.represent_float(float(data))
772772

773+
def tuple_representer(dumper, data):
774+
return dumper.represent_sequence("!tuple", data)
775+
773776
yaml.add_representer(
774777
np.float64, float64_representer, Dumper=yaml.SafeDumper
775778
)
@@ -782,6 +785,9 @@ def float64_representer(dumper, data):
782785
yaml.add_representer(
783786
pd.Series, series_representer, Dumper=yaml.SafeDumper
784787
)
788+
yaml.add_representer(
789+
tuple, tuple_representer, Dumper=yaml.SafeDumper
790+
)
785791

786792
with open(temp_dir / "session.yaml", "w") as f:
787793
yaml.safe_dump(self.model_dump(), f)
@@ -820,6 +826,7 @@ def ndarray_constructor(loader, node):
820826
def series_constructor(loader, node):
821827
"""Custom YAML constructor for pandas Series."""
822828
file_path = temp_dir / loader.construct_scalar(node)
829+
print(file_path, type(file_path))
823830
df = pd.read_csv(
824831
file_path,
825832
sep="\t",
@@ -828,8 +835,12 @@ def series_constructor(loader, node):
828835
float_precision="round_trip",
829836
)
830837
assert df.shape[1] == 1
838+
print(df.iloc[:, 0], type(df.iloc[:, 0]))
831839
return df.iloc[:, 0]
832840

841+
def tuple_constructor(loader, node):
842+
return tuple(loader.construct_sequence(node))
843+
833844
yaml.add_constructor(
834845
"!pandas.DataFrame",
835846
dataframe_constructor,
@@ -841,6 +852,9 @@ def series_constructor(loader, node):
841852
yaml.add_constructor(
842853
"!pandas.Series", series_constructor, Loader=yaml.SafeLoader
843854
)
855+
yaml.add_constructor(
856+
"!tuple", tuple_constructor, Loader=yaml.SafeLoader
857+
)
844858

845859
with open(temp_dir / "session.yaml") as f:
846860
data = yaml.safe_load(f)

tests/test_full_analysis.py

+11
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33

44
import numpy as np
5+
from test_session_model import assert_session_equal
56

67
from ccompass._testing.synthetic_data import (
78
SyntheticDataConfig,
@@ -201,3 +202,13 @@ def test_full():
201202

202203
...
203204
sess.to_numpy(Path(__file__).parent / "session_test_full.npy")
205+
sess.to_zip(Path(__file__).parent / "session_test_full.ccompass")
206+
207+
sess2 = SessionModel.from_numpy(
208+
Path(__file__).parent / "session_test_full.npy"
209+
)
210+
assert_session_equal(sess, sess2)
211+
sess2 = SessionModel.from_zip(
212+
Path(__file__).parent / "session_test_full.ccompass"
213+
)
214+
assert_session_equal(sess, sess2)

tests/test_session_model.py

+9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from pathlib import Path
44
from tempfile import TemporaryDirectory
55

6+
import numpy as np
67
import pandas as pd
8+
import pydantic
79

810
from ccompass.main_gui import SessionModel
911

@@ -22,6 +24,11 @@ def test_serialization():
2224

2325
def assert_equal(obj1, obj2):
2426
"""Check if two objects are equal."""
27+
if isinstance(obj1, pydantic.BaseModel):
28+
assert isinstance(obj2, pydantic.BaseModel)
29+
assert_equal(obj1.model_dump(), obj2.model_dump())
30+
return
31+
2532
if isinstance(obj1, dict):
2633
for key in obj1:
2734
assert key in obj2
@@ -33,6 +40,8 @@ def assert_equal(obj1, obj2):
3340
pd.testing.assert_frame_equal(obj1, obj2, check_dtype=False)
3441
elif isinstance(obj1, pd.Series):
3542
pd.testing.assert_series_equal(obj1, obj2)
43+
elif isinstance(obj1, np.ndarray):
44+
np.testing.assert_almost_equal(obj1, obj2)
3645
elif isinstance(obj1, float) and pd.isna(obj1):
3746
assert pd.isna(obj2)
3847
else:

0 commit comments

Comments
 (0)