Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[trainer] --model_parallel hasn't been implemented for most models #9347

Merged
merged 6 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):

- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
"""
config_class = None
base_model_prefix = ""
Expand All @@ -417,6 +418,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# trained, but which are deterministic)
_keys_to_ignore_on_save = None

is_parallelizable = False

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
config_class = GPT2Config
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
is_parallelizable = True

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ class T5PreTrainedModel(PreTrainedModel):
config_class = T5Config
load_tf_weights = load_tf_weights_in_t5
base_model_prefix = "transformer"
is_parallelizable = True

@property
def dummy_inputs(self):
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,11 @@ def __init__(
if model is None and model_init is not None:
model = self.call_model_init()

if self.args.model_parallel and not model.is_parallelizable:
raise ValueError(
f"{model.__class__.__name__} implementation currently doesn't support model parallelism, therefore --model_parallel cl arg cannot be used"
)

# Model parallel
if model is not None and not self.args.model_parallel:
model = model.to(args.device)
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ class TrainingArguments:
:obj:`"eval_loss"`.
- :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
If there is more than one device, whether to use model parallelism to distribute the model's modules across
devices or not.
If the model supports model parallelism and there is more than one device, whether to use model parallelism
to distribute the model's modules across devices or not.
ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
Expand Down