Commit b727c2b

Allow TrackioCallback to work when pynvml is not installed (#39851)
1 parent 1ec0fec commit b727c2b

File tree

1 file changed: +14 additions, -2 deletions

src/transformers/integrations/integration_utils.py

Lines changed: 14 additions & 2 deletions
@@ -35,6 +35,8 @@
 import numpy as np
 import packaging.version
 
+from transformers.utils.import_utils import _is_package_available
+
 
 if os.getenv("WANDB_MODE") == "offline":
     print("⚙️ Running in WANDB offline mode")
@@ -1043,6 +1045,14 @@ def on_predict(self, args, state, control, metrics, **kwargs):
 class TrackioCallback(TrainerCallback):
     """
     A [`TrainerCallback`] that logs metrics to Trackio.
+
+    It records training metrics, model (and PEFT) configuration, and GPU memory usage.
+    If `nvidia-ml-py` is installed, GPU power consumption is also tracked.
+
+    **Requires**:
+    ```bash
+    pip install trackio
+    ```
     """
 
     def __init__(self):
@@ -1119,12 +1129,14 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
             device_idx = torch.cuda.current_device()
             total_memory = torch.cuda.get_device_properties(device_idx).total_memory
             memory_allocated = torch.cuda.memory_allocated(device_idx)
-            power = torch.cuda.power_draw(device_idx)
+
             gpu_memory_logs = {
                 f"gpu/{device_idx}/allocated_memory": memory_allocated / (1024**3),  # GB
                 f"gpu/{device_idx}/memory_usage": memory_allocated / total_memory,  # ratio
-                f"gpu/{device_idx}/power": power / 1000,  # Watts
             }
+            if _is_package_available("pynvml"):
+                power = torch.cuda.power_draw(device_idx)
+                gpu_memory_logs[f"gpu/{device_idx}/power"] = power / 1000  # Watts
             if dist.is_available() and dist.is_initialized():
                 gathered_logs = [None] * dist.get_world_size()
                 dist.all_gather_object(gathered_logs, gpu_memory_logs)
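
The fix skips the power reading when the NVML Python bindings are missing, since `torch.cuda.power_draw` depends on them. Below is a minimal, self-contained sketch of the same guard pattern using `importlib.util.find_spec` in place of the transformers-internal `_is_package_available` helper; the function name `maybe_log_gpu_power` is hypothetical and only for illustration.

```python
import importlib.util

import torch


def maybe_log_gpu_power(device_idx: int, logs: dict) -> None:
    """Record GPU power draw only when the NVML bindings are importable.

    `torch.cuda.power_draw` relies on the `pynvml` module (shipped by the
    `nvidia-ml-py` package) and fails without it, so the reading is skipped
    instead of crashing the callback.
    """
    if importlib.util.find_spec("pynvml") is not None:
        power_mw = torch.cuda.power_draw(device_idx)  # reported in milliwatts
        logs[f"gpu/{device_idx}/power"] = power_mw / 1000  # convert to Watts
```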

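On the usage side, the new docstring notes that `trackio` must be installed, with `nvidia-ml-py` needed only for power logging. Here is a hedged sketch of attaching the callback to a `Trainer` run; `my_model` and `train_ds` are placeholders for an actual model and dataset, and `TrackioCallback` is assumed to be importable from `transformers.integrations` like the other reporting callbacks.

```python
from transformers import Trainer, TrainingArguments
from transformers.integrations import TrackioCallback  # assumed export path

args = TrainingArguments(
    output_dir="out",
    report_to="none",  # avoid enabling other reporting integrations implicitly
)
trainer = Trainer(
    model=my_model,            # placeholder: a PreTrainedModel instance
    args=args,
    train_dataset=train_ds,    # placeholder: a training dataset
    callbacks=[TrackioCallback()],  # attach the Trackio logger explicitly
)
trainer.train()
```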
0 commit comments
