
AUC metrics #3751

Merged
merged 40 commits into master from auc_metrics on Jul 8, 2021

Conversation

@liliarose (Contributor) commented on Jun 25, 2021

Patch description

Adding AUC (Area under ROC curve) metrics as an option for eval_model
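
For reference, the per-class value reported (e.g. AUC___notok__) corresponds to the standard area under the ROC curve of the probabilities the classifier assigns to that class, scored against the gold labels. A minimal sketch of what that number means, computed here with sklearn purely for illustration (the variable names are made up; this is not the PR's implementation):

```python
# Illustration only (not the PR's implementation): one-vs-rest ROC-AUC
# for the __notok__ class from hypothetical per-example probabilities.
from sklearn.metrics import roc_auc_score

probs = [0.95, 0.10, 0.80, 0.30, 0.60]   # hypothetical P(__notok__) per example
labels = [1, 0, 1, 1, 0]                 # 1 = __notok__, 0 = __ok__

print(roc_auc_score(labels, probs))      # ~0.83
```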

Testing steps

  1. Check that the other metric tests still pass (pytest -v -k TestMetrics / pytest -v -k TestMetric).
  2. Check via a unit test that AUC values aggregate correctly across batches (pytest -v -k TestAggregators); see the sketch after this list for the kind of property being checked.
  3. Check that it works via eval_model for both classes with micro aggregation, verifying that each agent's AUC state is reset correctly. (parlai em -mf zoo:dialogue_safety/multi_turn/model -t internal:civil_bias_toxic_comment:civil_bias_toxicity,dialogue_safety -auc 4 -rf parlai_internal/reports/auc_multi_safety_civil_multi.json -ne 3000 -micro True)
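
For step 2, the property being checked is roughly the one sketched below. This is not the actual TestAggregators test and the names are made up: the point is that the AUC of combined data is generally not the mean of per-batch AUCs, so the metric has to carry enough state (the scores themselves, or a summary of them) to recompute the curve when results from different batches or tasks are merged.

```python
# Hedged sketch (not the real test): why AUC needs special aggregation.
from sklearn.metrics import roc_auc_score

batch1 = ([0.9, 0.4], [1, 0])   # (positive-class probs, gold labels)
batch2 = ([0.3, 0.2], [1, 0])

mean_of_batch_aucs = (roc_auc_score(batch1[1], batch1[0])
                      + roc_auc_score(batch2[1], batch2[0])) / 2

combined_probs = batch1[0] + batch2[0]
combined_labels = batch1[1] + batch2[1]
combined_auc = roc_auc_score(combined_labels, combined_probs)

print(mean_of_batch_aucs)  # 1.0
print(combined_auc)        # 0.75: merging batches changes the curve
```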

Logs

[Screenshots of eval_model output, 2021-06-30]

| metric | dialogue_safety | internal:civil_bias_toxic_comment:civil_bias_toxicity | all |
| --- | --- | --- | --- |
| class___notok___f1 | 0.82453 | 0.55437 | 0.68945 |
| AUC___notok__ | 0.98469 | 0.92527 | 0.96373 |

Testing Step 3 (eval_model)

parlai em -mf zoo:dialogue_safety/multi_turn/model -t internal:civil_bias_toxic_comment:civil_bias_toxicity -auc True -rf parlai_internal/reports/auc_multi_safety_civil_single.json -ne 3000

Results:

Testing Step 2 (TestAggregators)

(conda_parlai) wzhang4343@devfair0173:~/ParlAI$ pytest -v -k TestAggregators
===================================================== test session starts =====================================================
platform linux -- Python 3.7.10, pytest-6.2.4, py-1.10.0, pluggy-1.0.0.dev0 -- /private/home/wzhang4343/.conda/envs/conda_parlai/bin/python
cachedir: .pytest_cache
rootdir: /private/home/wzhang4343/ParlAI, configfile: pytest.ini, testpaths: tests, parlai/tasks
plugins: requests-mock-1.9.2, regressions-2.2.0, hydra-core-1.0.6, datadir-1.3.1
collected 997 items / 990 deselected / 7 selected                                                                             

tests/test_metrics.py::TestAggregators::test_auc_metrics PASSED                                                         [ 14%]
tests/test_metrics.py::TestAggregators::test_classifier_metrics PASSED                                                  [ 28%]
tests/test_metrics.py::TestAggregators::test_macro_aggregation PASSED                                                   [ 42%]
tests/test_metrics.py::TestAggregators::test_micro_aggregation PASSED                                                   [ 57%]
tests/test_metrics.py::TestAggregators::test_time_metric PASSED                                                         [ 71%]
tests/test_metrics.py::TestAggregators::test_uneven_macro_aggrevation PASSED                                            [ 85%]
tests/test_metrics.py::TestAggregators::test_unnamed_aggregation PASSED                                                 [100%]

==================================================== slowest 10 durations =====================================================
0.01s call     tests/test_metrics.py::TestAggregators::test_auc_metrics

(9 durations < 0.005s hidden.  Use -vv to show these durations.)
======================================== 7 passed, 990 deselected, 2 warnings in 5.77s ========================================

Testing Step 1 (TestMetric / TestMetrics)

(conda_parlai) wzhang4343@devfair0173:~/ParlAI$ pytest -v -k TestMetric
===================================================== test session starts =====================================================
platform linux -- Python 3.7.10, pytest-6.2.4, py-1.10.0, pluggy-1.0.0.dev0 -- /private/home/wzhang4343/.conda/envs/conda_parlai/bin/python
cachedir: .pytest_cache
rootdir: /private/home/wzhang4343/ParlAI, configfile: pytest.ini, testpaths: tests, parlai/tasks
plugins: requests-mock-1.9.2, regressions-2.2.0, hydra-core-1.0.6, datadir-1.3.1
collected 996 items / 984 deselected / 12 selected                                                                            

tests/test_metrics.py::TestMetric::test_average_metric_additions PASSED                                                 [  8%]
tests/test_metrics.py::TestMetric::test_average_metric_inputs PASSED                                                    [ 16%]
tests/test_metrics.py::TestMetric::test_fixedmetric PASSED                                                              [ 25%]
tests/test_metrics.py::TestMetric::test_macroaverage_additions PASSED                                                   [ 33%]
tests/test_metrics.py::TestMetric::test_sum_metric_additions PASSED                                                     [ 41%]
tests/test_metrics.py::TestMetric::test_sum_metric_inputs PASSED                                                        [ 50%]
tests/test_metrics.py::TestMetrics::test_largebuffer PASSED                                                             [ 58%]
tests/test_metrics.py::TestMetrics::test_multithreaded PASSED                                                           [ 66%]
tests/test_metrics.py::TestMetrics::test_recent PASSED                                                                  [ 75%]
tests/test_metrics.py::TestMetrics::test_shared PASSED                                                                  [ 83%]
tests/test_metrics.py::TestMetrics::test_simpleadd PASSED                                                               [ 91%]
tests/test_metrics.py::TestMetrics::test_verymultithreaded PASSED                                                       [100%]

==================================================== slowest 10 durations =====================================================
0.13s call     tests/test_metrics.py::TestMetrics::test_verymultithreaded
0.09s call     tests/test_metrics.py::TestMetrics::test_largebuffer

(8 durations < 0.005s hidden.  Use -vv to show these durations.)
======================================= 12 passed, 984 deselected, 2 warnings in 7.61s ========================================
(conda_parlai) wzhang4343@devfair0173:~/ParlAI$ pytest -v -k TestMetrics
===================================================== test session starts =====================================================
platform linux -- Python 3.7.10, pytest-6.2.4, py-1.10.0, pluggy-1.0.0.dev0 -- /private/home/wzhang4343/.conda/envs/conda_parlai/bin/python
cachedir: .pytest_cache
rootdir: /private/home/wzhang4343/ParlAI, configfile: pytest.ini, testpaths: tests, parlai/tasks
plugins: requests-mock-1.9.2, regressions-2.2.0, hydra-core-1.0.6, datadir-1.3.1
collected 996 items / 990 deselected / 6 selected                                                                             

tests/test_metrics.py::TestMetrics::test_largebuffer PASSED                                                             [ 16%]
tests/test_metrics.py::TestMetrics::test_multithreaded PASSED                                                           [ 33%]
tests/test_metrics.py::TestMetrics::test_recent PASSED                                                                  [ 50%]
tests/test_metrics.py::TestMetrics::test_shared PASSED                                                                  [ 66%]
tests/test_metrics.py::TestMetrics::test_simpleadd PASSED                                                               [ 83%]
tests/test_metrics.py::TestMetrics::test_verymultithreaded PASSED                                                       [100%]

==================================================== slowest 10 durations =====================================================
0.13s call     tests/test_metrics.py::TestMetrics::test_verymultithreaded
0.09s call     tests/test_metrics.py::TestMetrics::test_largebuffer

(8 durations < 0.005s hidden.  Use -vv to show these durations.)
======================================== 6 passed, 990 deselected, 2 warnings in 5.16s ========================================

Other Info

  • I chose to store the AUC metric in the classifier because, if we ever want to compute it during training, it should be relatively easy to modify the code to do it there.
  • Also, I added 6 tests. I know it would be a lot less code if I looped over some of the commands, but that would make it harder to see exactly which test is which.
  • Please don't be scared by the line count: ~300 lines are in the tests.
  • Code used to generate the graphs (not in the final version):
    # for fun... (this snippet presumably lived inside the eval loop, where
    # curr_auc, opt, task, classifier_agent, and class_indices were in scope)
    import matplotlib.pyplot as plt
    from sklearn import metrics

    fpr, tpr, _, _ = curr_auc._calc_fpr_tpr()
    display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=curr_auc.value())
    display.plot()
    folder = '/'.join(opt['report_filename'].split('/')[:-1])
    fig_path = f"{folder}/AUC_graph_{task}_{classifier_agent.class_list[class_indices]}.png"
    plt.savefig(fig_path)
    plt.clf()
    print(f"graphed {fig_path}")
  • Code used to generate the table above from the report file:
# `report` is the metrics dict read from the JSON written via -rf /
# --report-filename (loading it is omitted here).
interested_metrics = {'AUC___notok__', 'class___notok___f1'}

fields = ['dialogue_safety', 'internal:civil_bias_toxic_comment:civil_bias_toxicity']
rows = []

for metric in interested_metrics:
    row = [metric]
    for field in fields:
        key = field + '/' + metric
        row.append(str(round(report[key], 5)))
    row.append(str(round(report[metric], 5)))
    rows.append('|'.join(row))

print('|'.join(['metric'] + fields + ['all']))
print('|'.join(['---'] * (len(fields) + 2)))
print('\n'.join(rows))

@liliarose liliarose marked this pull request as ready for review June 25, 2021 21:20
@stephenroller (Contributor) left a comment


Generally seems good and I appreciate the feature. Just have a clarification question.

Review comments on parlai/core/torch_classifier_agent.py (3 threads, resolved)
@stephenroller (Contributor) commented:

(That test is fantastic btw)

@liliarose liliarose marked this pull request as ready for review June 30, 2021 22:21
@jxmsML jxmsML self-requested a review July 1, 2021 14:44
@EricMichaelSmith (Contributor) left a comment


Wow, great to have! Yeah, I like the thorough tests. Will defer to others with more context for approval

Review comments on parlai/core/torch_classifier_agent.py (3 threads) and parlai/scripts/eval_model.py (2 threads)
@liliarose liliarose marked this pull request as draft July 6, 2021 20:46
@liliarose liliarose marked this pull request as ready for review July 6, 2021 21:07
@jxmsML (Contributor) left a comment


LGTM!! some nit

@stephenroller stephenroller changed the title from "Auc metrics" to "AUC metrics" on Jul 7, 2021
@liliarose liliarose merged commit 36004c9 into master Jul 8, 2021
@liliarose liliarose deleted the auc_metrics branch July 8, 2021 00:16