
New metric module to improve flexibility and intuitiveness #475

Closed · wants to merge 16 commits
14 changes: 14 additions & 0 deletions textattack/attack_metrics/__init__.py
@@ -0,0 +1,14 @@
"""

attack_metrics:
======================

TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.


"""

from .attack_metric import AttackMetric
from .attack_queries import AttackQueries
from .attack_success_rate import AttackSuccessRate
from .words_perturbed import WordsPerturbed
27 changes: 27 additions & 0 deletions textattack/attack_metrics/attack_metric.py
@@ -0,0 +1,27 @@
"""
Attack Metrics Class
========================

"""

from abc import ABC, abstractmethod

from textattack.attack_results import AttackResult


class AttackMetric(AttackResult, ABC):
Review (Collaborator): What's the reason why AttackMetric is a sub-class of AttackResult?

Review (Member): @sanchit97 I have the same concerns as @jinyongyoo above.

Review (Author): Yes, this is an oversight. Fixed.

"""A metric for evaluating Adversarial Attack candidates."""

@staticmethod
@abstractmethod
Review (Collaborator): I don't think @staticmethod here is necessary, since this isn't really a static method.

Review (Member): @sanchit97 please update regarding the comment.

def __init__(self, results, **kwargs):
"""Creates a pre-built :class:`~textattack.AttackMetric` that corresponds to
an evaluation metric for adversarial examples.
"""
raise NotImplementedError()

@staticmethod
@abstractmethod
def calculate():
Review (Collaborator): Same thing here. I don't think @staticmethod should be here (unless you intend to make it a static method). The other sub-classes define this as a regular instance method, so this should be one too, with self as the first argument.

Review (Member): @sanchit97 please update regarding the comment.

"""Abstract method for computing any values that are calculated as a whole during initialization."""
raise NotImplementedError
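For context, concrete metrics in this PR follow an __init__/calculate convention: the constructor stores the results and immediately calls calculate(), which precomputes everything the accessor methods later return. Below is a minimal sketch of a hypothetical metric written against that convention — the MaxQueries name is invented for illustration and is not part of the PR; each result is assumed to expose a num_queries attribute, as AttackResult does:

```python
# Hypothetical custom metric following the same __init__/calculate
# convention as the concrete subclasses in this PR. Not part of the PR.
class MaxQueries:
    def __init__(self, results):
        self.results = results
        self.calculate()

    def calculate(self):
        # Computed once up front, like the other metrics.
        self.max_queries = max(r.num_queries for r in self.results)

    def max_queries_num(self):
        return self.max_queries
```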
33 changes: 33 additions & 0 deletions textattack/attack_metrics/attack_queries.py
@@ -0,0 +1,33 @@
import numpy as np

from textattack.attack_results import SkippedAttackResult

from .attack_metric import AttackMetric


class AttackQueries(AttackMetric):
"""Calculates all metrics related to the number of queries in an attack.

Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Review (Collaborator): I might be wrong, but can you check whether this works as a valid docstring for Sphinx? I recall that references don't get nested; ideally we want something like list[GoalFunctionResult]. You can check by running sphinx-autobuild docs docs/_build/html --port 8765 --host 0.0.0.0 and then opening the host+port in your browser.

Review (Member): @sanchit97 please update regarding the comment.

Attack results for each instance in dataset
"""

def __init__(self, results):
self.results = results

self.calculate()

def calculate(self):
self.num_queries = np.array(
[
r.num_queries
for r in self.results
if not isinstance(r, SkippedAttackResult)
]
)

def avg_num_queries_num(self):
avg_num_queries = self.num_queries.mean()
avg_num_queries = round(avg_num_queries, 2)
return avg_num_queries
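The averaging in avg_num_queries_num reduces to a NumPy mean plus rounding over the non-skipped results. A standalone sketch with hypothetical query counts:

```python
import numpy as np

# Hypothetical per-result query counts; in AttackQueries these come from
# r.num_queries for every result that is not a SkippedAttackResult.
num_queries = np.array([12, 30, 45, 7])

# Same arithmetic as avg_num_queries_num: mean, rounded to 2 decimals.
avg_num_queries = round(num_queries.mean().item(), 2)
print(avg_num_queries)  # 23.5
```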
65 changes: 65 additions & 0 deletions textattack/attack_metrics/attack_success_rate.py
@@ -0,0 +1,65 @@
from textattack.attack_results import FailedAttackResult, SkippedAttackResult

from .attack_metric import AttackMetric


class AttackSuccessRate(AttackMetric):
"""Calculates all metrics related to the number of successful, failed, and skipped results in an attack.

Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Review (Collaborator): Similar thing here. Can you check if the docstring works?

Review (Member): @sanchit97 please update regarding the comment.

Attack results for each instance in dataset
"""

def __init__(self, results):
self.results = results
self.failed_attacks = 0
self.skipped_attacks = 0
self.successful_attacks = 0
self.total_attacks = len(self.results)

self.calculate()

def calculate(self):
for i, result in enumerate(self.results):
if isinstance(result, FailedAttackResult):
self.failed_attacks += 1
continue
elif isinstance(result, SkippedAttackResult):
self.skipped_attacks += 1
continue
else:
self.successful_attacks += 1

def successful_attacks_num(self):
return self.successful_attacks

def failed_attacks_num(self):
return self.failed_attacks

def skipped_attacks_num(self):
return self.skipped_attacks

def original_accuracy_perc(self):
original_accuracy = (
(self.total_attacks - self.skipped_attacks) * 100.0 / (self.total_attacks)
)
original_accuracy = round(original_accuracy, 2)
return original_accuracy

def attack_accuracy_perc(self):
accuracy_under_attack = (self.failed_attacks) * 100.0 / (self.total_attacks)
accuracy_under_attack = round(accuracy_under_attack, 2)
return accuracy_under_attack

def attack_success_rate_perc(self):
if self.successful_attacks + self.failed_attacks == 0:
attack_success_rate = 0
else:
attack_success_rate = (
self.successful_attacks
* 100.0
/ (self.successful_attacks + self.failed_attacks)
)
attack_success_rate = round(attack_success_rate, 2)
return attack_success_rate
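The three percentages above are simple ratios over the result counts. A self-contained sketch of the same arithmetic, using hypothetical counts (70 successful, 20 failed, 10 skipped):

```python
# Hypothetical counts; in AttackSuccessRate these are tallied in calculate().
successful, failed, skipped = 70, 20, 10
total = successful + failed + skipped

# Original accuracy: fraction of inputs the model classified correctly
# before the attack (skipped results were already misclassified).
original_accuracy = round((total - skipped) * 100.0 / total, 2)

# Accuracy under attack: inputs the model still gets right, i.e. failed attacks.
accuracy_under_attack = round(failed * 100.0 / total, 2)

# Success rate: successes over attempted (non-skipped) attacks.
attack_success_rate = round(successful * 100.0 / (successful + failed), 2)

print(original_accuracy, accuracy_under_attack, attack_success_rate)
# 90.0 20.0 77.78
```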
Empty file.
65 changes: 65 additions & 0 deletions textattack/attack_metrics/words_perturbed.py
@@ -0,0 +1,65 @@
import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult

from .attack_metric import AttackMetric


class WordsPerturbed(AttackMetric):
def __init__(self, results):
self.results = results
self.total_attacks = len(self.results)
self.all_num_words = np.zeros(len(self.results))
self.perturbed_word_percentages = np.zeros(len(self.results))
self.num_words_changed_until_success = np.zeros(2 ** 16)

self.calculate()

def calculate(self):
self.max_words_changed = 0
for i, result in enumerate(self.results):
self.all_num_words[i] = len(result.original_result.attacked_text.words)

if isinstance(result, FailedAttackResult) or isinstance(
result, SkippedAttackResult
):
continue

num_words_changed = len(
result.original_result.attacked_text.all_words_diff(
result.perturbed_result.attacked_text
)
)
self.num_words_changed_until_success[num_words_changed - 1] += 1
self.max_words_changed = max(
self.max_words_changed or num_words_changed, num_words_changed
)
if len(result.original_result.attacked_text.words) > 0:
perturbed_word_percentage = (
num_words_changed
* 100.0
/ len(result.original_result.attacked_text.words)
)
else:
perturbed_word_percentage = 0

self.perturbed_word_percentages[i] = perturbed_word_percentage

def avg_number_word_perturbed_num(self):
average_num_words = self.all_num_words.mean()
average_num_words = round(average_num_words, 2)
return average_num_words

def avg_perturbation_perc(self):
self.perturbed_word_percentages = self.perturbed_word_percentages[
self.perturbed_word_percentages > 0
]
average_perc_words_perturbed = self.perturbed_word_percentages.mean()
average_perc_words_perturbed = round(average_perc_words_perturbed, 2)
return average_perc_words_perturbed

def max_words_changed_num(self):
return self.max_words_changed

def num_words_changed_until_success_num(self):
return self.num_words_changed_until_success
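The num_words_changed_until_success array is a histogram indexed at num_words_changed - 1: bucket k-1 counts successful attacks that changed exactly k words. A standalone sketch of that bookkeeping, with hypothetical per-attack counts:

```python
import numpy as np

# 2 ** 16 buckets matches the (admittedly oversized) allocation in the PR,
# which the original code flags with "@ TODO: be smarter about this".
num_words_changed_until_success = np.zeros(2 ** 16)

for num_words_changed in [1, 2, 2, 3]:  # hypothetical per-attack counts
    num_words_changed_until_success[num_words_changed - 1] += 1

print(num_words_changed_until_success[:4])  # [1. 2. 1. 0.]
```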
123 changes: 38 additions & 85 deletions textattack/loggers/attack_log_manager.py
@@ -3,9 +3,7 @@
========================
"""

import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.attack_metrics import AttackQueries, AttackSuccessRate, WordsPerturbed

from . import CSVLogger, FileLogger, VisdomLogger, WeightsAndBiasesLogger

@@ -72,100 +70,55 @@ def log_summary(self):
total_attacks = len(self.results)
if total_attacks == 0:
return
# Count things about attacks.
all_num_words = np.zeros(len(self.results))
perturbed_word_percentages = np.zeros(len(self.results))
num_words_changed_until_success = np.zeros(
2 ** 16
) # @ TODO: be smarter about this
failed_attacks = 0
skipped_attacks = 0
successful_attacks = 0
max_words_changed = 0
for i, result in enumerate(self.results):
all_num_words[i] = len(result.original_result.attacked_text.words)
if isinstance(result, FailedAttackResult):
failed_attacks += 1
continue
elif isinstance(result, SkippedAttackResult):
skipped_attacks += 1
continue
else:
successful_attacks += 1
num_words_changed = len(
result.original_result.attacked_text.all_words_diff(
result.perturbed_result.attacked_text
)
)
num_words_changed_until_success[num_words_changed - 1] += 1
max_words_changed = max(
max_words_changed or num_words_changed, num_words_changed
)
if len(result.original_result.attacked_text.words) > 0:
perturbed_word_percentage = (
num_words_changed
* 100.0
/ len(result.original_result.attacked_text.words)
)
else:
perturbed_word_percentage = 0
perturbed_word_percentages[i] = perturbed_word_percentage

# Original classifier success rate on these samples.
original_accuracy = (total_attacks - skipped_attacks) * 100.0 / (total_attacks)
original_accuracy = str(round(original_accuracy, 2)) + "%"

# New classifier success rate on these samples.
accuracy_under_attack = (failed_attacks) * 100.0 / (total_attacks)
accuracy_under_attack = str(round(accuracy_under_attack, 2)) + "%"

# Attack success rate.
if successful_attacks + failed_attacks == 0:
attack_success_rate = 0
else:
attack_success_rate = (
successful_attacks * 100.0 / (successful_attacks + failed_attacks)
)
attack_success_rate = str(round(attack_success_rate, 2)) + "%"

perturbed_word_percentages = perturbed_word_percentages[
perturbed_word_percentages > 0
]
average_perc_words_perturbed = perturbed_word_percentages.mean()
average_perc_words_perturbed = str(round(average_perc_words_perturbed, 2)) + "%"

average_num_words = all_num_words.mean()
average_num_words = str(round(average_num_words, 2))
# Default metrics - calculated on every attack
attack_success_stats = AttackSuccessRate(self.results)
words_perturbed_stats = WordsPerturbed(self.results)
attack_query_stats = AttackQueries(self.results)

# @TODO generate this table based on user input - each column in specific class
# Example to demonstrate:
# summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
summary_table_rows = [
["Number of successful attacks:", str(successful_attacks)],
["Number of failed attacks:", str(failed_attacks)],
["Number of skipped attacks:", str(skipped_attacks)],
["Original accuracy:", original_accuracy],
["Accuracy under attack:", accuracy_under_attack],
["Attack success rate:", attack_success_rate],
["Average perturbed word %:", average_perc_words_perturbed],
["Average num. words per input:", average_num_words],
[
"Number of successful attacks:",
attack_success_stats.successful_attacks_num(),
],
["Number of failed attacks:", attack_success_stats.failed_attacks_num()],
["Number of skipped attacks:", attack_success_stats.skipped_attacks_num()],
[
"Original accuracy:",
str(attack_success_stats.original_accuracy_perc()) + "%",
],
[
"Accuracy under attack:",
str(attack_success_stats.attack_accuracy_perc()) + "%",
],
[
"Attack success rate:",
str(attack_success_stats.attack_success_rate_perc()) + "%",
],
[
"Average perturbed word %:",
str(words_perturbed_stats.avg_perturbation_perc()) + "%",
],
[
"Average num. words per input:",
words_perturbed_stats.avg_number_word_perturbed_num(),
],
]

num_queries = np.array(
[
r.num_queries
for r in self.results
if not isinstance(r, SkippedAttackResult)
]
summary_table_rows.append(
["Avg num queries:", attack_query_stats.avg_num_queries_num()]
)
avg_num_queries = num_queries.mean()
avg_num_queries = str(round(avg_num_queries, 2))
summary_table_rows.append(["Avg num queries:", avg_num_queries])
self.log_summary_rows(
summary_table_rows, "Attack Results", "attack_results_summary"
)
# Show histogram of words changed.
numbins = max(max_words_changed, 10)
numbins = max(words_perturbed_stats.max_words_changed_num(), 10)
for logger in self.loggers:
logger.log_hist(
num_words_changed_until_success[:numbins],
words_perturbed_stats.num_words_changed_until_success_num()[:numbins],
numbins=numbins,
title="Num Words Perturbed",
window_id="num_words_perturbed",
9 changes: 9 additions & 0 deletions textattack/metrics/__init__.py
@@ -0,0 +1,9 @@
"""

metrics:
======================

TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.


"""

Review (Collaborator): What is the difference between the attack_metrics module and the metrics module?