Skip to content

Commit

Permalink
Merge pull request #24 from twosixlabs/mnist-example
Browse files Browse the repository at this point in the history
Mnist example
  • Loading branch information
mwartell authored Oct 26, 2023
2 parents 0ddf47b + 9ecc553 commit 40f7f5c
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 10 deletions.
37 changes: 35 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,45 @@ jobs:
run: make test


integration-tests:
name: 🧪 Integration Tests - ${{ matrix.test.name }}
runs-on: ubuntu-latest
needs: code-linting
strategy:
fail-fast: false
matrix:
test:
- name: Image Classification
dir: src/charmory_examples/image_classification
script: mnist_vit_pgd.py --batch-size 2 --num-batches 1 --export-every-n-batches 1
- name: Object Detection
dir: src/charmory_examples/object_detection
script: yolov5_license_plates.py --batch-size 2 --num-batches 1 --export-every-n-batches 1
python-version: ["3.8"]
steps:
- name: Checkout
uses: actions/checkout@v3

- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install
# Consider switching to a non-YOLO object detection example and removing the
# yolo extra; installing it adds roughly 2 minutes to this install step.
run: cd examples && pip install --no-compile --editable .[armory,huggingface,yolo]

- name: Test
run: cd examples/${{ matrix.test.dir }} && python ${{ matrix.test.script }}


generate-docs:
name: 📖 Generate Docs
runs-on: ubuntu-latest
needs:
- code-linting
- unit-tests
- integration-tests
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down Expand Up @@ -114,8 +147,8 @@ jobs:
name: 🔨 Build
runs-on: ubuntu-latest
needs:
- code-linting
- unit-tests
- integration-tests
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
19 changes: 14 additions & 5 deletions examples/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ dependencies = []
charmory = "charmory_examples.cifar_example:main"

[project.optional-dependencies]
armory = ["albumentations", "charmory @ {root:uri}/../", "lightning"]
armory = [
"albumentations",
"armory-library @ {root:uri}/../",
]

