-
Notifications
You must be signed in to change notification settings - Fork 1
/
lightning_datamodules.py
63 lines (57 loc) · 1.93 KB
/
lightning_datamodules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pytorch_lightning as pl
from datasets import load_from_disk
from torch.utils.data import DataLoader
from utils import DataCollator
class SummaryDataModule(pl.LightningDataModule):
    """LightningDataModule serving a summarization dataset saved with
    ``datasets.Dataset.save_to_disk``.

    Args:
        data_path: Path passed to ``datasets.load_from_disk``; the saved
            dataset must contain ``"train"`` and ``"validation"`` splits.
        train_batch_size: Batch size for the training dataloader.
        tokenizer: Tokenizer handed to :class:`DataCollator` for padding
            and truncation.
        eval_batch_size: Batch size for validation; defaults to
            ``train_batch_size`` when ``None``.
        max_length: Maximum sequence length the collator truncates to.
        eval_size: Number of validation examples to keep (capped at the
            split size).
        train_size: Optional cap on the number of training examples; the
            full split is used when ``None``.
        num_workers: Worker processes per DataLoader.
        shuffle: Whether the training dataloader reshuffles each epoch.
            Defaults to ``False`` to preserve the module's historical
            behavior; ``True`` is the usual choice for training.
    """

    def __init__(
        self,
        data_path,
        train_batch_size,
        tokenizer,
        eval_batch_size=None,
        max_length=4096,
        eval_size=500,
        num_workers=8,
        train_size=None,
        shuffle=False,
    ):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.train_batch_size = train_batch_size
        # Fall back to the training batch size when no eval size is given.
        self.eval_batch_size = (
            train_batch_size if eval_batch_size is None else eval_batch_size
        )
        self.max_length = max_length
        self.eval_size = eval_size
        self.num_workers = num_workers
        self.train_size = train_size
        self.shuffle = shuffle

    def setup(self, stage):
        """Load and slice the train/validation splits for the ``"fit"`` stage.

        Other stages are a no-op: only ``fit`` needs these datasets.
        """
        if stage != "fit":
            return
        dataset = load_from_disk(self.data_path).with_format("torch")
        self.train_dataset = dataset["train"]
        if self.train_size is not None:
            # Clamp to the split size so an oversized request cannot raise
            # an out-of-range error from Dataset.select.
            self.train_dataset = self.train_dataset.select(
                range(min(self.train_size, len(self.train_dataset)))
            )
        eval_split = dataset["validation"]
        self.eval_dataset = eval_split.select(
            range(min(self.eval_size, len(eval_split)))
        )
        self.data_collator = DataCollator(
            self.tokenizer,
            padding="longest",
            max_length=self.max_length,
            truncate=True,
        )

    def train_dataloader(self):
        """Training DataLoader; shuffling is controlled by ``self.shuffle``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=self.data_collator,
        )

    def val_dataloader(self):
        """Validation DataLoader (never shuffled).

        Fix: uses ``eval_batch_size`` here — the original passed
        ``train_batch_size``, leaving the configured ``eval_batch_size``
        attribute unused.
        """
        return DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.data_collator,
        )