diff --git a/docs/0_get_started/installation.md b/docs/0_get_started/installation.md index ae44894f..26116a34 100644 --- a/docs/0_get_started/installation.md +++ b/docs/0_get_started/installation.md @@ -47,3 +47,34 @@ You can also install other miscallenous optional dependencies by running To install both groups of packages, run pip install textattack[tensorflow,optional] + + + +## FAQ on installation + +For many of the dependent library issues, the following command is the first you could try: +```bash +pip install --force-reinstall textattack +``` + +OR +```bash +pip install textattack[tensorflow,optional] +``` + + +Besides, we highly recommend you to use virtual environment for textattack use, +see [information here](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#removing-an-environment). Here is one conda example: + +```bash +conda create -n textattackenv python=3.7 +conda activate textattackenv +conda env list +``` + +If you want to use the most-up-to-date version of textattack (normally with newer bug fixes), you can run the following: +```bash +git clone https://github.com/QData/TextAttack.git +cd TextAttack +pip install .[dev] +``` \ No newline at end of file diff --git a/docs/apidoc/textattack.constraints.grammaticality.language_models.rst b/docs/apidoc/textattack.constraints.grammaticality.language_models.rst index f342ed86..d998a19d 100644 --- a/docs/apidoc/textattack.constraints.grammaticality.language_models.rst +++ b/docs/apidoc/textattack.constraints.grammaticality.language_models.rst @@ -7,6 +7,7 @@ textattack.constraints.grammaticality.language\_models package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.constraints.grammaticality.rst b/docs/apidoc/textattack.constraints.grammaticality.rst index e39cdb40..f3d2c34c 100644 --- a/docs/apidoc/textattack.constraints.grammaticality.rst +++ b/docs/apidoc/textattack.constraints.grammaticality.rst @@ -7,6 +7,7 @@ textattack.constraints.grammaticality package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.constraints.rst b/docs/apidoc/textattack.constraints.rst index 72dbb9e4..1907e29b 100644 --- a/docs/apidoc/textattack.constraints.rst +++ b/docs/apidoc/textattack.constraints.rst @@ -7,6 +7,7 @@ textattack.constraints package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.constraints.semantics.rst b/docs/apidoc/textattack.constraints.semantics.rst index 3e8b0973..e20d9c0e 100644 --- a/docs/apidoc/textattack.constraints.semantics.rst +++ b/docs/apidoc/textattack.constraints.semantics.rst @@ -7,6 +7,7 @@ textattack.constraints.semantics package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst b/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst index 9e712dd8..22be1b97 100644 --- a/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst +++ b/docs/apidoc/textattack.constraints.semantics.sentence_encoders.rst @@ -7,6 +7,7 @@ textattack.constraints.semantics.sentence\_encoders package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.datasets.rst b/docs/apidoc/textattack.datasets.rst index f4881aa6..d5e1564d 100644 --- a/docs/apidoc/textattack.datasets.rst +++ b/docs/apidoc/textattack.datasets.rst @@ -7,6 +7,7 @@ textattack.datasets package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.goal_functions.rst b/docs/apidoc/textattack.goal_functions.rst index 4e42db9b..a1a429f8 100644 --- a/docs/apidoc/textattack.goal_functions.rst +++ b/docs/apidoc/textattack.goal_functions.rst @@ -7,6 +7,7 @@ textattack.goal\_functions package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.metrics.attack_metrics.rst b/docs/apidoc/textattack.metrics.attack_metrics.rst new file mode 100644 index 00000000..b6ff602c --- /dev/null +++ b/docs/apidoc/textattack.metrics.attack_metrics.rst @@ -0,0 +1,26 @@ +textattack.metrics.attack\_metrics package +========================================== + +.. automodule:: textattack.metrics.attack_metrics + :members: + :undoc-members: + :show-inheritance: + + + +.. automodule:: textattack.metrics.attack_metrics.attack_queries + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: textattack.metrics.attack_metrics.attack_success_rate + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: textattack.metrics.attack_metrics.words_perturbed + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/apidoc/textattack.metrics.quality_metrics.rst b/docs/apidoc/textattack.metrics.quality_metrics.rst new file mode 100644 index 00000000..6d46e32d --- /dev/null +++ b/docs/apidoc/textattack.metrics.quality_metrics.rst @@ -0,0 +1,20 @@ +textattack.metrics.quality\_metrics package +=========================================== + +.. automodule:: textattack.metrics.quality_metrics + :members: + :undoc-members: + :show-inheritance: + + + +.. automodule:: textattack.metrics.quality_metrics.perplexity + :members: + :undoc-members: + :show-inheritance: + + +.. automodule:: textattack.metrics.quality_metrics.use + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/apidoc/textattack.metrics.rst b/docs/apidoc/textattack.metrics.rst new file mode 100644 index 00000000..bcad2dbe --- /dev/null +++ b/docs/apidoc/textattack.metrics.rst @@ -0,0 +1,22 @@ +textattack.metrics package +========================== + +.. automodule:: textattack.metrics + :members: + :undoc-members: + :show-inheritance: + + + +.. toctree:: + :maxdepth: 6 + + textattack.metrics.attack_metrics + textattack.metrics.quality_metrics + + + +.. automodule:: textattack.metrics.metric + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/apidoc/textattack.models.rst b/docs/apidoc/textattack.models.rst index 84074b07..153747f0 100644 --- a/docs/apidoc/textattack.models.rst +++ b/docs/apidoc/textattack.models.rst @@ -7,6 +7,7 @@ textattack.models package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.rst b/docs/apidoc/textattack.rst index 83ff65b5..325e4dea 100644 --- a/docs/apidoc/textattack.rst +++ b/docs/apidoc/textattack.rst @@ -19,6 +19,7 @@ textattack package textattack.goal_function_results textattack.goal_functions textattack.loggers + textattack.metrics textattack.models textattack.search_methods textattack.shared diff --git a/docs/apidoc/textattack.shared.rst b/docs/apidoc/textattack.shared.rst index 9679c45a..34a5a1b4 100644 --- a/docs/apidoc/textattack.shared.rst +++ b/docs/apidoc/textattack.shared.rst @@ -7,6 +7,7 @@ textattack.shared package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/docs/apidoc/textattack.transformations.rst b/docs/apidoc/textattack.transformations.rst index 1ae3f653..6e87f53f 100644 --- a/docs/apidoc/textattack.transformations.rst +++ b/docs/apidoc/textattack.transformations.rst @@ -7,6 +7,7 @@ textattack.transformations package :show-inheritance: + .. toctree:: :maxdepth: 6 diff --git a/tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt b/tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt new file mode 100644 index 00000000..9023d194 --- /dev/null +++ b/tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt @@ -0,0 +1,74 @@ +/.*/Attack( + (search_method): BeamSearch( + (beam_width): 10 + ) + (goal_function): UntargetedClassification + (transformation): WordSwapGradientBased( + (top_n): 1 + ) + (constraints): + (0): MaxWordsPerturbed( + (max_num_words): 2 + (compare_against_original): True + ) + (1): WordEmbeddingDistance( + (embedding): WordEmbedding + (min_cos_sim): 0.8 + (cased): False + (include_unknown_words): True + (compare_against_original): True + ) + (2): PartOfSpeech( + (tagger_type): nltk + (tagset): universal + (allow_verb_noun_swap): True + (compare_against_original): True + ) + (3): RepeatModification + (4): StopwordModification + (is_black_box): False +) + +--------------------------------------------- Result 1 --------------------------------------------- +[[Positive (96%)]] --> [[Negative (77%)]] + +the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[supplies]] with tremendous skill . + +the story gives ample opportunity for large-scale action and suspense , which director shekhar kapur [[stagnated]] with tremendous skill . + + +--------------------------------------------- Result 2 --------------------------------------------- +[[Negative (57%)]] --> [[[SKIPPED]]] + +red dragon " never cuts corners . + + +--------------------------------------------- Result 3 --------------------------------------------- +[[Positive (51%)]] --> [[[FAILED]]] + +fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense . + + +--------------------------------------------- Result 4 --------------------------------------------- +[[Positive (89%)]] --> [[[FAILED]]] + +throws in enough clever and unexpected twists to make the formula feel fresh . + + + ++-------------------------------+--------+ +| Attack Results | | ++-------------------------------+--------+ +| Number of successful attacks: | 1 | +| Number of failed attacks: | 2 | +| Number of skipped attacks: | 1 | +| Original accuracy: | 75.0% | +| Accuracy under attack: | 50.0% | +| Attack success rate: | 33.33% | +| Average perturbed word %: | 5.56% | +| Average num. words per input: | 15.5 | +| Avg num queries: | 1.33 | +| Average Original Perplexity: | 291.47 | +| Average Attack Perplexity: | 320.33 | +| Average Attack USE Score: | 0.91 | ++-------------------------------+--------+ diff --git a/tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt b/tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt new file mode 100644 index 00000000..1b01102f --- /dev/null +++ b/tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt @@ -0,0 +1,68 @@ +/.*/Attack( + (search_method): GreedyWordSwapWIR( + (wir_method): unk + ) + (goal_function): UntargetedClassification + (transformation): CompositeTransformation( + (0): WordSwapNeighboringCharacterSwap( + (random_one): True + ) + (1): WordSwapRandomCharacterSubstitution( + (random_one): True + ) + (2): WordSwapRandomCharacterDeletion( + (random_one): True + ) + (3): WordSwapRandomCharacterInsertion( + (random_one): True + ) + ) + (constraints): + (0): LevenshteinEditDistance( + (max_edit_distance): 30 + (compare_against_original): True + ) + (1): RepeatModification + (2): StopwordModification + (is_black_box): True +) + +--------------------------------------------- Result 1 --------------------------------------------- +[[Negative (100%)]] --> [[Positive (71%)]] + +[[hide]] [[new]] secretions from the parental units + +[[Ehide]] [[enw]] secretions from the parental units + + +--------------------------------------------- Result 2 --------------------------------------------- +[[Negative (100%)]] --> [[[FAILED]]] + +contains no wit , only labored gags + + +--------------------------------------------- Result 3 --------------------------------------------- +[[Positive (100%)]] --> [[Negative (96%)]] + +that [[loves]] its characters and communicates [[something]] [[rather]] [[beautiful]] about human nature + +that [[lodes]] its characters and communicates [[somethNng]] [[rathrer]] [[beautifdul]] about human nature + + + ++-------------------------------+---------+ +| Attack Results | | ++-------------------------------+---------+ +| Number of successful attacks: | 2 | +| Number of failed attacks: | 1 | +| Number of skipped attacks: | 0 | +| Original accuracy: | 100.0% | +| Accuracy under attack: | 33.33% | +| Attack success rate: | 66.67% | +| Average perturbed word %: | 30.95% | +| Average num. words per input: | 8.33 | +| Avg num queries: | 22.67 | +| Average Original Perplexity: | 1126.57 | +| Average Attack Perplexity: | 2823/.*/| +| Average Attack USE Score: | 0.76 | ++-------------------------------+---------+ diff --git a/tests/test_command_line/test_attack.py b/tests/test_command_line/test_attack.py index 714f73ed..5a455d22 100644 --- a/tests/test_command_line/test_attack.py +++ b/tests/test_command_line/test_attack.py @@ -48,6 +48,20 @@ "tests/sample_outputs/run_attack_transformers_datasets.txt", ), # + # test loading an attack from the transformers model hub and calculate perplexity and use + # + ( + "attack_from_transformers_adv_metrics", + ( + "textattack attack --model-from-huggingface " + "distilbert-base-uncased-finetuned-sst-2-english " + "--dataset-from-huggingface glue^sst2^train --recipe deepwordbug --num-examples 3 " + "--enable-advance-metrics" + "" + ), + "tests/sample_outputs/run_attack_transformers_datasets_adv_metrics.txt", + ), + # # test running an attack by loading a model and dataset from file # ( @@ -72,6 +86,17 @@ "tests/sample_outputs/run_attack_hotflip_lstm_mr_4.txt", ), # + # test hotflip on 10 samples from LSTM MR and calculate perplexity and use + # + ( + "run_attack_hotflip_lstm_mr_4_adv_metrics", + ( + "textattack attack --model lstm-mr --recipe hotflip " + "--num-examples 4 --num-examples-offset 3 --enable-advance-metrics " + ), + "tests/sample_outputs/run_attack_hotflip_lstm_mr_4_adv_metrics.txt", + ), + # # test: run_attack deepwordbug attack on 10 samples from LSTM MR # ( diff --git a/textattack/__init__.py b/textattack/__init__.py index 12f52331..a169173e 100644 --- a/textattack/__init__.py +++ b/textattack/__init__.py @@ -8,7 +8,6 @@ TextAttack provides components for common NLP tasks like sentence encoding, grammar-checking, and word replacement that can be used on their own. """ - from .attack_args import AttackArgs, CommandLineAttackArgs from .augment_args import AugmenterArgs from .dataset_args import DatasetArgs @@ -17,6 +16,7 @@ from .attack import Attack from .attacker import Attacker from .trainer import Trainer +from .metrics import Metric from . import ( attack_recipes, @@ -28,10 +28,12 @@ goal_function_results, goal_functions, loggers, + metrics, models, search_methods, shared, transformations, ) + name = "textattack" diff --git a/textattack/attack_args.py b/textattack/attack_args.py index ef510fdd..eaf5a725 100644 --- a/textattack/attack_args.py +++ b/textattack/attack_args.py @@ -174,6 +174,8 @@ class AttackArgs: Disable displaying individual attack results to stdout. silent (:obj:`bool`, `optional`, defaults to :obj:`False`): Disable all logging (except for errors). This is stronger than :obj:`disable_stdout`. + enable_advance_metrics (:obj:`bool`, `optional`, defaults to :obj:`False`): + Enable calculation and display of optional advance post-hoc metrics like perplexity, grammar errors, etc. """ num_examples: int = 10 @@ -194,6 +196,7 @@ class AttackArgs: log_to_wandb: str = None disable_stdout: bool = False silent: bool = False + enable_advance_metrics: bool = False def __post_init__(self): if self.num_successful_examples: @@ -351,6 +354,12 @@ def _add_parser_args(cls, parser): default=default_obj.silent, help="Disable all logging", ) + parser.add_argument( + "--enable-advance-metrics", + action="store_true", + default=default_obj.enable_advance_metrics, + help="Enable calculation and display of optional advance post-hoc metrics like perplexity, USE distance, etc.", + ) return parser diff --git a/textattack/attacker.py b/textattack/attacker.py index 7aa14a0a..96a9e21c 100644 --- a/textattack/attacker.py +++ b/textattack/attacker.py @@ -219,6 +219,10 @@ def _attack(self): # Enable summary stdout if not self.attack_args.silent and self.attack_args.disable_stdout: self.attack_log_manager.enable_stdout() + + if self.attack_args.enable_advance_metrics: + self.attack_log_manager.enable_advance_metrics = True + self.attack_log_manager.log_summary() self.attack_log_manager.flush() print() @@ -390,6 +394,10 @@ def _attack_parallel(self): # Enable summary stdout. if not self.attack_args.silent and self.attack_args.disable_stdout: self.attack_log_manager.enable_stdout() + + if self.attack_args.enable_advance_metrics: + self.attack_log_manager.enable_advance_metrics = True + self.attack_log_manager.log_summary() self.attack_log_manager.flush() print() diff --git a/textattack/commands/eval_model_command.py b/textattack/commands/eval_model_command.py index d05cd9fe..010d8a14 100644 --- a/textattack/commands/eval_model_command.py +++ b/textattack/commands/eval_model_command.py @@ -39,6 +39,8 @@ def get_preds(self, model, inputs): def test_model_on_dataset(self, args): model = ModelArgs._create_model_from_args(args) dataset = DatasetArgs._create_dataset_from_args(args) + if args.num_examples == -1: + args.num_examples = len(dataset) preds = [] ground_truth_outputs = [] diff --git a/textattack/dataset_args.py b/textattack/dataset_args.py index 6c4468c8..b7c61947 100644 --- a/textattack/dataset_args.py +++ b/textattack/dataset_args.py @@ -275,7 +275,9 @@ def _create_dataset_from_args(cls, args): dataset_args = (dataset_args,) if args.dataset_split: if len(dataset_args) > 1: - dataset_args[2] = args.dataset_split + dataset_args = ( + dataset_args[:1] + (args.dataset_split,) + dataset_args[2:] + ) dataset = textattack.datasets.HuggingFaceDataset( *dataset_args, shuffle=False ) diff --git a/textattack/loggers/attack_log_manager.py b/textattack/loggers/attack_log_manager.py index ab608b48..e07f1ec9 100644 --- a/textattack/loggers/attack_log_manager.py +++ b/textattack/loggers/attack_log_manager.py @@ -3,9 +3,12 @@ ======================== """ -import numpy as np - -from textattack.attack_results import FailedAttackResult, SkippedAttackResult +from textattack.metrics.attack_metrics import ( + AttackQueries, + AttackSuccessRate, + WordsPerturbed, +) +from textattack.metrics.quality_metrics import Perplexity, USEMetric from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger @@ -16,6 +19,7 @@ class AttackLogManager: def __init__(self): self.loggers = [] self.results = [] + self.enable_advance_metrics = False def enable_stdout(self): self.loggers.append(FileLogger(stdout=True)) @@ -72,103 +76,77 @@ def log_summary(self): total_attacks = len(self.results) if total_attacks == 0: return - # Count things about attacks. - all_num_words = np.zeros(len(self.results)) - perturbed_word_percentages = np.zeros(len(self.results)) - num_words_changed_until_success = np.zeros( - 2 ** 16 - ) # @ TODO: be smarter about this - failed_attacks = 0 - skipped_attacks = 0 - successful_attacks = 0 - max_words_changed = 0 - for i, result in enumerate(self.results): - all_num_words[i] = len(result.original_result.attacked_text.words) - if isinstance(result, FailedAttackResult): - failed_attacks += 1 - continue - elif isinstance(result, SkippedAttackResult): - skipped_attacks += 1 - continue - else: - successful_attacks += 1 - num_words_changed = result.original_result.attacked_text.words_diff_num( - result.perturbed_result.attacked_text - ) - # num_words_changed = len( - # result.original_result.attacked_text.all_words_diff( - # result.perturbed_result.attacked_text - # ) - # ) - num_words_changed_until_success[num_words_changed - 1] += 1 - max_words_changed = max( - max_words_changed or num_words_changed, num_words_changed - ) - if len(result.original_result.attacked_text.words) > 0: - perturbed_word_percentage = ( - num_words_changed - * 100.0 - / len(result.original_result.attacked_text.words) - ) - else: - perturbed_word_percentage = 0 - perturbed_word_percentages[i] = perturbed_word_percentage - - # Original classifier success rate on these samples. - original_accuracy = (total_attacks - skipped_attacks) * 100.0 / (total_attacks) - original_accuracy = str(round(original_accuracy, 2)) + "%" - - # New classifier success rate on these samples. - accuracy_under_attack = (failed_attacks) * 100.0 / (total_attacks) - accuracy_under_attack = str(round(accuracy_under_attack, 2)) + "%" - - # Attack success rate. - if successful_attacks + failed_attacks == 0: - attack_success_rate = 0 - else: - attack_success_rate = ( - successful_attacks * 100.0 / (successful_attacks + failed_attacks) - ) - attack_success_rate = str(round(attack_success_rate, 2)) + "%" - - perturbed_word_percentages = perturbed_word_percentages[ - perturbed_word_percentages > 0 - ] - average_perc_words_perturbed = perturbed_word_percentages.mean() - average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + "%" - average_num_words = all_num_words.mean() - average_num_words = str(round(average_num_words, 2)) + # Default metrics - calculated on every attack + attack_success_stats = AttackSuccessRate().calculate(self.results) + words_perturbed_stats = WordsPerturbed().calculate(self.results) + attack_query_stats = AttackQueries().calculate(self.results) + # @TODO generate this table based on user input - each column in specific class + # Example to demonstrate: + # summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ... summary_table_rows = [ - ["Number of successful attacks:", str(successful_attacks)], - ["Number of failed attacks:", str(failed_attacks)], - ["Number of skipped attacks:", str(skipped_attacks)], - ["Original accuracy:", original_accuracy], - ["Accuracy under attack:", accuracy_under_attack], - ["Attack success rate:", attack_success_rate], - ["Average perturbed word %:", average_perc_words_perturbed], - ["Average num. words per input:", average_num_words], + [ + "Number of successful attacks:", + attack_success_stats["successful_attacks"], + ], + ["Number of failed attacks:", attack_success_stats["failed_attacks"]], + ["Number of skipped attacks:", attack_success_stats["skipped_attacks"]], + [ + "Original accuracy:", + str(attack_success_stats["original_accuracy"]) + "%", + ], + [ + "Accuracy under attack:", + str(attack_success_stats["attack_accuracy_perc"]) + "%", + ], + [ + "Attack success rate:", + str(attack_success_stats["attack_success_rate"]) + "%", + ], + [ + "Average perturbed word %:", + str(words_perturbed_stats["avg_word_perturbed_perc"]) + "%", + ], + [ + "Average num. words per input:", + words_perturbed_stats["avg_word_perturbed"], + ], ] - num_queries = np.array( - [ - r.num_queries - for r in self.results - if not isinstance(r, SkippedAttackResult) - ] + summary_table_rows.append( + ["Avg num queries:", attack_query_stats["avg_num_queries"]] ) - avg_num_queries = num_queries.mean() - avg_num_queries = str(round(avg_num_queries, 2)) - summary_table_rows.append(["Avg num queries:", avg_num_queries]) + + if self.enable_advance_metrics: + perplexity_stats = Perplexity().calculate(self.results) + use_stats = USEMetric().calculate(self.results) + + summary_table_rows.append( + [ + "Average Original Perplexity:", + perplexity_stats["avg_original_perplexity"], + ] + ) + + summary_table_rows.append( + [ + "Average Attack Perplexity:", + perplexity_stats["avg_attack_perplexity"], + ] + ) + summary_table_rows.append( + ["Average Attack USE Score:", use_stats["avg_attack_use_score"]] + ) + self.log_summary_rows( summary_table_rows, "Attack Results", "attack_results_summary" ) # Show histogram of words changed. - numbins = max(max_words_changed, 10) + numbins = max(words_perturbed_stats["max_words_changed"], 10) for logger in self.loggers: logger.log_hist( - num_words_changed_until_success[:numbins], + words_perturbed_stats["num_words_changed_until_success"][:numbins], numbins=numbins, title="Num Words Perturbed", window_id="num_words_perturbed", diff --git a/textattack/metrics/__init__.py b/textattack/metrics/__init__.py new file mode 100644 index 00000000..fde2faf6 --- /dev/null +++ b/textattack/metrics/__init__.py @@ -0,0 +1,11 @@ +""" +""" + +from .metric import Metric + +from .attack_metrics import AttackSuccessRate +from .attack_metrics import WordsPerturbed +from .attack_metrics import AttackQueries + +from .quality_metrics import Perplexity +from .quality_metrics import USEMetric diff --git a/textattack/metrics/attack_metrics/__init__.py b/textattack/metrics/attack_metrics/__init__.py new file mode 100644 index 00000000..3eb90e34 --- /dev/null +++ b/textattack/metrics/attack_metrics/__init__.py @@ -0,0 +1,12 @@ +""" + +attack_metrics: +====================== + +TextAttack provide users common metrics on attacks' quality. + +""" + +from .attack_queries import AttackQueries +from .attack_success_rate import AttackSuccessRate +from .words_perturbed import WordsPerturbed diff --git a/textattack/metrics/attack_metrics/attack_queries.py b/textattack/metrics/attack_metrics/attack_queries.py new file mode 100644 index 00000000..7affc698 --- /dev/null +++ b/textattack/metrics/attack_metrics/attack_queries.py @@ -0,0 +1,41 @@ +""" + +Metrics on AttackQueries +========================= + +""" + +import numpy as np + +from textattack.attack_results import SkippedAttackResult +from textattack.metrics import Metric + + +class AttackQueries(Metric): + def __init__(self): + self.all_metrics = {} + + def calculate(self, results): + """Calculates all metrics related to number of queries in an attack + + Args: + results (``AttackResult`` objects): + Attack results for each instance in dataset + """ + + self.results = results + self.num_queries = np.array( + [ + r.num_queries + for r in self.results + if not isinstance(r, SkippedAttackResult) + ] + ) + self.all_metrics["avg_num_queries"] = self.avg_num_queries() + + return self.all_metrics + + def avg_num_queries(self): + avg_num_queries = self.num_queries.mean() + avg_num_queries = round(avg_num_queries, 2) + return avg_num_queries diff --git a/textattack/metrics/attack_metrics/attack_success_rate.py b/textattack/metrics/attack_metrics/attack_success_rate.py new file mode 100644 index 00000000..368e8b87 --- /dev/null +++ b/textattack/metrics/attack_metrics/attack_success_rate.py @@ -0,0 +1,74 @@ +""" + +Metrics on AttackSuccessRate +============================= + +""" + +from textattack.attack_results import FailedAttackResult, SkippedAttackResult +from textattack.metrics import Metric + + +class AttackSuccessRate(Metric): + def __init__(self): + self.failed_attacks = 0 + self.skipped_attacks = 0 + self.successful_attacks = 0 + + self.all_metrics = {} + + def calculate(self, results): + """Calculates all metrics related to number of succesful, failed and skipped results in an attack + + Args: + results (``AttackResult`` objects): + Attack results for each instance in dataset + """ + self.results = results + self.total_attacks = len(self.results) + + for i, result in enumerate(self.results): + if isinstance(result, FailedAttackResult): + self.failed_attacks += 1 + continue + elif isinstance(result, SkippedAttackResult): + self.skipped_attacks += 1 + continue + else: + self.successful_attacks += 1 + + # Calculated numbers + self.all_metrics["successful_attacks"] = self.successful_attacks + self.all_metrics["failed_attacks"] = self.failed_attacks + self.all_metrics["skipped_attacks"] = self.skipped_attacks + + # Percentages wrt the calculations + self.all_metrics["original_accuracy"] = self.original_accuracy_perc() + self.all_metrics["attack_accuracy_perc"] = self.attack_accuracy_perc() + self.all_metrics["attack_success_rate"] = self.attack_success_rate_perc() + + return self.all_metrics + + def original_accuracy_perc(self): + original_accuracy = ( + (self.total_attacks - self.skipped_attacks) * 100.0 / (self.total_attacks) + ) + original_accuracy = round(original_accuracy, 2) + return original_accuracy + + def attack_accuracy_perc(self): + accuracy_under_attack = (self.failed_attacks) * 100.0 / (self.total_attacks) + accuracy_under_attack = round(accuracy_under_attack, 2) + return accuracy_under_attack + + def attack_success_rate_perc(self): + if self.successful_attacks + self.failed_attacks == 0: + attack_success_rate = 0 + else: + attack_success_rate = ( + self.successful_attacks + * 100.0 + / (self.successful_attacks + self.failed_attacks) + ) + attack_success_rate = round(attack_success_rate, 2) + return attack_success_rate diff --git a/textattack/metrics/attack_metrics/words_perturbed.py b/textattack/metrics/attack_metrics/words_perturbed.py new file mode 100644 index 00000000..a9f29b92 --- /dev/null +++ b/textattack/metrics/attack_metrics/words_perturbed.py @@ -0,0 +1,85 @@ +""" + +Metrics on perturbed words +============================= + +""" + +import numpy as np + +from textattack.attack_results import FailedAttackResult, SkippedAttackResult +from textattack.metrics import Metric + + +class WordsPerturbed(Metric): + def __init__(self): + self.total_attacks = 0 + self.all_num_words = None + self.perturbed_word_percentages = None + self.num_words_changed_until_success = 0 + self.all_metrics = {} + + def calculate(self, results): + """Calculates all metrics related to perturbed words in an attack + + Args: + results (``AttackResult`` objects): + Attack results for each instance in dataset + """ + + self.results = results + self.total_attacks = len(self.results) + self.all_num_words = np.zeros(len(self.results)) + self.perturbed_word_percentages = np.zeros(len(self.results)) + self.num_words_changed_until_success = np.zeros(2 ** 16) + self.max_words_changed = 0 + + for i, result in enumerate(self.results): + self.all_num_words[i] = len(result.original_result.attacked_text.words) + + if isinstance(result, FailedAttackResult) or isinstance( + result, SkippedAttackResult + ): + continue + + num_words_changed = len( + result.original_result.attacked_text.all_words_diff( + result.perturbed_result.attacked_text + ) + ) + self.num_words_changed_until_success[num_words_changed - 1] += 1 + self.max_words_changed = max( + self.max_words_changed or num_words_changed, num_words_changed + ) + if len(result.original_result.attacked_text.words) > 0: + perturbed_word_percentage = ( + num_words_changed + * 100.0 + / len(result.original_result.attacked_text.words) + ) + else: + perturbed_word_percentage = 0 + + self.perturbed_word_percentages[i] = perturbed_word_percentage + + self.all_metrics["avg_word_perturbed"] = self.avg_number_word_perturbed_num() + self.all_metrics["avg_word_perturbed_perc"] = self.avg_perturbation_perc() + self.all_metrics["max_words_changed"] = self.max_words_changed + self.all_metrics[ + "num_words_changed_until_success" + ] = self.num_words_changed_until_success + + return self.all_metrics + + def avg_number_word_perturbed_num(self): + average_num_words = self.all_num_words.mean() + average_num_words = round(average_num_words, 2) + return average_num_words + + def avg_perturbation_perc(self): + self.perturbed_word_percentages = self.perturbed_word_percentages[ + self.perturbed_word_percentages > 0 + ] + average_perc_words_perturbed = self.perturbed_word_percentages.mean() + average_perc_words_perturbed = round(average_perc_words_perturbed, 2) + return average_perc_words_perturbed diff --git a/textattack/metrics/metric.py b/textattack/metrics/metric.py new file mode 100644 index 00000000..1a7c79c0 --- /dev/null +++ b/textattack/metrics/metric.py @@ -0,0 +1,28 @@ +""" +Metric Class +======================== + +""" + +from abc import ABC, abstractmethod + + +class Metric(ABC): + """A metric for evaluating Adversarial Attack candidates.""" + + @abstractmethod + def __init__(self, **kwargs): + """Creates pre-built :class:`~textattack.Metric` that correspond to + evaluation metrics for adversarial examples. + """ + raise NotImplementedError() + + @abstractmethod + def calculate(self, results): + """Abstract function for computing any values which are to be calculated as a whole during initialization + Args: + results (``AttackResult`` objects): + Attack results for each instance in dataset + """ + + raise NotImplementedError diff --git a/textattack/metrics/quality_metrics/__init__.py b/textattack/metrics/quality_metrics/__init__.py new file mode 100644 index 00000000..5addbad4 --- /dev/null +++ b/textattack/metrics/quality_metrics/__init__.py @@ -0,0 +1,12 @@ +""" + +Metrics on Quality +====================== + +TextAttack provide users common metrics on text examples' quality. + + +""" + +from .perplexity import Perplexity +from .use import USEMetric diff --git a/textattack/metrics/quality_metrics/perplexity.py b/textattack/metrics/quality_metrics/perplexity.py new file mode 100644 index 00000000..d508e29f --- /dev/null +++ b/textattack/metrics/quality_metrics/perplexity.py @@ -0,0 +1,93 @@ +""" + +Perplexity Metric: +====================== + +""" + +import torch +from transformers import GPT2LMHeadModel, GPT2Tokenizer + +from textattack.attack_results import FailedAttackResult, SkippedAttackResult +from textattack.metrics import Metric +import textattack.shared.utils + + +class Perplexity(Metric): + def __init__(self): + self.all_metrics = {} + self.original_candidates = [] + self.successful_candidates = [] + self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2") + self.ppl_model.to(textattack.shared.utils.device) + self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + self.ppl_model.eval() + self.max_length = self.ppl_model.config.n_positions + self.stride = 512 + + def calculate(self, results): + """Calculates average Perplexity on all successfull attacks using a pre-trained small GPT-2 model + + Args: + results (``AttackResult`` objects): + Attack results for each instance in dataset + """ + self.results = results + self.original_candidates_ppl = [] + self.successful_candidates_ppl = [] + + for i, result in enumerate(self.results): + if isinstance(result, FailedAttackResult): + continue + elif isinstance(result, SkippedAttackResult): + continue + else: + self.original_candidates.append( + result.original_result.attacked_text.text.lower() + ) + self.successful_candidates.append( + result.perturbed_result.attacked_text.text.lower() + ) + + ppl_orig = self.calc_ppl(self.original_candidates) + ppl_attack = self.calc_ppl(self.successful_candidates) + + self.all_metrics["avg_original_perplexity"] = round(ppl_orig[0], 2) + self.all_metrics["original_perplexity_list"] = ppl_orig[1] + + self.all_metrics["avg_attack_perplexity"] = round(ppl_attack[0], 2) + self.all_metrics["attack_perplexity_list"] = ppl_attack[1] + + return self.all_metrics + + def calc_ppl(self, texts): + + ppl_vals = [] + + with torch.no_grad(): + for text in texts: + eval_loss = [] + input_ids = torch.tensor( + self.ppl_tokenizer.encode(text, add_special_tokens=True) + ).unsqueeze(0) + # Strided perplexity calculation from huggingface.co/transformers/perplexity.html + for i in range(0, input_ids.size(1), self.stride): + begin_loc = max(i + self.stride - self.max_length, 0) + end_loc = min(i + self.stride, input_ids.size(1)) + trg_len = end_loc - i + input_ids_t = input_ids[:, begin_loc:end_loc].to( + textattack.shared.utils.device + ) + target_ids = input_ids_t.clone() + target_ids[:, :-trg_len] = -100 + + outputs = self.ppl_model(input_ids_t, labels=target_ids) + log_likelihood = outputs[0] * trg_len + + eval_loss.append(log_likelihood) + + ppl_vals.append( + torch.exp(torch.stack(eval_loss).sum() / end_loc).item() + ) + + return sum(ppl_vals) / len(ppl_vals), ppl_vals diff --git a/textattack/metrics/quality_metrics/use.py b/textattack/metrics/quality_metrics/use.py new file mode 100644 index 00000000..424727cf --- /dev/null +++ b/textattack/metrics/quality_metrics/use.py @@ -0,0 +1,44 @@ +from textattack.attack_results import FailedAttackResult, SkippedAttackResult +from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder +from textattack.metrics import Metric + + +class USEMetric(Metric): + def __init__(self, **kwargs): + self.use_obj = UniversalSentenceEncoder() + self.use_obj.model = UniversalSentenceEncoder() + self.original_candidates = [] + self.successful_candidates = [] + self.all_metrics = {} + + def calculate(self, results): + """Calculates average USE similarity on all successfull attacks + + Args: + results (``AttackResult`` objects): + Attack results for each instance in dataset + """ + self.results = results + + for i, result in enumerate(self.results): + if isinstance(result, FailedAttackResult): + continue + elif isinstance(result, SkippedAttackResult): + continue + else: + self.original_candidates.append(result.original_result.attacked_text) + self.successful_candidates.append(result.perturbed_result.attacked_text) + + use_scores = [] + for c in range(len(self.original_candidates)): + use_scores.append( + self.use_obj._sim_score( + self.original_candidates[c], self.successful_candidates[c] + ).item() + ) + + self.all_metrics["avg_attack_use_score"] = round( + sum(use_scores) / len(use_scores), 2 + ) + + return self.all_metrics