
Pc/ Fix AttributeError: 'Loss' object has no attr 'func' #29

Merged
merged 11 commits on Jul 15, 2022
30 changes: 18 additions & 12 deletions neptune_fastai/impl/__init__.py
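Background for the main fix: fastai only exposes loss_func.func when the criterion is a fastai BaseLoss wrapper (e.g. CrossEntropyLossFlat); a plain PyTorch module such as nn.CrossEntropyLoss is stored on the Learner as-is, so reading .func raises the AttributeError from the title. Below is a minimal sketch of the guarded lookup outside the callback, assuming fastai v2 (the criterion_repr helper is illustrative, not part of the package):

import torch.nn as nn
from fastai.losses import CrossEntropyLossFlat

def criterion_repr(loss_func) -> str:
    # fastai's BaseLoss wrappers keep the underlying PyTorch loss in .func;
    # a bare nn.Module has no such attribute, so fall back to the loss itself.
    if hasattr(loss_func, 'func'):
        return repr(loss_func.func)
    return repr(loss_func)

print(criterion_repr(CrossEntropyLossFlat()))  # repr of the wrapped nn.CrossEntropyLoss
print(criterion_repr(nn.CrossEntropyLoss()))   # repr of the bare criterion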
@@ -207,11 +207,14 @@ def _total_model_parameters(self) -> int:
 
     @property
     def _trainable_model_parameters(self) -> int:
-        return sum([p.numel() for p in trainable_params(self.learn)])
+        return sum(p.numel() for p in trainable_params(self.learn))
 
     @property
     def _optimizer_criterion(self) -> str:
-        return repr(self.loss_func.func)
+        if hasattr(self.learn.loss_func, 'func'):
+            return repr(self.loss_func.func)
+        else:
+            return repr(self.loss_func)

     @property
     def _optimizer_hyperparams(self) -> Optional[dict]:
@@ -242,14 +245,10 @@ def _target(self) -> str:
         return 'training' if self.learn.training else 'validation'
 
     def _log_model_configuration(self):
-        self.neptune_run[f'{self.base_namespace}/config'] = {
+        config = {
             'device': self._device,
             'batch_size': self._batch_size,
             'model': {
-                'vocab': {
-                    'details': self._vocab,
-                    'total': len(self._vocab)
-                },
                 'params': {
                     'total': self._total_model_parameters,
                     'trainable_params': self._trainable_model_parameters,
@@ -263,12 +262,19 @@ def _log_model_configuration(self):
             }
         }
 
+        if hasattr(self.learn.dls, 'vocab'):
+            config['model']['vocab'] = {
+                'details': self._vocab,
+                'total': len(self._vocab)
+            }
+
+        self.neptune_run[f'{self.base_namespace}/config'] = config
+
     def after_create(self):
-        if not hasattr(self, 'save_model'):
-            if self.upload_saved_models:
-                warnings.warn(
-                    'NeptuneCallback: SaveModelCallback is necessary for uploading model checkpoints.'
-                )
+        if not hasattr(self, 'save_model') and self.upload_saved_models:
+            warnings.warn(
+                'NeptuneCallback: SaveModelCallback is necessary for uploading model checkpoints.'
+            )
 
     def before_fit(self):
         self._log_model_configuration()
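For the second fix, only classification-style DataLoaders carry a vocab; tabular and regression learners do not, so the old unconditional vocab entry could fail the same way. A simplified sketch of the conditional config assembly (the model_config helper and its parameters are illustrative, not the callback's actual interface):

def model_config(dls, device, batch_size):
    # Base config that every run can log.
    config = {
        'device': device,
        'batch_size': batch_size,
        'model': {},
    }

    # Guard before reading dls.vocab so learners whose DataLoaders has no vocab
    # (tabular, regression) do not raise AttributeError here.
    if hasattr(dls, 'vocab'):
        config['model']['vocab'] = {
            'details': list(dls.vocab),
            'total': len(dls.vocab),
        }

    return config

Building the dict first and pushing it to the Neptune run in a single assignment, as the PR does with config, keeps the vocab entry out of the logged configuration entirely when the attribute is missing.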