diff --git a/textattack/attack_metrics/__init__.py b/textattack/attack_metrics/__init__.py
new file mode 100644
index 00000000..d7f9b54e
--- /dev/null
+++ b/textattack/attack_metrics/__init__.py
@@ -0,0 +1,14 @@
+"""
+
+attack_metrics:
+======================
+
+TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.
+
+
+"""
+
+from .attack_metric import AttackMetric
+from .attack_queries import AttackQueries
+from .attack_success_rate import AttackSuccessRate
+from .words_perturbed import WordsPerturbed
diff --git a/textattack/attack_metrics/attack_metric.py b/textattack/attack_metrics/attack_metric.py
new file mode 100644
index 00000000..1487b63c
--- /dev/null
+++ b/textattack/attack_metrics/attack_metric.py
@@ -0,0 +1,26 @@
+"""
+Attack Metrics Class
+========================
+
+"""
+
+from abc import ABC, abstractmethod
+
+
+# NOTE(review): a metric is not an attack result, so AttackMetric does not
+# inherit from AttackResult; @staticmethod is not combined with the abstract
+# methods because every subclass implements them as instance methods.
+class AttackMetric(ABC):
+    """A metric for evaluating adversarial attack candidates."""
+
+    @abstractmethod
+    def __init__(self, results, **kwargs):
+        """Creates pre-built :class:`~textattack.AttackMetric` that correspond to
+        evaluation metrics for adversarial examples.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def calculate(self):
+        """Abstract method for computing any values which are to be calculated as a whole during initialization."""
+        raise NotImplementedError()
diff --git a/textattack/attack_metrics/attack_queries.py b/textattack/attack_metrics/attack_queries.py
new file mode 100644
index 00000000..0a2d60ba
--- /dev/null
+++ b/textattack/attack_metrics/attack_queries.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+from textattack.attack_results import SkippedAttackResult
+
+from .attack_metric import AttackMetric
+
+
+class AttackQueries(AttackMetric):
+    """Calculates all metrics related to number of queries in an attack
+
+    Args:
+        results (:obj:`list` of :class:`~textattack.attack_results.AttackResult`):
+            Attack results for each instance in dataset
+    """
+
+    def __init__(self, results):
+        self.results = results
+
+        self.calculate()
+
+    def calculate(self):
+        self.num_queries = np.array(
+            [
+                r.num_queries
+                for r in self.results
+                if not isinstance(r, SkippedAttackResult)
+            ]
+        )
+
+    def avg_num_queries_num(self):
+        avg_num_queries = self.num_queries.mean()
+        avg_num_queries = round(avg_num_queries, 2)
+        return avg_num_queries
diff --git a/textattack/attack_metrics/attack_success_rate.py b/textattack/attack_metrics/attack_success_rate.py
new file mode 100644
index 00000000..7cb38eb8
--- /dev/null
+++ b/textattack/attack_metrics/attack_success_rate.py
@@ -0,0 +1,65 @@
+from textattack.attack_results import FailedAttackResult, SkippedAttackResult
+
+from .attack_metric import AttackMetric
+
+
+class AttackSuccessRate(AttackMetric):
+    """Calculates all metrics related to number of successful, failed and skipped results in an attack
+
+    Args:
+        results (:obj:`list` of :class:`~textattack.attack_results.AttackResult`):
+            Attack results for each instance in dataset
+    """
+
+    def __init__(self, results):
+        self.results = results
+        self.failed_attacks = 0
+        self.skipped_attacks = 0
+        self.successful_attacks = 0
+        self.total_attacks = len(self.results)
+
+        self.calculate()
+
+    def calculate(self):
+        for result in self.results:
+            if isinstance(result, FailedAttackResult):
+                self.failed_attacks += 1
+                continue
+            elif isinstance(result, SkippedAttackResult):
+                self.skipped_attacks += 1
+                continue
+            else:
+                self.successful_attacks += 1
+
+    def successful_attacks_num(self):
+        return self.successful_attacks
+
+    def failed_attacks_num(self):
+        return self.failed_attacks
+
+    def skipped_attacks_num(self):
+        return self.skipped_attacks
+
+    def original_accuracy_perc(self):
+        original_accuracy = (
+            (self.total_attacks - self.skipped_attacks) * 100.0 / (self.total_attacks)
+        )
+        original_accuracy = round(original_accuracy, 2)
+        return original_accuracy
+
+    def attack_accuracy_perc(self):
+        accuracy_under_attack = (self.failed_attacks) * 100.0 / (self.total_attacks)
+        accuracy_under_attack = round(accuracy_under_attack, 2)
+        return accuracy_under_attack
+
+    def attack_success_rate_perc(self):
+        if self.successful_attacks + self.failed_attacks == 0:
+            attack_success_rate = 0
+        else:
+            attack_success_rate = (
+                self.successful_attacks
+                * 100.0
+                / (self.successful_attacks + self.failed_attacks)
+            )
+        attack_success_rate = round(attack_success_rate, 2)
+        return attack_success_rate
diff --git a/textattack/attack_metrics/words_perturbed.py b/textattack/attack_metrics/words_perturbed.py
new file mode 100644
index 00000000..36d0dd59
--- /dev/null
+++ b/textattack/attack_metrics/words_perturbed.py
@@ -0,0 +1,65 @@
+import numpy as np
+
+from textattack.attack_results import FailedAttackResult, SkippedAttackResult
+
+from .attack_metric import AttackMetric
+
+
+class WordsPerturbed(AttackMetric):
+    def __init__(self, results):
+        self.results = results
+        self.total_attacks = len(self.results)
+        self.all_num_words = np.zeros(len(self.results))
+        self.perturbed_word_percentages = np.zeros(len(self.results))
+        self.num_words_changed_until_success = np.zeros(2 ** 16)
+
+        self.calculate()
+
+    def calculate(self):
+        self.max_words_changed = 0
+        for i, result in enumerate(self.results):
+            self.all_num_words[i] = len(result.original_result.attacked_text.words)
+
+            if isinstance(result, FailedAttackResult) or isinstance(
+                result, SkippedAttackResult
+            ):
+                continue
+
+            num_words_changed = len(
+                result.original_result.attacked_text.all_words_diff(
+                    result.perturbed_result.attacked_text
+                )
+            )
+            self.num_words_changed_until_success[num_words_changed - 1] += 1
+            self.max_words_changed = max(
+                self.max_words_changed or num_words_changed, num_words_changed
+            )
+            if len(result.original_result.attacked_text.words) > 0:
+                perturbed_word_percentage = (
+                    num_words_changed
+                    * 100.0
+                    / len(result.original_result.attacked_text.words)
+                )
+            else:
+                perturbed_word_percentage = 0
+
+            self.perturbed_word_percentages[i] = perturbed_word_percentage
+
+    def avg_number_word_perturbed_num(self):
+        average_num_words = self.all_num_words.mean()
+        average_num_words = round(average_num_words, 2)
+        return average_num_words
+
+    def avg_perturbation_perc(self):
+        perturbed = self.perturbed_word_percentages[
+            self.perturbed_word_percentages > 0
+        ]
+        average_perc_words_perturbed = perturbed.mean()
+        average_perc_words_perturbed = round(average_perc_words_perturbed, 2)
+        return average_perc_words_perturbed
+
+    def max_words_changed_num(self):
+        return self.max_words_changed
+
+    def num_words_changed_until_success_num(self):
+        return self.num_words_changed_until_success
diff --git a/textattack/loggers/attack_log_manager.py b/textattack/loggers/attack_log_manager.py
index 7141d479..232d2eb2 100644
--- a/textattack/loggers/attack_log_manager.py
+++ b/textattack/loggers/attack_log_manager.py
@@ -3,9 +3,7 @@
 ========================
 """
 
