Skip to content

Commit

Permalink
Merge pull request #139 from twosixlabs/120-explainable-ai-outputs
Browse files Browse the repository at this point in the history
120 explainable ai outputs
  • Loading branch information
mwartell authored Apr 26, 2024
2 parents 41e140c + c48723f commit e6d71d2
Show file tree
Hide file tree
Showing 17 changed files with 1,128 additions and 26 deletions.
8 changes: 4 additions & 4 deletions examples/notebooks/image_classification_food101.ipynb

Large diffs are not rendered by default.

78 changes: 68 additions & 10 deletions examples/notebooks/object_detection_license_plates.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ dependencies = [
cv = [
# 1.4.1 adds version constraints to scikit-learn and numpy that are incompatible with ART
"albumentations<1.4.1",
"captum",
"scikit-learn>=1.2", # required by xaitk-saliency
"tidecv",
"xaitk-saliency",
"yolov5",
]

Expand Down
22 changes: 19 additions & 3 deletions examples/src/armory/examples/image_classification/food101.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import armory.dataset
import armory.engine
import armory.evaluation
import armory.export.captum
import armory.export.criteria
import armory.export.image_classification
import armory.export.xaitksaliency
import armory.metric
import armory.metrics.compute
import armory.metrics.perturbation
Expand Down Expand Up @@ -81,7 +83,9 @@ def load_model():
armory_model = armory.model.image_classification.ImageClassifier(
name="ViT-finetuned-food101",
model=hf_model,
accessor=armory.data.Images.as_torch(scale=normalized_scale),
accessor=armory.data.Images.as_torch(
dim=armory.data.ImageDimensions.CHW, scale=normalized_scale
),
)

art_classifier = armory.track.track_init_params(
Expand Down Expand Up @@ -270,12 +274,24 @@ def create_metrics():
}


def create_exporters(export_every_n_batches):
def create_exporters(model, export_every_n_batches):
"""Create sample exporters"""
return [
armory.export.image_classification.ImageClassificationExporter(
criterion=armory.export.criteria.every_n_batches(export_every_n_batches)
),
armory.export.captum.CaptumImageClassificationExporter(
model,
criterion=armory.export.criteria.every_n_batches(export_every_n_batches),
),
armory.export.xaitksaliency.XaitkSaliencyBlackboxImageClassificationExporter(
name="slidingwindow",
model=model,
classes=[6, 23], # beignets(6), churros(23)
criterion=armory.export.criteria.when_metric_in(
armory.export.criteria.batch_targets(), [6, 23]
),
),
]


Expand Down Expand Up @@ -303,7 +319,7 @@ def main(
)
perturbations = dict()
metrics = create_metrics()
exporters = create_exporters(export_every_n_batches)
exporters = create_exporters(model, export_every_n_batches)
profiler = armory.metrics.compute.BasicProfiler()

if "benign" in chains:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import armory.engine
import armory.evaluation
import armory.export.criteria
import armory.export.drise
import armory.export.object_detection
import armory.metric
import armory.metrics.compute
Expand Down Expand Up @@ -168,12 +169,18 @@ def create_metrics():
}


def create_exporters(export_every_n_batches):
def create_exporters(model, export_every_n_batches):
"""Create sample exporters"""
return [
armory.export.object_detection.ObjectDetectionExporter(
criterion=armory.export.criteria.every_n_batches(export_every_n_batches)
),
armory.export.drise.DRiseSaliencyObjectDetectionExporter(
model,
criterion=armory.export.criteria.every_n_batches(export_every_n_batches),
num_classes=91,
num_masks=10,
),
]


Expand All @@ -188,7 +195,7 @@ def main(batch_size, export_every_n_batches, num_batches, seed, shuffle):
dataset = load_dataset(batch_size, shuffle)
attack = create_attack(art_detector, batch_size)
metrics = create_metrics()
exporters = create_exporters(export_every_n_batches)
exporters = create_exporters(model, export_every_n_batches)

evaluation = armory.evaluation.Evaluation(
name="coco-detection-fasterrcnn-resnet50",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import armory.engine
import armory.evaluation
import armory.export.criteria
import armory.export.drise
import armory.export.object_detection
import armory.metric
import armory.metrics.compute
Expand Down Expand Up @@ -198,12 +199,18 @@ def create_metrics():
}


def create_exporters(export_every_n_batches):
def create_exporters(model, export_every_n_batches):
"""Create sample exporters"""
return [
armory.export.object_detection.ObjectDetectionExporter(
criterion=armory.export.criteria.every_n_batches(export_every_n_batches)
),
armory.export.drise.DRiseSaliencyObjectDetectionExporter(
model,
criterion=armory.export.criteria.every_n_batches(export_every_n_batches),
num_classes=2,
num_masks=10,
),
]


Expand All @@ -218,7 +225,7 @@ def main(batch_size, export_every_n_batches, num_batches, seed, shuffle):
dataset = load_dataset(batch_size, shuffle)
attack = create_attack(art_detector, batch_size)
metrics = create_metrics()
exporters = create_exporters(export_every_n_batches)
exporters = create_exporters(model, export_every_n_batches)

evaluation = armory.evaluation.Evaluation(
name="license-plate-detection-yolos",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import armory.engine
import armory.evaluation
import armory.export.criteria
import armory.export.drise
import armory.export.object_detection
import armory.metric
import armory.metrics.compute
Expand Down Expand Up @@ -183,12 +184,18 @@ def create_metrics():
}


def create_exporters(export_every_n_batches):
def create_exporters(model, export_every_n_batches):
"""Create sample exporters"""
return [
armory.export.object_detection.ObjectDetectionExporter(
criterion=armory.export.criteria.every_n_batches(export_every_n_batches)
),
armory.export.drise.DRiseSaliencyObjectDetectionExporter(
model,
criterion=armory.export.criteria.every_n_batches(export_every_n_batches),
num_classes=1,
num_masks=10,
),
]


Expand All @@ -203,7 +210,7 @@ def main(batch_size, export_every_n_batches, num_batches, seed, shuffle):
dataset = load_dataset(batch_size, shuffle)
attack = create_attack(art_detector, batch_size)
metrics = create_metrics()
exporters = create_exporters(export_every_n_batches)
exporters = create_exporters(model, export_every_n_batches)

evaluation = armory.evaluation.Evaluation(
name="license-plate-detection-yolov5",
Expand Down
3 changes: 2 additions & 1 deletion examples/src/armory/examples/utils/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def display_object_detection_results(
for sample_idx in range(batch_size):
for chain_idx, chain in enumerate(chains):
image_filename = Exporter.artifact_path(
chain, batch_idx, sample_idx, "input.png"
chain, batch_idx, sample_idx, "objects.png"
)
client.download_artifacts(run_id, image_filename, tmpdir)
image = plt.imread(tmppath / image_filename)
Expand All @@ -112,6 +112,7 @@ def display_object_detection_results(
ax.tick_params(
bottom=False, left=False, labelbottom=False, labelleft=False
)
ax.axis("off")

fig.suptitle(f"Batch {batch_idx}")
fig.tight_layout()
2 changes: 2 additions & 0 deletions library/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ dependencies = [
extras = [
"adversarial-robustness-toolbox",
"albumentations",
"captum",
"matplotlib",
"pandas",
"tidecv",
"transformers",
"xaitk-saliency",
]


Expand Down
Loading

0 comments on commit e6d71d2

Please sign in to comment.