Skip to content

Commit

Permalink
Merge pull request #28 from twosixlabs/27-move-xview-dataset-to-s3
Browse files Browse the repository at this point in the history
moved xview dataset
  • Loading branch information
treubig26 authored Oct 31, 2023
2 parents 40f7f5c + 78f2110 commit 26dcbaf
Showing 1 changed file with 21 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from art.estimators.object_detection import PyTorchFasterRCNN
import boto3
import botocore
from datasets import load_dataset
from jatic_toolbox import __version__ as jatic_version
from jatic_toolbox.interop.huggingface import HuggingFaceObjectDetectionDataset
from jatic_toolbox.interop.torchvision import TorchVisionObjectDetector
Expand All @@ -25,25 +24,26 @@
from charmory.model.object_detection import JaticObjectDetectionModel
from charmory.tasks.object_detection import ObjectDetectionTask

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")
from datasets import load_from_disk
from datasets.filesystems import S3FileSystem

import armory.data.datasets
from charmory.track import track_init_params
from charmory.utils import create_jatic_dataset_transform

BATCH_SIZE = 1
TRAINING_EPOCHS = 20
BUCKET_NAME = "armory-library-data"
KEY = "fasterrcnn_mobilenet_v3_2"


torch.set_float32_matmul_precision("high")
import armory.data.datasets
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_huggingface_dataset():
train_data = load_dataset("Honaker/xview_dataset_subset", split="train")
s3 = S3FileSystem(anon=False)
train_dataset = load_from_disk("s3://armory-library-data/datasets/train/", fs=s3)

new_dataset = train_data.train_test_split(test_size=0.4, seed=3)
new_dataset = train_dataset.train_test_split(test_size=0.4, seed=3)
train_dataset, test_dataset = new_dataset["train"], new_dataset["test"]

train_dataset, test_dataset = HuggingFaceObjectDetectionDataset(
Expand All @@ -65,17 +65,17 @@ def main(argv: list = sys.argv[1:]):
# Model
###
s3 = boto3.resource("s3")
try:
s3.Bucket(BUCKET_NAME).download_file(
KEY, Path.cwd() / "fasterrcnn_mobilenet_v3_2"
)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
print("The object does not exist.")
else:
raise

model = torch.load(Path.cwd() / "fasterrcnn_mobilenet_v3_2")
my_file = Path.cwd() / "fasterrcnn_mobilenet_v3_2"
if not my_file.is_file():
try:
s3.Bucket(BUCKET_NAME).download_file(KEY, my_file)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
print("The object does not exist.")
else:
raise

model = torch.load(my_file)
model.to(DEVICE)

model = TorchVisionObjectDetector(
Expand Down Expand Up @@ -127,6 +127,7 @@ def transform(sample):
transformed = model_transform(transformed)
return transformed

train_dataset.set_transform(transform)
test_dataset.set_transform(transform)

train_dataloader = ArmoryDataLoader(
Expand All @@ -140,7 +141,7 @@ def transform(sample):
eval_dataset = Dataset(
name="XVIEW",
x_key="image",
y_key="label",
y_key="objects",
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
)
Expand Down

0 comments on commit 26dcbaf

Please sign in to comment.