Skip to content

Commit

Permalink
Added parse_loss function to HyperLoss class
Browse files Browse the repository at this point in the history
  • Loading branch information
Cmurilochem committed Dec 1, 2023
1 parent 9c42c0b commit 5dba88c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
46 changes: 34 additions & 12 deletions n3fit/src/n3fit/hyper_optimization/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""
import logging
from typing import Callable

import numpy as np

Expand All @@ -42,30 +43,29 @@ class HyperLoss:
statistics default to the average.
Args:
loss (str): the loss over the replicas to use
loss_type (str): the type of loss over the replicas to use
fold_statistic (str): the statistic over the folds to use
replica_statistic (str): the statistic over the replicas to use, for per replica losses
"""

def __init__(self, loss: str = None, replica_statistic: str = None, fold_statistic: str = None):
def __init__(
self, loss_type: str = None, replica_statistic: str = None, fold_statistic: str = None
):
self.implemented_stats = {
"average": self._average,
"best_worst": self._best_worst,
"std": self._std,
}
self._default_statistic = "average"
self.implemented_losses = ["chi2", "phi2"]

self.loss = loss
# self.implemented_losses = {
# "chi2": self._chi2,
# "phi2": self._phi2,
# }
self._default_losss = "chi2"
self._default_statistic = "average"
self._default_loss = "chi2"

self.loss = self._parse_loss(loss_type)
self.reduce_over_folds = self._parse_statistic(fold_statistic, "fold_statistic")
self.reduce_over_replicas = self._parse_statistic(replica_statistic, "replica_statistic")

def compute_loss(self, penalties, experimental_loss, pdf_models, experimental_data):
def compute_loss(self, penalties, experimental_loss, pdf_models, experimental_data) -> float:
"""
Compute the loss, including added penalties, for a single fold.
Expand All @@ -84,10 +84,32 @@ def compute_loss(self, penalties, experimental_loss, pdf_models, experimental_da
loss = self.reduce_over_replicas(experimental_loss)
elif self.loss == "phi2":
loss = compute_phi2(N3PDF(pdf_models), experimental_data)

return total_penalties + loss

def _parse_statistic(self, statistic: str, name) -> str:
def _parse_loss(self, loss_type: str) -> str:
    """
    Parse the type of loss and return the default if None.

    Args:
        loss_type (str): the loss to parse; must be one of the implemented
            losses ("chi2" or "phi2"), or None to fall back to the
            default (``self._default_loss``, i.e. "chi2")

    Returns:
        str: the validated loss type

    Raises:
        ValueError: if ``loss_type`` is not one of the implemented losses
    """
    if loss_type is None:
        # No explicit choice in the runcard: fall back to the default and
        # warn so the user knows which loss is actually being used.
        loss_type = self._default_loss
        log.warning(f"No loss_type selected in HyperLoss, defaulting to {loss_type}")

    if loss_type not in self.implemented_losses:
        # Fixed: the closing quote on 'phi2' was missing in the message.
        raise ValueError(
            f"Input loss type {loss_type} not recognized by HyperLoss. "
            "Options are 'chi2' or 'phi2'."
        )

    log.info(f"Using '{loss_type}' as the loss type for hyperoptimization")
    return loss_type

def _parse_statistic(self, statistic: str, name) -> Callable:
"""
Parse the statistic and return the default if None.
Expand Down
6 changes: 4 additions & 2 deletions n3fit/src/n3fit/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,11 @@ def __init__(
# Check what is the hyperoptimization target function
replica_statistic = kfold_parameters.get("replica_statistic", None)
fold_statistic = kfold_parameters.get("fold_statistic", None)
loss = kfold_parameters.get("loss", None)
loss_type = kfold_parameters.get("loss", None)
self._hyper_loss = HyperLoss(
loss=loss, replica_statistic=replica_statistic, fold_statistic=fold_statistic
loss_type=loss_type,
replica_statistic=replica_statistic,
fold_statistic=fold_statistic,
)

# Initialize the dictionaries which contain all fitting information
Expand Down

0 comments on commit 5dba88c

Please sign in to comment.