-import numpy as np
-
-from textattack.attack_results import FailedAttackResult, SkippedAttackResult
+from textattack.attack_metrics import AttackQueries, AttackSuccessRate, WordsPerturbed
 
 from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger
 
@@ -72,100 +70,55 @@ def log_summary(self):
         total_attacks = len(self.results)
         if total_attacks == 0:
             return
-        # Count things about attacks.
-        all_num_words = np.zeros(len(self.results))
-        perturbed_word_percentages = np.zeros(len(self.results))
-        num_words_changed_until_success = np.zeros(
-            2 ** 16
-        )  # @ TODO: be smarter about this
-        failed_attacks = 0
-        skipped_attacks = 0
-        successful_attacks = 0
-        max_words_changed = 0
-        for i, result in enumerate(self.results):
-            all_num_words[i] = len(result.original_result.attacked_text.words)
-            if isinstance(result, FailedAttackResult):
-                failed_attacks += 1
-                continue
-            elif isinstance(result, SkippedAttackResult):
-                skipped_attacks += 1
-                continue
-            else:
-                successful_attacks += 1
-            num_words_changed = len(
-                result.original_result.attacked_text.all_words_diff(
-                    result.perturbed_result.attacked_text
-                )
-            )
-            num_words_changed_until_success[num_words_changed - 1] += 1
-            max_words_changed = max(
-                max_words_changed or num_words_changed, num_words_changed
-            )
-            if len(result.original_result.attacked_text.words) > 0:
-                perturbed_word_percentage = (
-                    num_words_changed
-                    * 100.0
-                    / len(result.original_result.attacked_text.words)
-                )
-            else:
-                perturbed_word_percentage = 0
-            perturbed_word_percentages[i] = perturbed_word_percentage
-
-        # Original classifier success rate on these samples.
-        original_accuracy = (total_attacks - skipped_attacks) * 100.0 / (total_attacks)
-        original_accuracy = str(round(original_accuracy, 2)) + "%"
-
-        # New classifier success rate on these samples.
-        accuracy_under_attack = (failed_attacks) * 100.0 / (total_attacks)
-        accuracy_under_attack = str(round(accuracy_under_attack, 2)) + "%"
-
-        # Attack success rate.
-        if successful_attacks + failed_attacks == 0:
-            attack_success_rate = 0
-        else:
-            attack_success_rate = (
-                successful_attacks * 100.0 / (successful_attacks + failed_attacks)
-            )
-        attack_success_rate = str(round(attack_success_rate, 2)) + "%"
-
-        perturbed_word_percentages = perturbed_word_percentages[
-            perturbed_word_percentages > 0
-        ]
-        average_perc_words_perturbed = perturbed_word_percentages.mean()
-        average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + "%"
-
-        average_num_words = all_num_words.mean()
-        average_num_words = str(round(average_num_words, 2))
+        # Default metrics - calculated on every attack
+        attack_success_stats = AttackSuccessRate(self.results)
+        words_perturbed_stats = WordsPerturbed(self.results)
+        attack_query_stats = AttackQueries(self.results)
+
+        # @TODO generate this table based on user input - each column in specific class
+        # Example to demonstrate:
+        # summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
         summary_table_rows = [
-            ["Number of successful attacks:", str(successful_attacks)],
-            ["Number of failed attacks:", str(failed_attacks)],
-            ["Number of skipped attacks:", str(skipped_attacks)],
-            ["Original accuracy:", original_accuracy],
-            ["Accuracy under attack:", accuracy_under_attack],
-            ["Attack success rate:", attack_success_rate],
-            ["Average perturbed word %:", average_perc_words_perturbed],
-            ["Average num. words per input:", average_num_words],
+            [
+                "Number of successful attacks:",
+                attack_success_stats.successful_attacks_num(),
+            ],
+            ["Number of failed attacks:", attack_success_stats.failed_attacks_num()],
+            ["Number of skipped attacks:", attack_success_stats.skipped_attacks_num()],
+            [
+                "Original accuracy:",
+                str(attack_success_stats.original_accuracy_perc()) + "%",
+            ],
+            [
+                "Accuracy under attack:",
+                str(attack_success_stats.attack_accuracy_perc()) + "%",
+            ],
+            [
+                "Attack success rate:",
+                str(attack_success_stats.attack_success_rate_perc()) + "%",
+            ],
+            [
+                "Average perturbed word %:",
+                str(words_perturbed_stats.avg_perturbation_perc()) + "%",
+            ],
+            [
+                "Average num. words per input:",
+                words_perturbed_stats.avg_number_word_perturbed_num(),
+            ],
         ]
