Skip to content

Commit

Permalink
[CODE] Transfer from #475 + new structure
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit97 committed Aug 20, 2021
1 parent f64df48 commit c0e2993
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 0 deletions.
5 changes: 5 additions & 0 deletions textattack/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
"""

from .attack_metrics import AttackMetric
# from .quality_metrics import QualityMetric
14 changes: 14 additions & 0 deletions textattack/metrics/attack_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
attack_metrics:
======================
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.
"""

from .attack_metric import AttackMetric
from .attack_queries import AttackQueries
from .attack_success_rate import AttackSuccessRate
from .words_perturbed import WordsPerturbed
25 changes: 25 additions & 0 deletions textattack/metrics/attack_metrics/attack_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Attack Metrics Class
========================
"""

from abc import ABC, abstractmethod

from textattack.attack_results import AttackResult


class AttackMetric:
"""A metric for evaluating Adversarial Attack candidates."""

@abstractmethod
def __init__(self, results, **kwargs):
"""Creates pre-built :class:`~textattack.AttackMetric` that correspond to
evaluation metrics for adversarial examples.
"""
raise NotImplementedError()

@abstractmethod
def calculate():
""" Abstract function for computing any values which are to be calculated as a whole during initialization"""
raise NotImplementedError
36 changes: 36 additions & 0 deletions textattack/metrics/attack_metrics/attack_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from textattack.attack_results import SkippedAttackResult

from .attack_metric import AttackMetric


class AttackQueries(AttackMetric):
"""Calculates all metrics related to number of queries in an attack
Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Attack results for each instance in dataset
"""

def __init__(self, results):
self.results = results

self.all_metrics = {}

def calculate(self):
self.num_queries = np.array(
[
r.num_queries
for r in self.results
if not isinstance(r, SkippedAttackResult)
]
)
self.all_metrics['avg_num_queries'] = self.avg_num_queries()

return self.all_metrics

def avg_num_queries(self):
avg_num_queries = self.num_queries.mean()
avg_num_queries = round(avg_num_queries, 2)
return avg_num_queries
69 changes: 69 additions & 0 deletions textattack/metrics/attack_metrics/attack_success_rate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from textattack.attack_results import FailedAttackResult, SkippedAttackResult

from .attack_metric import AttackMetric


class AttackSuccessRate(AttackMetric):
"""Calculates all metrics related to number of succesful, failed and skipped results in an attack
Args:
results (:obj::`list`:class:`~textattack.goal_function_results.GoalFunctionResult`):
Attack results for each instance in dataset
"""

def __init__(self, results):
self.results = results
self.failed_attacks = 0
self.skipped_attacks = 0
self.successful_attacks = 0
self.total_attacks = len(self.results)

self.all_metrics = {}

def calculate(self):
for i, result in enumerate(self.results):
if isinstance(result, FailedAttackResult):
self.failed_attacks += 1
continue
elif isinstance(result, SkippedAttackResult):
self.skipped_attacks += 1
continue
else:
self.successful_attacks += 1

# Calculated numbers
self.all_metrics['successful_attacks'] = self.successful_attacks
self.all_metrics['failed_attacks'] = self.failed_attacks
self.all_metrics['skipped_attacks'] = self.skipped_attacks

# Percentages wrt the calculations
self.all_metrics['original_accuracy'] = self.original_accuracy_perc()
self.all_metrics['attack_accuracy_perc'] = self.attack_accuracy_perc()
self.all_metrics['attack_success_rate'] = self.attack_success_rate_perc()

return self.all_metrics


def original_accuracy_perc(self):
original_accuracy = (
(self.total_attacks - self.skipped_attacks) * 100.0 / (self.total_attacks)
)
original_accuracy = round(original_accuracy, 2)
return original_accuracy

def attack_accuracy_perc(self):
accuracy_under_attack = (self.failed_attacks) * 100.0 / (self.total_attacks)
accuracy_under_attack = round(accuracy_under_attack, 2)
return accuracy_under_attack

def attack_success_rate_perc(self):
if self.successful_attacks + self.failed_attacks == 0:
attack_success_rate = 0
else:
attack_success_rate = (
self.successful_attacks
* 100.0
/ (self.successful_attacks + self.failed_attacks)
)
attack_success_rate = round(attack_success_rate, 2)
return attack_success_rate
65 changes: 65 additions & 0 deletions textattack/metrics/attack_metrics/words_perturbed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import numpy as np

from textattack.attack_results import FailedAttackResult, SkippedAttackResult

from .attack_metric import AttackMetric


class WordsPerturbed(AttackMetric):
def __init__(self, results):
self.results = results
self.total_attacks = len(self.results)
self.all_num_words = np.zeros(len(self.results))
self.perturbed_word_percentages = np.zeros(len(self.results))
self.num_words_changed_until_success = np.zeros(2 ** 16)
self.all_metrics = {}

def calculate(self):
self.max_words_changed = 0
for i, result in enumerate(self.results):
self.all_num_words[i] = len(result.original_result.attacked_text.words)

if isinstance(result, FailedAttackResult) or isinstance(
result, SkippedAttackResult
):
continue

num_words_changed = len(
result.original_result.attacked_text.all_words_diff(
result.perturbed_result.attacked_text
)
)
self.num_words_changed_until_success[num_words_changed - 1] += 1
self.max_words_changed = max(
self.max_words_changed or num_words_changed, num_words_changed
)
if len(result.original_result.attacked_text.words) > 0:
perturbed_word_percentage = (
num_words_changed
* 100.0
/ len(result.original_result.attacked_text.words)
)
else:
perturbed_word_percentage = 0

self.perturbed_word_percentages[i] = perturbed_word_percentage

self.all_metrics['avg_word_perturbed'] = self.avg_number_word_perturbed_num()
self.all_metrics['avg_word_perturbed_perc'] = self.avg_perturbation_perc()
self.all_metrics['max_words_changed'] = self.max_words_changed
self.all_metrics['num_words_changed_until_success'] = self.num_words_changed_until_success

return self.all_metrics

def avg_number_word_perturbed_num(self):
average_num_words = self.all_num_words.mean()
average_num_words = round(average_num_words, 2)
return average_num_words

def avg_perturbation_perc(self):
self.perturbed_word_percentages = self.perturbed_word_percentages[
self.perturbed_word_percentages > 0
]
average_perc_words_perturbed = self.perturbed_word_percentages.mean()
average_perc_words_perturbed = round(average_perc_words_perturbed, 2)
return average_perc_words_perturbed
12 changes: 12 additions & 0 deletions textattack/metrics/quality_metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
attack_metrics:
======================
TextAttack allows users to use their own metrics on adversarial examples or select common metrics to display.
"""

from .quality_metric import QualityMetric

Empty file.

0 comments on commit c0e2993

Please sign in to comment.