Skip to content

Commit

Permalink
Merge pull request #1894 from swsuggs/assorted-improvements
Browse files Browse the repository at this point in the history
Assorted improvements
  • Loading branch information
lcadalzo authored Mar 27, 2023
2 parents 2560098 + 5abe034 commit f98e9ab
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
1 change: 1 addition & 0 deletions armory/baseline_models/tf_graph/audio_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def get_art_model(

loss_object = losses.SparseCategoricalCrossentropy()

@tf.function
def train_step(model, samples, labels):
with tf.GradientTape() as tape:
predictions = model(samples, training=True)
Expand Down
18 changes: 9 additions & 9 deletions armory/data/adversarial_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,15 +661,15 @@ def carla_over_obj_det_dev(
**kwargs,
):
"""
Dev set for CARLA object detection dataset, containing RGB and depth channels. The dev
Dev set for CARLA overhead object detection dataset, containing RGB and depth channels. The dev
set also contains green screens for adversarial patch insertion.
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_obj_det_dev dataset"
"Filtering by class is not supported for the carla_over_obj_det_dev dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_dev batch size must be set to 1")
raise ValueError("carla_over_obj_det_dev batch size must be set to 1")

modality = kwargs.pop("modality", "rgb")
if modality not in ["rgb", "depth", "both"]:
Expand Down Expand Up @@ -729,15 +729,15 @@ def carla_over_obj_det_test(
**kwargs,
):
"""
Dev set for CARLA object detection dataset, containing RGB and depth channels. The test
Test set for CARLA overhead object detection dataset, containing RGB and depth channels. The test
set also contains green screens for adversarial patch insertion.
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_obj_det_test dataset"
"Filtering by class is not supported for the carla_over_obj_det_test dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_test batch size must be set to 1")
raise ValueError("carla_over_obj_det_test batch size must be set to 1")

modality = kwargs.pop("modality", "rgb")
if modality not in ["rgb", "depth", "both"]:
Expand Down Expand Up @@ -924,7 +924,7 @@ def carla_video_tracking_dev(
"Filtering by class is not supported for the carla_video_tracking_dev dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_dev batch size must be set to 1")
raise ValueError("carla_video_tracking_dev batch size must be set to 1")

if max_frames:
clip = datasets.ClipFrames(max_frames)
Expand Down Expand Up @@ -975,10 +975,10 @@ def carla_video_tracking_test(
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_video_tracking_dev dataset"
"Filtering by class is not supported for the carla_video_tracking_test dataset"
)
if batch_size != 1:
raise ValueError("carla_obj_det_dev batch size must be set to 1")
raise ValueError("carla_video_tracking_test batch size must be set to 1")

if max_frames:
clip = datasets.ClipFrames(max_frames)
Expand Down
4 changes: 2 additions & 2 deletions armory/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,11 +1034,11 @@ def carla_over_obj_det_train(
**kwargs,
) -> ArmoryDataGenerator:
"""
Training set for CARLA object detection dataset, containing RGB and depth channels.
Training set for CARLA overhead object detection dataset, containing RGB and depth channels.
"""
if "class_ids" in kwargs:
raise ValueError(
"Filtering by class is not supported for the carla_obj_det_train dataset"
"Filtering by class is not supported for the carla_over_obj_det_train dataset"
)
modality = kwargs.pop("modality", "rgb")
if modality not in ["rgb", "depth", "both"]:
Expand Down
5 changes: 5 additions & 0 deletions armory/utils/config_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def load_dataset(dataset_config, *args, num_batches=None, check_run=False, **kwa
if check_run:
return EvalGenerator(dataset, num_eval_batches=1)
if num_batches:
if num_batches > dataset.batches_per_epoch:
# since num-eval-batches only applies at test time, we can assume there is only 1 epoch
raise ValueError(
f"{num_batches} eval batches were requested, but dataset has only {dataset.batches_per_epoch} batches of size {dataset.batch_size}"
)
return EvalGenerator(dataset, num_eval_batches=num_batches)
return dataset

Expand Down

0 comments on commit f98e9ab

Please sign in to comment.