Skip to content

Commit

Permalink
added weights_only
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Al-Saffar committed Aug 7, 2024
1 parent 8af269f commit cc60ed9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
3 changes: 2 additions & 1 deletion myresources/crocodile/deeplearning.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ def save_class(self, weights_only: bool = True, version: str = 'v0', strict: boo
try: generate_readme(get_hp_save_dir(hp), obj=self.__class__, desc=desc)
except Exception as ex: print(ex) # often fails because model is defined in main during experiments.
save_dir = get_hp_save_dir(hp).joinpath(f'{"weights" if weights_only else "model"}_save_{version}')
if weights_only: self.save_weights(save_dir.create())
if weights_only:
self.save_weights(save_dir.create())
else:
self.save_model(save_dir)

Expand Down
2 changes: 1 addition & 1 deletion myresources/crocodile/deeplearning_pytorch_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def main():
m.save_weights(save_dir=save_dir)
# save_onnx(save_dir=save_dir, model=m.model, dummy_ip=x)

m1 = BaseModel.load_model(save_dir=save_dir, map_location=None)
m1 = BaseModel.load_model(save_dir=save_dir, map_location=None, weights_only=True)
m_init = My2LayerNN(hp=hp)
# m_base = BaseModel(m_init, optimizer=optimizer, loss=loss, metrics=[])
m2 = BaseModel.load_weights(model=m_init, save_dir=save_dir, map_location=None)
Expand Down
4 changes: 2 additions & 2 deletions myresources/crocodile/deeplearning_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def summary(model: nn.Module, detailed: bool = False):

def save_model(self, save_dir: P): t.save(self.model, save_dir.joinpath("model.pth"))
@staticmethod
def load_model(save_dir: P, map_location: Union[str, Device, None]):
def load_model(save_dir: P, map_location: Union[str, Device, None], weights_only: bool):
if map_location is None and t.cuda.is_available():
map_location = "cpu"
model: nn.Module = t.load(save_dir.joinpath("model.pth"), map_location=map_location, weights_only=True) # type: ignore
model: nn.Module = t.load(save_dir.joinpath("model.pth"), map_location=map_location, weights_only=weights_only) # type: ignore
model.eval()
import traceback
try:
Expand Down

0 comments on commit cc60ed9

Please sign in to comment.