Skip to content

Commit

Permalink
resolved two issues
Browse files Browse the repository at this point in the history
  • Loading branch information
peteford21 committed Apr 24, 2024
1 parent 70e79f2 commit 774aa2b
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pandas as pd
from pydantic import ValidationError

from zamba.pytorch.dataloaders import __getitem__

from zamba.models.config import (
EarlyStoppingConfig,
ModelConfig,
Expand Down Expand Up @@ -574,3 +576,45 @@ def test_validate_provided_species_and_use_default_model_labels(labels_absolute_
"Conflicting information between `use_default_model_labels=True` and species provided."
in error
)


"""
test case where splits get automatically assigned and then check that that csv gets written
"""
def test_split_files(labels_absolute_path, tmp_path):
#adding a test case where splits get automatically assigned

#arbitrary training
config = TrainConfig(
labels=labels_absolute_path,
model_name="time_distributed",
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
#make sure the split was automatically generated
split = pd.read_csv(tmp_path / "splits.csv")["split"].values
assert(split == ["train", "val", "train", "val"]).all

#checking if the csv gets written
assert os.path.exists(tmp_path/"splits.csv"), "File does not exist!"

config.preprocess_labels(config,config.labels)

"""
Test case for bad videos
To test this, have the video file path contain two videos, the first one being legit and the second being bad
A corrupted mp4 is included in the tests labeled monkeys.mp4
"""
def test_bad_video():

#we have to call the __getitem__ function in dataloaders
# the first index will be of a good video, the second of a bad video
for index in range(2):
video =__getitem__(index)
all_zero = np.all(video == 0)

#the good video
if index == 0:
assert all_zero is False
else:
assert all_zero is True
78 changes: 78 additions & 0 deletions zamba/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,3 +911,81 @@ def get_default_video_loader_config(cls, values):
values["video_loader_config"] = VideoLoaderConfig(**config_dict["video_loader_config"])

return values


class CustomBase(BaseModel):
class Config:
hardware_dependent = []

def dict_hardware_independent(self):
full_dict = self.dict()

# remove hardware dependent fields on children
for field_name, field_type in self.__fields__.items():
field_value = getattr(self, field_name, None)
if isinstance(field_value, CustomBase):
full_dict[field_name] = field_value.dict_hardware_independent()

# remove hardware dependent fields on this model
for f in self.Config.hardware_dependent:
full_dict.pop(f)

return full_dict


class SubModelC(CustomBase):
x: str

class Config:
hardware_dependent = ['x']

class SubModelA(CustomBase):
a: str
b: str
c: SubModelC

class Config:
hardware_dependent = ['a']

#adding additional sub models
class SubModelB(CustomBase):
d: str
e: str
f: SubModelA
g: SubModelC

class Config:
hardware_dependent = ['d', 'f']

class SubModelD(CustomBase):
h: SubModelB
i: str
j: SubModelA

class Config:
hardware_dependent = ['h', 'i', 'j']

class SubModelE(CustomBase):
k: str
l: str

class Config:
hardware_dependent = ['k', 'l']

# created instances for c, a, and b to construct instances of parent model
instance_c = SubModelC(x = 'x')
instance_a = SubModelA(a="a", b="b", c=SubModelC)
instance_b = SubModelB(d = 'd', e = 'e', f = instance_a, g = instance_c)

# now make instance of parent model d
instance_d = SubModelD(h = instance_b, i = 'i', j = instance_a)

# can decide whether you want this one or not since not used in parent models
instance_e = SubModelE(k = 'k', l = 'l')

# now we just serialize the different models
instance_a.dict_hardware_independent()
instance_b.dict_hardware_independent()
instance_c.dict_hardware_independent()
instance_d.dict_hardware_independent()
instance_e.dict_hardware_independent()

0 comments on commit 774aa2b

Please sign in to comment.