Show CUDA matmul precision info only ever once (#17960)
(cherry picked from commit c5fae64)
awaelchli authored and Borda committed Jul 7, 2023
1 parent 9bfae44 commit 4a0c1f5
Showing 4 changed files with 17 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
```diff
@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed automatic step tracking in Fabric's CSVLogger ([#17942](https://github.com/Lightning-AI/lightning/pull/17942))
 
 
+- Fixed an issue causing the `torch.set_float32_matmul_precision` info message to show multiple times ([#17960](https://github.com/Lightning-AI/lightning/pull/17960))
+
+
 ## [2.0.3] - 2023-06-07
 
 - Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
```
1 change: 1 addition & 0 deletions src/lightning/fabric/accelerators/cuda.py
```diff
@@ -337,6 +337,7 @@ def _device_count_nvml() -> int:
     return len(visible_devices)
 
 
+@lru_cache(1)  # show the warning only ever once
 def _check_cuda_matmul_precision(device: torch.device) -> None:
     if not _TORCH_GREATER_EQUAL_1_12:
         # before 1.12, tf32 was used by default
```
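For context: `functools.lru_cache` memoizes the decorated function by its arguments, so a repeated call with the same `torch.device` returns the cached `None` and skips the logging body entirely. A minimal standalone sketch of the same once-only-message pattern (the `_warn_once` name and message below are illustrative, not from this commit):

```python
import logging
from functools import lru_cache

log = logging.getLogger(__name__)


@lru_cache(1)  # maxsize=1: remember only the most recent argument
def _warn_once(device: str) -> None:
    # runs only on a cache miss; a repeat call with the same argument
    # returns the cached None without logging again
    log.info("matmul precision hint for %s", device)


_warn_once("cuda:0")  # logs once
_warn_once("cuda:0")  # cache hit, no second message
```

Note that with `maxsize=1` a call with a *different* device evicts the cached entry, so the message could reappear after switching devices; for this check, which normally sees one device per process, that is enough.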
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
```diff
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed incorrect parsing of arguments when augmenting exception messages in DDP ([#17948](https://github.com/Lightning-AI/lightning/pull/17948))
 
 
+- Fixed an issue causing the `torch.set_float32_matmul_precision` info message to show multiple times ([#17960](https://github.com/Lightning-AI/lightning/pull/17960))
+
+
 - Added missing `map_location` argument for the `LightningDataModule.load_from_checkpoint` function ([#17950](https://github.com/Lightning-AI/lightning/pull/17950))
 
 
```
10 changes: 10 additions & 0 deletions tests/tests_fabric/accelerators/test_cuda.py
```diff
@@ -100,13 +100,15 @@ def test_tf32_message(_, __, caplog, monkeypatch):
     with caplog.at_level(logging.INFO):
         _check_cuda_matmul_precision(device)
     assert expected in caplog.text
+    _check_cuda_matmul_precision.cache_clear()
 
     caplog.clear()
     torch.backends.cuda.matmul.allow_tf32 = True  # changing this changes the string
     assert torch.get_float32_matmul_precision() == "high"
     with caplog.at_level(logging.INFO):
         _check_cuda_matmul_precision(device)
     assert not caplog.text
+    _check_cuda_matmul_precision.cache_clear()
 
     caplog.clear()
     torch.backends.cuda.matmul.allow_tf32 = False
@@ -115,12 +117,20 @@ def test_tf32_message(_, __, caplog, monkeypatch):
     with caplog.at_level(logging.INFO):
         _check_cuda_matmul_precision(device)
     assert not caplog.text
+    _check_cuda_matmul_precision.cache_clear()
 
     torch.set_float32_matmul_precision("highest")  # can be reverted
     with caplog.at_level(logging.INFO):
         _check_cuda_matmul_precision(device)
     assert expected in caplog.text
 
+    # subsequent calls don't produce more messages
+    caplog.clear()
+    with caplog.at_level(logging.INFO):
+        _check_cuda_matmul_precision(device)
+    assert expected not in caplog.text
+    _check_cuda_matmul_precision.cache_clear()
+
 
 def test_find_usable_cuda_devices_error_handling():
     """Test error handling for edge cases when using `find_usable_cuda_devices`."""
```
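Since the production function is now memoized, the test calls `cache_clear()` (which `lru_cache` attaches to the decorated function) between scenarios; otherwise every assertion after the first would see a cached no-op. A small sketch of that interaction (the `check` helper is illustrative, not from the test suite):

```python
from functools import lru_cache

calls = []


@lru_cache(1)
def check(device: str) -> None:
    calls.append(device)  # stands in for the one-time log.info side effect


check("cuda:0")
check("cuda:0")
assert calls == ["cuda:0"]  # the second call was served from the cache

check.cache_clear()  # reset between test scenarios
check("cuda:0")
assert calls == ["cuda:0", "cuda:0"]  # the body runs again after clearing
```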