-
-        num_queries = np.array(
-            [
-                r.num_queries
-                for r in self.results
-                if not isinstance(r, SkippedAttackResult)
-            ]
+        summary_table_rows.append(
+            ["Avg num queries:", attack_query_stats.avg_num_queries_num()]
         )
-        avg_num_queries = num_queries.mean()
-        avg_num_queries = str(round(avg_num_queries, 2))
-        summary_table_rows.append(["Avg num queries:", avg_num_queries])
         self.log_summary_rows(
             summary_table_rows, "Attack Results", "attack_results_summary"
         )
         # Show histogram of words changed.
-        numbins = max(max_words_changed, 10)
+        numbins = max(words_perturbed_stats.max_words_changed_num(), 10)
         for logger in self.loggers:
             logger.log_hist(
-                num_words_changed_until_success[:numbins],
+                words_perturbed_stats.num_words_changed_until_success_num()[:numbins],
                 numbins=numbins,
                 title="Num Words Perturbed",
                 window_id="num_words_perturbed",
diff --git a/textattack/metrics/__init__.py b/textattack/metrics/__init__.py
new file mode 100644
index 00000000..637ecf0f
--- /dev/null
+++ b/textattack/metrics/__init__.py
@@ -0,0 +1,9 @@
+"""
+
+metrics:
+======================
+
+TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.
+
+
+"""