Skip to content

Commit

Permalink
fix: optional metrics in sl and lp trainers
Browse files Browse the repository at this point in the history
  • Loading branch information
RemyLau committed Jul 10, 2023
1 parent 7f4aea5 commit 771ce41
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/obnb/model_trainer/label_propagation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import numpy as np

from obnb.model_trainer.base import StandardTrainer
from obnb.typing import LogLevel, Optional
from obnb.typing import Callable, Dict, LogLevel, Optional


class LabelPropagationTrainer(StandardTrainer):
"""Label propagation trainer."""

def __init__(
self,
metrics,
metrics: Optional[Dict[str, Callable[[np.ndarray, np.ndarray], float]]] = None,
train_on="train",
log_level: LogLevel = "WARNING",
log_path: Optional[str] = None,
Expand Down
6 changes: 4 additions & 2 deletions src/obnb/model_trainer/supervised_learning.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from obnb.model_trainer.base import StandardTrainer
from obnb.typing import LogLevel, Optional
from obnb.typing import Callable, Dict, LogLevel, Optional


class SupervisedLearningTrainer(StandardTrainer):
Expand All @@ -21,7 +23,7 @@ class SupervisedLearningTrainer(StandardTrainer):

def __init__(
self,
metrics,
metrics: Optional[Dict[str, Callable[[np.ndarray, np.ndarray], float]]] = None,
train_on="train",
log_level: LogLevel = "WARNING",
log_path: Optional[str] = None,
Expand Down

0 comments on commit 771ce41

Please sign in to comment.