Skip to content

Commit

Permalink
Merge pull request #7 from shiv3679/patch-2
Browse files Browse the repository at this point in the history
multiple instances of gss in nwpeval.py
  • Loading branch information
Debasish-Mahapatra authored Apr 6, 2024
2 parents deec65b + 42a6a4e commit dd45376
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions nwpeval/nwpeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def compute_metrics(self, metrics, dim=None, thresholds=None):
elif metric == 'PSS':
threshold = thresholds.get('PSS', 0.5)
metric_values[metric] = self.compute_pss(threshold, dim)
elif metric == 'GS':
threshold = thresholds.get('GS', 0.5)
# elif metric == 'GS':
# threshold = thresholds.get('GS', 0.5) # a mirror instance of gilbert skill score already present
metric_values[metric] = self.compute_gs(threshold, dim)
elif metric == 'SEDS':
threshold = thresholds.get('SEDS', 0.5)
Expand Down Expand Up @@ -454,29 +454,29 @@ def compute_pss(self, threshold, dim=None):

return pss

def compute_gs(self, threshold, dim=None):
"""
Compute the Gilbert Skill Score (GS) for a given threshold.
# def compute_gs(self, threshold, dim=None):
# """
# Compute the Gilbert Skill Score (GS) for a given threshold.

Args:
threshold (float): The threshold value for binary classification.
dim (str, list, or None): The dimension(s) along which to compute the GS.
If None, compute the GS over the entire data.
# Args:
# threshold (float): The threshold value for binary classification.
# dim (str, list, or None): The dimension(s) along which to compute the GS.
# If None, compute the GS over the entire data.

Returns:
xarray.DataArray: The computed GS values.
"""
# Convert data to binary based on the threshold
obs_binary = (self.obs_data >= threshold).astype(int)
model_binary = (self.model_data >= threshold).astype(int)
# Returns:
# xarray.DataArray: The computed GS values.
# """
# # Convert data to binary based on the threshold
# obs_binary = (self.obs_data >= threshold).astype(int)
# model_binary = (self.model_data >= threshold).astype(int)

# Calculate the confusion matrix
tn, fp, fn, tp = self.confusion_matrix(obs_binary, model_binary, dim)
# # Calculate the confusion matrix
# tn, fp, fn, tp = self.confusion_matrix(obs_binary, model_binary, dim)

# Calculate the GS
gs = (tp - ((tp + fp) * (tp + fn) / (tp + fp + fn + tn))) / (tp + fp + fn - ((tp + fp) * (tp + fn) / (tp + fp + fn + tn)))
# # Calculate the GS
# gs = (tp - ((tp + fp) * (tp + fn) / (tp + fp + fn + tn))) / (tp + fp + fn - ((tp + fp) * (tp + fn) / (tp + fp + fn + tn)))

return gs
# return gs

def compute_seds(self, threshold, dim=None):
"""
Expand Down

0 comments on commit dd45376

Please sign in to comment.