Skip to content

Commit

Permalink
train args defaulting None marked as Optional (huggingface#17156)
Browse files Browse the repository at this point in the history
Co-authored-by: Dom Miketa <dmiketa@exscientia.co.uk>
  • Loading branch information
2 people authored and Narsil committed May 12, 2022
1 parent 249229b commit 37d80b2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
20 changes: 11 additions & 9 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ class TrainingArguments:
)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
bf16: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -616,14 +616,14 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
)
tf32: bool = field(
tf32: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
},
)
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
xpu_backend: str = field(
xpu_backend: Optional[str] = field(
default=None,
metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
)
Expand All @@ -648,7 +648,7 @@ class TrainingArguments:
dataloader_drop_last: bool = field(
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
dataloader_num_workers: int = field(
default=0,
metadata={
Expand Down Expand Up @@ -770,14 +770,14 @@ class TrainingArguments:
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
)
hub_model_id: str = field(
hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_strategy: HubStrategy = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
gradient_checkpointing: bool = field(
default=False,
Expand All @@ -793,13 +793,15 @@ class TrainingArguments:
default="auto",
metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
)
push_to_hub_model_id: str = field(
push_to_hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
)
push_to_hub_organization: str = field(
push_to_hub_organization: Optional[str] = field(
default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
)
push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
push_to_hub_token: Optional[str] = field(
default=None, metadata={"help": "The token to use to push to the Model Hub."}
)
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/training_args_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import warnings
from dataclasses import dataclass, field
from typing import Tuple
from typing import Optional, Tuple

from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required
Expand Down Expand Up @@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
Whether to activate the XLA compilation or not.
"""

tpu_name: str = field(
tpu_name: Optional[str] = field(
default=None,
metadata={"help": "Name of TPU"},
)

tpu_zone: str = field(
tpu_zone: Optional[str] = field(
default=None,
metadata={"help": "Zone of TPU"},
)

gcp_project: str = field(
gcp_project: Optional[str] = field(
default=None,
metadata={"help": "Name of Cloud TPU-enabled project"},
)
Expand Down

0 comments on commit 37d80b2

Please sign in to comment.