Skip to content

Commit

Permalink
Merge pull request #28 from neptune-ai/pc/vocab-error-fix
Browse files Browse the repository at this point in the history
Pc/vocab error fix
  • Loading branch information
Blaizzy authored Jul 15, 2022
2 parents a995d8c + e5edbce commit 4ca89fb
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions neptune_fastai/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _total_model_parameters(self) -> int:

@property
def _trainable_model_parameters(self) -> int:
return sum([p.numel() for p in trainable_params(self.learn)])
return sum(p.numel() for p in trainable_params(self.learn))

@property
def _optimizer_criterion(self) -> str:
Expand Down Expand Up @@ -242,14 +242,10 @@ def _target(self) -> str:
return 'training' if self.learn.training else 'validation'

def _log_model_configuration(self):
self.neptune_run[f'{self.base_namespace}/config'] = {
config = {
'device': self._device,
'batch_size': self._batch_size,
'model': {
'vocab': {
'details': self._vocab,
'total': len(self._vocab)
},
'params': {
'total': self._total_model_parameters,
'trainable_params': self._trainable_model_parameters,
Expand All @@ -263,12 +259,19 @@ def _log_model_configuration(self):
}
}

if hasattr(self.learn.dls, 'vocab'):
config['model']['vocab'] = {
'details': self._vocab,
'total': len(self._vocab)
}

self.neptune_run[f'{self.base_namespace}/config'] = config

def after_create(self):
if not hasattr(self, 'save_model'):
if self.upload_saved_models:
warnings.warn(
'NeptuneCallback: SaveModelCallback is necessary for uploading model checkpoints.'
)
if not hasattr(self, 'save_model') and self.upload_saved_models:
warnings.warn(
'NeptuneCallback: SaveModelCallback is necessary for uploading model checkpoints.'
)

def before_fit(self):
self._log_model_configuration()
Expand Down

0 comments on commit 4ca89fb

Please sign in to comment.