Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ def clip_gradients(
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""clips all the optimizer parameters to the given value"""
self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)
self.precision_plugin.clip_gradients(
self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)

def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
"""Hook to do something at the end of a training epoch
Expand Down
9 changes: 7 additions & 2 deletions pytorch_lightning/accelerators/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities import _XLA_AVAILABLE, GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _XLA_AVAILABLE:
Expand Down Expand Up @@ -56,7 +56,12 @@ def run_optimizer_step(
) -> None:
xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
def clip_gradients(
self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM
) -> None:
assert gradient_clip_algorithm is GradClipAlgorithmType.NORM, \
"Only NORM gradient clipping is supported on TPU for now"

model = self.lightning_module
parameters = model.parameters()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def on_trainer_init(
f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
)
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = gradient_clip_algorithm
self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)

# gradient norm tracking
if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,10 @@ def track_and_norm_grad(self, optimizer):
grad_norm_dic = self._track_gradient_norm()

# clip gradients
self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val)
self.trainer.accelerator.clip_gradients(
optimizer, self.trainer.gradient_clip_val,
gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
)
self._cur_grad_norm_dict = grad_norm_dic

def _track_gradient_norm(self):
Expand Down
5 changes: 3 additions & 2 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_tpu_grad_norm(tmpdir):
@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_clip_grad_by_value(tmpdir):
"""Test if clip_gradients by value works on TPU."""
"""Test that clip_gradients by value is rejected on TPU (only NORM clipping is supported there for now)."""
tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
Expand All @@ -236,7 +236,8 @@ def test_tpu_clip_grad_by_value(tmpdir):
)

model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
with pytest.raises(AssertionError):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@RunIf(tpu=True)
Expand Down
21 changes: 11 additions & 10 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,13 +918,13 @@ def test_gradient_clipping_by_value(tmpdir):

model = BoringModel()

grad_clip_val = 0.0001
grad_clip_val = 1e-10
trainer = Trainer(
max_steps=10,
max_steps=1,
max_epochs=1,
gradient_clip_val=grad_clip_val,
gradient_clip_algorithm='value',
default_root_dir=tmpdir,
default_root_dir=tmpdir
)

trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward
Expand All @@ -938,8 +938,8 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
parameters = model.parameters()
grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
grad_max = torch.max(torch.stack(grad_max_list))
assert round(grad_max.item(), 6) <= grad_clip_val, \
f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ."
assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."

return ret_val

Expand Down Expand Up @@ -996,9 +996,9 @@ def test_gradient_clipping_by_value_fp16(tmpdir):
tutils.reset_seed()

model = BoringModel()
grad_clip_val = 0.0001
grad_clip_val = 1e-10
trainer = Trainer(
max_steps=10,
max_steps=1,
max_epochs=1,
precision=16,
gpus=1,
Expand All @@ -1016,9 +1016,10 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde
# test that gradient is clipped correctly
ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
parameters = model.parameters()
grad_max = torch.max(torch.stack([p.grad.detach() for p in parameters]))
assert round(grad_max.item(), 6) <= grad_clip_val, \
f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ."
grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
grad_max = torch.max(torch.stack(grad_max_list))
assert abs(grad_max.item() - grad_clip_val) < 1e-11, \
f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."

return ret_val

Expand Down