Skip to content

Commit

Permalink
Update plotting, integrate global accuracy and minimize layout size.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Feb 12, 2024
1 parent 8e0fa0d commit 08624a9
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 196 deletions.
50 changes: 31 additions & 19 deletions params.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
settings:
n_jobs: 4
n_jobs: 95
backend: joblib
mlflow_tracking_uri: http://localhost:5000

threshold_characteristics:
active: false
active: true
valuation_method: banzhaf_shapley # Method used to calculate the threshold characteristics.
model: logistic_regression # Default model to use for determining the values
max_plotting_percentage: 1e-4 # Threshold for stopping plotting in direction of x-axis.
Expand Down Expand Up @@ -34,8 +34,6 @@ active:
- tmc_shapley
- beta_shapley
- loo
- banzhaf_shapley
- least_core
repetitions:
from: 1
to: 20
Expand All @@ -48,31 +46,31 @@ experiments:
fn: metric
metric: accuracy
eval_model: logistic_regression
plot:
plots:
- accuracy
weighted_accuracy_drop_knn:
fn: metric
metric: accuracy
eval_model: knn
plot:
plots:
- accuracy
weighted_accuracy_drop_gradient_boosting_classifier:
fn: metric
metric: accuracy
eval_model: gradient_boosting_classifier
plot:
plots:
- accuracy
weighted_accuracy_drop_svm:
fn: metric
metric: accuracy
eval_model: svm
plot:
plots:
- accuracy
weighted_accuracy_drop_mlp:
fn: metric
metric: accuracy
eval_model: mlp
plot:
plots:
- accuracy

metrics:
Expand All @@ -85,9 +83,9 @@ experiments:
- weighted_accuracy_drop_mlp
fn: geometric_weighted_drop
input_perc: 1.0
plot:
- table
- boxplot
plots:
- table_wad
- box_wad

geometric_weighted_drop_half:
curve:
Expand All @@ -98,9 +96,9 @@ experiments:
- weighted_accuracy_drop_mlp
fn: geometric_weighted_drop
input_perc: 0.5
plot:
- table
- boxplot
plots:
- table_wad
- box_wad

noise_removal:
sampler: default
Expand All @@ -118,14 +116,14 @@ experiments:
curve:
- precision_recall
fn: roc_auc
plot:
- table
- boxplot
plots:
- table_auc
- box_auc

plots:
accuracy:
type: line
len_curve_perc: 0.5
plot_perc: 0.5
x_label: "n"
y_label: "Accuracy"

Expand All @@ -134,6 +132,20 @@ plots:
x_label: "Recall"
y_label: "Precision"

table_wad:
type: table

table_auc:
type: table

box_wad:
type: boxplot
x_label: "WAD"

box_auc:
type: boxplot
x_label: "AUC"


samplers:
default:
Expand Down
22 changes: 7 additions & 15 deletions scripts/calculate_threshold_characteristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _calculate_threshold_characteristics(
n_jobs = params["settings"]["n_jobs"]

logger.info("Calculating in class characteristics.")
in_cls_mar_acc, in_cls_stats = calculate_subset_score(
in_cls_mar_acc = calculate_subset_score(
val_set,
lambda c: np.argwhere(val_set.y_train == c)[:, 0],
model_name,
Expand All @@ -104,10 +104,11 @@ def _calculate_threshold_characteristics(
n_jobs,
backend,
)

logger.info("Calculating out of class characteristics.")
out_of_cls_mar_acc, out_of_cls_stats = calculate_subset_score(
global_mar_acc = calculate_subset_score(
val_set,
lambda c: np.argwhere(val_set.y_train != c)[:, 0],
lambda c: np.argwhere((val_set.y_train == c) | (val_set.y_train != c))[:, 0],
model_name,
model_seed,
sampler_seed,
Expand All @@ -116,20 +117,11 @@ def _calculate_threshold_characteristics(
backend,
)

logger.info("Calculating curves and statistics.")
threshold_characteristics_curves = calculate_threshold_characteristic_curves(
in_cls_mar_acc, out_of_cls_mar_acc
)
in_cls_out_of_cls_stats = pd.DataFrame(
[in_cls_stats, out_of_cls_stats], index=["in_cls", "out_of_cls"]
)

logger.info("Storing files.")
os.makedirs(output_dir, exist_ok=True)
in_cls_out_of_cls_stats.to_csv(output_dir / "threshold_characteristics_stats.csv")
threshold_characteristics_curves.to_csv(
output_dir / "threshold_characteristics_curves.csv", sep=";"
)

np.savetxt(output_dir / "in_cls_mar_acc.txt", in_cls_mar_acc)
np.savetxt(output_dir / "global_mar_acc.txt", global_mar_acc)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 08624a9

Please sign in to comment.