Skip to content

Commit

Permalink
Fix #965: add log loss metric
Browse files Browse the repository at this point in the history
  • Loading branch information
trentmc committed Apr 30, 2024
1 parent 0abf1ec commit aee1a09
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 6 additions & 2 deletions pdr_backend/sim/sim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import polars as pl

from enforce_typing import enforce_types
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import log_loss, precision_recall_fscore_support
from statsmodels.stats.proportion import proportion_confint

from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory
Expand Down Expand Up @@ -183,7 +183,11 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):
average="binary",
zero_division=0.0,
)
st.clm.update(acc_est, acc_l, acc_u, f1, precision, recall)
if min(st.ytrues) == max(st.ytrues):
loss = 1.0
else:
loss = log_loss(st.ytrues, st.probs_up)
st.clm.update(acc_est, acc_l, acc_u, f1, precision, recall, loss)

# trader: exit the trading position
if pred_up:
Expand Down
9 changes: 7 additions & 2 deletions pdr_backend/sim/sim_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def __init__(self):
self.precisions: List[float] = [] # [i] : precision
self.recalls: List[float] = [] # [i] : recall

def update(self, acc_est, acc_l, acc_u, f1, precision, recall):
self.losses: List[float] = [] # [i] : log-loss

def update(self, acc_est, acc_l, acc_u, f1, precision, recall, loss):
self.acc_ests.append(acc_est)
self.acc_ls.append(acc_l)
self.acc_us.append(acc_u)
Expand All @@ -26,9 +28,11 @@ def update(self, acc_est, acc_l, acc_u, f1, precision, recall):
self.precisions.append(precision)
self.recalls.append(recall)

self.losses.append(loss)

@staticmethod
def recent_metrics_names() -> List[str]:
return ["acc_est", "acc_l", "acc_u", "f1", "precision", "recall"]
return ["acc_est", "acc_l", "acc_u", "f1", "precision", "recall", "loss"]

def recent_metrics(self) -> List[Union[int, float]]:
"""Return most recent classifier metrics"""
Expand All @@ -40,6 +44,7 @@ def recent_metrics(self) -> List[Union[int, float]]:
self.f1s[-1],
self.precisions[-1],
self.recalls[-1],
self.losses[-1],
]


Expand Down

0 comments on commit aee1a09

Please sign in to comment.