-
Notifications
You must be signed in to change notification settings - Fork 404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New metric module to improve flexibility and intuitiveness #475
Changes from all commits
1900868
90cc0ce
0691c4c
8710133
bb5e17d
43dd3d2
a074816
16b8081
8beb20f
091bcd1
d4d2d1b
f93c6e4
3e1b05d
19d0c27
db572f9
769dbdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
|
||
attack_metrics: | ||
====================== | ||
|
||
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display. | ||
|
||
|
||
""" | ||
|
||
from .attack_metric import AttackMetric | ||
from .attack_queries import AttackQueries | ||
from .attack_success_rate import AttackSuccessRate | ||
from .words_perturbed import WordsPerturbed |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
""" | ||
Attack Metrics Class | ||
======================== | ||
|
||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from textattack.attack_results import AttackResult | ||
|
||
|
||
class AttackMetric(ABC):
    """Abstract base class for metrics evaluating adversarial attack results.

    Concrete metrics (e.g. attack success rate, query count, words
    perturbed) subclass this, store the list of attack results in
    ``__init__``, and compute their aggregate statistics in ``calculate()``.

    NOTE: this intentionally does NOT inherit from ``AttackResult`` — a
    metric is computed *over* results, it is not itself a result (the
    original inheritance was acknowledged as an oversight in review).
    """

    @abstractmethod
    def __init__(self, results, **kwargs):
        """Create a metric over pre-computed attack results.

        Args:
            results (list of :class:`~textattack.attack_results.AttackResult`):
                Attack results for each instance in the dataset.
            **kwargs: Metric-specific configuration options.
        """
        raise NotImplementedError()

    @abstractmethod
    def calculate(self):
        """Compute aggregate values derived from ``self.results``.

        Concrete subclasses call this once at the end of ``__init__``.
        """
        raise NotImplementedError()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import numpy as np | ||
