From 771ce418dbd946d860faefdb06787c352160a315 Mon Sep 17 00:00:00 2001 From: RemyLau Date: Mon, 10 Jul 2023 19:14:03 -0400 Subject: [PATCH] fix: optional metrics in sl and lp trainers --- src/obnb/model_trainer/label_propagation.py | 6 ++++-- src/obnb/model_trainer/supervised_learning.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/obnb/model_trainer/label_propagation.py b/src/obnb/model_trainer/label_propagation.py index ca3e9085..fd90282e 100644 --- a/src/obnb/model_trainer/label_propagation.py +++ b/src/obnb/model_trainer/label_propagation.py @@ -1,5 +1,7 @@ +import numpy as np + from obnb.model_trainer.base import StandardTrainer -from obnb.typing import LogLevel, Optional +from obnb.typing import Callable, Dict, LogLevel, Optional class LabelPropagationTrainer(StandardTrainer): @@ -7,7 +9,7 @@ class LabelPropagationTrainer(StandardTrainer): def __init__( self, - metrics, + metrics: Optional[Dict[str, Callable[[np.ndarray, np.ndarray], float]]] = None, train_on="train", log_level: LogLevel = "WARNING", log_path: Optional[str] = None, diff --git a/src/obnb/model_trainer/supervised_learning.py b/src/obnb/model_trainer/supervised_learning.py index cc8368b8..d8486119 100644 --- a/src/obnb/model_trainer/supervised_learning.py +++ b/src/obnb/model_trainer/supervised_learning.py @@ -1,5 +1,7 @@ +import numpy as np + from obnb.model_trainer.base import StandardTrainer -from obnb.typing import LogLevel, Optional +from obnb.typing import Callable, Dict, LogLevel, Optional class SupervisedLearningTrainer(StandardTrainer): @@ -21,7 +23,7 @@ class SupervisedLearningTrainer(StandardTrainer): def __init__( self, - metrics, + metrics: Optional[Dict[str, Callable[[np.ndarray, np.ndarray], float]]] = None, train_on="train", log_level: LogLevel = "WARNING", log_path: Optional[str] = None,