Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing errors #10

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,24 @@ repos:
rev: v1.6.0
hooks:
- id: mypy
# When charmory code passes mypy, add this pattern to the list below
# src/charmory/.*
additional_dependencies: [types-requests]
files: |
(?x)^(
src/charmory/.*|
examples/src/.*
)$
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.0
hooks:
- id: nbqa-mypy
additional_dependencies: [mypy, types-requests]
args: ["--ignore-missing-imports"]
files: |
(?x)^(
src/charmory/.*|
tutorials/notebooks/.*|
examples/notebooks/.*
)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
Expand Down
4 changes: 2 additions & 2 deletions examples/notebooks/api-walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@
"outputs": [],
"source": [
"import charmory.engine\n",
"engine = charmory.engine.Engine(baseline)"
"engine = charmory.engine.EvaluationEngine(baseline)"
]
},
{
Expand Down Expand Up @@ -391,7 +391,7 @@
"import charmory.engine\n",
"\n",
"baseline = charmory.blocks.cifar10.baseline\n",
"engine = charmory.engine.Engine(baseline)\n",
"engine = charmory.engine.EvaluationEngine(baseline)\n",
"result = engine.run()"
]
}
Expand Down
16 changes: 5 additions & 11 deletions examples/notebooks/jatic-food-track-demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@
"metadata": {},
"outputs": [],
"source": [
"from charmory.data import ArmoryDataLoader, JaticImageClassificationDataset\n",
"from charmory.data import ArmoryDataLoader\n",
"\n",
"generator = ArmoryDataLoader(\n",
" dataset=JaticImageClassificationDataset(dataset),\n",
" dataset=dataset,\n",
" batch_size=16,\n",
")"
]
Expand All @@ -214,7 +214,6 @@
"outputs": [],
"source": [
"import art.attacks.evasion\n",
"from armory.instrument.config import MetricsLogger\n",
"from armory.metrics.compute import BasicProfiler\n",
"from charmory.evaluation import (\n",
" Attack,\n",
Expand All @@ -229,7 +228,9 @@
"\n",
" eval_dataset = Dataset(\n",
" name=\"food-category-classification\",\n",
" test_dataset=generator,\n",
" x_key=\"image\",\n",
" y_key=\"label\",\n",
" test_dataloader=generator,\n",
" )\n",
"\n",
" eval_model = Model(\n",
Expand All @@ -255,13 +256,6 @@
"\n",
" eval_metric = Metric(\n",
" profiler=BasicProfiler(),\n",
" logger=MetricsLogger(\n",
" supported_metrics=[\"accuracy\"],\n",
" perturbation=[\"linf\"],\n",
" task=[\"categorical_accuracy\"],\n",
" means=True,\n",
" record_metric_per_sample=False,\n",
" ),\n",
" )\n",
"\n",
" eval_sysconfig = SysConfig(\n",
Expand Down
16 changes: 5 additions & 11 deletions examples/notebooks/jatic_toolbox-image-classification-demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@
"metadata": {},
"outputs": [],
"source": [
"from charmory.data import ArmoryDataLoader, JaticImageClassificationDataset\n",
"from charmory.data import ArmoryDataLoader\n",
"\n",
"generator = ArmoryDataLoader(\n",
" dataset=JaticImageClassificationDataset(dataset),\n",
" dataset=dataset,\n",
" batch_size=16,\n",
")"
]
Expand All @@ -214,7 +214,6 @@
"outputs": [],
"source": [
"import art.attacks.evasion\n",
"from armory.instrument.config import MetricsLogger\n",
"from armory.metrics.compute import BasicProfiler\n",
"from charmory.evaluation import (\n",
" Attack,\n",
Expand All @@ -227,7 +226,9 @@
"\n",
"eval_dataset = Dataset(\n",
" name=\"food-category-classification\",\n",
" test_dataset=generator,\n",
" x_key=\"image\",\n",
" y_key=\"label\",\n",
" test_dataloader=generator,\n",
")\n",
"\n",
"eval_model = Model(\n",
Expand All @@ -253,13 +254,6 @@
"\n",
"eval_metric = Metric(\n",
" profiler=BasicProfiler(),\n",
" logger=MetricsLogger(\n",
" supported_metrics=[\"accuracy\"],\n",
" perturbation=[\"linf\"],\n",
" task=[\"categorical_accuracy\"],\n",
" means=True,\n",
" record_metric_per_sample=False,\n",
" ),\n",
")\n",
"\n",
"eval_sysconfig = SysConfig(\n",
Expand Down
50 changes: 25 additions & 25 deletions src/armory/data/adversarial_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Adversarial datasets
"""

from typing import Callable
from typing import Callable, Optional

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -68,7 +68,7 @@ def imagenet_adversarial(
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = imagenet_adversarial_canonical_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
Expand Down Expand Up @@ -114,7 +114,7 @@ def librispeech_adversarial(
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = librispeech_adversarial_canonical_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
Expand Down Expand Up @@ -164,7 +164,7 @@ def resisc45_adversarial_224x224(
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = resisc45_adversarial_canonical_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
Expand Down Expand Up @@ -222,7 +222,7 @@ def ucf101_adversarial_112x112(
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = ucf101_adversarial_canonical_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
Expand Down Expand Up @@ -282,12 +282,12 @@ def gtsrb_poison(
split: str = "poison",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
preprocessing_fn: Callable = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Optional[Callable] = None,
cache_dataset: bool = True,
framework: str = "numpy",
clean_key: str = None,
adversarial_key: str = None,
clean_key: Optional[str] = None,
adversarial_key: Optional[str] = None,
shuffle_files: bool = False,
**kwargs,
) -> datasets.ArmoryDataGenerator:
Expand Down Expand Up @@ -361,7 +361,7 @@ def apricot_dev_adversarial(
split: str = "frcnn+ssd+retinanet",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = apricot_canonical_preprocessing,
label_preprocessing_fn: Callable = apricot_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -424,7 +424,7 @@ def apricot_test_adversarial(
split: str = "adversarial",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = apricot_canonical_preprocessing,
label_preprocessing_fn: Callable = apricot_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -487,7 +487,7 @@ def dapricot_dev_adversarial(
split: str = "large+medium+small",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = dapricot_canonical_preprocessing,
label_preprocessing_fn: Callable = dapricot_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -521,7 +521,7 @@ def dapricot_test_adversarial(
split: str = "large+medium+small",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = dapricot_canonical_preprocessing,
label_preprocessing_fn: Callable = dapricot_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -584,7 +584,7 @@ def carla_obj_det_dev(
split: str = "dev",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_obj_det_dev_canonical_preprocessing,
label_preprocessing_fn=carla_obj_det_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -652,7 +652,7 @@ def carla_over_obj_det_dev(
split: str = "dev",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_obj_det_dev_canonical_preprocessing,
label_preprocessing_fn=carla_obj_det_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -720,7 +720,7 @@ def carla_over_obj_det_test(
split: str = "test",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_obj_det_test_canonical_preprocessing,
label_preprocessing_fn=carla_obj_det_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -788,7 +788,7 @@ def carla_obj_det_test(
split: str = "test",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_obj_det_test_canonical_preprocessing,
label_preprocessing_fn=carla_obj_det_label_preprocessing,
cache_dataset: bool = True,
Expand Down Expand Up @@ -906,13 +906,13 @@ def carla_video_tracking_dev(
split: str = "dev",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_video_tracking_canonical_preprocessing,
label_preprocessing_fn=carla_video_tracking_label_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
shuffle_files: bool = False,
max_frames: int = None,
max_frames: Optional[int] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -960,13 +960,13 @@ def carla_video_tracking_test(
split: str = "test",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_video_tracking_canonical_preprocessing,
label_preprocessing_fn=carla_video_tracking_label_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
shuffle_files: bool = False,
max_frames: int = None,
max_frames: Optional[int] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1130,13 +1130,13 @@ def carla_multi_object_tracking_dev(
split: str = "dev",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_video_tracking_canonical_preprocessing,
label_preprocessing_fn=carla_mot_label_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
shuffle_files: bool = False,
max_frames: int = None,
max_frames: Optional[int] = None,
coco_format: bool = False,
**kwargs,
):
Expand Down Expand Up @@ -1191,13 +1191,13 @@ def carla_multi_object_tracking_test(
split: str = "test",
epochs: int = 1,
batch_size: int = 1,
dataset_dir: str = None,
dataset_dir: Optional[str] = None,
preprocessing_fn: Callable = carla_video_tracking_canonical_preprocessing,
label_preprocessing_fn=carla_mot_label_preprocessing,
cache_dataset: bool = True,
framework: str = "numpy",
shuffle_files: bool = False,
max_frames: int = None,
max_frames: Optional[int] = None,
coco_format: bool = False,
**kwargs,
):
Expand Down
Loading