Add PrecisionRecall Op #5111
Anyway, I can approve this for now until we have a better evaluator network implementation.
- micro average recall
- micro f1 score

To compute the above metrics, we need to statistic counts for true positives,
we need to statistic counts for => we need statistics to count...
Done.
To compute the above metrics, we need to statistic counts for true positives,
false positives and false negatives. Here count of true negatives is not
necessary, but statisticing it may provide potential usage and the cost is
"statistic" is a noun, not a verb. Change it to "calculating".
Done.
out2->mutable_data<T>(ctx.GetPlace());
auto accum_states = EigenMatrix<T>::From(*out2);
accum_states.setZero();
T* accum_states_data = out2->data<T>();
I think accumulation should be a more general mechanism. As in #4828 (comment), we could use an accumulating operator to build a more configurable evaluation subnetwork.
const T* weights_data = in2 ? in2->data<T>() : nullptr;
const T* states_data = in3 ? in3->data<T>() : nullptr;
T* batch_metrics_data = out0->mutable_data<T>(ctx.GetPlace());
T* accum_metrics_data = out1->mutable_data<T>(ctx.GetPlace());
The outputs of these metrics can just be of type float or double; type T should not affect the output type.
Done.
for (size_t i = 0; i < sample_num; ++i) {
  size_t max_idx = 0;
  T max_val = predictions_data[i * class_dim];
  for (size_t j = 1; j < class_dim; ++j) {
You can assume the input predictions are outputs of the topk op, so you don't need to find the max probability here.
topk outputs both the probabilities and the indices.
class PrecisionRecallKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in0 = ctx.Input<Tensor>("Predictions");
These temp var names are not quite human-readable.
LGTM!
Fixes #5070
This operator can be used to compute various metrics including:
To compute the above metrics, we need to count true positives, false positives and false negatives. The count of true negatives is not necessary here, but calculating it may prove useful and the cost is trivial, so the operator also provides the count of true negatives.
We define state as a 2-D tensor with shape [class number, 4]. Each row of a state contains the statistics for the corresponding class. The layout of each row is: TP (true positives), FP (false positives), TN (true negatives), FN (false negatives). If 'Input(Weights)' is provided, TP, FP, TN and FN are calculated from the given weights instead of instance counts.
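To make the state layout concrete, here is a minimal standalone sketch of how one batch's state could be filled in. It is not the kernel code from this PR: `ComputeBatchStates`, `StateColumn` and the plain `std::vector` containers are hypothetical names chosen for illustration, and the sketch assumes the predicted class index per sample is already known and that every index is smaller than `class_num`.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Column order within one state row, as described above: TP, FP, TN, FN.
enum StateColumn : std::size_t { kTP = 0, kFP = 1, kTN = 2, kFN = 3 };

using ClassState = std::array<double, 4>;
using States = std::vector<ClassState>;  // shape [class number, 4]

// Hypothetical helper (not part of the PR): builds the state of one batch.
// `predicted` and `labels` hold one class index per sample. `weights` may be
// empty, in which case every sample counts as 1 (plain instance counting).
States ComputeBatchStates(const std::vector<std::size_t>& predicted,
                          const std::vector<std::size_t>& labels,
                          const std::vector<double>& weights,
                          std::size_t class_num) {
  States states(class_num, ClassState{});  // all counters start at zero
  for (std::size_t i = 0; i < predicted.size(); ++i) {
    const std::size_t pred = predicted[i];
    const std::size_t label = labels[i];
    const double w = weights.empty() ? 1.0 : weights[i];
    if (pred == label) {
      states[pred][kTP] += w;   // correct prediction for class `pred`
    } else {
      states[pred][kFP] += w;   // `pred` was predicted but is wrong
      states[label][kFN] += w;  // `label` was missed
    }
    // Every class that is neither the prediction nor the label is a
    // true negative for this sample.
    for (std::size_t c = 0; c < class_num; ++c) {
      if (c != pred && c != label) states[c][kTN] += w;
    }
  }
  return states;
}
```

For example, with predictions {0, 1, 1}, labels {0, 1, 2} and no weights, class 1 ends up with TP = 1, FP = 1, TN = 1, FN = 0.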
This operator also supports computing metrics across batches. To achieve this, 'Input(StatesInfo)' should be provided. The state of the current batch is added to 'Input(StatesInfo)', and 'Output(AccumStatesInfo)' is the accumulated state.
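The cross-batch accumulation amounts to an element-wise sum of two state tensors. The sketch below illustrates that idea under stated assumptions: `AccumulateStates` is a hypothetical helper, the states are plain `std::vector` rows rather than framework tensors, and a null `previous_states` pointer stands in for an absent 'Input(StatesInfo)'.

```cpp
#include <array>
#include <cstddef>
#include <vector>

using ClassState = std::array<double, 4>;  // TP, FP, TN, FN
using States = std::vector<ClassState>;    // shape [class number, 4]

// Hypothetical helper (not part of the PR): returns the accumulated state.
// If `previous_states` is null, the accumulated state is just the state of
// the current batch. Both inputs are assumed to have the same class number.
States AccumulateStates(const States& batch_states,
                        const States* previous_states) {
  States accum = batch_states;
  if (previous_states != nullptr) {
    for (std::size_t c = 0; c < accum.size(); ++c) {
      for (std::size_t k = 0; k < 4; ++k) {
        accum[c][k] += (*previous_states)[c][k];
      }
    }
  }
  return accum;
}
```

Feeding the returned state back in as `previous_states` for the next batch gives running counts over the whole evaluation pass.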
'Output(BatchMetrics)' holds the metrics of the current batch, while 'Output(AccumMetrics)' holds the metrics computed from the accumulated state.