diff --git a/examples/src/charmory_examples/xview_example.py b/examples/src/charmory_examples/xview_example.py index 572da4142..e942c131d 100644 --- a/examples/src/charmory_examples/xview_example.py +++ b/examples/src/charmory_examples/xview_example.py @@ -1,3 +1,4 @@ +from pathlib import Path from pprint import pprint import sys @@ -5,11 +6,15 @@ import albumentations as A import art.attacks.evasion from art.estimators.object_detection import PyTorchFasterRCNN +import boto3 +import botocore from datasets import load_dataset -import jatic_toolbox from jatic_toolbox import __version__ as jatic_version from jatic_toolbox.interop.huggingface import HuggingFaceObjectDetectionDataset +from jatic_toolbox.interop.torchvision import TorchVisionObjectDetector import numpy as np +import torch +from torchvision.transforms._presets import ObjectDetection from armory.art_experimental.attacks.patch import AttackWrapper from armory.metrics.compute import BasicProfiler @@ -19,20 +24,26 @@ from charmory.evaluation import Attack, Dataset, Evaluation, Metric, Model, SysConfig from charmory.model.object_detection import JaticObjectDetectionModel from charmory.tasks.object_detection import ObjectDetectionTask -from charmory.track import track_init_params, track_params + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +from charmory.track import track_init_params from charmory.utils import create_jatic_dataset_transform BATCH_SIZE = 1 TRAINING_EPOCHS = 20 -import torch +BUCKET_NAME = "armory-library-data" +KEY = "fasterrcnn_mobilenet_v3_2" + torch.set_float32_matmul_precision("high") +import armory.data.datasets def load_huggingface_dataset(): - train_data = load_dataset("Honaker/xview_dataset", split="train") + train_data = load_dataset("Honaker/xview_dataset_subset", split="train") - new_dataset = train_data.train_test_split(test_size=0.2, seed=1) + new_dataset = train_data.train_test_split(test_size=0.4, seed=3) train_dataset, test_dataset = new_dataset["train"], new_dataset["test"] train_dataset, test_dataset = HuggingFaceObjectDetectionDataset( @@ -53,13 +64,23 @@ def main(argv: list = sys.argv[1:]): ### # Model ### - model = track_params(jatic_toolbox.load_model)( - provider="torchvision", - model_name="fasterrcnn_resnet50_fpn", - task="object-detection", + s3 = boto3.resource("s3") + try: + s3.Bucket(BUCKET_NAME).download_file( + KEY, Path.cwd() / "fasterrcnn_mobilenet_v3_2" + ) + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "404": + print("The object does not exist.") + else: + raise + + model = torch.load(Path.cwd() / "fasterrcnn_mobilenet_v3_2") + model.to(DEVICE) + + model = TorchVisionObjectDetector( + model=model, processor=ObjectDetection(), labels=None ) - - # Bypass JATIC model wrapper to allow targeted adversarial attacks model.forward = model._model.forward detector = track_init_params(PyTorchFasterRCNN)( @@ -67,6 +88,7 @@ def main(argv: list = sys.argv[1:]): channels_first=True, clip_values=(0.0, 1.0), ) + model_transform = create_jatic_dataset_transform(model.preprocessor) train_dataset, test_dataset = load_huggingface_dataset() @@ -105,7 +127,6 @@ def transform(sample): transformed = model_transform(transformed) return transformed - train_dataset.set_transform(transform) test_dataset.set_transform(transform) train_dataloader = ArmoryDataLoader( @@ -122,7 +143,7 @@ def transform(sample): test_dataset=test_dataloader, ) eval_model = Model( - name="fasterrcnn-resnet-50", + name="xview-trained-fasterrcnn-resnet-50", model=detector, ) @@ -167,7 +188,7 @@ def transform(sample): task = ObjectDetectionTask( evaluation, - export_every_n_batches=5, + export_every_n_batches=2, class_metrics=False, ) engine = LightningEngine(task, limit_test_batches=10)