developer = [
"hatch", # build tool
Expand All @@ -30,22 +33,28 @@ developer = [
"flake8",
]

jatic = [
"jatic_toolbox @ git+ssh://git@gitlab.jatic.net/jatic/cdao/jatic-toolbox.git@v0.2.0rc1",
# HuggingFace
huggingface = [
"datasets",
"huggingface_hub",
"transformers",
]

jatic = [
"jatic_toolbox @ git+ssh://git@gitlab.jatic.net/jatic/cdao/jatic-toolbox.git@v0.2.0rc1",
]

yolo = [
"yolov5",
]

# NOTE: the PyPI name "sklearn" is a deprecated placeholder; the actual
# distribution that provides the `sklearn` module is "scikit-learn".
demo = ["pandas", "scikit-learn", "plotly", "numpy", "matplotlib", "seaborn"]


all = [
"armory-examples[armory]",
"armory-examples[developer]",
"armory-examples[huggingface]",
"armory-examples[jatic]",
"armory-examples[yolo]",
]

[build-system]
Expand Down
140 changes: 140 additions & 0 deletions examples/src/charmory_examples/image_classification/mnist_vit_pgd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import argparse
from pprint import pprint

from art.attacks.evasion import ProjectedGradientDescent
from art.estimators.classification import PyTorchClassifier
import datasets
import torch
import torch.nn
from transformers import AutoImageProcessor, AutoModelForImageClassification

from armory.metrics.compute import BasicProfiler
from charmory.data import ArmoryDataLoader
from charmory.engine import EvaluationEngine
import charmory.evaluation as ev
from charmory.model.image_classification import JaticImageClassificationModel
from charmory.tasks.image_classification import ImageClassificationTask
from charmory.track import track_init_params, track_params
from charmory.utils import Unnormalize


def get_cli_args(argv=None):
    """
    Parse the command-line options for the MNIST ViT/PGD example.

    Args:
        argv: Optional sequence of argument strings to parse. When None
            (the default), argparse falls back to the process arguments.

    Returns:
        An `argparse.Namespace` with the integer attributes `batch_size`,
        `export_every_n_batches`, and `num_batches`.
    """
    parser = argparse.ArgumentParser(
        description="MNIST image classification using a ViT model and PGD attack",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # ArgumentDefaultsHelpFormatter only appends "(default: ...)" to arguments
    # that define a help string, so every option gets one.
    parser.add_argument(
        "--batch-size",
        default=16,
        type=int,
        help="Number of samples per batch",
    )
    parser.add_argument(
        "--export-every-n-batches",
        default=5,
        type=int,
        help="Export samples every N batches",
    )
    parser.add_argument(
        "--num-batches",
        default=10,
        type=int,
        help="Maximum number of test batches to evaluate",
    )
    return parser.parse_args(argv)


@track_params
def main(batch_size, export_every_n_batches, num_batches):
    """
    Evaluate a ViT MNIST classifier under a PGD attack and print the results.

    Args:
        batch_size: Batch size used for both the data loader and the attack
        export_every_n_batches: Forwarded to the image classification task
        num_batches: Passed to the engine as its `limit_test_batches` value
    """
    hf_checkpoint = "farleyknight-org-username/vit-base-mnist"

    ###
    # Model
    ###
    load_pretrained = track_params(AutoModelForImageClassification.from_pretrained)
    model = JaticImageClassificationModel(load_pretrained(hf_checkpoint))

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
    classifier = track_init_params(PyTorchClassifier)(
        model,
        loss=loss_fn,
        optimizer=optimizer,
        input_shape=(3, 224, 224),
        channels_first=True,
        nb_classes=10,
        clip_values=(-1, 1),
    )

    ###
    # Dataset
    ###
    dataset = datasets.load_dataset("mnist", split="test")
    processor = AutoImageProcessor.from_pretrained(hf_checkpoint)

    def transform(sample):
        # Convert each image from black-and-white to RGB, then run the batch
        # through the HF image processor
        rgb_images = [img.convert("RGB") for img in sample["image"]]
        sample["image"] = processor(rgb_images)["pixel_values"]
        return sample

    dataset.set_transform(transform)
    dataloader = ArmoryDataLoader(dataset, batch_size=batch_size)

    ###
    # Attack
    ###
    attack = track_init_params(ProjectedGradientDescent)(
        classifier,
        batch_size=batch_size,
        eps=0.031,
        eps_step=0.007,
        max_iter=20,
        num_random_init=1,
        random_eps=False,
        targeted=False,
        verbose=False,
    )

    ###
    # Evaluation
    ###
    eval_dataset = ev.Dataset(
        name="MNIST",
        x_key="image",
        y_key="label",
        test_dataloader=dataloader,
    )
    eval_model = ev.Model(name="ViT", model=classifier)
    eval_attack = ev.Attack(
        name="PGD",
        attack=attack,
        use_label_for_untargeted=False,
    )
    eval_metric = ev.Metric(profiler=BasicProfiler())
    evaluation = ev.Evaluation(
        name="mnist-vit-pgd",
        description="MNIST image classification using a ViT model and PGD attack",
        author="TwoSix",
        dataset=eval_dataset,
        model=eval_model,
        attack=eval_attack,
        metric=eval_metric,
    )

    ###
    # Engine
    ###
    # NOTE(review): the 0.5 mean/std presumably mirror the normalization done
    # by the HF image processor -- confirm against the checkpoint's config.
    export_adapter = Unnormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    task = ImageClassificationTask(
        evaluation,
        num_classes=10,
        export_adapter=export_adapter,
        export_every_n_batches=export_every_n_batches,
    )
    engine = EvaluationEngine(task, limit_test_batches=num_batches)

    ###
    # Execute
    ###
    results = engine.run()
    pprint(results)


if __name__ == "__main__":
    # Parse the CLI options and hand them to main() as keyword arguments
    cli_args = get_cli_args()
    main(**vars(cli_args))
7 changes: 6 additions & 1 deletion src/charmory/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Mapping, Optional
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional

import lightning.pytorch as pl
from lightning.pytorch.loggers import MLFlowLogger
Expand All @@ -14,6 +14,9 @@
from charmory.evaluation import Evaluation
from charmory.export import Exporter, MlflowExporter

ExportAdapter = Callable[[Any], Any]
"""An adapter for exported data (e.g., images). """


class BaseEvaluationTask(pl.LightningModule, ABC):
"""Base Armory evaluation task"""
Expand All @@ -23,12 +26,14 @@ def __init__(
evaluation: Evaluation,
skip_benign: bool = False,
skip_attack: bool = False,
export_adapter: Optional[ExportAdapter] = None,
export_every_n_batches: int = 0,
):
super().__init__()
self.evaluation = evaluation
self.skip_benign = skip_benign
self.skip_attack = skip_attack
self.export_adapter = export_adapter
self.export_every_n_batches = export_every_n_batches
self._exporter: Optional[Exporter] = None

Expand Down
5 changes: 4 additions & 1 deletion src/charmory/tasks/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def _export_image(self, name, batch_data, batch_idx):
batch_size = batch_data.shape[0]
for sample_idx in range(batch_size):
filename = f"batch_{batch_idx}_ex_{sample_idx}_{name}.png"
self.exporter.log_image(batch_data[sample_idx], filename)
image = batch_data[sample_idx]
if self.export_adapter is not None:
image = self.export_adapter(image)
self.exporter.log_image(image, filename)

@staticmethod
def _from_list(maybe_list, idx):
Expand Down
5 changes: 4 additions & 1 deletion src/charmory/tasks/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,14 @@ def export_batch(self, batch: BaseEvaluationTask.Batch):
def _export_image(self, name, images, truth, preds, batch_idx):
batch_size = images.shape[0]
for sample_idx in range(batch_size):
image = images[sample_idx]
if self.export_adapter is not None:
image = self.export_adapter(image)
boxes_above_threshold = preds[sample_idx]["boxes"][
preds[sample_idx]["scores"] > self.export_score_threshold
]
with_boxes = draw_boxes_on_image(
image=images[sample_idx],
image=image,
ground_truth_boxes=truth[sample_idx]["boxes"],
pred_boxes=boxes_above_threshold,
)
Expand Down
24 changes: 24 additions & 0 deletions src/charmory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,27 @@ def __call__(self, sample):
sample["image"] = [img.transpose(2, 0, 1) for img in sample["image"]]

return sample


class Unnormalize:
    """
    Inverse of the `torchvision.transforms.Normalize` transform.

    Each channel is multiplied by its standard deviation and then shifted by
    its mean.
    """

    def __init__(self, mean, std):
        """
        Store the per-channel statistics used to undo the normalization.

        Args:
            mean: Sequence of means for each channel
            std: Sequence of standard deviations for each channel
        """
        self.std = std
        self.mean = mean

    def __call__(self, data):
        """
        Return an unnormalized deep copy of `data`; the input is left untouched.

        Args:
            data: Channel-first data to unnormalize
                (assumes iterating it yields per-channel views that support
                in-place `*=` and `+=` -- TODO confirm against callers)

        Returns:
            The unnormalized copy
        """
        result = deepcopy(data)
        channel_stats = zip(self.mean, self.std)
        # Channels are paired with (mean, std) in order; pairing stops at the
        # shortest of the three sequences
        for channel, (offset, scale) in zip(result, channel_stats):
            channel *= scale
            channel += offset
        return result

0 comments on commit 40f7f5c

Please sign in to comment.