Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issues #122 and #154 #320

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pandas as pd
from pydantic import ValidationError

from zamba.pytorch.dataloaders import __getitem__

from zamba.models.config import (
EarlyStoppingConfig,
ModelConfig,
Expand Down Expand Up @@ -574,3 +576,45 @@ def test_validate_provided_species_and_use_default_model_labels(labels_absolute_
"Conflicting information between `use_default_model_labels=True` and species provided."
in error
)


"""
test case where splits get automatically assigned and then check that that csv gets written
"""
def test_split_files(labels_absolute_path, tmp_path):
#adding a test case where splits get automatically assigned

#arbitrary training
config = TrainConfig(
labels=labels_absolute_path,
model_name="time_distributed",
skip_load_validation=True,
save_dir=tmp_path / "my_model",
)
#make sure the split was automatically generated
split = pd.read_csv(tmp_path / "splits.csv")["split"].values
assert(split == ["train", "val", "train", "val"]).all

#checking if the csv gets written
assert os.path.exists(tmp_path/"splits.csv"), "File does not exist!"

config.preprocess_labels(config,config.labels)

"""
Test case for bad videos
To test this, have the video file path contain two videos, the first one being legit and the second being bad
A corrupted mp4 is included in the tests labeled monkeys.mp4
"""
def test_bad_video():

#we have to call the __getitem__ function in dataloaders
# the first index will be of a good video, the second of a bad video
for index in range(2):
video =__getitem__(index)
all_zero = np.all(video == 0)

#the good video
if index == 0:
assert all_zero is False
else:
assert all_zero is True
78 changes: 78 additions & 0 deletions zamba/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,3 +911,81 @@ def get_default_video_loader_config(cls, values):
values["video_loader_config"] = VideoLoaderConfig(**config_dict["video_loader_config"])

return values


class CustomBase(BaseModel):
class Config:
hardware_dependent = []

def dict_hardware_independent(self):
full_dict = self.dict()

# remove hardware dependent fields on children
for field_name, field_type in self.__fields__.items():
field_value = getattr(self, field_name, None)
if isinstance(field_value, CustomBase):
full_dict[field_name] = field_value.dict_hardware_independent()

# remove hardware dependent fields on this model
for f in self.Config.hardware_dependent:
full_dict.pop(f)

return full_dict


class SubModelC(CustomBase):
x: str

class Config:
hardware_dependent = ['x']

class SubModelA(CustomBase):
a: str
b: str
c: SubModelC

class Config:
hardware_dependent = ['a']

#adding additional sub models
class SubModelB(CustomBase):
d: str
e: str
f: SubModelA
g: SubModelC

class Config:
hardware_dependent = ['d', 'f']

class SubModelD(CustomBase):
h: SubModelB
i: str
j: SubModelA

class Config:
hardware_dependent = ['h', 'i', 'j']

class SubModelE(CustomBase):
k: str
l: str

class Config:
hardware_dependent = ['k', 'l']

# created instances for c, a, and b to construct instances of parent model
instance_c = SubModelC(x = 'x')
instance_a = SubModelA(a="a", b="b", c=SubModelC)
instance_b = SubModelB(d = 'd', e = 'e', f = instance_a, g = instance_c)

# now make instance of parent model d
instance_d = SubModelD(h = instance_b, i = 'i', j = instance_a)

# can decide whether you want this one or not since not used in parent models
instance_e = SubModelE(k = 'k', l = 'l')

# now we just serialize the different models
instance_a.dict_hardware_independent()
instance_b.dict_hardware_independent()
instance_c.dict_hardware_independent()
instance_d.dict_hardware_independent()
instance_e.dict_hardware_independent()