diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f5f33dc9..8d5f7f17 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -18,10 +18,10 @@ jobs: - name: Install Poetry run: | pipx install poetry - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: ${{ matrix.python-version }} + python-version: "3.11" cache: 'poetry' - name: Install lint tools run: | @@ -40,6 +40,7 @@ jobs: fail-fast: false matrix: python-version: + - "3.11" - "3.10" - 3.9 - 3.8 diff --git a/tests/integration/test_meta.py b/tests/integration/test_meta.py index 0d6b50cc..e94785fd 100644 --- a/tests/integration/test_meta.py +++ b/tests/integration/test_meta.py @@ -40,23 +40,22 @@ class NodeWithEnvParam(NodeWithEnv): OMP_NUM_THREADS = zntrack.meta.Environment("1", is_parameter=True) -def test_NodeWithEnvNone(proj_path): - with zntrack.Project() as proj: - _ = NodeWithEnvNone() # the actual test is inside the run method. - proj.run() - +@pytest.mark.parametrize("eager", [True, False]) +def test_NodeWithMeta(proj_path, eager): + with zntrack.Project() as project: + node_w_meta = NodeWithMeta() -def test_NodeWithMeta(proj_path): - NodeWithMeta().write_graph() + project.run(eager=eager) + if not eager: + node_w_meta.load() - node_w_meta = NodeWithMeta.from_rev() assert node_w_meta.author == "Fabian" - dvc_yaml = yaml.safe_load(pathlib.Path("dvc.yaml").read_text()) - assert dvc_yaml["stages"]["NodeWithMeta"]["meta"] == { - "author": "Fabian", - "title": "Test Node", - } + +def test_NodeWithEnvNone(proj_path): + with zntrack.Project() as proj: + _ = NodeWithEnvNone() # the actual test is inside the run method. + proj.run() class CombinedNodeWithMeta(zntrack.Node): diff --git a/zntrack/fields/meta/__init__.py b/zntrack/fields/meta/__init__.py index e7a02bf2..9e387650 100644 --- a/zntrack/fields/meta/__init__.py +++ b/zntrack/fields/meta/__init__.py @@ -1,8 +1,10 @@ """Additional fields that are neither dvc/zn i/o fields.""" +import json import pathlib import typing import yaml +import znjson from zntrack.fields.field import Field, FieldGroup from zntrack.utils import file_io @@ -16,6 +18,7 @@ class Text(Field): dvc_option: str = None group = FieldGroup.PARAMETER + use_dvc_yaml: bool = False def get_files(self, instance) -> list: """Get the params.yaml file.""" @@ -23,16 +26,30 @@ def get_files(self, instance) -> list: def save(self, instance): """Save the field to disk.""" - file_io.update_meta( - file=pathlib.Path("dvc.yaml"), - node_name=instance.name, - data={self.name: getattr(instance, self.name)}, - ) + value = getattr(instance, self.name) + if pathlib.Path("dvc.yaml").exists() and self.use_dvc_yaml: + file_io.update_meta( + file=pathlib.Path("dvc.yaml"), + node_name=instance.name, + data={self.name: value}, + ) + else: + # load from zntrack.json + self._write_value_to_config(value, instance, encoder=znjson.ZnEncoder) def get_data(self, instance: "Node") -> any: """Get the value of the field from the file.""" - dvc_dict = yaml.safe_load(instance.state.fs.read_text("dvc.yaml")) - return dvc_dict["stages"][instance.name]["meta"].get(self.name, None) + if pathlib.Path("dvc.yaml").exists() and self.use_dvc_yaml: + dvc_dict = yaml.safe_load(instance.state.fs.read_text("dvc.yaml")) + return dvc_dict["stages"][instance.name]["meta"].get(self.name, None) + else: + # load from zntrack.json + zntrack_dict = json.loads( + instance.state.fs.read_text("zntrack.json"), + ) + return json.loads( + json.dumps(zntrack_dict[instance.name][self.name]), cls=znjson.ZnDecoder + ) def get_stage_add_argument(self, instance) -> typing.List[tuple]: """Get the dvc command for this field."""