Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
AUC Updated
  • Loading branch information
TimotheeeNiven authored Jan 24, 2025
1 parent b8c0305 commit 05b1d3b
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions benchmark/runner/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def normalize_probabilities(scores):

# Function to calculate accuracy
def calculate_accuracy(y_pred, labels):
y_pred = (y_pred - y_pred.min(axis=0)) / (y_pred.max(axis=0) - y_pred.min(axis=0) + 1e-10)
y_pred = np.array([[value, 1- value] for value in y_pred])
y_pred_label = np.argmax(y_pred, axis=1)
correct = np.sum(labels == y_pred_label)
accuracy = 100 * correct / len(y_pred)
Expand All @@ -121,11 +123,14 @@ def calculate_accuracy(y_pred, labels):

# Function to calculate AUC
def calculate_auc(y_pred, labels, n_classes):
thresholds = np.arange(0.0, 1.01, .01)
# Normalize y_pred for each class
y_pred = (y_pred - y_pred.min(axis=0)) / (y_pred.max(axis=0) - y_pred.min(axis=0) + 1e-10)
y_pred = np.array([[value, 1- value] for value in y_pred])
thresholds = np.arange(0.0, 1.01, 0.01)
fpr = np.zeros([n_classes, len(thresholds)])
tpr = np.zeros([n_classes, len(thresholds)])
roc_auc = np.zeros(n_classes)

for class_item in range(n_classes):
all_positives = sum(labels == class_item)
all_negatives = len(labels) - all_positives
Expand All @@ -146,13 +151,14 @@ def calculate_auc(y_pred, labels, n_classes):
fpr[class_item, 0] = 1
tpr[class_item, 0] = 1
for threshold_item in range(len(thresholds) - 1):
roc_auc[class_item] += .5 * (tpr[class_item, threshold_item] + tpr[class_item, threshold_item + 1]) * (
roc_auc[class_item] += 0.5 * (tpr[class_item, threshold_item] + tpr[class_item, threshold_item + 1]) * (
fpr[class_item, threshold_item] - fpr[class_item, threshold_item + 1])

roc_auc_avg = np.mean(roc_auc)
print(f"Simplified average ROC AUC = {roc_auc_avg:.3f}")
return roc_auc


# Summarize results
def summarize_result(result):
num_correct_files = 0
Expand Down Expand Up @@ -227,4 +233,4 @@ def summarize_result(result):
"mode": args.mode
}
result = run_test(**config)
summarize_result(result)
summarize_result(result)

0 comments on commit 05b1d3b

Please sign in to comment.