Skip to content

Commit

Permalink
Add implementation for DALI AA and TA with readme
Browse files Browse the repository at this point in the history
Adjust some configuration options to accomodate it.
Remove the obsolete pipeline.
Provide the readme.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Mar 6, 2023
1 parent c792869 commit a923ea0
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions docs/examples/use_cases/pytorch/efficientnet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ def add_parser_arguments(parser, skip_arch=False):
parser.add_argument(
"--data-backend",
metavar="BACKEND",
default="dali-cpu",
default="dali",
choices=DATA_BACKEND_CHOICES,
help="data backend: "
+ " | ".join(DATA_BACKEND_CHOICES)
+ " (default: dali-cpu)",
+ " (default: dali)",
)
parser.add_argument(
"--interpolation",
Expand Down Expand Up @@ -111,7 +111,14 @@ def add_parser_arguments(parser, skip_arch=False):
default=2,
type=int,
metavar="N",
help="number of samples prefetched by each loader",
help="number of samples prefetched by each loader (PyTorch only)",
)
parser.add_argument(
"--dali-device",
default="gpu",
type=str,
choices=["cpu", "gpu"],
help=("The placement of DALI decode and random resized crop operations (default: gpu)"),
)
parser.add_argument(
"--epochs",
Expand Down Expand Up @@ -315,11 +322,11 @@ def add_parser_arguments(parser, skip_arch=False):
)
parser.add_argument("--use-ema", default=None, type=float, help="use EMA")
parser.add_argument(
"--augmentation",
"--automatic-augmentation",
type=str,
default=None,
choices=[None, "autoaugment"],
help="augmentation method",
default="autoaugment",
choices=["disabled", "autoaugment", "trivialaugment"],
help="Automatic augmentation method, trivialaugment is supported only for DALI data backend",
)

parser.add_argument(
Expand Down Expand Up @@ -480,11 +487,8 @@ def _worker_init_fn(id):
args.workers = args.workers * 2
get_train_loader = get_pytorch_train_loader
get_val_loader = get_pytorch_val_loader
elif args.data_backend == "dali-gpu":
get_train_loader = get_dali_train_loader(dali_cpu=False)
get_val_loader = get_dali_val_loader()
elif args.data_backend == "dali-cpu":
get_train_loader = get_dali_train_loader(dali_cpu=True)
elif args.data_backend == "dali":
get_train_loader = get_dali_train_loader(dali_device=args.dali_device)
get_val_loader = get_dali_val_loader()
elif args.data_backend == "synthetic":
get_val_loader = get_synthetic_loader
Expand All @@ -500,7 +504,7 @@ def _worker_init_fn(id):
model_args.num_classes,
args.mixup > 0.0,
interpolation=args.interpolation,
augmentation=args.augmentation,
augmentation=args.automatic_augmentation,
start_epoch=start_epoch,
workers=args.workers,
_worker_init_fn=_worker_init_fn,
Expand Down

0 comments on commit a923ea0

Please sign in to comment.