From 774aa2b3f2f59ead2b8964e3bb1bb14f09a11428 Mon Sep 17 00:00:00 2001 From: peteford Date: Wed, 24 Apr 2024 17:16:48 -0400 Subject: [PATCH] resolved two issues --- tests/test_config.py | 44 ++++++++++++++++++++++++ zamba/models/config.py | 78 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/tests/test_config.py b/tests/test_config.py index 2a6c77d0..196dfd79 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,6 +7,8 @@ import pandas as pd from pydantic import ValidationError +from zamba.pytorch.dataloaders import __getitem__ + from zamba.models.config import ( EarlyStoppingConfig, ModelConfig, @@ -574,3 +576,45 @@ def test_validate_provided_species_and_use_default_model_labels(labels_absolute_ "Conflicting information between `use_default_model_labels=True` and species provided." in error ) + + +""" +test case where splits get automatically assigned and then check that that csv gets written +""" +def test_split_files(labels_absolute_path, tmp_path): + #adding a test case where splits get automatically assigned + + #arbitrary training + config = TrainConfig( + labels=labels_absolute_path, + model_name="time_distributed", + skip_load_validation=True, + save_dir=tmp_path / "my_model", + ) + #make sure the split was automatically generated + split = pd.read_csv(tmp_path / "splits.csv")["split"].values + assert(split == ["train", "val", "train", "val"]).all + + #checking if the csv gets written + assert os.path.exists(tmp_path/"splits.csv"), "File does not exist!" + + config.preprocess_labels(config,config.labels) + +""" +Test case for bad videos +To test this, have the video file path contain two videos, the first one being legit and the second being bad +A corrupted mp4 is included in the tests labeled monkeys.mp4 +""" +def test_bad_video(): + + #we have to call the __getitem__ function in dataloaders + # the first index will be of a good video, the second of a bad video + for index in range(2): + video =__getitem__(index) + all_zero = np.all(video == 0) + + #the good video + if index == 0: + assert all_zero is False + else: + assert all_zero is True \ No newline at end of file diff --git a/zamba/models/config.py b/zamba/models/config.py index bd826381..9adab2c4 100644 --- a/zamba/models/config.py +++ b/zamba/models/config.py @@ -911,3 +911,81 @@ def get_default_video_loader_config(cls, values): values["video_loader_config"] = VideoLoaderConfig(**config_dict["video_loader_config"]) return values + + +class CustomBase(BaseModel): + class Config: + hardware_dependent = [] + + def dict_hardware_independent(self): + full_dict = self.dict() + + # remove hardware dependent fields on children + for field_name, field_type in self.__fields__.items(): + field_value = getattr(self, field_name, None) + if isinstance(field_value, CustomBase): + full_dict[field_name] = field_value.dict_hardware_independent() + + # remove hardware dependent fields on this model + for f in self.Config.hardware_dependent: + full_dict.pop(f) + + return full_dict + + +class SubModelC(CustomBase): + x: str + + class Config: + hardware_dependent = ['x'] + +class SubModelA(CustomBase): + a: str + b: str + c: SubModelC + + class Config: + hardware_dependent = ['a'] + +#adding additional sub models +class SubModelB(CustomBase): + d: str + e: str + f: SubModelA + g: SubModelC + + class Config: + hardware_dependent = ['d', 'f'] + +class SubModelD(CustomBase): + h: SubModelB + i: str + j: SubModelA + + class Config: + hardware_dependent = ['h', 'i', 'j'] + +class SubModelE(CustomBase): + k: str + l: str + + class Config: + hardware_dependent = ['k', 'l'] + +# created instances for c, a, and b to construct instances of parent model +instance_c = SubModelC(x = 'x') +instance_a = SubModelA(a="a", b="b", c=SubModelC) +instance_b = SubModelB(d = 'd', e = 'e', f = instance_a, g = instance_c) + +# now make instance of parent model d +instance_d = SubModelD(h = instance_b, i = 'i', j = instance_a) + +# can decide whether you want this one or not since not used in parent models +instance_e = SubModelE(k = 'k', l = 'l') + +# now we just serialize the different models +instance_a.dict_hardware_independent() +instance_b.dict_hardware_independent() +instance_c.dict_hardware_independent() +instance_d.dict_hardware_independent() +instance_e.dict_hardware_independent() \ No newline at end of file