Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check layer types for Optimizer construction #10598

Merged
merged 2 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
get_parameter_names,
nested_concat,
nested_detach,
nested_numpify,
Expand Down Expand Up @@ -613,14 +614,15 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
"""
if self.optimizer is None:
no_decay = ["bias", "LayerNorm.weight"]
decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
"params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
"weight_decay": self.args.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
"params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
"weight_decay": 0.0,
},
]
Expand Down
16 changes: 16 additions & 0 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,19 @@ def save_state(self):

path = os.path.join(self.args.output_dir, "trainer_state.json")
self.state.save_to_json(path)


def get_parameter_names(model, forbidden_layer_types):
"""
Returns the names of the model parameters that are not inside a forbidden layer.
"""
result = []
for name, child in model.named_children():
result += [
f"{name}.{n}"
for n in get_parameter_names(child, forbidden_layer_types)
if not isinstance(child, tuple(forbidden_layer_types))
]
# Add model specific parameters (defined with nn.Parameter) since they are not in any child.
result += list(model._parameters.keys())
return result
26 changes: 26 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,20 @@ def forward(self, input_x, labels=None, **kwargs):
loss = torch.nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)

class TstLayer(torch.nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
self.ln1 = torch.nn.LayerNorm(hidden_size)
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
self.ln2 = torch.nn.LayerNorm(hidden_size)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

def forward(self, x):
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
h = torch.nn.functional.relu(self.linear2(x))
return self.ln2(x + h + self.bias)

def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, **kwargs):
label_names = kwargs.get("label_names", None)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
Expand Down Expand Up @@ -991,6 +1005,18 @@ def test_fp16_full_eval(self):
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)

def test_no_wd_param_group(self):
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
trainer = Trainer(model=model)
trainer.create_optimizer_and_scheduler(10)
# fmt: off
wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']
# fmt: on
wd_params = [p for n, p in model.named_parameters() if n in wd_names]
no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)


@require_torch
@require_optuna
Expand Down
24 changes: 24 additions & 0 deletions tests/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,23 @@
DistributedTensorGatherer,
LabelSmoother,
LengthGroupedSampler,
get_parameter_names,
)

class TstLayer(torch.nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
self.ln1 = torch.nn.LayerNorm(hidden_size)
self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
self.ln2 = torch.nn.LayerNorm(hidden_size)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size))

def forward(self, x):
h = self.ln1(torch.nn.functional.relu(self.linear1(x)))
h = torch.nn.functional.relu(self.linear2(x))
return self.ln2(x + h + self.bias)


@require_torch
class TrainerUtilsTest(unittest.TestCase):
Expand Down Expand Up @@ -117,3 +132,12 @@ def test_distributed_length_grouped(self):
self.assertEqual(lengths[indices_process_0[0]], 50)
# The indices should be a permutation of range(100)
self.assertEqual(list(sorted(indices_process_0 + indices_process_1)), list(range(100)))

def test_get_parameter_names(self):
model = torch.nn.Sequential(TstLayer(128), torch.nn.ModuleList([TstLayer(128), TstLayer(128)]))
# fmt: off
self.assertEqual(
get_parameter_names(model, [torch.nn.LayerNorm]),
['0.linear1.weight', '0.linear1.bias', '0.linear2.weight', '0.linear2.bias', '0.bias', '1.0.linear1.weight', '1.0.linear1.bias', '1.0.linear2.weight', '1.0.linear2.bias', '1.0.bias', '1.1.linear1.weight', '1.1.linear1.bias', '1.1.linear2.weight', '1.1.linear2.bias', '1.1.bias']
)
# fmt: on