Skip to content

Commit

Permalink
minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Al-Saffar committed May 11, 2024
1 parent f07cf5b commit b32c953
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions myresources/crocodile/deeplearning.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def compile(self, loss: Optional[Any] = None, optimizer: Optional[Any] = None, m
def fit(self, viz: bool = True, weight_name: Optional[str] = None,
val_sample_weight: Optional['npt.NDArray[np.float64]'] = None, sample_weight: Optional['npt.NDArray[np.float64]'] = None,
verbose: Union[int, str] = "auto", callbacks: Optional[list[Any]] = None,
validation_freq: int = 1, use_multiprocessing: bool = False,
validation_freq: int = 1,
**kwargs: Any):
assert self.data.split is not None, "Split your data before you start fitting."
x_train = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.ip_names, which_split="train")]
Expand Down Expand Up @@ -391,7 +391,7 @@ def fit(self, viz: bool = True, weight_name: Optional[str] = None,
batch_size=self.hp.batch_size, epochs=self.hp.epochs, shuffle=self.hp.shuffle,
)
default_settings.update(kwargs)
hist = self.model.fit(**default_settings, callbacks=callbacks, sample_weight=sample_weight, verbose=verbose, validation_freq=validation_freq, use_multiprocessing=use_multiprocessing)
hist = self.model.fit(**default_settings, callbacks=callbacks, sample_weight=sample_weight, verbose=verbose, validation_freq=validation_freq)
self.history.append(copy.deepcopy(hist.history)) # it is paramount to copy, cause source can change.
if viz:
artist = self.plot_loss()
Expand Down

0 comments on commit b32c953

Please sign in to comment.