support data_parallel training and ucf101 dataset #4819
Conversation
'--model_path_pre',
type=str,
default='tsm',
help='default model path pre is tsm.')
What is the meaning of model_path_pre?
dygraph/tsm/train.py
Outdated
# load resnet50 pretrain
pre_state_dict = fluid.load_program_state(args.resnet50_dir)
for key in pre_state_dict.keys():
    print('pre_state_dict.key: {}'.format(key))
Is this print debugging code? Printing every parameter name produces very long output; suggest commenting it out or deleting it.
dygraph/tsm/train.py
Outdated
current_step_lr))

# 6.2 save checkpoint
save_parameters = (not use_data_parallel) or (
Is the `or` logic unnecessary? In the single-card case you can also save when local_rank == 0; for reference:
https://github.com/huangjun12/models/blob/9e2809a85c64115df92564d31055066300661141/dygraph/slowfast/train.py#L441
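The reviewer's point can be sketched as follows. This is a minimal illustration, not the PR's code: `should_save` is a hypothetical helper showing that gating on `local_rank == 0` alone covers both the single-card case (where the rank is 0) and the multi-card case, so the extra `not use_data_parallel` branch is redundant.

```python
# Hypothetical sketch of the suggested checkpoint-saving condition.
# Rank 0 saves in every configuration, so no `or` branch is needed.
def should_save(local_rank):
    return local_rank == 0

print(should_save(0))  # True: single card, or card 0 of a parallel run
print(should_save(1))  # False: non-zero ranks skip saving
```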
for batch_id, data in enumerate(train_reader()):
    t1 = time.time()
Could t1–t5 be renamed to something descriptive, e.g. batch_start_time?
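A sketch of the suggested renaming, with hypothetical variable names (the actual names are up to the author). Using a monotonic clock keeps the elapsed-time arithmetic safe:

```python
import time

# Descriptive timing variables instead of t1..t5 (illustrative names).
batch_start_time = time.perf_counter()
# ... read batch ...
reader_end_time = time.perf_counter()
# ... forward / backward / update ...
train_end_time = time.perf_counter()

reader_cost = reader_end_time - batch_start_time  # time spent reading data
batch_cost = train_end_time - batch_start_time    # total time for the batch
```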
video_model = fluid.dygraph.parallel.DataParallel(video_model,
                                                  strategy)

# 4. load checkpoint
if args.checkpoint:
When resuming from a checkpoint, should the epoch counter be adjusted accordingly?
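The reviewer's concern can be sketched like this. `resume_epoch_range` and `checkpoint_epoch` are illustrative names, not the PR's API: the idea is that a run resumed from a checkpoint saved at the end of epoch N should continue at epoch N + 1, not restart at 0, so the LR schedule and logs stay consistent.

```python
# Hypothetical sketch: adjust the epoch loop when resuming.
def resume_epoch_range(total_epochs, checkpoint_epoch=None):
    # A checkpoint saved at the end of epoch N resumes at epoch N + 1;
    # with no checkpoint, training starts from epoch 0.
    start = 0 if checkpoint_epoch is None else checkpoint_epoch + 1
    return range(start, total_epochs)

print(list(resume_epoch_range(5)))     # [0, 1, 2, 3, 4]
print(list(resume_epoch_range(5, 2)))  # [3, 4]
```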
outputs = video_model(imgs)
t3 = time.time()

loss = fluid.layers.cross_entropy(
    input=outputs, label=labels, ignore_index=-1)
avg_loss = fluid.layers.mean(loss)
Copy avg_loss to a new variable and print that copy instead of avg_loss, to avoid printing avg_loss after the scale_loss call, where it has already been divided by the number of cards.
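A paddle-free sketch of why this matters, with illustrative names (`n_cards`, `log_loss`, and this toy `scale_loss` are assumptions, standing in for DataParallel's loss scaling, which divides the loss by the number of cards before backward):

```python
n_cards = 4  # assumed number of cards in the parallel run

def scale_loss(loss, n_cards):
    # Stand-in for DataParallel loss scaling: divide by the card count.
    return loss / n_cards

avg_loss = 2.0        # loss computed on this card
log_loss = avg_loss   # copy taken BEFORE scaling, used only for logging
avg_loss = scale_loss(avg_loss, n_cards)

print(log_loss)  # 2.0 -- the true per-card loss
print(avg_loss)  # 0.5 -- already divided by the number of cards
```

Logging `log_loss` keeps the reported loss comparable between single-card and multi-card runs.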
Add a result of multi-card training?