-
Notifications
You must be signed in to change notification settings - Fork 404
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CODE] Transfer from #475 + new structure
- Loading branch information
sanchit97
committed
Aug 20, 2021
1 parent
f64df48
commit c0e2993
Showing
8 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
""" | ||
""" | ||
|
||
from .attack_metrics import AttackMetric | ||
# from .quality_metrics import QualityMetric |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
attack_metrics: | ||
====================== | ||
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display. | ||
""" | ||
|
||
from .attack_metric import AttackMetric | ||
from .attack_queries import AttackQueries | ||
from .attack_success_rate import AttackSuccessRate | ||
from .words_perturbed import WordsPerturbed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
Attack Metrics Class | ||
======================== | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from textattack.attack_results import AttackResult | ||
|
||
|
||
class AttackMetric: | ||
"""A metric for evaluating Adversarial Attack candidates.""" | ||
|
||
@abstractmethod | ||
def __init__(self, results, **kwargs): | ||
"""Creates pre-built :class:`~textattack.AttackMetric` that correspond to | ||
evaluation metrics for adversarial examples. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def calculate(): | ||
""" Abstract function for computing any values which are to be calculated as a whole during initialization""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
|
||
from textattack.attack_results import SkippedAttackResult | ||
|
||
from .attack_metric import AttackMetric | ||
|
||
|
||
class AttackQueries(AttackMetric): | ||
"""Calculates all metrics related to number of queries in an attack | ||
Args: | ||
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`): | ||
Attack results for each instance in dataset | ||
""" | ||
|
||
def __init__(self, results): | ||
self.results = results | ||
|
||
self.all_metrics = {} | ||
|
||
def calculate(self): | ||
self.num_queries = np.array( | ||
[ | ||
r.num_queries | ||
for r in self.results | ||
if not isinstance(r, SkippedAttackResult) | ||
] | ||
) | ||
self.all_metrics['avg_num_queries'] = self.avg_num_queries() | ||
|
||
return self.all_metrics | ||
|
||
def avg_num_queries(self): | ||
avg_num_queries = self.num_queries.mean() | ||
avg_num_queries = round(avg_num_queries, 2) | ||
return avg_num_queries |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from textattack.attack_results import FailedAttackResult, SkippedAttackResult | ||
|
||
from .attack_metric import AttackMetric | ||
|
||
|
||
class AttackSuccessRate(AttackMetric): | ||
"""Calculates all metrics related to number of succesful, failed and skipped results in an attack | ||
Args: | ||
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`): | ||
Attack results for each instance in dataset | ||
""" | ||
|
||
def __init__(self, results): | ||
self.results = results | ||
self.failed_attacks = 0 | ||
self.skipped_attacks = 0 | ||
self.successful_attacks = 0 | ||
self.total_attacks = len(self.results) | ||
|
||
self.all_metrics = {} | ||
|
||
def calculate(self): | ||
for i, result in enumerate(self.results): | ||
if isinstance(result, FailedAttackResult): | ||
self.failed_attacks += 1 | ||
continue | ||
elif isinstance(result, SkippedAttackResult): | ||
self.skipped_attacks += 1 | ||
continue | ||
else: | ||
self.successful_attacks += 1 | ||
|
||
# Calculated numbers | ||
self.all_metrics['successful_attacks'] = self.successful_attacks | ||
self.all_metrics['failed_attacks'] = self.failed_attacks | ||
self.all_metrics['skipped_attacks'] = self.skipped_attacks | ||
|
||
# Percentages wrt the calculations | ||
self.all_metrics['original_accuracy'] = self.original_accuracy_perc() | ||
self.all_metrics['attack_accuracy_perc'] = self.attack_accuracy_perc() | ||
self.all_metrics['attack_success_rate'] = self.attack_success_rate_perc() | ||
|
||
return self.all_metrics | ||
|
||
|
||
def original_accuracy_perc(self): | ||
original_accuracy = ( | ||
(self.total_attacks - self.skipped_attacks) * 100.0 / (self.total_attacks) | ||
) | ||
original_accuracy = round(original_accuracy, 2) | ||
return original_accuracy | ||
|
||
def attack_accuracy_perc(self): | ||
accuracy_under_attack = (self.failed_attacks) * 100.0 / (self.total_attacks) | ||
accuracy_under_attack = round(accuracy_under_attack, 2) | ||
return accuracy_under_attack | ||
|
||
def attack_success_rate_perc(self): | ||
if self.successful_attacks + self.failed_attacks == 0: | ||
attack_success_rate = 0 | ||
else: | ||
attack_success_rate = ( | ||
self.successful_attacks | ||
* 100.0 | ||
/ (self.successful_attacks + self.failed_attacks) | ||
) | ||
attack_success_rate = round(attack_success_rate, 2) | ||
return attack_success_rate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
|
||
from textattack.attack_results import FailedAttackResult, SkippedAttackResult | ||
|
||
from .attack_metric import AttackMetric | ||
|
||
|
||
class WordsPerturbed(AttackMetric): | ||
def __init__(self, results): | ||
self.results = results | ||
self.total_attacks = len(self.results) | ||
self.all_num_words = np.zeros(len(self.results)) | ||
self.perturbed_word_percentages = np.zeros(len(self.results)) | ||
self.num_words_changed_until_success = np.zeros(2 ** 16) | ||
self.all_metrics = {} | ||
|
||
def calculate(self): | ||
self.max_words_changed = 0 | ||
for i, result in enumerate(self.results): | ||
self.all_num_words[i] = len(result.original_result.attacked_text.words) | ||
|
||
if isinstance(result, FailedAttackResult) or isinstance( | ||
result, SkippedAttackResult | ||
): | ||
continue | ||
|
||
num_words_changed = len( | ||
result.original_result.attacked_text.all_words_diff( | ||
result.perturbed_result.attacked_text | ||
) | ||
) | ||
self.num_words_changed_until_success[num_words_changed - 1] += 1 | ||
self.max_words_changed = max( | ||
self.max_words_changed or num_words_changed, num_words_changed | ||
) | ||
if len(result.original_result.attacked_text.words) > 0: | ||
perturbed_word_percentage = ( | ||
num_words_changed | ||
* 100.0 | ||
/ len(result.original_result.attacked_text.words) | ||
) | ||
else: | ||
perturbed_word_percentage = 0 | ||
|
||
self.perturbed_word_percentages[i] = perturbed_word_percentage | ||
|
||
self.all_metrics['avg_word_perturbed'] = self.avg_number_word_perturbed_num() | ||
self.all_metrics['avg_word_perturbed_perc'] = self.avg_perturbation_perc() | ||
self.all_metrics['max_words_changed'] = self.max_words_changed | ||
self.all_metrics['num_words_changed_until_success'] = self.num_words_changed_until_success | ||
|
||
return self.all_metrics | ||
|
||
def avg_number_word_perturbed_num(self): | ||
average_num_words = self.all_num_words.mean() | ||
average_num_words = round(average_num_words, 2) | ||
return average_num_words | ||
|
||
def avg_perturbation_perc(self): | ||
self.perturbed_word_percentages = self.perturbed_word_percentages[ | ||
self.perturbed_word_percentages > 0 | ||
] | ||
average_perc_words_perturbed = self.perturbed_word_percentages.mean() | ||
average_perc_words_perturbed = round(average_perc_words_perturbed, 2) | ||
return average_perc_words_perturbed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
attack_metrics: | ||
====================== | ||
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display. | ||
""" | ||
|
||
from .quality_metric import QualityMetric | ||
|
Empty file.