|
||
from textattack.attack_results import SkippedAttackResult | ||
|
||
from .attack_metric import AttackMetric | ||
|
||
|
||
class AttackQueries(AttackMetric):
    """Query-count statistics over a set of attack results.

    Args:
        results (list of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results
        self.calculate()

    def calculate(self):
        # Skipped results never queried the victim model, so exclude them
        # from the query statistics.
        self.num_queries = np.array(
            [
                result.num_queries
                for result in self.results
                if not isinstance(result, SkippedAttackResult)
            ]
        )

    def avg_num_queries_num(self):
        """Return the mean number of model queries, rounded to 2 decimals.

        Returns 0 when every result was skipped, avoiding the NaN (and
        RuntimeWarning) produced by taking the mean of an empty array.
        """
        if self.num_queries.size == 0:
            return 0
        return round(self.num_queries.mean(), 2)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from textattack.attack_results import FailedAttackResult, SkippedAttackResult | ||
|
||
from .attack_metric import AttackMetric | ||
|
||
|
||
class AttackSuccessRate(AttackMetric):
    """Counts of successful, failed and skipped results in an attack,
    plus the accuracy percentages derived from them.

    Args:
        results (list of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results
        self.failed_attacks = 0
        self.skipped_attacks = 0
        self.successful_attacks = 0
        self.total_attacks = len(self.results)
        self.calculate()

    def calculate(self):
        # Tally each result into exactly one of the three buckets.
        for result in self.results:
            if isinstance(result, FailedAttackResult):
                self.failed_attacks += 1
            elif isinstance(result, SkippedAttackResult):
                self.skipped_attacks += 1
            else:
                self.successful_attacks += 1

    def successful_attacks_num(self):
        """Number of successful attacks."""
        return self.successful_attacks

    def failed_attacks_num(self):
        """Number of failed attacks."""
        return self.failed_attacks

    def skipped_attacks_num(self):
        """Number of skipped attacks (model was already wrong)."""
        return self.skipped_attacks

    def original_accuracy_perc(self):
        """Model accuracy on unperturbed inputs, as a percentage.

        Returns 0 when there are no results, avoiding division by zero
        (consistent with ``attack_success_rate_perc``).
        """
        if self.total_attacks == 0:
            return 0
        original_accuracy = (
            (self.total_attacks - self.skipped_attacks) * 100.0 / self.total_attacks
        )
        return round(original_accuracy, 2)

    def attack_accuracy_perc(self):
        """Model accuracy under attack, as a percentage.

        Returns 0 when there are no results, avoiding division by zero.
        """
        if self.total_attacks == 0:
            return 0
        accuracy_under_attack = self.failed_attacks * 100.0 / self.total_attacks
        return round(accuracy_under_attack, 2)

    def attack_success_rate_perc(self):
        """Percentage of attempted (non-skipped) attacks that succeeded."""
        attempted_attacks = self.successful_attacks + self.failed_attacks
        if attempted_attacks == 0:
            return 0
        attack_success_rate = self.successful_attacks * 100.0 / attempted_attacks
        return round(attack_success_rate, 2)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
|
||
from textattack.attack_results import FailedAttackResult, SkippedAttackResult | ||
|
||
from .attack_metric import AttackMetric | ||
|
||
|
||
class WordsPerturbed(AttackMetric):
    """Word-perturbation statistics over a set of attack results.

    Args:
        results (list of :class:`~textattack.attack_results.AttackResult`):
            Attack results for each instance in the dataset.
    """

    def __init__(self, results):
        self.results = results
        self.total_attacks = len(self.results)
        self.all_num_words = np.zeros(len(self.results))
        self.perturbed_word_percentages = np.zeros(len(self.results))
        # Histogram indexed by (words changed - 1); 2 ** 16 is a generous
        # upper bound on the number of words in a single input.
        self.num_words_changed_until_success = np.zeros(2 ** 16)
        self.calculate()

    def calculate(self):
        self.max_words_changed = 0
        for i, result in enumerate(self.results):
            original_words = result.original_result.attacked_text.words
            self.all_num_words[i] = len(original_words)

            # Only successful attacks have a meaningful perturbation.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue

            num_words_changed = len(
                result.original_result.attacked_text.all_words_diff(
                    result.perturbed_result.attacked_text
                )
            )
            self.num_words_changed_until_success[num_words_changed - 1] += 1
            self.max_words_changed = max(self.max_words_changed, num_words_changed)
            if len(original_words) > 0:
                perturbed_word_percentage = (
                    num_words_changed * 100.0 / len(original_words)
                )
            else:
                perturbed_word_percentage = 0
            self.perturbed_word_percentages[i] = perturbed_word_percentage

    def avg_number_word_perturbed_num(self):
        """Mean input length in words, rounded to 2 decimals.

        Returns 0 when there are no results, avoiding the NaN produced by
        taking the mean of an empty array.
        """
        if self.all_num_words.size == 0:
            return 0
        return round(self.all_num_words.mean(), 2)

    def avg_perturbation_perc(self):
        """Mean perturbed-word percentage over successful attacks.

        Returns 0 when no attack perturbed any words. Unlike the original
        destructive filter, this leaves ``self.perturbed_word_percentages``
        intact so repeated reads of the attribute stay consistent.
        """
        nonzero_percentages = self.perturbed_word_percentages[
            self.perturbed_word_percentages > 0
        ]
        if nonzero_percentages.size == 0:
            return 0
        return round(nonzero_percentages.mean(), 2)

    def max_words_changed_num(self):
        """Largest number of words changed by any successful attack."""
        return self.max_words_changed

    def num_words_changed_until_success_num(self):
        """Histogram of words-changed counts (index = count - 1)."""
        return self.num_words_changed_until_success
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the difference between |
||
|
||
metrics: | ||
====================== | ||
|
||
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display. | ||
|
||
|
||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reason why
AttackMetric
is a sub-class ofAttackResult
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sanchit97 I have the same concerns as @jinyongyoo above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is an oversight. Fixed.