Add benchmark from torchvision training references #714
Overall, LGTM! I think we can land this PR as it is and address the comments later. I am just leaving them as TODOs for our reference.
We should add a README with instructions to this folder. I have no preference whether it is added in this PR or in a later one, since it is not the most urgent item.
    return metric_logger.acc1.global_avg


def create_data_loaders(args):
We will likely need to accept the path to the datasets as an argument.
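A minimal sketch of what that could look like with `argparse`. The `--data-path` flag name and the `train`/`val` subfolder layout are assumptions made for illustration; only `args.workers` appears in this PR's diff.

```python
import argparse
import os.path


def build_parser():
    parser = argparse.ArgumentParser(description="Classification benchmark (sketch)")
    # Hypothetical flag: the PR does not define how the dataset path is passed.
    parser.add_argument("--data-path", required=True, help="root folder of the dataset")
    # `args.workers` is already read by create_data_loaders() in the diff.
    parser.add_argument("--workers", type=int, default=4, help="number of loading workers")
    return parser


def resolve_dataset_dirs(args):
    # Assumed layout: <data-path>/train and <data-path>/val.
    return os.path.join(args.data_path, "train"), os.path.join(args.data_path, "val")


# Parse an explicit argument list so the sketch runs without a command line.
args = build_parser().parse_args(["--data-path", "/datasets/imagenet", "--workers", "8"])
train_dir, val_dir = resolve_dataset_dirs(args)
```

`create_data_loaders(args)` could then build its datasets from `train_dir` and `val_dir` instead of a hard-coded location.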
    train_data_loader = DataLoader2(
        train_dataset,
        datapipe_adapter_fn=adapter.Shuffle(),
        reading_service=MultiProcessingReadingService(num_workers=args.workers),
Note that MultiProcessingReadingService will be updated to use a slightly different backend (i.e. what is currently PrototypeMultiProcessingReadingService).
I'm going to look into that part to see if there is any issue.
    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
nit: Is it possible to break this into smaller pieces? Or just a few inline comments along the way would be helpful. Thanks!
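One way such a split could look is sketched below. This is not the PR's `log_every`; it is a simplified standalone version, and the `format_progress` helper and the `log_fn` parameter are hypothetical names introduced only for this illustration.

```python
import time


def format_progress(header, done, total, seconds_per_iter):
    # Build one progress line; the ETA is extrapolated from the average iteration time.
    remaining = (total - done) * seconds_per_iter
    return f"{header} [{done}/{total}] eta: {remaining:.1f}s"


def log_every(iterable, print_freq, header="", log_fn=print):
    # Yield every item unchanged, logging once every `print_freq` items and at the end.
    total = len(iterable)
    start = time.monotonic()
    for index, item in enumerate(iterable):
        yield item
        # Execution resumes here once the caller has finished processing `item`.
        done = index + 1
        if done % print_freq == 0 or done == total:
            elapsed = time.monotonic() - start
            log_fn(format_progress(header, done, total, elapsed / done))


lines = []
items = list(log_every(range(5), print_freq=2, header="Epoch 0", log_fn=lines.append))
```

Keeping the formatting in its own function leaves the generator body short, so each step can carry its own inline comment.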

    train_data_loader, val_data_loader, train_sampler = create_data_loaders(args)

    num_classes = 1000  # I'm lazy. TODO change this
What does this do? Define the number of classes to classify?
Yes, we need to pass this to the model's constructor. The model needs to know how many classes there are because this directly determines its number of outputs: one output (probability) per class.
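To make the relationship concrete, here is a dependency-free toy sketch. It is not the model used in this PR: `make_linear_head` and `predict_probabilities` are hypothetical helpers standing in for a real classifier's final layer followed by a softmax.

```python
import math
import random


def make_linear_head(in_features, num_classes, seed=0):
    # One weight row and one bias per class: num_classes fixes the output size.
    rng = random.Random(seed)
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(in_features)]
               for _ in range(num_classes)]
    biases = [0.0] * num_classes
    return weights, biases


def predict_probabilities(head, features):
    weights, biases = head
    # One logit per class.
    logits = [sum(w * x for w, x in zip(row, features)) + b
              for row, b in zip(weights, biases)]
    # Numerically stable softmax turns the logits into one probability per class.
    largest = max(logits)
    exps = [math.exp(value - largest) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


num_classes = 1000
head = make_linear_head(in_features=8, num_classes=num_classes)
probabilities = predict_probabilities(head, [0.5] * 8)
```

Changing `num_classes` changes the length of `probabilities`, which is why the value has to match the dataset being trained on.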
@NicolasHug has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Towards #416
This is a modified and simplified version of the torchvision classification training reference that provides:
I removed a lot of non-essential features from the original reference, but I can simplify further. Typically I would expect the MetricLogger to disappear, or be trimmed down to its most essential bits.