Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reproducible checkpoint #11582

Merged
merged 18 commits into from
May 4, 2021
43 changes: 43 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import dataclasses
import gc
import math
import os
import random
import re
import tempfile
import unittest
Expand Down Expand Up @@ -195,6 +197,26 @@ def forward(self, input_x, labels=None, **kwargs):
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)

class RegressionRandomPreTrainedModel(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"

def __init__(self, config):
super().__init__(config)
self.a = torch.nn.Parameter(torch.tensor(config.a).float())
self.b = torch.nn.Parameter(torch.tensor(config.b).float())

def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
if self.training:
stas00 marked this conversation as resolved.
Show resolved Hide resolved
# Add random noise from torch, numpy and random
y += 0.05 * torch.randn(1).squeeze() + 0.05 * torch.tensor(np.random.rand() + random.random())

if labels is None:
return (y,)
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y)

class TstLayer(torch.nn.Module):
def __init__(self, hidden_size):
super().__init__()
Expand Down Expand Up @@ -699,6 +721,27 @@ def test_can_resume_training(self):
trainer.train(resume_from_checkpoint=True)
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))

def test_resume_training_with_randomness(self):
train_dataset = RegressionDataset(length=128)
eval_dataset = RegressionDataset()

config = RegressionModelConfig(a=0, b=2)
model = RegressionRandomPreTrainedModel(config)

tmp_dir = self.get_auto_remove_tmp_dir()
args = RegressionTrainingArguments(tmp_dir, save_steps=5, learning_rate=0.1)
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)

trainer.train()
(a, b) = trainer.model.a.item(), trainer.model.b.item()

model = RegressionRandomPreTrainedModel(config)
trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train(resume_from_checkpoint=os.path.join(tmp_dir, "checkpoint-15"))
(a1, b1) = trainer.model.a.item(), trainer.model.b.item()
self.assertTrue(math.isclose(a, a1, rel_tol=1e-4))
self.assertTrue(math.isclose(b, b1, rel_tol=1e-4))

def test_resume_training_with_gradient_accumulation(self):
if torch.cuda.device_count() > 2:
# This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
Expand Down