Skip to content

Commit

Permalink
Update lm_target.py
Browse files Browse the repository at this point in the history
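Fixed a bug that occurred when not using model-parallel training: `args` may lack `tensor_model_parallel_size` and `pipeline_model_parallel_size`, so both now default to 1.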
  • Loading branch information
hhou435 authored Mar 6, 2024
1 parent 86531a8 commit b0a9591
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tencentpretrain/targets/lm_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@ def __init__(self, args, vocab_size):
super(LmTarget, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = args.hidden_size
- self.tensor_model_parallel_size = args.tensor_model_parallel_size
- self.pipeline_model_parallel_size = args.pipeline_model_parallel_size
+ if hasattr(args, "tensor_model_parallel_size"):
+ self.tensor_model_parallel_size = args.tensor_model_parallel_size
+ else:
+ self.tensor_model_parallel_size = 1
+ if hasattr(args, "pipeline_model_parallel_size"):
+ self.pipeline_model_parallel_size = args.pipeline_model_parallel_size
+ else:
+ self.pipeline_model_parallel_size = 1
if "label_smoothing" in args:
self.label_smoothing = args.label_smoothing
else:
Expand Down

0 comments on commit b0a9591

Please sign in to comment.