diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2cb426ac2ce8..276c08788a13 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -52,7 +52,7 @@ from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn -from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from . import __version__ from .configuration_utils import PretrainedConfig @@ -353,7 +353,7 @@ def __init__( model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None,