Let the DispReconstructor also compute a score for the sign prediction #2479
@@ -672,6 +672,7 @@ def _predict(self, key, table):
         )
         X, valid = table_to_X(table, self.features, self.log)
         prediction = np.full(len(table), np.nan)
+        score = np.full(len(table), np.nan)

         if np.any(valid):
             valid_norms = self._models[key][0].predict(X)
@@ -681,12 +682,17 @@ def _predict(self, key, table):
             else:
                 prediction[valid] = valid_norms

-            prediction[valid] *= self._models[key][1].predict(X)
+            sign_proba = self._models[key][1].predict_proba(X)[:, 0]
+            # proba is in [0, 1] where 0 => very certain -1, 1 => very certain 1
+            # and 0.5 means random guessing either way. So we transform to a score
+            # where 0 means "guessing" and 1 means "very certain"
+            score[valid] = np.abs(2 * sign_proba - 1.0)
+            prediction[valid] *= np.where(sign_proba >= 0.5, 1.0, -1.0)

         if self.unit is not None:
             prediction = u.Quantity(prediction, self.unit, copy=False)

-        return prediction, valid
+        return prediction, score, valid

     def __call__(self, event: ArrayEventContainer) -> None:
         """Event-wise prediction for the EventSource-Loop.
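For reference, the probability-to-score transform introduced above can be illustrated in isolation. This is a minimal NumPy sketch with made-up probabilities, not code from the PR:

```python
import numpy as np

# Made-up sign-class probabilities for five events, standing in for
# predict_proba(X)[:, 0]: 0.5 means the classifier is guessing,
# 0.0 or 1.0 means it is fully certain of one of the two signs.
sign_proba = np.array([0.5, 0.75, 1.0, 0.25, 0.0])

# Transform to a certainty score in [0, 1]:
# 0.5 -> 0 ("guessing"), 0.0 or 1.0 -> 1 ("very certain").
score = np.abs(2 * sign_proba - 1.0)

# The predicted sign only depends on which side of 0.5 the probability falls.
sign = np.where(sign_proba >= 0.5, 1.0, -1.0)

print(score)  # [0.  0.5 1.  0.5 1. ]
print(sign)   # [ 1.  1.  1. -1. -1.]
```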
@@ -705,10 +711,15 @@ def __call__(self, event: ArrayEventContainer) -> None:
             passes_quality_checks = self.quality_query.get_table_mask(table)[0]

             if passes_quality_checks:
-                disp, valid = self._predict(self.subarray.tel[tel_id], table)
+                disp, sign_score, valid = self._predict(
+                    self.subarray.tel[tel_id], table
+                )

                 if valid:
-                    disp_container = DispContainer(parameter=disp[0])
+                    disp_container = DispContainer(
+                        parameter=disp[0],
+                        sign_score=sign_score[0],
+                    )

                     hillas = event.dl1.tel[tel_id].parameters.hillas
                     psi = hillas.psi.to_value(u.rad)
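In this event-loop path, `_predict` runs on a single-row table, so every returned array has length one and is unpacked with `[0]` before being stored in the container. A small sketch of that unpacking, using a stand-in dataclass rather than the real `DispContainer`:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class FakeDispContainer:
    """Illustrative stand-in, not ctapipe's DispContainer."""

    parameter: float
    sign_score: float


# Pretend _predict ran on a one-row table: all outputs have length 1.
disp = np.array([0.12])       # reconstructed disp value (arbitrary number)
sign_score = np.array([0.8])  # certainty of the sign prediction
valid = np.array([True])

if valid[0]:
    container = FakeDispContainer(parameter=disp[0], sign_score=sign_score[0])
    print(container)
```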
@@ -775,11 +786,19 @@ def predict_table(self, key, table: Table) -> Dict[ReconstructionProperty, Table
         n_rows = len(table)
         disp = u.Quantity(np.full(n_rows, np.nan), self.unit, copy=False)
         is_valid = np.full(n_rows, False)
+        sign_score = np.full(n_rows, np.nan)

         valid = self.quality_query.get_table_mask(table)
-        disp[valid], is_valid[valid] = self._predict(key, table[valid])
+        disp[valid], sign_score[valid], is_valid[valid] = self._predict(
+            key, table[valid]
+        )

-        disp_result = Table({f"{self.prefix}_tel_parameter": disp})
+        disp_result = Table(
+            {
+                f"{self.prefix}_tel_parameter": disp,
+                f"{self.prefix}_tel_sign_score": sign_score,
+            }
+        )
         add_defaults_and_meta(
             disp_result,
             DispContainer,
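The table returned by `predict_table` therefore carries a sign-score column next to the disp parameter. As an illustration of how a consumer might use it (the column names assume a reconstructor prefix of `disp`, and the cut value is arbitrary):

```python
import numpy as np
from astropy.table import Table

# Hypothetical predict_table output, assuming the reconstructor prefix is "disp".
disp_result = Table(
    {
        "disp_tel_parameter": np.array([0.5, 1.2, 0.8, 2.0]),
        "disp_tel_sign_score": np.array([0.1, 0.9, 0.4, 0.95]),
    }
)

# Possible downstream use: keep only telescope predictions with a confident sign.
confident = disp_result["disp_tel_sign_score"] > 0.5  # arbitrary threshold
print(disp_result[confident])
```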
@@ -917,10 +936,10 @@ def __call__(self, telescope_type, table):
                     {
                         "cv_fold": np.full(len(truth), fold, dtype=np.uint8),
                         "tel_type": [str(telescope_type)] * len(truth),
-                        "prediction": cv_prediction,
                         "truth": truth,
                         "true_energy": test["true_energy"],
                         "true_impact_distance": test["true_impact_distance"],
+                        **cv_prediction,
                     }
                 )
             )
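This works because each `_cross_validate_*` method now returns a dict mapping output column names to arrays, which is splatted into the per-fold results table with `**`. A self-contained sketch of that pattern (column names and values are invented for illustration):

```python
import numpy as np
from astropy.table import Table

truth = np.array([1.0, -1.0, 1.0, -1.0])

# What a _cross_validate_* method might return as its first element:
# one entry per output column, however many the model produces.
cv_prediction = {
    "disp_parameter": np.array([0.9, -1.1, 0.7, 1.3]),
    "disp_sign_score": np.array([0.8, 0.9, 0.2, 0.1]),
}

# The caller builds one table per fold; ** splices in all prediction columns.
fold_table = Table(
    {
        "cv_fold": np.full(len(truth), 0, dtype=np.uint8),
        "truth": truth,
        **cv_prediction,
    }
)
print(fold_table.colnames)  # ['cv_fold', 'truth', 'disp_parameter', 'disp_sign_score']
```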
@@ -945,7 +964,7 @@ def _cross_validate_regressor(self, telescope_type, train, test):
         prediction, _ = regressor._predict(telescope_type, test)
         truth = test[regressor.target]
         r2 = r2_score(truth, prediction)
-        return prediction, truth, {"R^2": r2}
+        return {f"{regressor.prefix}_energy": prediction}, truth, {"R^2": r2}
Review comment: It looks like there isn't any test for …

Reply: There is no direct test, yes. However, the code is covered by the tests for the train tools, which also inspect the cross-validation output.

     def _cross_validate_classification(self, telescope_type, train, test):
         classifier = self.model_component
@@ -957,15 +976,23 @@ def _cross_validate_classification(self, telescope_type, train, test):
             0,
         )
         roc_auc = roc_auc_score(truth, prediction)
-        return prediction, truth, {"ROC AUC": roc_auc}
+        return (
+            {f"{classifier.prefix}_prediction": prediction},
+            truth,
+            {"ROC AUC": roc_auc},
+        )

     def _cross_validate_disp(self, telescope_type, train, test):
         models = self.model_component
         models.fit(telescope_type, train)
-        prediction, _ = models._predict(telescope_type, test)
+        disp, sign_score, _ = models._predict(telescope_type, test)
         truth = test[models.target]
-        r2 = r2_score(np.abs(truth), np.abs(prediction))
-        accuracy = accuracy_score(np.sign(truth), np.sign(prediction))
+        r2 = r2_score(np.abs(truth), np.abs(disp))
+        accuracy = accuracy_score(np.sign(truth), np.sign(disp))
+        prediction = {
+            f"{models.prefix}_parameter": disp,
+            f"{models.prefix}_sign_score": sign_score,
+        }
         return prediction, truth, {"R^2": r2, "accuracy": accuracy}

     def write(self, overwrite=False):
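`_cross_validate_disp` scores the two halves of the prediction separately: R² on the absolute value (the norm regression) and accuracy on the sign (the classification). A toy example of those two metrics with made-up values:

```python
import numpy as np
from sklearn.metrics import accuracy_score, r2_score

# Made-up true and reconstructed (signed) disp values.
truth = np.array([1.0, -2.0, 0.5, -1.5])
disp = np.array([1.1, -1.8, -0.4, -1.6])

# Quality of the norm regression, ignoring the sign.
r2 = r2_score(np.abs(truth), np.abs(disp))

# Quality of the sign classification, ignoring the norm.
accuracy = accuracy_score(np.sign(truth), np.sign(disp))

print({"R^2": r2, "accuracy": accuracy})  # accuracy is 0.75: one of four signs is wrong
```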
@@ -0,0 +1 @@
+The ``DispReconstructor`` now computes a score for how certain the prediction of the disp sign is.
Review comment: I don't understand this explanation, is it supposed to say that `sign_proba` goes from -1 to 1?

Reply: The classes predicted by the sign model are "1" and "-1", and the `[:, 0]` component of the output of `predict_proba` is the likelihood score for the "1" class. Therefore a score of 0 corresponds to a high likelihood of class "-1".
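One way to settle this kind of question for a plain scikit-learn classifier (whether the sign model is one is an assumption here) is to inspect `classes_`: the columns of `predict_proba` follow that order, so printing it shows which class `[:, 0]` refers to. A toy check:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = np.where(X[:, 0] > 0, 1, -1)  # toy sign labels

clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

# scikit-learn orders predict_proba columns according to clf.classes_.
print(clf.classes_)                # [-1  1] for these labels
proba = clf.predict_proba(X[:3])
print(proba[:, 0] + proba[:, 1])   # each row sums to 1
```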